Blame view
volia/plot.py
3.14 KB
8774b988e New scatter plot ... |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from core.data import read_features, read_lst, read_labels
from utils import SubCommandRunner


def scatter_plot(features: str, labels: str, outfile: str):
    """Generate a simple scatter plot.

    Mainly used to visualise data already reduced (e.g. with t-SNE or a
    similar algorithm). Only the first two coordinates of each point are
    plotted, one colour/legend entry per label.

    Args:
        features (str): Features file in 2d or 3d
        labels (str): Labels file, or None to draw every point under a
            single placeholder label
        outfile (str): output file (format chosen by matplotlib from the
            extension, e.g. ``.pdf``)
    """
    # Use the ``features`` argument: the previous code read the global
    # ``args.features``, which only exists when this file runs as a script.
    id_to_features = read_features(features)
    ids = list(id_to_features.keys())

    # Stack per-element vectors into an (n_points, n_dims) matrix.
    points = np.vstack([id_to_features[id_] for id_ in ids])

    if labels is None:
        # The CLI's ``--labels`` defaults to None; presumably read_labels
        # cannot open a None path (TODO confirm), so group everything
        # under one label instead of failing.
        labels_list = ["unlabeled"] * len(ids)
    else:
        utt2label = read_labels(labels)
        # Each entry of utt2label is indexed with [0]: only the first
        # label of every element is used.
        labels_list = [utt2label[id_][0] for id_ in ids]

    print("Number of labels: ", len(np.unique(labels_list)))

    df = pd.DataFrame(dict(x=points[:, 0], y=points[:, 1], label=labels_list))

    # Plot one marker series per label so each gets its own legend entry.
    fig, ax = plt.subplots()
    for label, group in df.groupby('label'):
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
    ax.legend()
    fig.savefig(outfile)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Your plot is saved well (no check of this affirmation)")
40650f20d Preparing interac... |
48 49 50 51 52 53 54 55 56 57 58 |
def interactive_scatter_plot(features: str, labels: str, outdir: str):
    """Placeholder for an interactive 3D scatter plot exported as Web files.

    Meant for visualising data reduced with t-SNE or a similar algorithm.
    Not implemented yet: calling it does nothing and returns None.

    Args:
        features (str): Features file in 2d or 3d
        labels (str): Labels file
        outdir (str): output directory where the Web files would be saved
    """
    # TODO: implement the interactive (Web) rendering.
    return None
8774b988e New scatter plot ... |
59 60 61 62 63 64 |
if __name__ == '__main__':
    # Main parser
    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(title="action")

    # scatter with labels
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    parser_scatter.set_defaults(which="scatter")

    # interactive scatter
    # The options below previously went to ``parser_scatter``. The second
    # "--labels" raised an argparse conflict at startup, and set_defaults
    # overwrote the scatter command's ``which``.
    parser_interactive_scatter = subparsers.add_parser("interactive_scatter")
    parser_interactive_scatter.add_argument("features", type=str, help="features files with only 3D will be converted into csv")
    parser_interactive_scatter.add_argument("--labels", default=None, type=str, help="Specify the labels of each utterance/element")
    # NOTE(review): ".out" creates a hidden directory; "./out" may have been intended.
    parser_interactive_scatter.add_argument("--outdir", default=".out", type=str, help="output directory where static web and data files are saved.")
    parser_interactive_scatter.set_defaults(which="interactive_scatter")

    # Parse
    args = parser.parse_args()
    # Without a subcommand no ``which`` default is set. Report a usage error
    # instead of crashing with AttributeError below.
    if "which" not in args:
        parser.error("an action is required (scatter or interactive_scatter)")

    # Run commands
    runner = SubCommandRunner({
        "scatter": scatter_plot,
        "interactive_scatter": interactive_scatter_plot,
    })
    runner.run(args.which, args.__dict__, remove="which")