# plot.py
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from core.data import read_features, read_lst, read_labels
from utils import SubCommandRunner


def scatter_plot(features: str, labels: str, outfile: str):
    """Draw a 2-D scatter plot of a features file and save it to disk.

    Mainly meant for visualising data that was already reduced to a low
    dimension (with t-SNE or a similar algorithm). Only the first two
    columns of the stacked feature matrix are plotted; any further
    dimension is ignored.

    Args:
        features (str): Features file, handed to ``read_features``. The
            result is used as a mapping from id to feature vector.
        labels (str): Labels file, handed to ``read_labels``. The first
            entry of each id's label value is used to group the points
            into one plotted series per label. May be ``None`` (the
            command-line default), in which case every point is drawn in
            a single unlabelled series and no legend is added.
        outfile (str): Path of the figure written by matplotlib.
    """
    # Read from the `features` argument rather than the CLI-global `args`,
    # so the function also works when imported and called directly.
    id_to_features = read_features(features)
    ids = list(id_to_features)

    # One row per id; columns 0 and 1 become the x and y coordinates.
    points = np.vstack([id_to_features[id_] for id_ in ids])
    xs, ys = points[:, 0], points[:, 1]

    fig, ax = plt.subplots()

    if labels is None:
        # No labels file given: nothing to group by, plot everything at once.
        ax.plot(xs, ys, marker='o', linestyle='', ms=1)
    else:
        utt2label = read_labels(labels)
        # Keep the labels in the same order as the rows of `points`.
        labels_list = [utt2label[id_][0] for id_ in ids]
        print("Number of labels: ", len(np.unique(labels_list)))

        df = pd.DataFrame(dict(x=xs, y=ys, label=labels_list))
        # One series per label so that each label gets its own legend entry.
        for label, group in df.groupby('label'):
            ax.plot(group.x, group.y, marker='o', linestyle='', ms=1, label=label)
        ax.legend()

    fig.savefig(outfile)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print("Your plot is saved well (no check of this affirmation)")
    


if __name__ == '__main__':
    # Command-line entry point: build the parser, then dispatch the chosen
    # action to its implementation through SubCommandRunner.

    # Main parser
    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(title="action")

    # scatter with labels
    parser_scatter = subparsers.add_parser("scatter")
    parser_scatter.add_argument("features", type=str, help="define the main features file")
    parser_scatter.add_argument("--labels", default=None, type=str, help="specify the labels of each utterance/element")
    parser_scatter.add_argument("--outfile", default="./out.pdf", type=str, help="Specify the output file (better in pdf)")
    # "which" records the selected action so it can be dispatched below.
    parser_scatter.set_defaults(which="scatter")

    # Parse
    args = parser.parse_args()

    # Sub-commands are optional for argparse, so "which" is absent when no
    # action was given; report a usage error instead of failing later with
    # an AttributeError.
    if not hasattr(args, "which"):
        parser.error("an action is required (choose from: scatter)")

    # Run commands
    runner = SubCommandRunner({
        "scatter": scatter_plot,
    })

    # NOTE(review): presumably `remove` drops the "which" key before the
    # remaining arguments are forwarded to the action - confirm against
    # SubCommandRunner.
    runner.run(args.which, vars(args), remove="which")