diff --git a/volia/core/data.py b/volia/core/data.py index 30cabf8..6861f2b 100644 --- a/volia/core/data.py +++ b/volia/core/data.py @@ -7,7 +7,9 @@ import numpy as np import sys # Defining some types -from typing import List, Dict +from typing import List, Dict, Tuple + +from numpy.lib.shape_base import expand_dims KeyToList = Dict[str, List[str]] KeyToLabels = Dict[str, List[str]] KeyToIntLabels = Dict[str, List[int]] @@ -67,6 +69,27 @@ def read_features(file_path: str) -> KeyToFeatures: return read_id_values(file_path, np.float64) +def read_features_with_matrix(file_path: str) -> Tuple[List[str], np.ndarray]: + """Read a features file and returns the keys (utterances ids) + with the corresponding matrix of values. + + Args: + file_path (str): path of the features file + + Returns: + [Tuple(List[str], np.ndarray)]: a tuple with a list of keys and the matrix + """ + data = read_id_values(file_path, np.float64) + keys = [] + matrix = None + for key, values in data.items(): + keys.append(key) + if matrix is None: + matrix = np.expand_dims(values, axis=0) + matrix = np.append(matrix, np.expand_dims(values, axis=0), axis=0) + + return (keys, matrix) + def read_labels(file_path: str) -> KeyToLabels: ''' Read features files with the following structure : diff --git a/volia/stats.py b/volia/stats.py index fe20213..c22a75e 100644 --- a/volia/stats.py +++ b/volia/stats.py @@ -93,6 +93,24 @@ def pred_distribution(predictions: str, labels: str, labelencoder: str, outdir: print("Decisions") +def pred_distribution_wt_sel(predictions: str, labels: str, labelencoder: str, outdir: str): + + keys_preds, matrix_preds = core.data.read_features_with_matrix(predictions) + n = 3 + print(matrix_preds.shape) + for j in range(matrix_preds.shape[1]): + indices = (-matrix_preds[:, j]).argsort()[:n] + print(f"INDICE: {j}") + print("indices") + print(indices) + print("Best values") + print(matrix_preds[indices, j]) + print("All dimensions of best values") + print(matrix_preds[indices]) + # Select the n best for each column + pass + + def utt2dur(utt2dur: str, labels: str): if labels == None: pass @@ -128,6 +146,13 @@ if __name__ == "__main__": parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True) parser_pred_dist.set_defaults(which="pred_distribution") + # pred-distribution-with-selection + parser_pred_dist_wt_sel = subparsers.add_parser("pred-distribution-with-selection", help="plot distributions of prediction through labels with a selection of the n best records by column/class prediction.") + parser_pred_dist_wt_sel.add_argument("--predictions", type=str, help="prediction file", required=True) + parser_pred_dist_wt_sel.add_argument("--labels", type=str, help="label file", required=True) + parser_pred_dist_wt_sel.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True) + parser_pred_dist_wt_sel.add_argument("--outdir", type=str, help="output file", required=True) + parser_pred_dist_wt_sel.set_defaults(which="pred_distribution_with_selection") # duration-stats parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur") parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True) @@ -139,7 +164,8 @@ if __name__ == "__main__": # Run commands runner = SubCommandRunner({ - "pred-distribution": pred_distribution, + "pred_distribution": pred_distribution, + "pred_distribution_with_selection": pred_distribution_wt_sel, "utt2dur": utt2dur })