Commit 29318d99165a89b6bece42ad02ed4f878753008a

Authored by Mathias
1 parent 4aa3a0ea73
Exists in master

Repair error about the definition of the axis for the multiplication

Showing 1 changed file with 15 additions and 3 deletions Side-by-side Diff

... ... @@ -169,10 +169,15 @@
169 169 return np.divide(a, divider, out=np.zeros_like(a), where=divider!=0)
170 170  
171 171 def compute_purity_score(count_matrix, axis=0):
  172 + if axis==0:
  173 + other_axis = 1
  174 + else:
  175 + other_axis = 0
172 176 count_per_row = count_matrix.sum(axis=axis)
173 177 dividers = np.square(count_per_row)
  178 +
174 179 count_matrix_squared = np.square(count_matrix)
175   - matrix_divided = np.apply_along_axis(divide_line, 0, np.asarray(count_matrix_squared, dtype=np.float), dividers)
  180 + matrix_divided = np.apply_along_axis(divide_line, other_axis, np.asarray(count_matrix_squared, dtype=np.float), dividers)
176 181 vector_purity = np.sum(matrix_divided, axis=axis)
177 182  
178 183 scalar_purity = np.average(vector_purity, weights=count_per_row)
... ... @@ -186,7 +191,6 @@
186 191 K = np.sqrt(purity_cluster_score * purity_class_score)
187 192  
188 193 for i in range(count_matrix.shape[0]):
189   -
190 194 for j in range(count_matrix.shape[1]):
191 195 count_matrix[i][j]
192 196 count_matrix[i]
193 197  
194 198  
195 199  
... ... @@ -198,15 +202,23 @@
198 202  
199 203  
200 204 if __name__ == "__main__":
  205 + print("Purity test #1")
201 206 # Hypothesis
202 207 y_hat = np.asarray([0, 1, 2, 0, 1, 0, 3, 2, 2, 3, 3, 0])
203 208 # Truth
204 209 y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
205 210  
206 211 (result_matrix, result_vector, result) = entropy_score(y, y_hat)
  212 + print(purity_score(y, y_hat))
207 213  
  214 + exit(1)
  215 + print("Purity test #2")
  216 + # Hypothesis
  217 + y_hat = np.asarray([0, 1, 2, 0, 1, 0, 3, 2, 2, 3, 3, 0, 4, 4, 4])
  218 + # Truth
  219 + y = np.asarray([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 0, 3, 3, 3])
208 220  
209   - print(purity_score(y, y_hat))
  221 + (result_matrix, result_vector, result) = entropy_score(y, y_hat)
210 222 exit(1)
211 223 print("Result matrix: ")
212 224 print(result_matrix)