def th_confusion_matrix(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes=None):
    """Build a dense confusion matrix from two label tensors.

    Each pair ``(y_true[i], y_pred[i])`` contributes one count to cell
    ``[y_true[i], y_pred[i]]`` of a ``(num_classes + 1)``-sized square
    matrix. Row 0 and column 0 are then dropped, so any sample whose true or
    predicted label is 0 is excluded and label ``k`` (``k >= 1``) lands at
    index ``k - 1`` of the result.

    Args:
        y_true: 1-D tensor of shape [n_samples] with the reference labels
            (rows of the result). Values are truncated to integers.
        y_pred: 1-D tensor of shape [n_samples] with the predicted labels
            (columns of the result). Values are truncated to integers.
        num_classes: scalar, the largest label value. When ``None`` the
            matrix size is inferred by ``torch.sparse_coo_tensor`` from the
            largest label seen on each axis, so the result is not
            necessarily square.

    Returns:
        A float32 tensor of counts with shape ``[num_classes, num_classes]``
        (or the inferred shape minus the first row and column).
    """
    # Sparse indices must be integral: build them as int64 explicitly rather
    # than routing the labels through float32, which cannot represent every
    # integer above 2**24.
    indices = torch.stack([y_true.long(), y_pred.long()], dim=0)
    # Accumulate in int64 so that cells with very many samples are counted
    # exactly; float32 accumulation stops being exact past 2**24.
    counts = torch.ones_like(indices[0])
    if num_classes is None:
        cm = torch.sparse_coo_tensor(indices=indices, values=counts)
    else:
        cm = torch.sparse_coo_tensor(
            indices=indices,
            values=counts,
            size=[num_classes + 1, num_classes + 1],
        )
    # Duplicate (true, pred) pairs are summed when densifying. Drop the
    # label-0 row/column and keep the float32 return dtype callers rely on.
    return cm.to_dense()[1:, 1:].float()