ignite.distributed¶
Helper module to use distributed settings for multiple backends:

- backends from native torch distributed configuration: "nccl", "gloo", "mpi"
- XLA on TPUs via pytorch/xla
- the Horovod framework as a backend
Distributed launcher and auto helpers¶
We provide a context manager to simplify the code of distributed configuration setup for all of the supported backends above. In addition, the methods auto_model(), auto_optim() and auto_dataloader() help to adapt the provided model, optimizer and data loaders to the existing configuration in a transparent way:
# main.py

import ignite.distributed as idist

def training(local_rank, config, **kwargs):
    print(idist.get_rank(), ": run with config:", config, "- backend=", idist.backend())

    train_loader = idist.auto_dataloader(dataset, batch_size=32, num_workers=12, shuffle=True, **kwargs)
    # batch size, num_workers and sampler are automatically adapted to existing configuration
    # ...

    model = resnet50()
    model = idist.auto_model(model)
    # model is DDP or DP or just itself according to existing configuration
    # ...

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    optimizer = idist.auto_optim(optimizer)
    # optimizer is itself, except for the XLA configuration where its `step()` method is overridden.
    # User can safely call `optimizer.step()` (behind the scenes `xm.optimizer_step(optimizer)` is performed)

backend = "nccl"  # torch native distributed configuration on multiple GPUs
# backend = "xla-tpu"  # XLA TPUs distributed configuration
# backend = None  # no distributed configuration

with idist.Parallel(backend=backend, **dist_configs) as parallel:
    parallel.run(training, config, a=1, b=2)
The above code may be executed with the torch.distributed.launch tool or with plain python, specifying the distributed configuration in the code. For more details, please see Parallel, auto_model(), auto_optim() and auto_dataloader(). A complete example of CIFAR10 training can be found here.
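For example, a minimal sketch of the two launch modes for the main.py above (training and config refer to the snippet above; the backend, the process count and the nproc_per_node spawn kwarg passed through dist_configs are illustrative, see Parallel for the accepted options):

import ignite.distributed as idist

# Option 1: processes are created by an external tool such as torch.distributed.launch
# (see the initialize() example below for a concrete command line)
with idist.Parallel(backend="nccl") as parallel:
    parallel.run(training, config)

# Option 2: run "python main.py" and let Parallel spawn the processes itself
with idist.Parallel(backend="nccl", nproc_per_node=4) as parallel:
    parallel.run(training, config)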
ignite.distributed.auto¶
DistributedProxySampler – Distributed sampler proxy to adapt a user's sampler for the distributed data parallelism configuration.
auto_dataloader() – Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting all available backends from available_backends()).
auto_model() – Helper method to adapt the provided model for non-distributed and distributed configurations (supporting all available backends from available_backends()).
auto_optim() – Helper method to adapt the optimizer for non-distributed and distributed configurations (supporting all available backends from available_backends()).
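As an illustration, a minimal sketch of wrapping a user-defined sampler with DistributedProxySampler so that each process only iterates over its own shard of the sampled indices (the dataset, weights and batch size are placeholders chosen for the example):

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

import ignite.distributed as idist
from ignite.distributed.auto import DistributedProxySampler

# placeholder dataset and a custom (non-distributed) sampler
dataset = TensorDataset(torch.rand(1000, 3), torch.randint(0, 2, (1000,)))
weights = torch.ones(len(dataset))
sampler = WeightedRandomSampler(weights, num_samples=len(dataset))

# adapt the sampler to the current distributed configuration
dist_sampler = DistributedProxySampler(
    sampler, num_replicas=idist.get_world_size(), rank=idist.get_rank()
)
train_loader = DataLoader(dataset, batch_size=32, sampler=dist_sampler)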
ignite.distributed.launcher¶
Parallel – Distributed launcher context manager to simplify distributed configuration setup for multiple backends.
ignite.distributed.utils¶
This module wraps common methods to fetch information about distributed configuration, initialize/finalize process group or spawn multiple processes.
backend() – Returns computation model's backend.
broadcast() – Helper method to perform broadcast operation.
device() – Returns current device according to current distributed configuration.
available_backends() – Returns available backends.
model_name() – Returns distributed configuration name (given by ignite).
get_world_size() – Returns world size of current distributed configuration.
get_rank() – Returns process rank within current distributed configuration.
get_local_rank() – Returns local process rank within current distributed configuration.
get_nproc_per_node() – Returns number of processes (or tasks) per node within current distributed configuration.
get_node_rank() – Returns node rank within current distributed configuration.
get_nnodes() – Returns number of nodes within current distributed configuration.
spawn() – Spawns nproc_per_node processes that run fn and initialize the distributed configuration defined by backend.
initialize() – Initializes distributed configuration according to provided backend.
finalize() – Finalizes distributed configuration.
show_config() – Helper method to display distributed configuration via logging.
set_local_rank() – Method to hint the local rank in case the torch native distributed context is created by the user without using initialize() or spawn().
all_reduce() – Helper method to perform all reduce operation.
all_gather() – Helper method to perform all gather operation.
barrier() – Helper method to synchronize all processes.
hostname() – Returns host name for current process within current distributed configuration.
sync() – Helper method to force this module to synchronize with current distributed context.
one_rank_only() – Decorator to filter handlers wrt a rank number.
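A minimal sketch showing how these helpers can be combined to inspect the current configuration from any process (the printed values are illustrative):

import ignite.distributed as idist

# safe in both non-distributed and distributed runs: without a distributed
# context the helpers fall back to rank 0 / world size 1 defaults
print("backend:", idist.backend())            # e.g. None, "nccl", "xla-tpu" or "horovod"
print("world size:", idist.get_world_size())  # 1 if no distributed configuration
print("rank:", idist.get_rank(), "local rank:", idist.get_local_rank())
print("device:", idist.device())              # e.g. cpu, cuda:0 or xla:0
idist.show_config()                           # logs the full configuration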
ignite.distributed.utils.has_native_dist_support¶
True if torch.distributed is available.

ignite.distributed.utils.has_xla_support¶
True if the torch_xla package is found.
ignite.distributed.utils.all_gather(tensor)¶
Helper method to perform all gather operation.

- Parameters
  tensor (Union[torch.Tensor, float, str]) – tensor or number or str to collect across participating processes.
- Returns
  torch.Tensor of shape (world_size * tensor.shape[0], tensor.shape[1], ...) if input is a tensor, torch.Tensor of shape (world_size, ) if input is a number, or List of strings if input is a string.
- Return type
  Union[torch.Tensor, float, List[float], List[str]]
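For illustration, a minimal usage sketch of all_gather() (the tensor shape and string format below are placeholders chosen for the example):

import torch
import ignite.distributed as idist

rank = idist.get_rank()
world_size = idist.get_world_size()

# each process contributes a tensor of shape (2, 3)
t = torch.full((2, 3), float(rank), device=idist.device())
gathered = idist.all_gather(t)
# gathered has shape (world_size * 2, 3); rows are ordered by rank
assert gathered.shape == (world_size * 2, 3)

# numbers and strings are also supported
ranks = idist.all_gather(float(rank))     # tensor of shape (world_size, )
names = idist.all_gather(f"proc-{rank}")  # list of strings, one per process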
ignite.distributed.utils.all_reduce(tensor, op='SUM')¶
Helper method to perform all reduce operation.

- Parameters
  tensor (Union[torch.Tensor, float]) – tensor or number to collect across participating processes.
  op (str) – reduction operation, "SUM" by default. Possible values: "SUM", "PRODUCT", "MIN", "MAX", "AND", "OR". Horovod backend supports only "SUM", "AVERAGE", "ADASUM", "MIN", "MAX", "PRODUCT".
- Returns
  torch.Tensor or number
- Return type
  Union[torch.Tensor, float]
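For illustration, a minimal usage sketch of all_reduce(), e.g. averaging a per-process metric (the local value is a placeholder):

import ignite.distributed as idist

# each process computes its own partial value, e.g. a running loss
local_loss = 0.5 * (idist.get_rank() + 1)  # placeholder value

# sum across all participating processes, then divide by the world size
total_loss = idist.all_reduce(local_loss, op="SUM")
mean_loss = total_loss / idist.get_world_size()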
ignite.distributed.utils.available_backends()¶
Returns available backends.

- Return type
  Tuple[str, ...]
ignite.distributed.utils.backend()¶
Returns computation model's backend:

- None for no distributed configuration
- "nccl" or "gloo" or "mpi" for native torch distributed configuration
- "xla-tpu" for XLA distributed configuration
- "horovod" for Horovod distributed framework

- Returns
  str or None
- Return type
  Optional[str]

Changed in version 0.4.2: Added Horovod distributed framework.
ignite.distributed.utils.broadcast(tensor, src=0)¶
Helper method to perform broadcast operation.

- Parameters
  tensor (Union[torch.Tensor, float, str]) – tensor or number or str to broadcast to participating processes. Make sure to respect the dtype of the torch tensor input for all processes, otherwise execution will crash.
  src (int) – source rank. Default, 0.
- Returns
  torch.Tensor or string or number
- Return type
  Union[torch.Tensor, float, str]

Examples

if idist.get_rank() == 0:
    t1 = torch.rand(4, 5, 6, device=idist.device())
    s1 = "abc"
    x = 12.3456
else:
    t1 = torch.empty(4, 5, 6, device=idist.device())
    s1 = ""
    x = 0.0

# Broadcast tensor t1 from rank 0 to all processes
t1 = idist.broadcast(t1, src=0)
assert isinstance(t1, torch.Tensor)

# Broadcast string s1 from rank 0 to all processes
s1 = idist.broadcast(s1, src=0)
# >>> s1 = "abc"

# Broadcast float number x from rank 0 to all processes
x = idist.broadcast(x, src=0)
# >>> x = 12.3456

New in version 0.4.2.
ignite.distributed.utils.device()¶
Returns current device according to current distributed configuration:

- torch.device("cpu") if no distributed configuration or torch native gloo distributed configuration
- torch.device("cuda:local_rank") if torch native nccl or horovod distributed configuration
- torch.device("xla:index") if XLA distributed configuration

- Returns
  torch.device
- Return type
  torch.device

Changed in version 0.4.2: Added Horovod distributed framework.
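For illustration, a minimal sketch of placing data on the backend-appropriate device (the tensor shape is a placeholder):

import torch
import ignite.distributed as idist

device = idist.device()  # cpu, cuda:<local_rank> or xla:<index> depending on the backend
x = torch.rand(8, 3, device=device)
y = (x * 2.0).sum()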
ignite.distributed.utils.finalize()¶
Finalizes distributed configuration. For example, in case of native pytorch distributed configuration, it calls dist.destroy_process_group().

- Return type
  None
ignite.distributed.utils.get_local_rank()¶
Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.

- Return type
  int

ignite.distributed.utils.get_nnodes()¶
Returns number of nodes within current distributed configuration. Returns 1 if no distributed configuration.

- Return type
  int

ignite.distributed.utils.get_node_rank()¶
Returns node rank within current distributed configuration. Returns 0 if no distributed configuration.

- Return type
  int

ignite.distributed.utils.get_nproc_per_node()¶
Returns number of processes (or tasks) per node within current distributed configuration. Returns 1 if no distributed configuration.

- Return type
  int

ignite.distributed.utils.get_rank()¶
Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.

- Return type
  int

ignite.distributed.utils.get_world_size()¶
Returns world size of current distributed configuration. Returns 1 if no distributed configuration.

- Return type
  int

ignite.distributed.utils.hostname()¶
Returns host name for current process within current distributed configuration.

- Return type
  str
ignite.distributed.utils.initialize(backend, **kwargs)¶
Initializes distributed configuration according to provided backend.

Examples

Launch single node multi-GPU training with the torch.distributed.launch utility.

# >>> python -m torch.distributed.launch --nproc_per_node=4 main.py
# main.py

import ignite.distributed as idist

def train_fn(local_rank, a, b, c):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")

idist.initialize("nccl")
local_rank = idist.get_local_rank()
train_fn(local_rank, a, b, c)
idist.finalize()

- Parameters
  backend (str) – backend to use: nccl, gloo, xla-tpu, horovod.
  kwargs (Any) – acceptable keyword arguments according to provided backend.
- Return type
  None

Changed in version 0.4.2: backend now accepts horovod distributed framework.
ignite.distributed.utils.model_name()¶
Returns distributed configuration name (given by ignite):

- serial for no distributed configuration
- native-dist for native torch distributed configuration
- xla-dist for XLA distributed configuration
- horovod-dist for Horovod distributed framework

Changed in version 0.4.2: horovod-dist will be returned for Horovod distributed framework.

- Return type
  str
ignite.distributed.utils.one_rank_only(rank=0, with_barrier=False)¶
Decorator to filter handlers wrt a rank number.

- Parameters
  rank (int) – rank number of the handler. Default, 0.
  with_barrier (bool) – synchronisation with a barrier. Default, False.
- Return type
  Callable

Examples

engine = ...

@engine.on(...)
@one_rank_only()  # means @one_rank_only(rank=0)
def some_handler(_):
    ...

@engine.on(...)
@one_rank_only(rank=1)
def some_handler(_):
    ...
ignite.distributed.utils.set_local_rank(index)¶
Method to hint the local rank in case the torch native distributed context is created by the user without using initialize() or spawn().

Usage:

User sets up the torch native distributed process group:

import ignite.distributed as idist

def run(local_rank, *args, **kwargs):
    idist.set_local_rank(local_rank)
    # ...
    dist.init_process_group(**dist_info)
    # ...
ignite.distributed.utils.show_config()¶
Helper method to display distributed configuration via logging.

- Return type
  None
ignite.distributed.utils.spawn(backend, fn, args, kwargs_dict=None, nproc_per_node=1, **kwargs)¶
Spawns nproc_per_node processes that run fn with args/kwargs_dict and initialize distributed configuration defined by backend.

Examples

Launch single node multi-GPU training using torch native distributed framework

# >>> python main.py
# main.py

import ignite.distributed as idist

def train_fn(local_rank, a, b, c, d=12):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == 4

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")

idist.spawn("nccl", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=4)

Launch multi-node multi-GPU training using torch native distributed framework

# >>> (node 0): python main.py --node_rank=0 --nnodes=8 --master_addr=master --master_port=2222
# >>> (node 1): python main.py --node_rank=1 --nnodes=8 --master_addr=master --master_port=2222
# >>> ...
# >>> (node 7): python main.py --node_rank=7 --nnodes=8 --master_addr=master --master_port=2222

# main.py

import torch
import ignite.distributed as idist

def train_fn(local_rank, nnodes, nproc_per_node):
    import torch.distributed as dist
    assert dist.is_available() and dist.is_initialized()
    assert dist.get_world_size() == nnodes * nproc_per_node

    device = idist.device()
    assert device == torch.device(f"cuda:{local_rank}")

idist.spawn(
    "nccl",
    train_fn,
    args=(nnodes, nproc_per_node),
    nproc_per_node=nproc_per_node,
    nnodes=nnodes,
    node_rank=node_rank,
    master_addr=master_addr,
    master_port=master_port,
)

Launch single node multi-TPU training (for example on Google Colab) using PyTorch/XLA

# >>> python main.py
# main.py

import ignite.distributed as idist

def train_fn(local_rank, a, b, c, d=12):
    import torch_xla.core.xla_model as xm
    assert xm.get_world_size() == 8

    device = idist.device()
    assert "xla" in device.type

idist.spawn("xla-tpu", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=8)
- Parameters
  backend (str) – backend to use: nccl, gloo, xla-tpu, horovod.
  fn (Callable) – function to be called as the entrypoint of the spawned process. This function must be defined at the top level of a module so it can be pickled and spawned. This is a requirement imposed by multiprocessing. The function is called as fn(i, *args, **kwargs_dict), where i is the process index and args is the passed-through tuple of arguments.
  args (Tuple) – arguments passed to fn.
  kwargs_dict (Optional[Mapping]) – kwargs passed to fn.
  nproc_per_node (int) – number of processes to spawn on a single node. Default, 1.
  kwargs (Any) – acceptable kwargs according to provided backend:
    - "nccl" or "gloo": nnodes (default, 1), node_rank (default, 0), master_addr (default, "127.0.0.1"), master_port (default, 2222), timeout to dist.init_process_group function and kwargs for mp.start_processes function.
    - "xla-tpu": nnodes (default, 1), node_rank (default, 0) and kwargs to xmp.spawn function.
    - "horovod": hosts (default, None) and other kwargs to hvd_run function. Arguments nnodes=1 and node_rank=0 are tolerated and ignored, otherwise an exception is raised.
- Return type
  None

Changed in version 0.4.2: backend now accepts horovod distributed framework.