MultiLabelConfusionMatrix
class ignite.metrics.MultiLabelConfusionMatrix(num_classes, output_transform=<function MultiLabelConfusionMatrix.<lambda>>, device=device(type='cpu'), normalized=False)
- Calculates a confusion matrix for multi-labelled, multi-class data.
- update must receive output of the form (y_pred, y).
- y_pred must contain 0s and 1s and has the following shape (batch_size, num_classes, …). For example, y_pred[i, j] = 1 denotes that the j’th class is one of the labels of the i’th sample as predicted. 
- y should have the following shape (batch_size, num_classes, …) with 0s and 1s. For example, y[i, j] = 1 denotes that the j’th class is one of the labels of the i’th sample according to the ground truth. 
- Both y and y_pred must be torch Tensors of one of the following dtypes: {torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}. They must have the same shape.
- The confusion matrix ‘M’ is of dimension (num_classes, 2, 2).
- M[i, 0, 0] corresponds to count/rate of true negatives of class i
- M[i, 0, 1] corresponds to count/rate of false positives of class i 
- M[i, 1, 0] corresponds to count/rate of false negatives of class i 
- M[i, 1, 1] corresponds to count/rate of true positives of class i 
 
- The classes present in M are indexed as 0, … , num_classes-1 as can be inferred from above. 
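The index layout of M can be illustrated with a dependency-free sketch. This is not the library's implementation: the helper name multilabel_confusion and the nested-list inputs are assumptions made for illustration (the metric itself expects integer torch Tensors). It only shows that the first 0/1 index follows the ground truth and the second follows the prediction.

```python
from typing import List


def multilabel_confusion(
    y_pred: List[List[int]], y: List[List[int]], num_classes: int
) -> List[List[List[int]]]:
    """Hypothetical helper: return M where M[i] = [[TN, FP], [FN, TP]] for class i."""
    m = [[[0, 0], [0, 0]] for _ in range(num_classes)]
    for pred_row, true_row in zip(y_pred, y):
        for i in range(num_classes):
            # First index is the ground-truth label, second is the predicted label.
            m[i][true_row[i]][pred_row[i]] += 1
    return m


# Three samples, three classes; each row is a multi-hot label vector.
y_pred = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
y = [[1, 1, 0], [0, 1, 1], [0, 1, 0]]

M = multilabel_confusion(y_pred, y, num_classes=3)
for class_index, block in enumerate(M):
    print(class_index, block)
```

For class 1 in this toy data, the first sample is a false negative and the other two are true positives, so M[1] is [[0, 0], [1, 2]].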
 - Parameters
- num_classes (int) – Number of classes, should be > 1. 
- output_transform (Callable) – a callable that is used to transform the Engine’s process_function’s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs.
- device (Union[str, torch.device]) – specifies which device updates are accumulated on. Setting the metric’s device to be the same as your update arguments ensures the update method is non-blocking. By default, CPU.
- normalized (bool) – whether to normalize the confusion matrix by its sum.
 
 - New in version 0.5.0.
 - Methods
 - compute: Computes the metric based on its accumulated state.
 - reset: Resets the metric to its initial state.
 - update: Updates the metric’s state using the passed batch output.
compute()
- Computes the metric based on its accumulated state.
- By default, this is called at the end of each epoch.
- Returns
- the actual quantity of interest. However, if a Mapping is returned, it will be (shallow) flattened into engine.state.metrics when completed() is called.
- Return type
- Any 
- Raises
- NotComputableError – raised when the metric cannot be computed. 
 
reset()
- Resets the metric to its initial state.
- By default, this is called at the start of each epoch.
- Return type
- None
 
update(output)
- Updates the metric’s state using the passed batch output.
- By default, this is called once for each batch.
- Parameters
- output (Sequence[torch.Tensor]) – this is the output from the engine’s process function.
- Return type
- None