Commit d27fe6fcc52cc5c3a4ea61289d4d4b0c38c53b83
1 parent
acbafc4147
Exists in
master
add utt2dur statistics command and changed original stats command by a subcommand pred-distribution
Showing 1 changed file with 55 additions and 23 deletions Side-by-side Diff
volia/stats.py
| ... | ... | @@ -9,31 +9,13 @@ |
| 9 | 9 | import pickle |
| 10 | 10 | import matplotlib.pyplot as plt |
| 11 | 11 | import matplotlib.colors as mcolors |
| 12 | +from utils import SubCommandRunner | |
| 12 | 13 | |
| 13 | 14 | |
| 14 | - | |
| 15 | 15 | from cycler import cycler |
| 16 | 16 | |
| 17 | -def stats(): | |
| 18 | - print("Decisions") | |
| 17 | +def pred_distribution(predictions: str, labels: str, labelencoder: str, outdir: str): | |
| 19 | 18 | |
| 20 | - | |
| 21 | -print(list(mcolors.TABLEAU_COLORS)) | |
| 22 | - | |
| 23 | - | |
| 24 | -if __name__ == "__main__": | |
| 25 | - | |
| 26 | - # Parser | |
| 27 | - parser = argparse.ArgumentParser(description="") | |
| 28 | - | |
| 29 | - # Arguments | |
| 30 | - parser.add_argument("--predictions", type=str, help="prediction file", required=True) | |
| 31 | - parser.add_argument("--labels", type=str, help="label file", required=True) | |
| 32 | - parser.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True) | |
| 33 | - parser.add_argument("--outdir", type=str, help="output file", required=True) | |
| 34 | - | |
| 35 | - args = parser.parse_args() | |
| 36 | - | |
| 37 | 19 | predictions = core.data.read_id_values(args.predictions, float) |
| 38 | 20 | labels = core.data.read_labels(args.labels) |
| 39 | 21 | |
| 40 | 22 | |
| ... | ... | @@ -108,8 +90,58 @@ |
| 108 | 90 | plt.savefig(os.path.join(args.outdir, f"{label}_prediction_cdf.pdf")) |
| 109 | 91 | plt.clf() |
| 110 | 92 | |
| 93 | + print("Decisions") | |
| 111 | 94 | |
| 112 | - # TODO: | |
| 113 | - # One graph for each label. Distribution of their predictions output are displayed. | |
| 114 | - | |
| 95 | + | |
| 96 | +def utt2dur(utt2dur: str, labels: str): | |
| 97 | + if labels == None: | |
| 98 | + pass | |
| 99 | + else: | |
| 100 | + pass | |
| 101 | + | |
| 102 | + durations = [] | |
| 103 | + with open(utt2dur, "r") as f: | |
| 104 | + for line in f: | |
| 105 | + splited = line.replace("\n", "").split(" ") | |
| 106 | + durations.append(float(splited[1])) | |
| 107 | + | |
| 108 | + durations = np.asarray(durations, dtype=float) | |
| 109 | + print(durations.shape) | |
| 110 | + mean = np.mean(durations) | |
| 111 | + std = np.std(durations) | |
| 112 | + | |
| 113 | + print(f"mean: {mean}") | |
| 114 | + print(f"std: {std}") | |
| 115 | + | |
| 116 | + | |
| 117 | +if __name__ == "__main__": | |
| 118 | + | |
| 119 | + # Parser | |
| 120 | + parser = argparse.ArgumentParser(description="Statistics") | |
| 121 | + subparsers = parser.add_subparsers(title="actions") | |
| 122 | + | |
| 123 | + # pred-distribution | |
| 124 | + parser_pred_dist = subparsers.add_parser("pred-distribution", help="plot distributions of prediction through labels") | |
| 125 | + parser_pred_dist.add_argument("--predictions", type=str, help="prediction file", required=True) | |
| 126 | + parser_pred_dist.add_argument("--labels", type=str, help="label file", required=True) | |
| 127 | + parser_pred_dist.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True) | |
| 128 | + parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True) | |
| 129 | + parser_pred_dist.set_defaults(which="pred_distribution") | |
| 130 | + | |
| 131 | + # duration-stats | |
| 132 | + parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur") | |
| 133 | + parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True) | |
| 134 | + parser_utt2dur.add_argument("--labels", type=str, default=None, help="labels file") | |
| 135 | + parser_utt2dur.set_defaults(which="utt2dur") | |
| 136 | + | |
| 137 | + # Parse | |
| 138 | + args = parser.parse_args() | |
| 139 | + | |
| 140 | + # Run commands | |
| 141 | + runner = SubCommandRunner({ | |
| 142 | + "pred-distribution": pred_distribution, | |
| 143 | + "utt2dur": utt2dur | |
| 144 | + }) | |
| 145 | + | |
| 146 | + runner.run(args.which, args.__dict__, remove="which") |