Commit 9a2c6b4d026288745fe95708e3ff55f00a2351fd
1 parent
890a775449
Exists in
master
New file that helps generate some stats and distribution stats (via plots)
Showing 1 changed file with 114 additions and 0 deletions Inline Diff
volia/stats.py
File was created | 1 | ||
2 | import argparse | ||
3 | |||
4 | import os | ||
5 | import core.data | ||
6 | import math | ||
7 | import numpy as np | ||
8 | import scipy.stats | ||
9 | import pickle | ||
10 | import matplotlib.pyplot as plt | ||
11 | import matplotlib.colors as mcolors | ||
12 | |||
13 | |||
14 | |||
15 | from cycler import cycler | ||
16 | |||
def stats():
    """Print the literal text ``Decisions`` followed by a newline."""
    heading = "Decisions"
    print(heading)
19 | |||
20 | |||
# Debug aid: show the names of matplotlib's Tableau palette entries.
print([*mcolors.TABLEAU_COLORS])
22 | |||
23 | |||
def group_predictions_by_label(predictions, labels):
    """Group per-utterance prediction vectors by their reference label.

    Args:
        predictions: Mapping from utterance id to a 1-D sequence of
            per-class scores.
        labels: Mapping from utterance id to a sequence whose first
            element is the label of that utterance.

    Returns:
        A dict mapping each label to a 2-D array of shape
        ``(n_utterances, n_classes)``; rows follow the iteration order
        of *predictions*.

    Raises:
        KeyError: If an id of *predictions* is missing from *labels*.
        ValueError: If the vectors of one label differ in length.
    """
    rows_by_label = {}
    for id_, scores in predictions.items():
        rows_by_label.setdefault(labels[id_][0], []).append(scores)
    # Stack once per label: appending to an array inside the loop would
    # copy the whole array for every utterance.
    return {
        label: np.stack(rows, axis=0)
        for label, rows in rows_by_label.items()
    }


def plot_prediction_cdfs(predictions_by_label, le, outdir):
    """Write one PDF per label with the Gaussian CDF of each class score.

    For every label, the mean and standard deviation of each score column
    are computed, printed, and the CDF of the matching normal distribution
    is drawn on the interval [0, 1]. The figure is saved as
    ``<outdir>/<label>_prediction_cdf.pdf``.

    Args:
        predictions_by_label: Mapping from label to a 2-D array of shape
            ``(n_utterances, n_classes)``.
        le: Object exposing ``inverse_transform`` that maps a list of
            class indices to their names (presumably a scikit-learn
            ``LabelEncoder`` -- confirm against the pickled object).
        outdir: Directory receiving the PDF files; created if missing.
    """
    os.makedirs(outdir, exist_ok=True)

    prop_cycle = (cycler(color=list(mcolors.TABLEAU_COLORS)) *
                  cycler(linestyle=["-", "--", "-."]))
    x = np.linspace(0, 1, 1000)

    for label, label_predictions in predictions_by_label.items():
        fig, ax = plt.subplots()
        ax.set_prop_cycle(prop_cycle)

        means = np.mean(label_predictions, axis=0)
        stds = np.std(label_predictions, axis=0)

        for i, (mu, sigma) in enumerate(zip(means, stds)):
            label_str = le.inverse_transform([i])[0]
            variance = sigma * sigma
            print(f"{i}: mu {mu}, var {variance}, sigma {sigma}")

            if sigma > 0:
                cdf = scipy.stats.norm.cdf(x, mu, sigma)
            else:
                # Degenerate case (e.g. a single utterance): a zero scale
                # makes norm.cdf return NaN, so draw the limiting step
                # function at mu instead of an empty curve.
                cdf = (x >= mu).astype(float)
            ax.plot(x, cdf, label=label_str, alpha=0.5)

        ax.legend()
        fig.savefig(os.path.join(outdir, f"{label}_prediction_cdf.pdf"))
        plt.close(fig)


def main():
    """Parse the command line, load the inputs and plot the CDFs."""
    # Parser
    parser = argparse.ArgumentParser(description="")

    # Arguments
    parser.add_argument("--predictions", type=str, help="prediction file", required=True)
    parser.add_argument("--labels", type=str, help="label file", required=True)
    parser.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
    parser.add_argument("--outdir", type=str, help="output directory", required=True)

    args = parser.parse_args()

    predictions = core.data.read_id_values(args.predictions, float)
    labels = core.data.read_labels(args.labels)

    # NOTE: unpickling can execute arbitrary code; only pass trusted files.
    with open(args.labelencoder, "rb") as f:
        le = pickle.load(f)

    print("PREDICTIONS ---------------------------")
    predictions_by_label = group_predictions_by_label(predictions, labels)

    print("CALCULATING ---------------------------")
    plot_prediction_cdfs(predictions_by_label, le, args.outdir)


if __name__ == "__main__":
    main()
110 | |||
111 | |||
112 | # TODO: | ||
113 | # One graph for each label, displaying the distribution of its prediction outputs. | ||
114 | |||
115 |