# test-kmeans.py
# 799 Bytes
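'''
Load a pickled KMeans model, run its predict() on the feature vectors selected
by an lst file, and print the shape of the array of distinct predicted labels.

Usage: test-kmeans.py <kmeans> <features> <lst>
'''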
import argparse
import numpy as np
from sklearn.cluster import KMeans
from data import read_file, index_by_id
import pickle
# Command-line interface: three positional paths (saved model, features, lst).
parser = argparse.ArgumentParser(description="...")
parser.add_argument("kmeans", type=str, help="kmean saved file")
parser.add_argument("features", type=str, help="features file")
parser.add_argument("lst", type=str, help="lst file")
args = parser.parse_args()

KMEAN_FILE = args.kmeans
FEATURES_FILE = args.features
LST_FILE = args.lst

# Load features and lst.
# read_file / index_by_id come from the project-local `data` module; the
# structure they return is not visible here.
features = read_file(FEATURES_FILE)
features_ind = index_by_id(features)
lst = read_file(LST_FILE)

# Load Kmeans.
# Use a context manager so the file handle is closed deterministically instead
# of being left to the garbage collector.
# NOTE: pickle can execute arbitrary code while loading; only pass model files
# that come from a trusted source.
with open(KMEAN_FILE, "rb") as model_file:
    kmeans = pickle.load(model_file)

# Get all x: one feature vector per lst entry.
# NOTE(review): presumably entry[0][0] and entry[0][3] are the two keys into
# the indexed features and [1] selects the vector itself; confirm against the
# `data` module.
X = np.asarray(
    [features_ind[entry[0][0]][entry[0][3]][1] for entry in lst]
)

# Assign each vector to a cluster and print the shape of the array of unique
# labels, i.e. a 1-tuple holding how many distinct clusters were predicted.
predicts = kmeans.predict(X)
print(np.unique(predicts).shape)