diff --git a/volia/plot.py b/volia/plot.py new file mode 100644 index 0000000..9aa3645 --- /dev/null +++ b/volia/plot.py @@ -0,0 +1,71 @@ +import argparse +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from core.data import read_features, read_lst, read_labels +from utils import SubCommandRunner + + +def scatter_plot(features: str, labels: str, outfile: str): + """Generate a simple scatter plot. Mainly used for + data visualisation processed with tsne or algorithm like + this later. + + Args: + features (str): Features file in 2d or 3d + labels (str): Labels file + outfile (str) : output file + """ + + id_to_features = read_features(args.features) + ids = [ key for key in id_to_features.keys() ] + utt2label = read_labels(labels) + + features = [ id_to_features[id_] for id_ in ids ] + features = np.vstack(features) + + labels_list = [ utt2label[id_][0] for id_ in ids ] + + features_T = features.transpose() + print("Number of labels: ", len(np.unique(labels_list))) + df = pd.DataFrame(dict( + x=features_T[0], + y=features_T[1], + label=labels_list)) + + groups = df.groupby('label') + + # Plot + fig, ax = plt.subplots() + + for label, group in groups: + p = ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label) + ax.legend() + plt.savefig(outfile) + print("Your plot is saved well (no check of this affirmation)") + + + +if __name__ == '__main__': + + # Main parser + parser = argparse.ArgumentParser(description="") + subparsers = parser.add_subparsers(title="action") + + # with label + parser_scatter = subparsers.add_parser("scatter") + parser_scatter.add_argument("features", type=str, help="define the main features file") + parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element") + parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)") + parser_scatter.set_defaults(which="scatter") + + # Parse + args = parser.parse_args() + + # Run commands + runner = SubCommandRunner({ + "scatter" : scatter_plot + }) + + runner.run(args.which, args.__dict__, remove="which") \ No newline at end of file