Commit d27fe6fcc52cc5c3a4ea61289d4d4b0c38c53b83
1 parent
acbafc4147
Exists in
master
add utt2dur statistics command and changed original stats command by a subcommand pred-distribution
Showing 1 changed file with 55 additions and 23 deletions Side-by-side Diff
volia/stats.py
... | ... | @@ -9,31 +9,13 @@ |
9 | 9 | import pickle |
10 | 10 | import matplotlib.pyplot as plt |
11 | 11 | import matplotlib.colors as mcolors |
12 | +from utils import SubCommandRunner | |
12 | 13 | |
13 | 14 | |
14 | - | |
15 | 15 | from cycler import cycler |
16 | 16 | |
17 | -def stats(): | |
18 | - print("Decisions") | |
17 | +def pred_distribution(predictions: str, labels: str, labelencoder: str, outdir: str): | |
19 | 18 | |
20 | - | |
21 | -print(list(mcolors.TABLEAU_COLORS)) | |
22 | - | |
23 | - | |
24 | -if __name__ == "__main__": | |
25 | - | |
26 | - # Parser | |
27 | - parser = argparse.ArgumentParser(description="") | |
28 | - | |
29 | - # Arguments | |
30 | - parser.add_argument("--predictions", type=str, help="prediction file", required=True) | |
31 | - parser.add_argument("--labels", type=str, help="label file", required=True) | |
32 | - parser.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True) | |
33 | - parser.add_argument("--outdir", type=str, help="output file", required=True) | |
34 | - | |
35 | - args = parser.parse_args() | |
36 | - | |
37 | 19 | predictions = core.data.read_id_values(args.predictions, float) |
38 | 20 | labels = core.data.read_labels(args.labels) |
39 | 21 | |
40 | 22 | |
... | ... | @@ -108,8 +90,58 @@ |
108 | 90 | plt.savefig(os.path.join(args.outdir, f"{label}_prediction_cdf.pdf")) |
109 | 91 | plt.clf() |
110 | 92 | |
93 | + print("Decisions") | |
111 | 94 | |
112 | - # TODO: | |
113 | - # One graph for each label. Distribution of their predictions output are displayed. | |
114 | - | |
95 | + | |
96 | +def utt2dur(utt2dur: str, labels: str): | |
97 | + if labels == None: | |
98 | + pass | |
99 | + else: | |
100 | + pass | |
101 | + | |
102 | + durations = [] | |
103 | + with open(utt2dur, "r") as f: | |
104 | + for line in f: | |
105 | + splited = line.replace("\n", "").split(" ") | |
106 | + durations.append(float(splited[1])) | |
107 | + | |
108 | + durations = np.asarray(durations, dtype=float) | |
109 | + print(durations.shape) | |
110 | + mean = np.mean(durations) | |
111 | + std = np.std(durations) | |
112 | + | |
113 | + print(f"mean: {mean}") | |
114 | + print(f"std: {std}") | |
115 | + | |
116 | + | |
117 | +if __name__ == "__main__": | |
118 | + | |
119 | + # Parser | |
120 | + parser = argparse.ArgumentParser(description="Statistics") | |
121 | + subparsers = parser.add_subparsers(title="actions") | |
122 | + | |
123 | + # pred-distribution | |
124 | + parser_pred_dist = subparsers.add_parser("pred-distribution", help="plot distributions of prediction through labels") | |
125 | + parser_pred_dist.add_argument("--predictions", type=str, help="prediction file", required=True) | |
126 | + parser_pred_dist.add_argument("--labels", type=str, help="label file", required=True) | |
127 | + parser_pred_dist.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True) | |
128 | + parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True) | |
129 | + parser_pred_dist.set_defaults(which="pred_distribution") | |
130 | + | |
131 | + # duration-stats | |
132 | + parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur") | |
133 | + parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True) | |
134 | + parser_utt2dur.add_argument("--labels", type=str, default=None, help="labels file") | |
135 | + parser_utt2dur.set_defaults(which="utt2dur") | |
136 | + | |
137 | + # Parse | |
138 | + args = parser.parse_args() | |
139 | + | |
140 | + # Run commands | |
141 | + runner = SubCommandRunner({ | |
142 | + "pred-distribution": pred_distribution, | |
143 | + "utt2dur": utt2dur | |
144 | + }) | |
145 | + | |
146 | + runner.run(args.which, args.__dict__, remove="which") |