"""WandB logger and its helper handlers."""
from typing import Any, Callable, List, Optional, Union
from torch.optim import Optimizer
from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.engine import Engine, Events
from ignite.handlers import global_step_from_engine
__all__ = ["WandBLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"]
[docs]class WandBLogger(BaseLogger):
"""`Weights & Biases <https://wandb.ai/site>`_ handler to log metrics, model/optimizer parameters, gradients
during training and validation. It can also be used to log model checkpoints to the Weights & Biases cloud.
.. code-block:: bash
pip install wandb
This class is also a wrapper for the wandb module. This means that you can call any wandb function using
this wrapper. See examples on how to save model parameters and gradients.
Args:
args: Positional arguments accepted by `wandb.init`.
kwargs: Keyword arguments accepted by `wandb.init`.
Please see `wandb.init <https://docs.wandb.ai/library/init>`_ for documentation of possible parameters.
Examples:
.. code-block:: python
from ignite.contrib.handlers.wandb_logger import *
# Create a logger. All parameters are optional. See documentation
# on wandb.init for details.
wandb_logger = WandBLogger(
entity="shared",
project="pytorch-ignite-integration",
name="cnn-mnist",
config={"max_epochs": 10},
tags=["pytorch-ignite", "minst"]
)
# Attach the logger to the trainer to log training loss at each iteration
wandb_logger.attach_output_handler(
trainer,
event_name=Events.ITERATION_COMPLETED,
tag="training",
output_transform=lambda loss: {"loss": loss}
)
# Attach the logger to the evaluator on the training dataset and log NLL, Accuracy metrics after each epoch
# We setup `global_step_transform=lambda *_: trainer.state.iteration` to take iteration value
# of the `trainer`:
wandb_logger.attach_output_handler(
train_evaluator,
event_name=Events.EPOCH_COMPLETED,
tag="training",
metric_names=["nll", "accuracy"],
global_step_transform=lambda *_: trainer.state.iteration,
)
# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=lambda *_: trainer.state.iteration` to take iteration value
# of the `trainer` instead of `evaluator`.
wandb_logger.attach_output_handler(
evaluator,
event_name=Events.EPOCH_COMPLETED,
tag="validation",
metric_names=["nll", "accuracy"],
global_step_transform=lambda *_: trainer.state.iteration,
)
# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
wandb_logger.attach_opt_params_handler(
trainer,
event_name=Events.ITERATION_STARTED,
optimizer=optimizer,
param_name='lr' # optional
)
If you want to log model gradients, the model call graph, etc., use the logger as wrapper of wandb. Refer
to the documentation of wandb.watch for details:
.. code-block:: python
wandb_logger = WandBLogger(
entity="shared",
project="pytorch-ignite-integration",
name="cnn-mnist",
config={"max_epochs": 10},
tags=["pytorch-ignite", "minst"]
)
model = torch.nn.Sequential(...)
wandb_logger.watch(model)
For model checkpointing, Weights & Biases creates a local run dir, and automatically synchronizes all
files saved there at the end of the run. You can just use the `wandb_logger.run.dir` as path for the
`ModelCheckpoint`:
.. code-block:: python
from ignite.handlers import ModelCheckpoint
def score_function(engine):
return engine.state.metrics['accuracy']
model_checkpoint = ModelCheckpoint(
wandb_logger.run.dir, n_saved=2, filename_prefix='best',
require_empty=False, score_function=score_function,
score_name="validation_accuracy",
global_step_transform=global_step_from_engine(trainer)
)
evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {'model': model})
"""
def __init__(self, *args: Any, **kwargs: Any):
try:
import wandb
self._wandb = wandb
except ImportError:
raise RuntimeError(
"This contrib module requires wandb to be installed. "
"You man install wandb with the command:\n pip install wandb\n"
)
if kwargs.get("init", True):
wandb.init(*args, **kwargs)
def __getattr__(self, attr: Any) -> Any:
return getattr(self._wandb, attr) # type: ignore[misc]
def _create_output_handler(self, *args: Any, **kwargs: Any) -> "OutputHandler":
return OutputHandler(*args, **kwargs)
def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> "OptimizerParamsHandler":
return OptimizerParamsHandler(*args, **kwargs)
[docs]class OutputHandler(BaseOutputHandler):
"""Helper handler to log engine's output and/or metrics
Examples:
.. code-block:: python
from ignite.contrib.handlers.wandb_logger import *
# Create a logger. All parameters are optional. See documentation
# on wandb.init for details.
wandb_logger = WandBLogger(
entity="shared",
project="pytorch-ignite-integration",
name="cnn-mnist",
config={"max_epochs": 10},
tags=["pytorch-ignite", "minst"]
)
# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# each epoch. We setup `global_step_transform=lambda *_: trainer.state.iteration,` to take iteration value
# of the `trainer`:
wandb_logger.attach(
evaluator,
log_handler=OutputHandler(
tag="validation",
metric_names=["nll", "accuracy"],
global_step_transform=lambda *_: trainer.state.iteration,
),
event_name=Events.EPOCH_COMPLETED
)
# or equivalently
wandb_logger.attach_output_handler(
evaluator,
event_name=Events.EPOCH_COMPLETED,
tag="validation",
metric_names=["nll", "accuracy"],
global_step_transform=lambda *_: trainer.state.iteration,
)
Another example, where model is evaluated every 500 iterations:
.. code-block:: python
from ignite.contrib.handlers.wandb_logger import *
@trainer.on(Events.ITERATION_COMPLETED(every=500))
def evaluate(engine):
evaluator.run(validation_set, max_epochs=1)
# Create a logger. All parameters are optional. See documentation
# on wandb.init for details.
wandb_logger = WandBLogger(
entity="shared",
project="pytorch-ignite-integration",
name="cnn-mnist",
config={"max_epochs": 10},
tags=["pytorch-ignite", "minst"]
)
def global_step_transform(*args, **kwargs):
return trainer.state.iteration
# Attach the logger to the evaluator on the validation dataset and log NLL, Accuracy metrics after
# every 500 iterations. Since evaluator engine does not have access to the training iteration, we
# provide a global_step_transform to return the trainer.state.iteration for the global_step, each time
# evaluator metrics are plotted on Weights & Biases.
wandb_logger.attach_output_handler(
evaluator,
event_name=Events.EPOCH_COMPLETED,
tag="validation",
metrics=["nll", "accuracy"],
global_step_transform=global_step_transform
)
Args:
tag: common title for all produced plots. For example, "training"
metric_names: list of metric names to plot or a string "all" to plot all available
metrics.
output_transform: output transform function to prepare `engine.state.output` as a number.
For example, `output_transform = lambda output: output`
This function can also return a dictionary, e.g `{"loss": loss1, "another_loss": loss2}` to label the plot
with corresponding keys.
global_step_transform: global step transform function to output a desired global step.
Input of the function is `(engine, event_name)`. Output of function should be an integer.
Default is None, global_step based on attached engine. If provided,
uses function output as global_step. To setup global step from another engine, please use
:meth:`~ignite.contrib.handlers.wandb_logger.global_step_from_engine`.
sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
the default value of wandb.log.
Note:
Example of `global_step_transform`:
.. code-block:: python
def global_step_transform(engine, event_name):
return engine.state.get_event_attrib_value(event_name)
"""
def __init__(
self,
tag: str,
metric_names: Optional[List[str]] = None,
output_transform: Optional[Callable] = None,
global_step_transform: Optional[Callable] = None,
sync: Optional[bool] = None,
):
super().__init__(tag, metric_names, output_transform, global_step_transform)
self.sync = sync
def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
if not isinstance(logger, WandBLogger):
raise RuntimeError(f"Handler '{self.__class__.__name__}' works only with WandBLogger.")
global_step = self.global_step_transform(engine, event_name) # type: ignore[misc]
if not isinstance(global_step, int):
raise TypeError(
f"global_step must be int, got {type(global_step)}."
" Please check the output of global_step_transform."
)
metrics = self._setup_output_metrics(engine)
if self.tag is not None:
metrics = {f"{self.tag}/{name}": value for name, value in metrics.items()}
logger.log(metrics, step=global_step, sync=self.sync)
[docs]class OptimizerParamsHandler(BaseOptimizerParamsHandler):
"""Helper handler to log optimizer parameters
Examples:
.. code-block:: python
from ignite.contrib.handlers.wandb_logger import *
# Create a logger. All parameters are optional. See documentation
# on wandb.init for details.
wandb_logger = WandBLogger(
entity="shared",
project="pytorch-ignite-integration",
name="cnn-mnist",
config={"max_epochs": 10},
tags=["pytorch-ignite", "minst"]
)
# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
wandb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizer),
event_name=Events.ITERATION_STARTED
)
# or equivalently
wandb_logger.attach_opt_params_handler(
trainer,
event_name=Events.ITERATION_STARTED,
optimizer=optimizer
)
Args:
optimizer: torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name: parameter name
tag: common title for all produced plots. For example, "generator"
sync: If set to False, process calls to log in a seperate thread. Default (None) uses whatever
the default value of wandb.log.
"""
def __init__(
self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None, sync: Optional[bool] = None,
):
super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
self.sync = sync
def __call__(self, engine: Engine, logger: WandBLogger, event_name: Union[str, Events]) -> None:
if not isinstance(logger, WandBLogger):
raise RuntimeError("Handler OptimizerParamsHandler works only with WandBLogger")
global_step = engine.state.get_event_attrib_value(event_name)
tag_prefix = f"{self.tag}/" if self.tag else ""
params = {
f"{tag_prefix}{self.param_name}/group_{i}": float(param_group[self.param_name])
for i, param_group in enumerate(self.optimizer.param_groups)
}
logger.log(params, step=global_step, sync=self.sync)