Commit 765b51bc7741c15001b3983ef6df4d68eedbcd62
1 parent
d27fe6fcc5
Exists in
master
Little modification to synchronize
Showing 2 changed files with 51 additions and 2 deletions Side-by-side Diff
volia/core/data.py
... | ... | @@ -7,7 +7,9 @@ |
7 | 7 | import sys |
8 | 8 | |
9 | 9 | # Defining some types |
10 | -from typing import List, Dict | |
10 | +from typing import List, Dict, Tuple | |
11 | + | |
12 | +from numpy.lib.shape_base import expand_dims | |
11 | 13 | KeyToList = Dict[str, List[str]] |
12 | 14 | KeyToLabels = Dict[str, List[str]] |
13 | 15 | KeyToIntLabels = Dict[str, List[int]] |
... | ... | @@ -66,6 +68,27 @@ |
66 | 68 | ''' |
67 | 69 | return read_id_values(file_path, np.float64) |
68 | 70 | |
71 | + | |
72 | +def read_features_with_matrix(file_path: str) -> Tuple[List[str], np.ndarray]: | |
73 | + """Read a features file and returns the keys (utterances ids) | |
74 | + with the corresponding matrix of values. | |
75 | + | |
76 | + Args: | |
77 | + file_path (str): path of the features file | |
78 | + | |
79 | + Returns: | |
80 | + [Tuple(List[str], np.ndarray)]: a tuple with a list of keys and the matrix | |
81 | + """ | |
82 | + data = read_id_values(file_path, np.float64) | |
83 | + keys = [] | |
84 | + matrix = None | |
85 | + for key, values in data.items(): | |
86 | + keys.append(key) | |
87 | + if matrix is None: | |
88 | + matrix = np.expand_dims(values, axis=0) | |
89 | + matrix = np.append(matrix, np.expand_dims(values, axis=0), axis=0) | |
90 | + | |
91 | + return (keys, matrix) | |
69 | 92 | |
70 | 93 | def read_labels(file_path: str) -> KeyToLabels: |
71 | 94 | ''' |
volia/stats.py
... | ... | @@ -93,6 +93,24 @@ |
93 | 93 | print("Decisions") |
94 | 94 | |
95 | 95 | |
96 | +def pred_distribution_wt_sel(predictions: str, labels: str, labelencoder: str, outdir: str): | |
97 | + | |
98 | + keys_preds, matrix_preds = core.data.read_features_with_matrix(predictions) | |
99 | + n = 3 | |
100 | + print(matrix_preds.shape) | |
101 | + for j in range(matrix_preds.shape[1]): | |
102 | + indices = (-matrix_preds[:, j]).argsort()[:n] | |
103 | + print(f"INDICE: {j}") | |
104 | + print("indices") | |
105 | + print(indices) | |
106 | + print("Best values") | |
107 | + print(matrix_preds[indices, j]) | |
108 | + print("All dimensions of best values") | |
109 | + print(matrix_preds[indices]) | |
110 | + # Select the n best for each column | |
111 | + pass | |
112 | + | |
113 | + | |
96 | 114 | def utt2dur(utt2dur: str, labels: str): |
97 | 115 | if labels == None: |
98 | 116 | pass |
... | ... | @@ -128,6 +146,13 @@ |
128 | 146 | parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True) |
129 | 147 | parser_pred_dist.set_defaults(which="pred_distribution") |
130 | 148 | |
149 | + # pred-distribution-with-selection | |
150 | + parser_pred_dist_wt_sel = subparsers.add_parser("pred-distribution-with-selection", help="plot distributions of prediction through labels with a selection of the n best records by column/class prediction.") | |
151 | + parser_pred_dist_wt_sel.add_argument("--predictions", type=str, help="prediction file", required=True) | |
152 | + parser_pred_dist_wt_sel.add_argument("--labels", type=str, help="label file", required=True) | |
153 | + parser_pred_dist_wt_sel.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True) | |
154 | + parser_pred_dist_wt_sel.add_argument("--outdir", type=str, help="output file", required=True) | |
155 | + parser_pred_dist_wt_sel.set_defaults(which="pred_distribution_with_selection") | |
131 | 156 | # duration-stats |
132 | 157 | parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur") |
133 | 158 | parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True) |
... | ... | @@ -139,7 +164,8 @@ |
139 | 164 | |
140 | 165 | # Run commands |
141 | 166 | runner = SubCommandRunner({ |
142 | - "pred-distribution": pred_distribution, | |
167 | + "pred_distribution": pred_distribution, | |
168 | + "pred_distribution_with_selection": pred_distribution_wt_sel, | |
143 | 169 | "utt2dur": utt2dur |
144 | 170 | }) |
145 | 171 |