Checkpoint
class ignite.handlers.checkpoint.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, filename_pattern=None, include_self=False, greater_or_equal=False)
Checkpoint handler can be used to periodically save and load objects that have state_dict/load_state_dict methods. This class can use specific save handlers to store checkpoints on disk, in cloud storage, etc. When used with DiskSaver, the Checkpoint handler also automatically moves data on TPU to CPU before writing the checkpoint.
Parameters:
- to_save (Mapping) – Dictionary with the objects to save. Objects should implement the state_dict and load_state_dict methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is automatically saved (to avoid the additional module. prefix in the state dictionary keys).
- save_handler (Union[Callable, ignite.handlers.checkpoint.BaseSaveHandler]) – Method or callable class to use to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement the remove method to keep a fixed number of saved checkpoints. If the user needs to save the engine's checkpoint on disk, save_handler can be defined with DiskSaver.
- filename_prefix (str) – Prefix for the file name to which objects will be saved. See Note for details.
- score_function (Optional[Callable]) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.
- score_name (Optional[str]) – If score_function is not None, its value can be stored in the filename under score_name. See Notes for more details.
- n_saved (Optional[int]) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.
- global_step_transform (Optional[Callable]) – Global step transform function to output a desired global step. The function takes (engine, event_name) as input and should return an integer. Default is None, in which case the global step is based on the attached engine. If provided, the function output is used as the global step. To set up the global step from another engine, please use global_step_from_engine().
- filename_pattern (Optional[str]) – If provided, this pattern will be used to render checkpoint filenames. Otherwise, the default pattern is used. See Note for details.
- include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.
- greater_or_equal (bool) – If True, the latest equally scored model is stored. Otherwise, the first model is kept. Default, False.
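The interplay of n_saved, score_function and greater_or_equal can be illustrated with a small pure-Python sketch. This is not Ignite's implementation. The class RetentionSketch and its offer method are hypothetical. The sketch assumes three things: each item's priority is its score when score_function is given and its global step otherwise; the lowest-priority item is the one evicted; and greater_or_equal decides ties.

```python
from typing import List, NamedTuple, Optional


class Item(NamedTuple):
    priority: float
    filename: str


class RetentionSketch:
    """Hypothetical model of which checkpoint files are kept (not Ignite's code)."""

    def __init__(self, n_saved: Optional[int] = 1, greater_or_equal: bool = False):
        self.n_saved = n_saved
        self.greater_or_equal = greater_or_equal
        self.saved: List[Item] = []  # kept sorted by ascending priority

    def _has_room(self) -> bool:
        return self.n_saved is None or len(self.saved) < self.n_saved

    def _beats_lowest(self, priority: float) -> bool:
        if not self.saved:
            return False
        lowest = self.saved[0].priority
        # greater_or_equal=True lets an equally scored, newer checkpoint win the tie.
        return priority >= lowest if self.greater_or_equal else priority > lowest

    def offer(self, priority: float, filename: str) -> List[str]:
        """Offer a new checkpoint; return the filenames that would be removed."""
        if not (self._has_room() or self._beats_lowest(priority)):
            return []  # the new checkpoint is not stored at all
        self.saved.append(Item(priority, filename))
        # Stable sort: among equal priorities the older item stays in front.
        self.saved.sort(key=lambda item: item.priority)
        removed = []
        while self.n_saved is not None and len(self.saved) > self.n_saved:
            removed.append(self.saved.pop(0).filename)
        return removed


keeper = RetentionSketch(n_saved=2)
for score, name in [(0.70, "a"), (0.75, "b"), (0.72, "c"), (0.80, "d")]:
    keeper.offer(score, name)
print([item.filename for item in keeper.saved])  # ['b', 'd']
```

With greater_or_equal=False a tie never displaces the stored file, so the first equally scored model survives. With greater_or_equal=True the newer one replaces it.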
Note
This class stores a single file as a dictionary of provided objects to save. The filename is defined by filename_pattern and by default has the following structure: {filename_prefix}_{name}_{suffix}.{ext}, where
- filename_prefix is the argument passed to the constructor,
- name is the key in to_save if a single object is to be stored, otherwise name is "checkpoint",
- suffix is composed as follows: {global_step}_{score_name}={score}.
The suffix depends on which arguments are provided (X: provided, None: not provided):

    score_function  score_name  global_step_transform  suffix
    None            None        None                   {engine.state.iteration}
    X               None        None                   {score}
    X               None        X                      {global_step}_{score}
    X               X           X                      {global_step}_{score_name}={score}
    None            None        X                      {global_step}
    X               X           None                   {score_name}={score}

Above, global_step is defined by the output of global_step_transform and score by the output of score_function.
By default, none of score_function, score_name and global_step_transform is defined, so the suffix is set from the attached engine's current iteration. The filename will be {filename_prefix}_{name}_{engine.state.iteration}.{ext}.
For example, with score_name="neg_val_loss" and a score_function that returns -loss (as objects with the highest scores will be retained), the saved filename will be {filename_prefix}_{name}_neg_val_loss=-0.1234.pt.
Note
If filename_pattern is given, it will be used to render the filenames. filename_pattern is a string that can contain {filename_prefix}, {name}, {score}, {score_name} and {global_step} as templates.
For example, with filename_pattern="{global_step}-{name}-{score}.pt", the saved filename will be 30000-checkpoint-94.pt.
Warning: keep in mind that if a filename collides with one already used to save a checkpoint, the new checkpoint will replace the older one. This means that a filename like checkpoint.pt will be written on every call and will always be overwritten by newer checkpoints.
Note
To get the last stored filename, the handler exposes the attribute last_checkpoint:

    handler = Checkpoint(...)
    ...
    print(handler.last_checkpoint)
    # > checkpoint_12345.pt

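The naming rules described in the notes above can be condensed into a small pure-Python sketch. The helper checkpoint_filename is hypothetical and not part of Ignite. It makes these assumptions:
- It passes None for arguments that are not provided.
- It formats the score with plain str(), which may differ from the library's own float formatting.
- It drops the prefix and its underscore when filename_prefix is empty, as the checkpoint_7000.pt example output suggests.
- It renders a custom filename_pattern with str.format, which ignores keyword arguments the pattern does not use.

```python
from typing import Optional


def checkpoint_filename(
    name: str,
    iteration: int,
    filename_prefix: str = "",
    score: Optional[float] = None,
    score_name: Optional[str] = None,
    global_step: Optional[int] = None,
    ext: str = "pt",
    filename_pattern: Optional[str] = None,
) -> str:
    """Hypothetical re-creation of the documented naming rules (not Ignite's code).

    score=None stands for "no score_function"; global_step=None stands for
    "no global_step_transform".
    """
    if filename_pattern is not None:
        # Custom pattern: every template key is offered; unused ones are ignored.
        return filename_pattern.format(
            filename_prefix=filename_prefix, name=name, score=score,
            score_name=score_name, global_step=global_step, ext=ext,
        )
    if score is None:
        # No score: the suffix is the global step, or the engine iteration by default.
        suffix = str(global_step if global_step is not None else iteration)
    else:
        scored = f"{score_name}={score}" if score_name else f"{score}"
        suffix = f"{global_step}_{scored}" if global_step is not None else scored
    base = f"{filename_prefix}_{name}" if filename_prefix else name
    return f"{base}_{suffix}.{ext}"


print(checkpoint_filename("checkpoint", iteration=7000))  # checkpoint_7000.pt
print(checkpoint_filename("model", iteration=480, filename_prefix="best",
                          score=0.78, score_name="val_acc", global_step=10))  # best_model_10_val_acc=0.78.pt
print(checkpoint_filename("checkpoint", iteration=30000, score=94, global_step=30000,
                          filename_pattern="{global_step}-{name}-{score}.pt"))  # 30000-checkpoint-94.pt
# A pattern without placeholders renders the same name on every call, so files overwrite each other.
print(checkpoint_filename("checkpoint", iteration=1, filename_pattern="checkpoint.pt"))  # checkpoint.pt
```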
Note
This class is distributed configuration-friendly: it is not required to instantiate the class only in the rank 0 process. The class automatically supports distributed configurations and, if used with DiskSaver, the checkpoint is stored by the rank 0 process.
Warning
When running on XLA devices, the handler should be run in all processes, otherwise the application can get stuck while saving the checkpoint.

    # Wrong:
    # if idist.get_rank() == 0:
    #     handler = Checkpoint(...)
    #     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    # Correct:
    handler = Checkpoint(...)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

Examples
Attach the handler to make checkpoints during training:

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver

    trainer = ...
    model = ...
    optimizer = ...
    lr_scheduler = ...

    to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}

    if checkpoint_iters:  # user-defined flag: checkpoint by iteration (A) or by epoch (B)
        # A: Output is "checkpoint_<iteration>.pt"
        handler = Checkpoint(
            to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2
        )
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
    else:
        # B: Output is "checkpoint_<epoch>.pt"
        gst = lambda *_: trainer.state.epoch
        handler = Checkpoint(
            to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2, global_step_transform=gst
        )
        trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

    trainer.run(data_loader, max_epochs=6)
    # A: ["checkpoint_7000.pt", "checkpoint_8000.pt", ]
    # B: ["checkpoint_5.pt", "checkpoint_6.pt", ]

Attach the handler to an evaluator to save the best model during training according to a computed validation metric:

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

    trainer = ...
    evaluator = ...
    model = ...
    # Setup Accuracy metric computation on evaluator
    # Run evaluation on epoch completed event
    # ...

    score_function = Checkpoint.get_default_score_fn("accuracy")

    to_save = {'model': model}
    handler = Checkpoint(
        to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2,
        filename_prefix='best', score_function=score_function, score_name="val_acc",
        global_step_transform=global_step_from_engine(trainer)
    )

    evaluator.add_event_handler(Events.COMPLETED, handler)

    trainer.run(data_loader, max_epochs=10)
    # ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]

Changed in version 0.4.3:
- Checkpoint can save model with same filename.
- Added greater_or_equal argument.
Methods
- get_default_score_fn – Helper method to get default score function based on the metric name.
- load_objects – Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.
- load_state_dict – Method replaces internal state of the class with provided state dict data.
- reset – Method to reset saved checkpoint names.
- setup_filename_pattern – Helper method to get the default filename pattern for a checkpoint.
- state_dict – Method returns state dict with saved items: list of (priority, filename) pairs.
class Item(priority, filename)
Create new instance of Item(priority, filename).
- property filename – Alias for field number 1.
- property priority – Alias for field number 0.
static get_default_score_fn(metric_name, score_sign=1.0)
Helper method to get default score function based on the metric name.
Parameters:
- metric_name (str) – metric name to get the value from engine.state.metrics. Engine is the one to which the Checkpoint handler is added.
- score_sign (float) – sign of the score: 1.0 or -1.0. For error-like metrics, i.e. where smaller is better, a negative score sign should be used (objects with larger score are retained). Default, 1.0.
Return type: Callable
Examples:

    from ignite.engine import Events
    from ignite.handlers import Checkpoint

    best_acc_score = Checkpoint.get_default_score_fn("accuracy")

    best_model_handler = Checkpoint(
        to_save, save_handler, score_name="val_accuracy", score_function=best_acc_score
    )
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

Usage with error-like metric:

    from ignite.engine import Events
    from ignite.handlers import Checkpoint

    neg_loss_score = Checkpoint.get_default_score_fn("loss", -1.0)

    best_model_handler = Checkpoint(
        to_save, save_handler, score_name="val_neg_loss", score_function=neg_loss_score
    )
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

New in version 0.4.3.
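Judging from the description above, the returned callable presumably reads the named metric from engine.state.metrics and multiplies it by score_sign. The sketch below re-creates that assumed behavior. default_score_fn_sketch is a hypothetical name. A types.SimpleNamespace object stands in for an Ignite Engine so the example runs without Ignite installed.

```python
from types import SimpleNamespace
from typing import Any, Callable


def default_score_fn_sketch(metric_name: str, score_sign: float = 1.0) -> Callable[[Any], float]:
    """Hypothetical equivalent of Checkpoint.get_default_score_fn (assumed behavior)."""
    if score_sign not in (1.0, -1.0):
        raise ValueError("score_sign should be 1.0 or -1.0")

    def score_fn(engine: Any) -> float:
        # The engine is the one the Checkpoint handler is attached to, e.g. an evaluator.
        return score_sign * engine.state.metrics[metric_name]

    return score_fn


# Stand-in for an evaluator after a validation run.
evaluator = SimpleNamespace(state=SimpleNamespace(metrics={"accuracy": 0.91, "loss": 0.25}))
print(default_score_fn_sketch("accuracy")(evaluator))   # 0.91
print(default_score_fn_sketch("loss", -1.0)(evaluator))  # -0.25
```

With the negative sign, a lower validation loss produces a higher score, which is what the handler retains.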
static load_objects(to_load, checkpoint, **kwargs)
Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.
Examples:

    import torch
    from ignite.engine import Engine, Events
    from ignite.handlers import ModelCheckpoint, Checkpoint

    trainer = Engine(lambda engine, batch: None)
    handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    to_save = {"weights": model, "optimizer": optimizer}
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
    # 10 iterations per epoch for 5 epochs: checkpoints are written at epochs 2 and 4 (iterations 20 and 40)
    trainer.run(torch.randn(10, 1), 5)

    to_load = to_save
    checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pt"
    checkpoint = torch.load(checkpoint_fp)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

Note
If to_load contains objects of type torch DistributedDataParallel or DataParallel, the load_state_dict method will be applied to their internal wrapped model (obj.module).
Parameters:
- to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, …}
- checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {"model": model_state_dict, "optimizer": opt_state_dict}. If to_load contains a single key, then checkpoint can directly contain the corresponding state_dict.
- kwargs (Any) – Keyword arguments accepted by nn.Module.load_state_dict(). Passing strict=False enables the user to load part of a pretrained model (useful, for example, in transfer learning).
Return type: None
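The loading rules above can be sketched in plain Python. load_objects_sketch and the toy ToyState and Wrapper classes are hypothetical. The sketch swaps in a simplification: it unwraps any object with a module attribute, whereas the documented behavior applies only to DistributedDataParallel and DataParallel. It also assumes two things: a single-key to_load whose key is missing from checkpoint receives checkpoint itself as the state_dict, and a key missing from a multi-object checkpoint raises ValueError.

```python
from typing import Any, Dict, Mapping


class ToyState:
    """Toy object exposing the state_dict/load_state_dict protocol."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state: Mapping[str, Any]) -> None:
        self.count = state["count"]


class Wrapper:
    """Toy stand-in for a DataParallel-style wrapper around a model."""

    def __init__(self, module: Any) -> None:
        self.module = module


def load_objects_sketch(to_load: Mapping[str, Any], checkpoint: Mapping[str, Any], **kwargs: Any) -> None:
    def unwrap(obj: Any) -> Any:
        # Load into the wrapped object, mirroring the obj.module rule.
        return obj.module if hasattr(obj, "module") else obj

    if len(to_load) == 1:
        key, obj = next(iter(to_load.items()))
        if key not in checkpoint:
            # Single object: checkpoint is assumed to be its state_dict directly.
            unwrap(obj).load_state_dict(checkpoint, **kwargs)
            return
    for key, obj in to_load.items():
        if key not in checkpoint:
            raise ValueError(f"Object labeled by '{key}' is not found in the checkpoint")
        unwrap(obj).load_state_dict(checkpoint[key], **kwargs)


model, steps = ToyState(), ToyState()
model.count, steps.count = 3, 7
saved = {"model": model.state_dict(), "steps": steps.state_dict()}

restored_model, restored_steps = ToyState(), ToyState()
load_objects_sketch({"model": Wrapper(restored_model), "steps": restored_steps}, saved)
print(restored_model.count, restored_steps.count)  # 3 7
```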
load_state_dict(state_dict)
Method replaces the internal state of the class with provided state dict data.
Parameters:
- state_dict (Mapping) – a dict with a "saved" key and a list of (priority, filename) pairs as values.
Return type: None
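The saved state is a list of (priority, filename) pairs under a "saved" key. Because Item stores priority as field 0 and filename as field 1, plain tuples convert straight back into items. The following sketch round-trips such a state with a hypothetical SavedNamesSketch class, which also has a reset method that clears the saved names. It illustrates the documented format only and is not Ignite's implementation.

```python
from typing import Any, Dict, List, Mapping, NamedTuple


class Item(NamedTuple):
    priority: float  # field number 0
    filename: str    # field number 1


class SavedNamesSketch:
    """Hypothetical holder of saved checkpoint names (not Ignite's Checkpoint)."""

    def __init__(self) -> None:
        self.saved: List[Item] = []

    def state_dict(self) -> Dict[str, Any]:
        return {"saved": [(item.priority, item.filename) for item in self.saved]}

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        # Replace the internal state entirely, rebuilding Item tuples from the pairs.
        self.saved = [Item(*pair) for pair in state_dict["saved"]]

    def reset(self) -> None:
        # Forget all saved names, e.g. before an independent run of the engine.
        self.saved = []


source = SavedNamesSketch()
source.saved = [Item(20, "checkpoint_20.pt"), Item(40, "checkpoint_40.pt")]
target = SavedNamesSketch()
target.load_state_dict(source.state_dict())
print(target.saved[-1].filename)  # checkpoint_40.pt
```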
reset()
Method to reset saved checkpoint names.
Use this method if the engine will independently run multiple times:

    from ignite.engine import Events
    from ignite.handlers import Checkpoint

    trainer = ...
    checkpointer = Checkpoint(...)

    trainer.add_event_handler(Events.COMPLETED, checkpointer)
    trainer.add_event_handler(Events.STARTED, checkpointer.reset)

    # fold 0
    trainer.run(data0, max_epochs=max_epochs)
    print("Last checkpoint:", checkpointer.last_checkpoint)

    # fold 1
    trainer.run(data1, max_epochs=max_epochs)
    print("Last checkpoint:", checkpointer.last_checkpoint)

New in version 0.4.3.
Return type: None
static setup_filename_pattern(with_prefix=True, with_score=True, with_score_name=True, with_global_step=True)
Helper method to get the default filename pattern for a checkpoint.
Parameters:
- with_prefix (bool) – If True, the filename_prefix is added to the filename pattern: {filename_prefix}_{name}... . Default, True.
- with_score (bool) – If True, score is added to the filename pattern: ..._{score}.{ext}. Default, True. At least one of with_score and with_global_step should be True.
- with_score_name (bool) – If True, score_name is added to the filename pattern: ..._{score_name}={score}.{ext}. If activated, the argument with_score should also be True, otherwise an error is raised. Default, True.
- with_global_step (bool) – If True, {global_step} is added to the filename pattern: ...{name}_{global_step}... . At least one of with_score and with_global_step should be True. Default, True.
Return type: str
Example

    from ignite.handlers import Checkpoint

    filename_pattern = Checkpoint.setup_filename_pattern()
    print(filename_pattern)
    # > "{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}"

New in version 0.4.3.
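The parameter rules above can be re-created in a few lines of plain Python. setup_filename_pattern_sketch is a hypothetical stand-in that follows the documented constraints and reproduces the documented default output. The exact error messages are assumptions.

```python
def setup_filename_pattern_sketch(
    with_prefix: bool = True,
    with_score: bool = True,
    with_score_name: bool = True,
    with_global_step: bool = True,
) -> str:
    """Hypothetical re-creation of the documented pattern rules (not Ignite's code)."""
    if not (with_score or with_global_step):
        raise ValueError("At least one of with_score and with_global_step should be True.")
    if with_score_name and not with_score:
        raise ValueError("If with_score_name is True, with_score should also be True.")

    pattern = "{filename_prefix}_{name}" if with_prefix else "{name}"
    if with_global_step:
        pattern += "_{global_step}"
    if with_score_name:
        pattern += "_{score_name}={score}"
    elif with_score:
        pattern += "_{score}"
    return pattern + ".{ext}"


pattern = setup_filename_pattern_sketch()
print(pattern)  # {filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}
# Rendering the pattern with str.format fills in the placeholders.
print(pattern.format(filename_prefix="best", name="model", global_step=10,
                     score_name="val_acc", score=0.78, ext="pt"))  # best_model_10_val_acc=0.78.pt
```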