Commit 9a2c6b4d026288745fe95708e3ff55f00a2351fd

Authored by Quillot Mathias
1 parent 890a775449
Exists in master

New file that helps generate some stats and distribution stats (via plots)

Showing 1 changed file with 114 additions and 0 deletions Inline Diff

File was created 1
2 import argparse
3
4 import os
5 import core.data
6 import math
7 import numpy as np
8 import scipy.stats
9 import pickle
10 import matplotlib.pyplot as plt
11 import matplotlib.colors as mcolors
12
13
14
15 from cycler import cycler
16
def stats():
    """Print the "Decisions" heading on standard output.

    Takes no arguments and returns None.
    """
    heading = "Decisions"
    print(heading)
19
20
# Debug aid: list the names in matplotlib's TABLEAU_COLORS table, which is
# the palette used further down to build the plot colour cycle.
tableau_color_names = list(mcolors.TABLEAU_COLORS)
print(tableau_color_names)
22
23
def group_predictions_by_label(predictions, labels):
    """Group prediction vectors by the reference label of their utterance.

    Args:
        predictions: mapping from utterance id to a 1-D sequence of scores
            (one score per class).
        labels: mapping from utterance id to a sequence whose first item is
            the reference label of that utterance.

    Returns:
        A dict mapping each label to ``{"nb_utt": int, "predictions": array}``
        where ``predictions`` is a 2-D array with one row per utterance.

    Raises:
        KeyError: if an id present in ``predictions`` is missing from
            ``labels``.
    """
    # Collect rows in plain lists and stack once per label: calling np.append
    # inside the loop re-copies the whole array for every utterance.
    rows_by_label = {}
    for id_, predictions_ in predictions.items():
        label = labels[id_][0]
        rows_by_label.setdefault(label, []).append(np.asarray(predictions_))

    return {
        label: {"nb_utt": len(rows), "predictions": np.stack(rows, axis=0)}
        for label, rows in rows_by_label.items()
    }


def gaussian_cdf(x, mu, sigma):
    """Evaluate the CDF of a normal distribution N(mu, sigma**2) at ``x``.

    Args:
        x: array-like of evaluation points.
        mu: mean of the distribution.
        sigma: standard deviation of the distribution (>= 0).

    Returns:
        An array of CDF values with the same shape as ``x``.
    """
    if sigma == 0:
        # Degenerate case (single utterance or constant scores): the
        # distribution is a point mass at mu, so the CDF is a unit step.
        # scipy.stats.norm would return NaN here because scale must be > 0.
        return (np.asarray(x) >= mu).astype(float)
    return scipy.stats.norm.cdf(x, mu, sigma)


def main():
    """Plot, for each reference label, the Gaussian CDF of every output score.

    Reads the prediction and label files given on the command line, fits a
    normal distribution (mean / standard deviation) to each score column of
    the utterances sharing a reference label, and writes one
    ``<label>_prediction_cdf.pdf`` figure per label into ``--outdir``.
    """
    # Parser
    parser = argparse.ArgumentParser(description="")

    # Arguments
    parser.add_argument("--predictions", type=str, help="prediction file", required=True)
    parser.add_argument("--labels", type=str, help="label file", required=True)
    parser.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
    parser.add_argument("--outdir", type=str, help="output directory", required=True)

    args = parser.parse_args()

    predictions = core.data.read_id_values(args.predictions, float)
    labels = core.data.read_labels(args.labels)

    # NOTE(security): pickle.load can execute arbitrary code; only pass
    # label-encoder files that come from a trusted source.
    with open(args.labelencoder, "rb") as f:
        le = pickle.load(f)

    print("PREDICTIONS ---------------------------")
    # Named stats_by_label so the module-level stats() function is not rebound.
    stats_by_label = group_predictions_by_label(predictions, labels)

    print("CALCULATING ---------------------------")

    # savefig does not create missing directories.
    os.makedirs(args.outdir, exist_ok=True)

    # Every (colour, linestyle) combination, so curves stay distinguishable
    # even when there are more classes than Tableau colours.
    custom_cycler = (cycler(color=list(mcolors.TABLEAU_COLORS)) *
                     cycler(linestyle=['-', '--', '-.']))
    plot_kwargs = dict(alpha=0.5)

    # Evaluation grid; assumes scores lie in [0, 1] (e.g. probabilities) --
    # TODO confirm against the prediction file format.
    x = np.linspace(0, 1, 1000)

    for label, stats_ in stats_by_label.items():
        plt.gca().set_prop_cycle(custom_cycler)
        stats_mean = np.mean(stats_["predictions"], axis=0)
        stats_std = np.std(stats_["predictions"], axis=0)

        for i, (mu, sigma) in enumerate(zip(stats_mean, stats_std)):
            # Column index i is mapped back to its class name for the legend.
            label_str = le.inverse_transform([i])[0]
            variance = sigma * sigma
            print(f"{i}: mu {mu}, var {variance}, sigma {sigma}")

            plt.plot(x, gaussian_cdf(x, mu, sigma), label=label_str, **plot_kwargs)

        plt.legend()
        plt.savefig(os.path.join(args.outdir, f"{label}_prediction_cdf.pdf"))
        # Reset the figure so the next label starts from an empty canvas.
        plt.clf()

    # TODO:
    # One graph for each label, displaying the distribution of its
    # prediction outputs.


if __name__ == "__main__":
    main()
114
115