File size: 2,204 Bytes
430712c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import time
import argparse

import torch

from cuml.cluster import KMeans

# from sklearn.cluster import KMeans
from sklearn.cluster import AgglomerativeClustering

from sklearn.metrics import normalized_mutual_info_score


def main(args):
    # Load embeddings
    embeddings_file = torch.load(args.embeddings_file)
    files = list(embeddings_file.keys())
    labels = [file.split('/')[-3] for file in files]
    embeddings = torch.cat(list(embeddings_file.values())).numpy()
    print(f"Embedding shape: {embeddings.shape}")

    # K-Means
    print("KMeans...")
    kmeans_start_time = time.time()
    kmeans = KMeans(
        n_clusters=args.n_clusters,
        random_state=0,
        max_samples_per_batch=1000000,
        verbose=True
    ).fit(embeddings)
    pseudo_labels = kmeans.labels_
    centroids = kmeans.cluster_centers_
    print(f"K-Means duration: {(time.time() - kmeans_start_time)/60:.2f} min")

    # AHC
    if args.n_clusters_ahc > 0:
        print("AHC...")
        ahc_start_time = time.time()
        ahc_labels = AgglomerativeClustering(
            n_clusters=args.n_clusters_ahc
        ).fit_predict(centroids)
        pseudo_labels = [ahc_labels[pl] for pl in pseudo_labels]
        print(f"AHC duration: {(time.time() - ahc_start_time)/60:.2f} min")

    # Print NMI
    nmi_score = normalized_mutual_info_score(labels, pseudo_labels)
    print(f"NMI: {nmi_score}")

    # Export pseudo labels
    with open(args.output_file, 'w') as f:
        for file, pseudo_label in zip(files, pseudo_labels):
            f.write(f"{pseudo_label} {file}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'embeddings_file',
        help='Path to embeddings file (.pt).'
    )
    parser.add_argument(
        'output_file',
        help='Path to output file (.txt).'
    )
    parser.add_argument(
        '--n_clusters',
        help='Number of clusters for KMeans.',
        type=int,
        default=50000
    )
    parser.add_argument(
        '--n_clusters_ahc',
        help='Number of clusters for Agglomerative Clustering.',
        type=int,
        default=7500
    )
    args = parser.parse_args()

    main(args)