Commit 765b51bc7741c15001b3983ef6df4d68eedbcd62
1 parent
d27fe6fcc5
Exists in
master
Little modification to synchronize
Showing 2 changed files with 51 additions and 2 deletions Side-by-side Diff
volia/core/data.py
| ... | ... | @@ -7,7 +7,9 @@ |
| 7 | 7 | import sys |
| 8 | 8 | |
| 9 | 9 | # Defining some types |
| 10 | -from typing import List, Dict | |
| 10 | +from typing import List, Dict, Tuple | |
| 11 | + | |
| 12 | +from numpy.lib.shape_base import expand_dims | |
| 11 | 13 | KeyToList = Dict[str, List[str]] |
| 12 | 14 | KeyToLabels = Dict[str, List[str]] |
| 13 | 15 | KeyToIntLabels = Dict[str, List[int]] |
| ... | ... | @@ -66,6 +68,27 @@ |
| 66 | 68 | ''' |
| 67 | 69 | return read_id_values(file_path, np.float64) |
| 68 | 70 | |
| 71 | + | |
| 72 | +def read_features_with_matrix(file_path: str) -> Tuple[List[str], np.ndarray]: | |
| 73 | + """Read a features file and return the keys (utterance ids) | |
| 74 | + with the corresponding matrix of values. | |
| 75 | + | |
| 76 | + Args: | |
| 77 | + file_path (str): path of the features file | |
| 78 | + | |
| 79 | + Returns: | |
| 80 | + [Tuple(List[str], np.ndarray)]: a tuple with a list of keys and the matrix | |
| 81 | + """ | |
| 82 | + data = read_id_values(file_path, np.float64) | |
| 83 | + keys = [] | |
| 84 | + matrix = None | |
| 85 | + for key, values in data.items(): | |
| 86 | + keys.append(key) | |
| 87 | + if matrix is None: | |
| 88 | + matrix = np.expand_dims(values, axis=0) | |
| 89 | + matrix = np.append(matrix, np.expand_dims(values, axis=0), axis=0) | |
| 90 | + | |
| 91 | + return (keys, matrix) | |
| 69 | 92 | |
| 70 | 93 | def read_labels(file_path: str) -> KeyToLabels: |
| 71 | 94 | ''' |
volia/stats.py
| ... | ... | @@ -93,6 +93,24 @@ |
| 93 | 93 | print("Decisions") |
| 94 | 94 | |
| 95 | 95 | |
| 96 | +def pred_distribution_wt_sel(predictions: str, labels: str, labelencoder: str, outdir: str): | |
| 97 | + | |
| 98 | + keys_preds, matrix_preds = core.data.read_features_with_matrix(predictions) | |
| 99 | + n = 3 | |
| 100 | + print(matrix_preds.shape) | |
| 101 | + for j in range(matrix_preds.shape[1]): | |
| 102 | + indices = (-matrix_preds[:, j]).argsort()[:n] | |
| 103 | + print(f"INDICE: {j}") | |
| 104 | + print("indices") | |
| 105 | + print(indices) | |
| 106 | + print("Best values") | |
| 107 | + print(matrix_preds[indices, j]) | |
| 108 | + print("All dimensions of best values") | |
| 109 | + print(matrix_preds[indices]) | |
| 110 | + # Select the n best for each column | |
| 111 | + pass | |
| 112 | + | |
| 113 | + | |
| 96 | 114 | def utt2dur(utt2dur: str, labels: str): |
| 97 | 115 | if labels == None: |
| 98 | 116 | pass |
| ... | ... | @@ -128,6 +146,13 @@ |
| 128 | 146 | parser_pred_dist.add_argument("--outdir", type=str, help="output file", required=True) |
| 129 | 147 | parser_pred_dist.set_defaults(which="pred_distribution") |
| 130 | 148 | |
| 149 | + # pred-distribution-with-selection | |
| 150 | + parser_pred_dist_wt_sel = subparsers.add_parser("pred-distribution-with-selection", help="plot distributions of prediction through labels with a selection of the n best records by column/class prediction.") | |
| 151 | + parser_pred_dist_wt_sel.add_argument("--predictions", type=str, help="prediction file", required=True) | |
| 152 | + parser_pred_dist_wt_sel.add_argument("--labels", type=str, help="label file", required=True) | |
| 153 | + parser_pred_dist_wt_sel.add_argument("--labelencoder", type=str, help="label encode pickle file", required=True) | |
| 154 | + parser_pred_dist_wt_sel.add_argument("--outdir", type=str, help="output file", required=True) | |
| 155 | + parser_pred_dist_wt_sel.set_defaults(which="pred_distribution_with_selection") | |
| 131 | 156 | # duration-stats |
| 132 | 157 | parser_utt2dur = subparsers.add_parser("utt2dur", help="distribution of utt2dur") |
| 133 | 158 | parser_utt2dur.add_argument("--utt2dur", type=str, help="utt2dur file", required=True) |
| ... | ... | @@ -139,7 +164,8 @@ |
| 139 | 164 | |
| 140 | 165 | # Run commands |
| 141 | 166 | runner = SubCommandRunner({ |
| 142 | - "pred-distribution": pred_distribution, | |
| 167 | + "pred_distribution": pred_distribution, | |
| 168 | + "pred_distribution_with_selection": pred_distribution_wt_sel, | |
| 143 | 169 | "utt2dur": utt2dur |
| 144 | 170 | }) |
| 145 | 171 |