Commit 8774b988e43462ca1a216c5a2514b688eada570c

Authored by Quillot Mathias
1 parent 81260862fd
Exists in master

New scatter plot module

Showing 1 changed file with 71 additions and 0 deletions Side-by-side Diff

  1 +import argparse
  2 +import matplotlib.pyplot as plt
  3 +import numpy as np
  4 +import pandas as pd
  5 +
  6 +from core.data import read_features, read_lst, read_labels
  7 +from utils import SubCommandRunner
  8 +
  9 +
def scatter_plot(features: str, labels: str, outfile: str):
    """Generate a simple 2D scatter plot and save it to a file.

    Mainly intended for visualising data that has already been reduced
    with t-SNE or a similar algorithm.

    Args:
        features (str): Path of the features file, passed to
            ``read_features``. Only the first two feature columns are
            plotted; any further column (e.g. a third dimension) is ignored.
        labels (str): Path of the labels file, passed to ``read_labels``.
            May be ``None`` (the CLI default), in which case every point is
            drawn as one unlabelled series and no legend is produced.
        outfile (str): Path of the image file to write.
    """
    # Use the `features` parameter, not the script-level `args` namespace:
    # the latter only exists when this module is run as a script and would
    # raise NameError (or silently ignore the argument) otherwise.
    id_to_features = read_features(features)
    ids = list(id_to_features)

    # One row per id, in the same order as `ids`.
    points = np.vstack([id_to_features[id_] for id_ in ids])
    xs, ys = points[:, 0], points[:, 1]

    fig, ax = plt.subplots()
    try:
        if labels is None:
            ax.plot(xs, ys, marker='o', linestyle='', ms=1)
        else:
            utt2label = read_labels(labels)
            # Each label entry is indexed with [0]; only its first element
            # is used to colour the points.
            labels_list = [utt2label[id_][0] for id_ in ids]
            print("Number of labels: ", len(np.unique(labels_list)))

            df = pd.DataFrame(dict(x=xs, y=ys, label=labels_list))
            # One plotted series per distinct label so the legend gets one
            # entry per label.
            for label, group in df.groupby('label'):
                ax.plot(group.x, group.y, marker='o', linestyle='', ms=1,
                        label=label)
            ax.legend()

        fig.savefig(outfile)
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures in pyplot's global state.
        plt.close(fig)

    print("Your plot is saved well (no check of this affirmation)")
  47 +
  48 +
  49 +
if __name__ == '__main__':

    # Main parser
    parser = argparse.ArgumentParser(description="")
    # `dest="which"` + `required` make argparse itself reject a missing
    # action with a usage error; previously a bare invocation reached
    # `args.which` below and crashed with AttributeError.
    subparsers = parser.add_subparsers(title="action", dest="which")
    subparsers.required = True

    # "scatter" action: plot features, optionally coloured by label
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    parser_scatter.set_defaults(which="scatter")

    # Parse
    args = parser.parse_args()

    # Map each action name to the function implementing it
    runner = SubCommandRunner({
        "scatter": scatter_plot
    })

    # "which" only selects the action; it is removed from the arguments
    # handed to the runner.
    runner.run(args.which, args.__dict__, remove="which")