def utt2dur(utt2dur: str, labels: str):
    """Print summary statistics of the durations stored in an utt2dur file.

    Every non-blank line of the file must contain at least two
    whitespace-separated fields; the second field is parsed as the duration
    (``<utt-id> <duration>``).  The shape of the resulting duration array,
    its mean and its standard deviation are printed to stdout.

    Args:
        utt2dur: Path to the utt2dur file.
        labels: Path to a labels file, or ``None``.  Accepted so the
            command-line arguments can be forwarded unchanged, but it is
            not used by the computation.

    Raises:
        ValueError: If a non-blank line has no duration field, or if the
            duration field is not a valid float.
    """
    durations = []
    with open(utt2dur, "r") as f:
        for lineno, line in enumerate(f, start=1):
            # split() without arguments tolerates tabs, repeated spaces and
            # the trailing newline; a blank line yields an empty list.
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 2:
                raise ValueError(
                    f"{utt2dur}:{lineno}: expected '<utt-id> <duration>', "
                    f"got {line!r}"
                )
            durations.append(float(fields[1]))

    durations = np.asarray(durations, dtype=float)
    print(durations.shape)
    mean = np.mean(durations)
    std = np.std(durations)

    print(f"mean: {mean}")
    print(f"std: {std}")


def build_parser():
    """Build the command-line parser with one sub-command per statistic.

    The chosen sub-command name is stored in ``args.which``
    (``"pred-distribution"`` or ``"utt2dur"``), i.e. exactly the keys used
    to dispatch in ``main``.  A sub-command is mandatory, so ``args.which``
    is always set after a successful parse.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Statistics")
    subparsers = parser.add_subparsers(title="actions", dest="which")
    # Attribute assignment instead of the ``required=`` keyword keeps this
    # working on interpreters older than Python 3.7.
    subparsers.required = True

    # pred-distribution
    parser_pred_dist = subparsers.add_parser(
        "pred-distribution",
        help="plot distributions of prediction through labels",
    )
    parser_pred_dist.add_argument("--predictions", type=str, help="prediction file", required=True)
    parser_pred_dist.add_argument("--labels", type=str, help="label file", required=True)
    parser_pred_dist.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
    parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True)

    # duration-stats
    parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur")
    parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True)
    parser_utt2dur.add_argument("--labels", type=str, default=None, help="labels file")

    return parser


def main(argv=None):
    """Parse the command line and run the selected sub-command.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    # Imported here so that importing this module does not require the
    # project-local ``utils`` module.
    from utils import SubCommandRunner

    args = build_parser().parse_args(argv)

    # NOTE(review): presumably SubCommandRunner looks up ``args.which`` in
    # this mapping and calls the function with the remaining arguments as
    # keyword arguments -- confirm against utils.SubCommandRunner.
    # ``pred_distribution`` is not defined in this block; it must be in
    # scope in this module when main() runs.
    runner = SubCommandRunner({
        "pred-distribution": pred_distribution,
        "utt2dur": utt2dur,
    })

    runner.run(args.which, vars(args), remove="which")


if __name__ == "__main__":
    main()