ignite.metrics
Metrics provide a way to compute various quantities of interest in an online fashion without having to store the entire output history of a model.
In practice a user needs to attach the metric instance to an engine. The metric value is then computed using the output of the engine’s process_function:
from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return y_pred, y

engine = Engine(process_function)
metric = Accuracy()
metric.attach(engine, "accuracy")
# ...
state = engine.run(data)
print(f"Accuracy: {state.metrics['accuracy']}")
If the engine’s output is not in the format (y_pred, y) or {'y_pred': y_pred, 'y': y, ...}, the user can use the output_transform argument to transform it:
from ignite.engine import Engine
from ignite.metrics import Accuracy

def process_function(engine, batch):
    # ...
    return {'y_pred': y_pred, 'y_true': y, ...}

engine = Engine(process_function)

def output_transform(output):
    # `output` variable is returned by above `process_function`
    y_pred = output['y_pred']
    y = output['y_true']
    return y_pred, y  # output format is according to `Accuracy` docs

metric = Accuracy(output_transform=output_transform)
metric.attach(engine, "accuracy")
# ...
state = engine.run(data)
print(f"Accuracy: {state.metrics['accuracy']}")
Warning
Please be careful when using lambda functions to set up multiple output_transform functions for multiple metrics:
# Wrong
# metrics_group = [Accuracy(output_transform=lambda output: output[name]) for name in names]
# The lambda does not store the current value of `name`; it looks `name` up when called,
# so every `output_transform` will use the last `name`

# A correct way. For example, using functools.partial
from functools import partial

def ot_func(output, name):
    return output[name]

metrics_group = [Accuracy(output_transform=partial(ot_func, name=name)) for name in names]
For more details, see here
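The pitfall described in this warning is ordinary Python closure behaviour and can be reproduced without Ignite. The following is a minimal sketch using only the standard library; the `names` list, the `output` dictionary and the `select` helper are hypothetical stand-ins chosen for illustration.

```python
from functools import partial

names = ["a", "b", "c"]
output = {"a": 1, "b": 2, "c": 3}

# Each lambda looks up `name` when it is called, after the comprehension
# has finished, so all three of them read the final value "c".
late_bound = [lambda out: out[name] for name in names]


def select(out, name):
    """Return the entry of `out` stored under `name`."""
    return out[name]


# partial() stores the value of `name` at creation time.
bound = [partial(select, name=name) for name in names]

late_results = [f(output) for f in late_bound]
bound_results = [f(output) for f in bound]

print(late_results)   # [3, 3, 3]
print(bound_results)  # [1, 2, 3]
```

Binding the value through a default argument, as in `lambda out, name=name: out[name]`, has the same effect as `partial` here.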
Note
Most of the implemented metrics are adapted to distributed computations and reduce their internal states across supported devices before computing the metric value. This can be helpful to run the evaluation on multiple nodes/GPU instances/TPUs with a distributed data sampler. The following code snippet shows in detail how to use metrics:
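import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

# `local_rank`, `model`, `test_dataset`, `batch_size` and `num_workers` are assumed to be defined elsewhere
device = f"cuda:{local_rank}"
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[local_rank],
    output_device=local_rank,
)
test_sampler = DistributedSampler(test_dataset)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    sampler=test_sampler,
    num_workers=num_workers,
    pin_memory=True,
)
evaluator = create_supervised_evaluator(model, metrics={'accuracy': Accuracy()}, device=device)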
Note
Metrics cannot be serialized using the pickle module because the implementation is based on lambda functions. Use the third-party library dill to overcome this limitation of pickle.
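The pickle limitation mentioned in this note can be seen with the standard library alone. The sketch below does not use Ignite or dill. The two dictionaries are hypothetical stand-ins for an object that stores a callable, and `can_pickle` is a helper introduced only for this illustration.

```python
import operator
import pickle


def can_pickle(obj):
    """Return True if the standard pickle module can serialize `obj`."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# Hypothetical state holding a lambda, similar to a default output_transform.
state_with_lambda = {"output_transform": lambda x: x}

# The same state holding a picklable callable from the standard library.
state_with_itemgetter = {"output_transform": operator.itemgetter(0)}

print(can_pickle(state_with_lambda))      # False
print(can_pickle(state_with_itemgetter))  # True
```

pickle stores functions by their importable name, and a lambda has no such name, so any object that holds one fails to serialize.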
Reset, Update, Compute API
Users can also call the following methods directly on the metric:
reset(): resets internal variables and accumulators
update(): updates internal variables and accumulators with the provided batch output (y_pred, y)
compute(): computes the custom metric and returns the result
This API gives finer-grained control over how a metric is computed. For example:
from ignite.metrics import Precision

# Define the metric
precision = Precision()

# Start accumulation:
for x, y in data:
    y_pred = model(x)
    precision.update((y_pred, y))

# Compute the result
print("Precision: ", precision.compute())

# Reset metric
precision.reset()

# Start new accumulation:
for x, y in data:
    y_pred = model(x)
    precision.update((y_pred, y))

# Compute new result
print("Precision: ", precision.compute())
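To make the reset/update/compute pattern concrete without any framework, here is a minimal pure-Python sketch. `RunningAccuracy` is a hypothetical class written for this illustration and is not part of Ignite. It keeps only two counters, which is how a metric can be computed online without storing the full output history.

```python
class RunningAccuracy:
    """Hypothetical pure-Python accuracy accumulator (not an Ignite class)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the accumulators.
        self._num_correct = 0
        self._num_examples = 0

    def update(self, output):
        # `output` is a (y_pred, y) pair of equally long label sequences.
        y_pred, y = output
        self._num_correct += sum(int(p == t) for p, t in zip(y_pred, y))
        self._num_examples += len(y)

    def compute(self):
        if self._num_examples == 0:
            raise ValueError("RunningAccuracy needs at least one example before compute().")
        return self._num_correct / self._num_examples


batches = [([0, 1, 1], [0, 1, 0]), ([2, 2], [2, 1]), ([1], [1])]

metric = RunningAccuracy()
for batch in batches:
    metric.update(batch)
first = metric.compute()   # 4 correct predictions out of 6

metric.reset()
metric.update(([1, 0], [1, 0]))
second = metric.compute()  # 2 correct predictions out of 2
```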
Metric arithmetics
Metrics can be combined to form new metrics. This can be done through arithmetic, such as metric1 + metric2, with PyTorch operators, such as (metric1 + metric2).pow(2).mean(), or with a lambda function, such as MetricsLambda(lambda a, b: torch.mean(a + b), metric1, metric2).
For example:
from ignite.metrics import Precision, Recall
precision = Precision(average=False)
recall = Recall(average=False)
F1 = (precision * recall * 2 / (precision + recall)).mean()
Note
This example computes the mean of F1 across classes. To combine precision and recall to get F1 or other F metrics, we have to be careful to set average=False, i.e. to use the unaveraged precision and recall; otherwise we will not be computing F-beta metrics.
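A small numeric check shows why the unaveraged values matter. The sketch below is plain Python with made-up per-class precision and recall values (assumptions for illustration only). It compares the mean of the per-class F1 scores with the F1 formula applied to already averaged precision and recall.

```python
def f1(p, r):
    """Harmonic mean of precision `p` and recall `r`."""
    return 2 * p * r / (p + r)


# Hypothetical per-class values for a two-class problem.
precision = [1.0, 0.5]
recall = [0.5, 1.0]

# Mean of per-class F1: the quantity obtained from unaveraged inputs.
per_class_f1 = [f1(p, r) for p, r in zip(precision, recall)]
mean_f1 = sum(per_class_f1) / len(per_class_f1)

# F1 formula applied to averaged precision and recall: a different quantity.
avg_precision = sum(precision) / len(precision)
avg_recall = sum(recall) / len(recall)
f1_of_averages = f1(avg_precision, avg_recall)

print(mean_f1, f1_of_averages)
```

With these values the mean of the per-class F1 scores is 2/3, while the formula applied to the averages gives 0.75.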
Metrics also support the indexing operation (if the metric’s result is a vector/matrix/tensor). For example, this can be useful to compute a mean metric (e.g. precision, recall or IoU) ignoring the background:
from ignite.metrics import ConfusionMatrix, IoU

cm = ConfusionMatrix(num_classes=10)
iou_metric = IoU(cm)
iou_no_bg_metric = iou_metric[:9]  # We assume that the background index is 9
mean_iou_no_bg_metric = iou_no_bg_metric.mean()
# mean_iou_no_bg_metric.compute() -> tensor(0.12345)
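The arithmetic behind such an indexed mean can be sketched in plain Python. The example below uses the common definition of per-class IoU from a confusion matrix (diagonal divided by row sum plus column sum minus diagonal). This definition is an assumption and Ignite's implementation may differ in details such as numerical stabilisation. The 3x3 matrix is made up, and the last class is assumed to be the background.

```python
def iou_per_class(cm):
    """Per-class IoU from a square confusion matrix given as nested lists."""
    n = len(cm)
    ious = []
    for c in range(n):
        tp = cm[c][c]
        row_sum = sum(cm[c])
        col_sum = sum(cm[r][c] for r in range(n))
        union = row_sum + col_sum - tp
        ious.append(tp / union if union else 0.0)
    return ious


# Hypothetical confusion matrix: rows and columns index the three classes.
cm = [
    [5, 1, 0],
    [2, 3, 1],
    [0, 0, 8],
]

ious = iou_per_class(cm)                # 5/8, 3/7, 8/9
ious_no_bg = ious[:2]                   # drop the background class (index 2)
mean_iou_no_bg = sum(ious_no_bg) / len(ious_no_bg)
print(mean_iou_no_bg)
```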
How to create a custom metric
To create a custom metric one needs to create a new class inheriting from Metric and override three methods:
reset(): resets internal variables and accumulators
update(): updates internal variables and accumulators with the provided batch output (y_pred, y)
compute(): computes the custom metric and returns the result
For example, we would like to implement for illustration purposes a multi-class accuracy metric with some specific condition (e.g. ignore user-defined classes):
import torch

from ignite.metrics import Metric
from ignite.exceptions import NotComputableError

# These decorators help with distributed settings
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced


class CustomAccuracy(Metric):

    def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
        self.ignored_class = ignored_class
        self._num_correct = None
        self._num_examples = None
        super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self):
        self._num_correct = torch.tensor(0, device=self._device)
        self._num_examples = 0
        super(CustomAccuracy, self).reset()

    @reinit__is_reduced
    def update(self, output):
        y_pred, y = output[0].detach(), output[1].detach()

        indices = torch.argmax(y_pred, dim=1)

        # Keep only samples where neither the target nor the prediction is the ignored class
        mask = (y != self.ignored_class)
        mask &= (indices != self.ignored_class)
        y = y[mask]
        indices = indices[mask]
        correct = torch.eq(indices, y).view(-1)

        self._num_correct += torch.sum(correct).to(self._device)
        self._num_examples += correct.shape[0]

    @sync_all_reduce("_num_examples", "_num_correct")
    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
        return self._num_correct.item() / self._num_examples
We imported the necessary classes, Metric and NotComputableError, and the decorators that adapt the metric for the distributed setting. In the reset method, we reset the internal variables _num_correct and _num_examples, which are used to compute the custom metric. In the update method we define how to update the internal variables. Finally, in the compute method, we compute the metric value.
Notice that _num_correct is a tensor, since in update we accumulate tensor values. _num_examples is a Python scalar since we accumulate normal integers. For differentiable metrics, you must detach the accumulated values before adding them to the internal variables.
We can check this implementation in a simple case:
import torch
torch.manual_seed(8)
m = CustomAccuracy(ignored_class=3)
batch_size = 4
num_classes = 5
y_pred = torch.rand(batch_size, num_classes)
y = torch.randint(0, num_classes, size=(batch_size, ))
m.update((y_pred, y))
res = m.compute()
print(y, torch.argmax(y_pred, dim=1))
# Out: tensor([2, 2, 2, 3]) tensor([2, 1, 0, 0])
print(m._num_correct, m._num_examples, res)
# Out: 1 3 0.3333333333333333
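Metrics and their usages
By default, metrics are epoch-wise, which means that a metric is reset at the start of each epoch, updated on every ITERATION_COMPLETED event and computed on every EPOCH_COMPLETED event.
Usages can be user-defined by creating a class inheriting from MetricUsage. See the list of usages below.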
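The practical difference between an epoch-wise and a batch-wise usage can be sketched in plain Python. The two functions below are hypothetical and do not use Ignite's event system. They assume that the accumulators are cleared at the start of each epoch in the epoch-wise case and at the start of each iteration in the batch-wise case.

```python
def accuracy_epoch_wise(batches):
    """One value per epoch: accumulate over all batches, compute at the end."""
    num_correct = 0   # assumption: state is cleared when the epoch starts
    num_examples = 0
    for y_pred, y in batches:          # update on every iteration
        num_correct += sum(int(p == t) for p, t in zip(y_pred, y))
        num_examples += len(y)
    return num_correct / num_examples  # compute when the epoch completes


def accuracy_batch_wise(batches):
    """One value per batch: update and compute on every iteration."""
    values = []
    for y_pred, y in batches:
        # assumption: state is cleared when the iteration starts
        num_correct = sum(int(p == t) for p, t in zip(y_pred, y))
        values.append(num_correct / len(y))
    return values


batches = [([0, 1, 1], [0, 1, 0]), ([2, 2], [2, 1]), ([1], [1])]
epoch_value = accuracy_epoch_wise(batches)   # 4 correct out of 6
batch_values = accuracy_batch_wise(batches)  # 2/3, 1/2, 1/1
```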
Complete list of usages
Metrics and distributed computations
In the above example, CustomAccuracy has its reset, update and compute methods decorated with reinit__is_reduced() and sync_all_reduce(). The purpose of these decorators is to adapt metrics to distributed computations on supported backends and devices (see ignite.distributed for more details). More precisely, in the above example we added @sync_all_reduce("_num_examples", "_num_correct") over the compute method. This means that when the compute method is called, the metric’s internal variables self._num_examples and self._num_correct are summed up over all participating devices. Therefore, once collected, these internal variables can be used to compute the final metric value.
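The effect of summing the internal variables before computing can be illustrated with a small simulation. The sketch below is plain Python, with no distributed backend involved. `all_reduce_sum` is a hypothetical helper that mimics a sum reduction, and the per-rank counts are made up.

```python
def all_reduce_sum(per_rank_values):
    """Simulate a sum all-reduce: every rank ends up holding the global sum."""
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]


# Hypothetical accumulators held by two processes after their local updates.
per_rank_correct = [8, 1]
per_rank_examples = [10, 2]

# Reduce the accumulators first, then compute the metric on every rank.
reduced_correct = all_reduce_sum(per_rank_correct)
reduced_examples = all_reduce_sum(per_rank_examples)
global_accuracy = [c / n for c, n in zip(reduced_correct, reduced_examples)]

# Averaging the per-rank accuracies instead gives a different number
# when the ranks saw different numbers of examples.
local_accuracy = [c / n for c, n in zip(per_rank_correct, per_rank_examples)]
naive_mean = sum(local_accuracy) / len(local_accuracy)

print(global_accuracy)  # [0.75, 0.75]
```

Reducing the counters rather than the per-rank metric values is what makes the result independent of how the examples are split across processes.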
Complete list of metrics
Helper class to compute arithmetic average of a single variable.
Helper class to compute geometric average of a single variable.
Single variable accumulator helper to compute (arithmetic, geometric, harmonic) average of a single variable.
Calculates the accuracy for binary, multiclass and multilabel data.
Calculates confusion matrix for multi-class data.
Calculates Dice Coefficient for a given
Calculates the Jaccard Index using
Calculates Intersection over Union using
Calculates mean Intersection over Union using
Class for metrics that should be computed on the entire output history of a model.
Calculates F-beta score.
Provides metrics for the number of examples processed per second.
Calculates the average loss according to the passed loss_fn.
Calculates the mean absolute error.
Calculates the mean
Calculates the mean squared error.
Base class for all Metrics.
Apply a function to other metrics to obtain a new metric.
Calculates a confusion matrix for multi-labelled, multi-class data.
Calculates precision for binary and multiclass data.
Computes average Peak signal-to-noise ratio (PSNR).
Calculates recall for binary and multiclass data.
Calculates the root mean squared error.
Compute running average of a metric or the output of process function.
Computes Structural Similarity Index Measure
Calculates the top-k categorical accuracy.
Helpers for customizing metrics
MetricUsage
class ignite.metrics.metric.MetricUsage(started, completed, iteration_completed)
Base class for all usages of metrics.
A usage of metric defines the events when a metric starts to compute, updates and completes. Valid events are from Events.
- Parameters
started (ignite.engine.events.Events) – event when the metric starts to compute. This event will be associated to started().
completed (ignite.engine.events.Events) – event when the metric completes. This event will be associated to completed().
iteration_completed (ignite.engine.events.CallableEventWithFilter) – event when the metric updates. This event will be associated to iteration_completed().
- Return type
EpochWise
class ignite.metrics.metric.EpochWise
Epoch-wise usage of Metrics. It’s the default and most common usage of metrics.
Metric’s methods are triggered on the following engine events:
iteration_completed() on every ITERATION_COMPLETED.
completed() on every EPOCH_COMPLETED.
- Return type
BatchWise
class ignite.metrics.metric.BatchWise
Batch-wise usage of Metrics.
Metric’s methods are triggered on the following engine events:
iteration_completed() on every ITERATION_COMPLETED.
completed() on every ITERATION_COMPLETED.
- Return type
BatchFiltered
class ignite.metrics.metric.BatchFiltered(*args, **kwargs)
Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.
Metric’s methods are triggered on the following engine events:
iteration_completed() on filtered ITERATION_COMPLETED.
completed() on every EPOCH_COMPLETED.
- Parameters
args (Any) – Positional arguments to setup ITERATION_COMPLETED
kwargs (Any) – Keyword arguments to setup ITERATION_COMPLETED handled by iteration_completed().
- Return type
reinit__is_reduced
ignite.metrics.metric.reinit__is_reduced(func)
Helper decorator for distributed configuration.
See ignite.metrics on how to use it.
- Parameters
func (Callable) – A callable to reinit.
- Return type
Callable
sync_all_reduce
ignite.metrics.metric.sync_all_reduce(*attrs)
Helper decorator for distributed configuration to collect instance attribute value across all participating processes.
See ignite.metrics on how to use it.
- Parameters
attrs (Any) – attribute names of decorated class
- Return type
Callable