Commit 9a2c6b4d026288745fe95708e3ff55f00a2351fd

Authored by Quillot Mathias
1 parent 890a775449
Exists in master

New file that helps generate some statistics and distribution statistics (via plots)

Showing 1 changed file with 114 additions and 0 deletions Side-by-side Diff

  1 +
  2 +import argparse
  3 +
  4 +import os
  5 +import core.data
  6 +import math
  7 +import numpy as np
  8 +import scipy.stats
  9 +import pickle
  10 +import matplotlib.pyplot as plt
  11 +import matplotlib.colors as mcolors
  12 +
  13 +
  14 +
  15 +from cycler import cycler
  16 +
def stats() -> None:
    """Print the literal line ``Decisions`` to standard output.

    Takes no arguments and returns nothing. No call to this function is
    visible in this part of the file.
    """
    heading = "Decisions"
    print(heading)
  19 +
  20 +
# Debug output: prints list(mcolors.TABLEAU_COLORS) to stdout.
# NOTE(review): this statement sits outside the ``__main__`` guard, so it runs
# whenever the module is imported, not only when it is executed as a script --
# confirm that this is intended.
print(list(mcolors.TABLEAU_COLORS))
  22 +
  23 +
def group_predictions_by_label(predictions, labels):
    """Group prediction vectors by the first label of their utterance.

    Args:
        predictions: Mapping from utterance id to a prediction vector
            (assumed to be a 1-D sequence of floats -- confirm against
            ``core.data.read_id_values``).
        labels: Mapping from utterance id to a sequence of labels; only
            the first entry of each sequence is used.

    Returns:
        A dict mapping each label to a 2-D array with one row per
        utterance carrying that label, in iteration order of
        ``predictions``.

    Raises:
        KeyError: If an id of ``predictions`` is missing from ``labels``.
        ValueError: If vectors sharing a label differ in shape.
    """
    rows_by_label = {}
    for id_, scores in predictions.items():
        rows_by_label.setdefault(labels[id_][0], []).append(np.asarray(scores))
    # Stack once per label: calling np.append inside the loop would copy the
    # whole accumulated array for every utterance (quadratic cost).
    return {label: np.stack(rows) for label, rows in rows_by_label.items()}


def gaussian_cdf(x, mu, sigma):
    """Evaluate the CDF of a normal distribution N(mu, sigma) at ``x``.

    ``scipy.stats.norm.cdf`` yields NaN when the scale is not strictly
    positive, which happens here as soon as a label has a single utterance
    or constant scores. In that case the distribution is a point mass at
    ``mu`` and its CDF is the step function ``x >= mu``.

    Args:
        x: Scalar or array of evaluation points.
        mu: Mean of the distribution.
        sigma: Standard deviation of the distribution (may be 0).

    Returns:
        CDF values with the same shape as ``x``.
    """
    if sigma > 0:
        return scipy.stats.norm.cdf(x, mu, sigma)
    return (np.asarray(x) >= mu).astype(float)


def plot_label_cdfs(label, label_predictions, le, outdir, prop_cycle):
    """Plot one Gaussian CDF per prediction column and save the figure.

    For every column of ``label_predictions`` the mean and standard
    deviation are computed over the rows, printed, and the corresponding
    normal CDF is drawn on [0, 1]. The figure is written to
    ``<outdir>/<label>_prediction_cdf.pdf``.

    Args:
        label: Reference label the rows belong to (used in the file name).
        label_predictions: 2-D array, one row per utterance.
        le: Object exposing ``inverse_transform``; used to turn a column
            index into a legend name (presumably a scikit-learn
            ``LabelEncoder`` -- verify against the pickle producer).
        outdir: Directory receiving the PDF.
        prop_cycle: Property cycle applied to the axes (colors/linestyles).
    """
    fig, ax = plt.subplots()
    ax.set_prop_cycle(prop_cycle)

    means = np.mean(label_predictions, axis=0)
    stds = np.std(label_predictions, axis=0)

    # The x range is fixed to [0, 1]; this presumes the scores are
    # probability-like -- TODO confirm.
    x = np.linspace(0, 1, 1000)

    for i, (mu, sigma) in enumerate(zip(means, stds)):
        label_str = le.inverse_transform([i])[0]
        print(f"{i}: mu {mu}, var {sigma * sigma}, sigma {sigma}")
        ax.plot(x, gaussian_cdf(x, mu, sigma), label=label_str, alpha=0.5)

    ax.legend()
    fig.savefig(os.path.join(outdir, f"{label}_prediction_cdf.pdf"))
    # Close the figure so that memory does not grow with the number of labels.
    plt.close(fig)


def main():
    """Parse the command line and write one CDF plot per reference label."""
    parser = argparse.ArgumentParser(description="")

    parser.add_argument("--predictions", type=str, help="prediction file", required=True)
    parser.add_argument("--labels", type=str, help="label file", required=True)
    parser.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
    # The value is used as a directory (one PDF per label is written into it).
    parser.add_argument("--outdir", type=str, help="output directory", required=True)

    args = parser.parse_args()

    predictions = core.data.read_id_values(args.predictions, float)
    labels = core.data.read_labels(args.labels)

    # SECURITY NOTE: pickle.load can execute arbitrary code. Only pass label
    # encoder files that come from a trusted source.
    with open(args.labelencoder, "rb") as f:
        le = pickle.load(f)

    print("PREDICTIONS ---------------------------")
    predictions_by_label = group_predictions_by_label(predictions, labels)

    print("CALCULATING ---------------------------")
    # savefig does not create missing directories.
    os.makedirs(args.outdir, exist_ok=True)

    # Product of cyclers: every color is combined with every linestyle.
    custom_cycler = (cycler(color=list(mcolors.TABLEAU_COLORS)) *
                     cycler(linestyle=['-', '--', '-.']))

    for label, label_predictions in predictions_by_label.items():
        plot_label_cdfs(label, label_predictions, le, args.outdir, custom_cycler)


if __name__ == "__main__":
    main()

# TODO: one graph for each label, displaying the distribution of its
# prediction outputs.
  114 +