Compare View

switch
from
...
to
 
Commits (3)

Changes

Showing 2 changed files Side-by-side Diff

... ... @@ -7,7 +7,9 @@ import numpy as np
7 7 import sys
8 8  
9 9 # Defining some types
10   -from typing import List, Dict
  10 +from typing import List, Dict, Tuple
  11 +
  12 +from numpy.lib.shape_base import expand_dims
11 13 KeyToList = Dict[str, List[str]]
12 14 KeyToLabels = Dict[str, List[str]]
13 15 KeyToIntLabels = Dict[str, List[int]]
... ... @@ -67,6 +69,27 @@ def read_features(file_path: str) -> KeyToFeatures:
67 69 return read_id_values(file_path, np.float64)
68 70  
69 71  
def read_features_with_matrix(file_path: str) -> Tuple[List[str], np.ndarray]:
    """Read a features file and return its keys with the matching value matrix.

    The file is parsed with ``read_id_values`` using ``np.float64`` values.
    Row ``i`` of the returned matrix holds the values read for ``keys[i]``,
    so the matrix always has exactly ``len(keys)`` rows.

    Args:
        file_path (str): path of the features file

    Returns:
        Tuple[List[str], np.ndarray]: the list of keys (utterance ids) and the
        matrix stacking one row of values per key. When the file yields no
        entry, the key list is empty and the matrix is an empty
        ``(0, 0)`` float64 array.

    Raises:
        ValueError: if the entries do not all have the same number of values.
    """
    data = read_id_values(file_path, np.float64)
    keys = list(data)
    if not keys:
        # Keep the documented ndarray return type even without any entry.
        return (keys, np.empty((0, 0), dtype=np.float64))

    # Stack every row once, in key order. Building the matrix in a single call
    # avoids both the duplicated first row and the repeated reallocation that
    # appending inside the loop caused.
    matrix = np.stack([data[key] for key in keys], axis=0)

    return (keys, matrix)
  92 +
70 93 def read_labels(file_path: str) -> KeyToLabels:
71 94 '''
72 95 Read features files with the following structure :
... ... @@ -9,30 +9,12 @@ import scipy.stats
9 9 import pickle
10 10 import matplotlib.pyplot as plt
11 11 import matplotlib.colors as mcolors
12   -
  12 +from utils import SubCommandRunner
13 13  
14 14  
15 15 from cycler import cycler
16 16  
17   -def stats():
18   - print("Decisions")
19   -
20   -
21   -print(list(mcolors.TABLEAU_COLORS))
22   -
23   -
24   -if __name__ == "__main__":
25   -
26   - # Parser
27   - parser = argparse.ArgumentParser(description="")
28   -
29   - # Arguments
30   - parser.add_argument("--predictions", type=str, help="prediction file", required=True)
31   - parser.add_argument("--labels", type=str, help="label file", required=True)
32   - parser.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
33   - parser.add_argument("--outdir", type=str, help="output file", required=True)
34   -
35   - args = parser.parse_args()
  17 +def pred_distribution(predictions: str, labels: str, labelencoder: str, outdir: str):
36 18  
37 19 predictions = core.data.read_id_values(args.predictions, float)
38 20 labels = core.data.read_labels(args.labels)
... ... @@ -108,7 +90,83 @@ if __name__ == "__main__":
108 90 plt.savefig(os.path.join(args.outdir, f"{label}_prediction_cdf.pdf"))
109 91 plt.clf()
110 92  
  93 + print("Decisions")
111 94  
112   - # TODO:
113   - # One graph for each label. Distribution of their predictions output are displayed.
114   -
  95 +
def pred_distribution_wt_sel(predictions: str, labels: str, labelencoder: str, outdir: str):
    """Print, for each prediction column, the rows holding its highest values.

    The prediction matrix is loaded with
    ``core.data.read_features_with_matrix`` and, column by column, the row
    indices of the three largest values are printed together with those
    values and with the complete rows they belong to.

    Args:
        predictions (str): path of the prediction file
        labels (str): label file (not used by the current implementation)
        labelencoder (str): label encoder pickle file (not used by the
            current implementation)
        outdir (str): output directory (not used by the current
            implementation)
    """
    top_n = 3
    _, scores = core.data.read_features_with_matrix(predictions)
    print(scores.shape)

    for column in range(scores.shape[1]):
        # Negating the column makes the ascending argsort list the largest
        # values first; only the first `top_n` row indices are kept.
        best_rows = (-scores[:, column]).argsort()[:top_n]
        print(f"INDICE: {column}")
        print("indices")
        print(best_rows)
        print("Best values")
        print(scores[best_rows, column])
        print("All dimensions of best values")
        print(scores[best_rows])
        # NOTE: the selection of the n best records for each column is only
        # printed for now; nothing is stored or plotted yet.
  112 +
  113 +
def utt2dur(utt2dur: str, labels: str = None):
    """Print the mean and standard deviation of the durations of an utt2dur file.

    Every non-blank line is split on whitespace and its second field is
    parsed as a float duration. The shape of the resulting array, its mean
    and its standard deviation are printed on standard output.

    Args:
        utt2dur (str): path of the utt2dur file
        labels (str, optional): path of a labels file. It is accepted for the
            command line interface but is not used by the current
            implementation.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
        ValueError: if the second field of a line is not a number.
    """
    durations = []
    with open(utt2dur, "r") as f:
        for line in f:
            # Split on any run of whitespace so that repeated spaces, tabs or
            # a trailing "\r" do not produce an empty duration field.
            fields = line.split()
            if not fields:
                # Blank lines (e.g. at the end of the file) carry no duration.
                continue
            durations.append(float(fields[1]))

    durations = np.asarray(durations, dtype=float)
    print(durations.shape)
    mean = np.mean(durations)
    std = np.std(durations)

    print(f"mean: {mean}")
    print(f"std: {std}")
  133 +
  134 +
if __name__ == "__main__":
    # Command line entry point: each sub-command stores, in the "which"
    # default, the key of the function that the runner must call.

    # Parser
    parser = argparse.ArgumentParser(description="Statistics")
    subparsers = parser.add_subparsers(title="actions")

    # pred-distribution
    parser_pred_dist = subparsers.add_parser("pred-distribution", help="plot distributions of prediction through labels")
    parser_pred_dist.add_argument("--predictions", type=str, help="prediction file", required=True)
    parser_pred_dist.add_argument("--labels", type=str, help="label file", required=True)
    parser_pred_dist.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
    parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True)
    parser_pred_dist.set_defaults(which="pred_distribution")

    # pred-distribution-with-selection
    parser_pred_dist_wt_sel = subparsers.add_parser("pred-distribution-with-selection", help="plot distributions of prediction through labels with a selection of the n best records by column/class prediction.")
    parser_pred_dist_wt_sel.add_argument("--predictions", type=str, help="prediction file", required=True)
    parser_pred_dist_wt_sel.add_argument("--labels", type=str, help="label file", required=True)
    parser_pred_dist_wt_sel.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True)
    parser_pred_dist_wt_sel.add_argument("--outdir", type=str, help="output file", required=True)
    parser_pred_dist_wt_sel.set_defaults(which="pred_distribution_with_selection")

    # duration-stats
    parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur")
    parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True)
    parser_utt2dur.add_argument("--labels", type=str, default=None, help="labels file")
    parser_utt2dur.set_defaults(which="utt2dur")

    # Parse
    args = parser.parse_args()

    # When no sub-command is given, no "which" default is set: stop with a
    # usage message instead of failing on the missing attribute below.
    if not hasattr(args, "which"):
        parser.error("an action is required")

    # Run commands
    runner = SubCommandRunner({
        "pred_distribution": pred_distribution,
        "pred_distribution_with_selection": pred_distribution_wt_sel,
        "utt2dur": utt2dur
    })

    runner.run(args.which, args.__dict__, remove="which")
115 173 \ No newline at end of file