Authored by Mathias
1 parent 4aa3a0ea73
Exists in

### Repair error about the definition of the axis for the multiplication

Showing 1 changed file with 15 additions and 3 deletions

volia/measures.py

 ... ... @@ -169,10 +169,15 @@ 169 169 return np.divide(a, divider, out=np.zeros_like(a), where=divider!=0) 170 170 171 171 def compute_purity_score(count_matrix, axis=0): 172 + if axis==0: 173 + other_axis = 1 174 + else: 175 + other_axis = 0 172 176 count_per_row = count_matrix.sum(axis=axis) 173 177 dividers = np.square(count_per_row) 178 + 174 179 count_matrix_squared = np.square(count_matrix) 175 - matrix_divided = np.apply_along_axis(divide_line, 0, np.asarray(count_matrix_squared, dtype=np.float), dividers) 180 + matrix_divided = np.apply_along_axis(divide_line, other_axis, np.asarray(count_matrix_squared, dtype=np.float), dividers) 176 181 vector_purity = np.sum(matrix_divided, axis=axis) 177 182 178 183 scalar_purity = np.average(vector_purity, weights=count_per_row) ... ... @@ -186,7 +191,6 @@ 186 191 K = np.sqrt(purity_cluster_score * purity_class_score) 187 192 188 193 for i in range(count_matrix.shape[0]): 189 - 190 194 for j in range(count_matrix.shape[1]): 191 195 count_matrix[i][j] 192 196 count_matrix[i] 193 197 194 198 195 199 ... ... @@ -198,15 +202,23 @@ 198 202 199 203 200 204 if __name__ == "__main__": 205 + print("Purity test #1") 201 206 # Hypothesis 202 207 y_hat = np.asarray([0, 1, 2, 0, 1, 0, 3, 2, 2, 3, 3, 0]) 203 208 # Truth 204 209 y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) 205 210 206 211 (result_matrix, result_vector, result) = entropy_score(y, y_hat) 212 + print(purity_score(y, y_hat)) 207 213 214 + exit(1) 215 + print("Purity test #2") 216 + # Hypothesis 217 + y_hat = np.asarray([0, 1, 2, 0, 1, 0, 3, 2, 2, 3, 3, 0, 4, 4, 4]) 218 + # Truth 219 + y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 0, 3, 3, 3]) 208 220 209 - print(purity_score(y, y_hat)) 221 + (result_matrix, result_vector, result) = entropy_score(y, y_hat) 210 222 exit(1) 211 223 print("Result matrix: ") 212 224 print(result_matrix)