Commit d27fe6fcc52cc5c3a4ea61289d4d4b0c38c53b83

Authored by Quillot Mathias
1 parent acbafc4147
Exists in master

add utt2dur statistics command and changed original stats command by a subcommand pred-distribution

Showing 1 changed file with 55 additions and 23 deletions Side-by-side Diff

... ... @@ -9,31 +9,13 @@
9 9 import pickle
10 10 import matplotlib.pyplot as plt
11 11 import matplotlib.colors as mcolors
  12 +from utils import SubCommandRunner
12 13  
13 14  
14   -
15 15 from cycler import cycler
16 16  
17   -def stats():
18   - print("Decisions")
  17 +def pred_distribution(predictions: str, labels: str, labelencoder: str, outdir: str):
19 18  
20   -
21   -print(list(mcolors.TABLEAU_COLORS))
22   -
23   -
24   -if __name__ == "__main__":
25   -
26   - # Parser
27   - parser = argparse.ArgumentParser(description="")
28   -
29   - # Arguments
30   - parser.add_argument("--predictions", type=str, help="prediction file", required=True)
31   - parser.add_argument("--labels", type=str, help="label file", required=True)
32   - parser.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
33   - parser.add_argument("--outdir", type=str, help="output file", required=True)
34   -
35   - args = parser.parse_args()
36   -
37 19 predictions = core.data.read_id_values(args.predictions, float)
38 20 labels = core.data.read_labels(args.labels)
39 21  
40 22  
... ... @@ -108,8 +90,58 @@
108 90 plt.savefig(os.path.join(args.outdir, f"{label}_prediction_cdf.pdf"))
109 91 plt.clf()
110 92  
  93 + print("Decisions")
111 94  
112   - # TODO:
113   - # One graph for each label. Distribution of their predictions output are displayed.
114   -
  95 +
  96 +def utt2dur(utt2dur: str, labels: str):
  97 + if labels == None:
  98 + pass
  99 + else:
  100 + pass
  101 +
  102 + durations = []
  103 + with open(utt2dur, "r") as f:
  104 + for line in f:
  105 + splited = line.replace("\n", "").split(" ")
  106 + durations.append(float(splited[1]))
  107 +
  108 + durations = np.asarray(durations, dtype=float)
  109 + print(durations.shape)
  110 + mean = np.mean(durations)
  111 + std = np.std(durations)
  112 +
  113 + print(f"mean: {mean}")
  114 + print(f"std: {std}")
  115 +
  116 +
  117 +if __name__ == "__main__":
  118 +
  119 + # Parser
  120 + parser = argparse.ArgumentParser(description="Statistics")
  121 + subparsers = parser.add_subparsers(title="actions")
  122 +
  123 + # pred-distribution
  124 + parser_pred_dist = subparsers.add_parser("pred-distribution", help="plot distributions of prediction through labels")
  125 + parser_pred_dist.add_argument("--predictions", type=str, help="prediction file", required=True)
  126 + parser_pred_dist.add_argument("--labels", type=str, help="label file", required=True)
  127 + parser_pred_dist.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
  128 + parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True)
  129 + parser_pred_dist.set_defaults(which="pred_distribution")
  130 +
  131 + # duration-stats
  132 + parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur")
  133 + parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True)
  134 + parser_utt2dur.add_argument("--labels", type=str, default=None, help="labels file")
  135 + parser_utt2dur.set_defaults(which="utt2dur")
  136 +
  137 + # Parse
  138 + args = parser.parse_args()
  139 +
  140 + # Run commands
  141 + runner = SubCommandRunner({
  142 + "pred-distribution": pred_distribution,
  143 + "utt2dur": utt2dur
  144 + })
  145 +
  146 + runner.run(args.which, args.__dict__, remove="which")