
Checkpoint

class ignite.handlers.checkpoint.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, filename_pattern=None, include_self=False, greater_or_equal=False)

The Checkpoint handler can be used to periodically save and load objects that implement state_dict/load_state_dict. This class can use specific save handlers to store checkpoints on disk, in cloud storage, etc. When used with DiskSaver, the Checkpoint handler also automatically moves data on TPU to CPU before writing the checkpoint.

Parameters
  • to_save (Mapping) – Dictionary with the objects to save. Objects should implement the state_dict and load_state_dict methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is automatically saved (to avoid the additional key prefix module. in the state dictionary).

  • save_handler (Union[Callable, ignite.handlers.checkpoint.BaseSaveHandler]) – Method or callable class used to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement a remove method to keep a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on disk, save_handler can be defined with DiskSaver.

  • filename_prefix (str) – Prefix for the file name to which objects will be saved. See Note for details.

  • score_function (Optional[Callable]) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.

  • score_name (Optional[str]) – If score_function is not None, its value can be stored in the filename using score_name. See Notes for more details.

  • n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files (or, with score_function, lower-scored files) will be removed. If set to None, all objects are kept.

  • global_step_transform (Optional[Callable]) – Global step transform function to output a desired global step. The function's input is (engine, event_name) and its output should be an integer. Default is None, in which case global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, use global_step_from_engine().

  • filename_pattern (Optional[str]) – If filename_pattern is provided, this pattern will be used to render checkpoint filenames. If the pattern is not defined, the default pattern would be used. See Note for details.

  • include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with the key checkpointer.

  • greater_or_equal (bool) – If True, the latest equally scored model is stored; otherwise, the first one is kept. Default, False.
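How n_saved, score_function and greater_or_equal interact can be simulated in plain Python. The sketch below is not Ignite's code: simulate_retention and the model_<step>.pt names are hypothetical, and it assumes each call's priority is its score and that, once n_saved items are stored, a new item must beat the lowest stored priority and then evicts it.

```python
from typing import List, Optional, Sequence, Tuple


def simulate_retention(
    scores: Sequence[float],
    n_saved: Optional[int] = 1,
    greater_or_equal: bool = False,
) -> List[Tuple[float, str]]:
    """Hypothetical simulation of which checkpoints survive, lowest priority first."""
    saved: List[Tuple[float, str]] = []
    for step, score in enumerate(scores, start=1):
        filename = f"model_{step}.pt"  # hypothetical naming
        has_room = n_saved is None or len(saved) < n_saved
        if not has_room:
            lowest = saved[0][0]
            beats_lowest = score >= lowest if greater_or_equal else score > lowest
            if not beats_lowest:
                continue  # not good enough: nothing is written
            saved.pop(0)  # evict (a real saver would also delete) the lowest-priority file
        saved.append((score, filename))
        saved.sort(key=lambda item: item[0])  # stable sort keeps older equal scores first
    return saved


print(simulate_retention([0.5, 0.7, 0.6, 0.7], n_saved=2))
# [(0.7, 'model_2.pt'), (0.7, 'model_4.pt')]
print(simulate_retention([0.8, 0.8], n_saved=1))
# [(0.8, 'model_1.pt')]
print(simulate_retention([0.8, 0.8], n_saved=1, greater_or_equal=True))
# [(0.8, 'model_2.pt')]
```

Without a score_function the priority is the global step, so the evicted item is simply the oldest one, matching the n_saved description.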

Note

This class stores a single file as a dictionary of provided objects to save. The filename is defined by filename_pattern and by default has the following structure: {filename_prefix}_{name}_{suffix}.{ext} where

  • filename_prefix is the argument passed to the constructor,

  • name is the key in to_save if a single object is to be stored, otherwise name is “checkpoint”.

  • suffix is composed as follows: {global_step}_{score_name}={score}.

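X marks an argument that is provided, None one that is left unset:

score_function | score_name | global_step_transform | suffix
---------------+------------+-----------------------+-----------------------------------
None           | None       | None                  | {engine.state.iteration}
X              | None       | None                  | {score}
X              | None       | X                     | {global_step}_{score}
X              | X          | X                     | {global_step}_{score_name}={score}
None           | None       | X                     | {global_step}
X              | X          | None                  | {score_name}={score}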

Above, global_step is defined by the output of global_step_transform and score by the output of score_function.

By default, none of score_function, score_name and global_step_transform is defined, and the suffix is set from the attached engine's current iteration. The filename will be {filename_prefix}_{name}_{engine.state.iteration}.{ext}.

For example, with score_name="neg_val_loss" and a score_function that returns -loss (as objects with the highest scores will be retained), the saved filename will be {filename_prefix}_{name}_neg_val_loss=-0.1234.pt.
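The suffix rules can be expressed as a small function. default_filename is a hypothetical helper, not part of Ignite; it assumes an empty filename_prefix adds no leading underscore (as in checkpoint_12345.pt) and that score_name is only used together with a score.

```python
from typing import Optional


def default_filename(
    filename_prefix: str,
    name: str,
    ext: str,
    iteration: int,
    global_step: Optional[int] = None,
    score: Optional[float] = None,
    score_name: Optional[str] = None,
) -> str:
    """Hypothetical helper reproducing the documented suffix rules."""
    if score is None and global_step is None:
        global_step = iteration  # fallback: the attached engine's current iteration
    parts = []
    if global_step is not None:
        parts.append(str(global_step))
    if score is not None:
        parts.append(f"{score_name}={score}" if score_name else str(score))
    stem = f"{name}_{'_'.join(parts)}"
    if filename_prefix:  # an empty prefix adds no leading underscore
        stem = f"{filename_prefix}_{stem}"
    return f"{stem}.{ext}"


print(default_filename("", "checkpoint", "pt", iteration=12345))
# checkpoint_12345.pt
print(default_filename("", "model", "pt", iteration=0, score=-0.1234, score_name="neg_val_loss"))
# model_neg_val_loss=-0.1234.pt
```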

Note

If filename_pattern is given, it will be used to render the filenames. filename_pattern is a string that can contain {filename_prefix}, {name}, {score}, {score_name} and {global_step} as templates.

For example, with filename_pattern="{global_step}-{name}-{score}.pt", the saved filename will be 30000-checkpoint-94.pt.

Warning: keep in mind that if a filename collides with one already used to save a checkpoint, the new checkpoint will replace the older one. This means that a filename like checkpoint.pt will be written on every call and will always be overwritten by newer checkpoints.
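Assuming the pattern is rendered with Python's str.format, which silently ignores unused keyword arguments, both the example and the collision warning can be reproduced directly. The values dictionary below is illustrative only; its numbers are made up.

```python
# Values that could be substituted into a pattern (illustrative numbers only).
values = {
    "filename_prefix": "",
    "name": "checkpoint",
    "score_name": None,
    "score": 94,
    "global_step": 30000,
    "ext": "pt",
}

pattern = "{global_step}-{name}-{score}.pt"
print(pattern.format(**values))  # 30000-checkpoint-94.pt

# A pattern without {global_step} or {score} renders the same name every time,
# so each new checkpoint overwrites the previous file.
fixed = "{name}.pt"
first = fixed.format(**values)
second = fixed.format(**{**values, "global_step": 31000, "score": 95})
print(first, second, first == second)  # checkpoint.pt checkpoint.pt True
```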

Note

To get the last stored filename, the handler exposes the attribute last_checkpoint:

handler = Checkpoint(...)
...
print(handler.last_checkpoint)
> checkpoint_12345.pt

Note

This class is distributed configuration-friendly: it is not required to instantiate the class only in the rank 0 process. It supports distributed configurations automatically and, if used with DiskSaver, the checkpoint is stored by the rank 0 process.

Warning

When running on XLA devices, the handler should be run in all processes; otherwise the application can get stuck while saving the checkpoint.

# Wrong:
# if idist.get_rank() == 0:
#     handler = Checkpoint(...)
#     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

# Correct:
handler = Checkpoint(...)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

Examples

Attach the handler to make checkpoints during training:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

trainer = ...
model = ...
optimizer = ...
lr_scheduler = ...

to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}

if checkpoint_iters:
    # A: Output is "checkpoint_<iteration>.pt"
    handler = Checkpoint(
        to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2
    )
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
else:
    # B: Output is "checkpoint_<epoch>.pt"
    gst = lambda *_: trainer.state.epoch
    handler = Checkpoint(
        to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2, global_step_transform=gst
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

trainer.run(data_loader, max_epochs=6)
> A: ["checkpoint_7000.pt", "checkpoint_8000.pt"]
> B: ["checkpoint_5.pt", "checkpoint_6.pt"]

Attach the handler to an evaluator to save the best model during training according to a computed validation metric:

from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

trainer = ...
evaluator = ...
# Setup Accuracy metric computation on evaluator
# Run evaluation on epoch completed event
# ...

score_function = Checkpoint.get_default_score_fn("accuracy")

model = ...
to_save = {'model': model}
handler = Checkpoint(
    to_save, DiskSaver('/tmp/models', create_dir=True),
    n_saved=2, filename_prefix='best',
    score_function=score_function, score_name="val_acc",
    global_step_transform=global_step_from_engine(trainer)
)

evaluator.add_event_handler(Events.COMPLETED, handler)

trainer.run(data_loader, max_epochs=10)
> ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt"]

Changed in version 0.4.3:

  • Checkpoint can save a model with the same filename.

  • Added greater_or_equal argument.

Methods

get_default_score_fn

Helper method to get default score function based on the metric name.

load_objects

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

load_state_dict

Replaces the internal state of the class with the provided state dict data.

reset

Method to reset saved checkpoint names.

setup_filename_pattern

Helper method to get the default filename pattern for a checkpoint.

state_dict

Returns a state dict with the saved items: a list of (priority, filename) pairs.

class Item(priority, filename)

Create new instance of Item(priority, filename)

Parameters
  • priority (int) –

  • filename (str) –

property filename

Alias for field number 1

property priority

Alias for field number 0
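Item is a named tuple, so both attribute access and the numbered field access listed above work. The class below is a stand-in that assumes only the two documented fields; it is not imported from Ignite.

```python
from typing import NamedTuple


class Item(NamedTuple):
    """Stand-in mirroring the documented Checkpoint.Item fields."""

    priority: int
    filename: str


item = Item(priority=3000, filename="checkpoint_3000.pt")
print(item.priority, item[0])  # 3000 3000 (field number 0)
print(item.filename, item[1])  # checkpoint_3000.pt checkpoint_3000.pt (field number 1)

# Tuples compare element-wise, so sorting Items orders them by priority first.
print(sorted([Item(5, "b.pt"), Item(2, "a.pt")]))
```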

static get_default_score_fn(metric_name, score_sign=1.0)

Helper method to get default score function based on the metric name.

Parameters
  • metric_name (str) – name of the metric whose value is read from engine.state.metrics, where engine is the one to which the Checkpoint handler is added.

  • score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, i.e. where smaller is better, a negative score sign should be used (objects with larger scores are retained). Default, 1.0.

Return type

Callable

Examples:

from ignite.engine import Events
from ignite.handlers import Checkpoint

best_acc_score = Checkpoint.get_default_score_fn("accuracy")

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

Usage with error-like metric:

from ignite.engine import Events
from ignite.handlers import Checkpoint

neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

best_model_handler = Checkpoint(
    to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

New in version 0.4.3.
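Conceptually, the returned function multiplies the named metric by score_sign. make_default_score_fn below is a hypothetical re-implementation of that documented behavior (including the assumption that any other score_sign is rejected), exercised on a fake evaluator built with SimpleNamespace instead of a real Engine.

```python
from types import SimpleNamespace
from typing import Any, Callable


def make_default_score_fn(metric_name: str, score_sign: float = 1.0) -> Callable[[Any], float]:
    """Hypothetical re-implementation of the documented behavior."""
    if score_sign not in (1.0, -1.0):
        raise ValueError("score_sign should be 1.0 or -1.0")

    def score_fn(engine: Any) -> float:
        # Read the metric computed on the engine the handler is attached to.
        return score_sign * engine.state.metrics[metric_name]

    return score_fn


# Fake evaluator exposing only engine.state.metrics.
evaluator = SimpleNamespace(state=SimpleNamespace(metrics={"accuracy": 0.78, "loss": 0.25}))
print(make_default_score_fn("accuracy")(evaluator))    # 0.78
print(make_default_score_fn("loss", -1.0)(evaluator))  # -0.25
```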

static load_objects(to_load, checkpoint, **kwargs)

Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

Examples:

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pt"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

Note

If to_load contains objects of type torch DistributedDataParallel or DataParallel, the load_state_dict method will be applied to their internal wrapped model (obj.module).

Parameters
  • to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, …}

  • checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {"model": model_state_dict, "optimizer": opt_state_dict}. If to_load contains a single key, then checkpoint can directly contain the corresponding state_dict.

  • kwargs (Any) – Keyword arguments accepted for nn.Module.load_state_dict(). Passing strict=False enables the user to load part of the pretrained model (useful, for example, in transfer learning).

Return type

None
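The loading rules above (per-key loading, the single-key shortcut, and unwrapping of parallel wrappers) can be sketched without PyTorch. load_objects_sketch, Recorder and Wrapper are hypothetical; the sketch omits the kwargs handling and treats any object with a .module attribute as a wrapper, which is a simplification of the DataParallel/DistributedDataParallel check.

```python
from typing import Any, Mapping


def load_objects_sketch(to_load: Mapping[str, Any], checkpoint: Mapping[str, Any]) -> None:
    """Hypothetical, simplified version of the documented loading rules."""

    def unwrap(obj: Any) -> Any:
        # Simplification: any object with a `.module` attribute is treated as a wrapper.
        return getattr(obj, "module", obj)

    if len(to_load) == 1:
        (key, obj), = to_load.items()
        if key not in checkpoint:
            # Single object and the checkpoint is directly its state_dict.
            unwrap(obj).load_state_dict(checkpoint)
            return
    for key, obj in to_load.items():
        if key not in checkpoint:
            raise ValueError(f"'{key}' from to_load is not found in the checkpoint")
        unwrap(obj).load_state_dict(checkpoint[key])


class Recorder:
    """Toy object that just remembers the state it was given."""

    def __init__(self) -> None:
        self.loaded = None

    def load_state_dict(self, state: Mapping[str, Any]) -> None:
        self.loaded = dict(state)


class Wrapper:
    """Toy stand-in for a DataParallel-style wrapper."""

    def __init__(self, module: Any) -> None:
        self.module = module


model, optimizer = Recorder(), Recorder()
load_objects_sketch({"model": model, "optimizer": optimizer},
                    {"model": {"w": 1}, "optimizer": {"lr": 0.1}})
inner = Recorder()
load_objects_sketch({"model": Wrapper(inner)}, {"w": 2})  # single key, bare state_dict
print(model.loaded, optimizer.loaded, inner.loaded)
```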

load_state_dict(state_dict)

Replaces the internal state of the class with the provided state dict data.

Parameters

state_dict (Mapping) – a dict with a "saved" key whose value is a list of (priority, filename) pairs.

Return type

None

reset()

Method to reset saved checkpoint names.

Use this method if the engine will independently run multiple times:

from ignite.engine import Events
from ignite.handlers import Checkpoint

trainer = ...
checkpointer = Checkpoint(...)

trainer.add_event_handler(Events.COMPLETED, checkpointer)
trainer.add_event_handler(Events.STARTED, checkpointer.reset)

# fold 0
trainer.run(data0, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

# fold 1
trainer.run(data1, max_epochs=max_epochs)
print("Last checkpoint:", checkpointer.last_checkpoint)

New in version 0.4.3.

Return type

None
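Why reset matters across runs can be shown with a toy keeper. BestKeeper below is hypothetical, not Ignite code: its saved list is its only state, n_saved defaults to 1, and a new entry must beat the lowest stored priority. Without reset, the second run competes against the first run's entries.

```python
from typing import List, Tuple


class BestKeeper:
    """Hypothetical toy: keeps the n_saved highest-priority (priority, filename) pairs."""

    def __init__(self, n_saved: int = 1) -> None:
        self.n_saved = n_saved
        self.saved: List[Tuple[float, str]] = []

    def __call__(self, priority: float, filename: str) -> None:
        if len(self.saved) < self.n_saved or priority > self.saved[0][0]:
            if len(self.saved) >= self.n_saved:
                self.saved.pop(0)  # evict the lowest-priority entry
            self.saved.append((priority, filename))
            self.saved.sort(key=lambda item: item[0])

    def reset(self) -> None:
        self.saved = []  # forget names from previous runs


no_reset = BestKeeper()
no_reset(0.9, "fold0.pt")
no_reset(0.8, "fold1.pt")  # ignored: 0.8 does not beat fold 0's 0.9
print(no_reset.saved)  # [(0.9, 'fold0.pt')]

with_reset = BestKeeper()
with_reset(0.9, "fold0.pt")
with_reset.reset()  # e.g. on Events.STARTED of the next run
with_reset(0.8, "fold1.pt")
print(with_reset.saved)  # [(0.8, 'fold1.pt')]
```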

static setup_filename_pattern(with_prefix=True, with_score=True, with_score_name=True, with_global_step=True)

Helper method to get the default filename pattern for a checkpoint.

Parameters
  • with_prefix (bool) – If True, the filename_prefix is added to the filename pattern: {filename_prefix}_{name}.... Default, True.

  • with_score (bool) – If True, score is added to the filename pattern: ..._{score}.{ext}. Default, True. At least one of with_score and with_global_step should be True.

  • with_score_name (bool) – If True, score_name is added to the filename pattern: ..._{score_name}={score}.{ext}. If activated, argument with_score should also be True, otherwise an error is raised. Default, True.

  • with_global_step (bool) – If True, {global_step} is added to the filename pattern: ...{name}_{global_step}.... At least one of with_score and with_global_step should be True. Default, True.

Return type

str

Example

from ignite.handlers import Checkpoint

filename_pattern = Checkpoint.setup_filename_pattern()

print(filename_pattern)
> "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

New in version 0.4.3.
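The documented flags can be mirrored by a short builder. build_filename_pattern is hypothetical and assumes the pieces are appended in the order shown in the parameter descriptions; its error messages are made up.

```python
def build_filename_pattern(
    with_prefix: bool = True,
    with_score: bool = True,
    with_score_name: bool = True,
    with_global_step: bool = True,
) -> str:
    """Hypothetical re-creation of the documented pattern rules."""
    if not (with_score or with_global_step):
        raise ValueError("At least one of with_score and with_global_step should be True.")
    if with_score_name and not with_score:
        raise ValueError("with_score_name=True requires with_score=True.")
    pattern = "{name}"
    if with_global_step:
        pattern += "_{global_step}"
    if with_score:
        pattern += "_{score_name}={score}" if with_score_name else "_{score}"
    pattern += ".{ext}"
    if with_prefix:
        pattern = "{filename_prefix}_" + pattern
    return pattern


print(build_filename_pattern())
# {filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}
print(build_filename_pattern(with_prefix=False, with_score=False, with_score_name=False))
# {name}_{global_step}.{ext}
```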

state_dict()

Returns a state dict with the saved items: a list of (priority, filename) pairs. Can be used to save the internal state of the class.

Return type

OrderedDict[str, List[Tuple[int, str]]]
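The documented layout, a "saved" key mapping to (priority, filename) pairs, round-trips through state_dict and load_state_dict. SavedNames below is a hypothetical stand-in that only models this layout, not the Checkpoint class itself.

```python
from collections import OrderedDict
from typing import List, Mapping, Tuple


class SavedNames:
    """Hypothetical stand-in using the documented {"saved": [(priority, filename), ...]} layout."""

    def __init__(self) -> None:
        self._saved: List[Tuple[int, str]] = []

    def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
        return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

    def load_state_dict(self, state_dict: Mapping[str, List[Tuple[int, str]]]) -> None:
        # Replace, rather than merge with, the current list of saved items.
        self._saved = [(p, f) for p, f in state_dict["saved"]]


source = SavedNames()
source._saved = [(7000, "checkpoint_7000.pt"), (8000, "checkpoint_8000.pt")]
state = source.state_dict()

restored = SavedNames()
restored.load_state_dict(state)
print(restored.state_dict() == state)  # True
```

With include_self=True, this kind of state is stored in the checkpoint under the checkpointer key, which allows the list of saved names to be restored when training resumes.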