import socket
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union

import torch

from ignite.distributed.comp_models import (
    _SerialModel,
    has_hvd_support,
    has_native_dist_support,
    has_xla_support,
    registered_computation_models,
)
from ignite.utils import setup_logger

__all__ = [
    "backend",
    "broadcast",
    "device",
    "available_backends",
    "model_name",
    "get_world_size",
    "get_rank",
    "get_local_rank",
    "get_nproc_per_node",
    "get_node_rank",
    "get_nnodes",
    "spawn",
    "initialize",
    "finalize",
    "show_config",
    "set_local_rank",
    "all_reduce",
    "all_gather",
    "barrier",
    "hostname",
    "has_xla_support",
    "has_native_dist_support",
    "has_hvd_support",
    "sync",
    "registered_computation_models",
    "one_rank_only",
]

_model = _SerialModel()
_need_to_sync = True


def sync(temporary: bool = False) -> None:
    """Helper method to force this module to synchronize with the current distributed context.
    This method should be used when the distributed context is manually created or destroyed.
    Args:
        temporary: If True, distributed model synchronization is done on every call of ``idist.get_*`` methods.
            This may have a negative performance impact.
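    Examples:
        A minimal sketch, assuming the user manually sets up a torch native
        distributed process group (``dist_info`` is a hypothetical configuration dict):

        .. code-block:: python

            import torch.distributed as dist
            import ignite.distributed as idist

            # dist_info is a hypothetical configuration dict for init_process_group
            dist.init_process_group(**dist_info)
            # make idist helpers pick up the manually created context
            idist.sync()
            assert idist.get_world_size() == dist.get_world_size()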
    """
    global _model
    for comp_model_cls in registered_computation_models:
        if comp_model_cls == _SerialModel:
            continue
        model = comp_model_cls.create_from_context()
        if model is not None:
            _set_model(model, temporary=temporary)
            return
    _model = _SerialModel() 


def device() -> torch.device:
    """Returns current device according to current distributed configuration.
    - `torch.device("cpu")` if no distributed configuration or torch native gloo distributed configuration
    - `torch.device("cuda:local_rank")` if torch native nccl or horovod distributed configuration
    - `torch.device("xla:index")` if XLA distributed configuration
    Returns:
        torch.device
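    Examples:
        A minimal sketch, assuming a torch native ``nccl`` configuration:

        .. code-block:: python

            import ignite.distributed as idist

            device = idist.device()
            # e.g. torch.device("cuda:0") on the process with local rank 0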
    .. versionchanged:: 0.4.2
        Added Horovod distributed framework.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.device() 


def backend() -> Optional[str]:
    """Returns computation model's backend.
    - `None` for no distributed configuration
    - "nccl" or "gloo" or "mpi" for native torch distributed configuration
    - "xla-tpu" for XLA distributed configuration
    - "horovod" for Horovod distributed framework
    Returns:
        str or None
    .. versionchanged:: 0.4.2
        Added Horovod distributed framework.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.backend() 


def available_backends() -> Tuple[str, ...]:
    """Returns available backends."""
    out: Tuple[str, ...] = ()
    for m in registered_computation_models:
        out += m.available_backends
    return out 


def model_name() -> str:
    """Returns distributed configuration name (given by ignite)
    - `serial` for no distributed configuration
    - `native-dist` for native torch distributed configuration
    - `xla-dist` for XLA distributed configuration
    - `horovod-dist` for Horovod distributed framework
    .. versionchanged:: 0.4.2
        `horovod-dist` will be returned for Horovod distributed framework.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.name 


def get_world_size() -> int:
    """Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_world_size() 


def get_rank() -> int:
    """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_rank() 


def get_local_rank() -> int:
    """Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_local_rank() 


def get_nproc_per_node() -> int:
    """Returns number of processes (or tasks) per node within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_nproc_per_node() 


def get_nnodes() -> int:
    """Returns number of nodes within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_nnodes() 


def get_node_rank() -> int:
    """Returns node rank within current distributed configuration.
    Returns 0 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_node_rank() 


def hostname() -> str:
    """Returns the host name for the current process within the current distributed configuration.
    """
    return socket.gethostname() 


def spawn(
    backend: str,
    fn: Callable,
    args: Tuple,
    kwargs_dict: Optional[Mapping] = None,
    nproc_per_node: int = 1,
    **kwargs: Any,
) -> None:
    """Spawns ``nproc_per_node`` processes that run ``fn`` with ``args``/``kwargs_dict`` and initialize
    distributed configuration defined by ``backend``.
    Examples:
        1) Launch single node multi-GPU training using torch native distributed framework
        .. code-block:: python
            # >>> python main.py
            # main.py
            import torch
            import ignite.distributed as idist
            def train_fn(local_rank, a, b, c, d=12):
                import torch.distributed as dist
                assert dist.is_available() and dist.is_initialized()
                assert dist.get_world_size() == 4
                device = idist.device()
                assert device == torch.device(f"cuda:{local_rank}")
            idist.spawn("nccl", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=4)
        2) Launch multi-node multi-GPU training using torch native distributed framework
        .. code-block:: python
            # >>> (node 0): python main.py --node_rank=0 --nnodes=8 --master_addr=master --master_port=2222
            # >>> (node 1): python main.py --node_rank=1 --nnodes=8 --master_addr=master --master_port=2222
            # >>> ...
            # >>> (node 7): python main.py --node_rank=7 --nnodes=8 --master_addr=master --master_port=2222
            # main.py
            import torch
            import ignite.distributed as idist
            def train_fn(local_rank, nnodes, nproc_per_node):
                import torch.distributed as dist
                assert dist.is_available() and dist.is_initialized()
                assert dist.get_world_size() == nnodes * nproc_per_node
                device = idist.device()
                assert device == torch.device(f"cuda:{local_rank}")
            idist.spawn(
                "nccl",
                train_fn,
                args=(nnodes, nproc_per_node),
                nproc_per_node=nproc_per_node,
                nnodes=nnodes,
                node_rank=node_rank,
                master_addr=master_addr,
                master_port=master_port
            )
        3) Launch single node multi-TPU training (for example on Google Colab) using PyTorch/XLA
        .. code-block:: python
            # >>> python main.py
            # main.py
            import ignite.distributed as idist
            def train_fn(local_rank, a, b, c, d=12):
                import torch_xla.core.xla_model as xm
                assert xm.xrt_world_size() == 8
                device = idist.device()
                assert "xla" in device.type
            idist.spawn("xla-tpu", train_fn, args=(a, b, c), kwargs_dict={"d": 23}, nproc_per_node=8)
    Args:
        backend: backend to use: `nccl`, `gloo`, `xla-tpu`, `horovod`
        fn: function to be called as the entrypoint of the spawned process.
            This function must be defined at the top level of a module so it can be pickled and spawned.
            This is a requirement imposed by multiprocessing. The function is called as ``fn(i, *args, **kwargs_dict)``,
            where ``i`` is the process index and ``args`` is the tuple of arguments passed through.
        args: arguments passed to `fn`.
        kwargs_dict: kwargs passed to `fn`.
        nproc_per_node: number of processes to spawn on a single node. Default, 1.
        kwargs: acceptable kwargs according to provided backend:
            - | "nccl" or "gloo" : `nnodes` (default, 1), `node_rank` (default, 0), `master_addr`
              | (default, "127.0.0.1"), `master_port` (default, 2222), `timeout` to `dist.init_process_group`_ function
              | and kwargs for `mp.start_processes`_ function.
            - | "xla-tpu" : `nnodes` (default, 1), `node_rank` (default, 0) and kwargs to `xmp.spawn`_ function.
            - | "horovod": `hosts` (default, None) and other kwargs to `hvd_run`_ function. Arguments `nnodes=1`
              | and `node_rank=0` are tolerated and ignored, otherwise an exception is raised.
    .. _dist.init_process_group: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
    .. _mp.start_processes: https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
    .. _xmp.spawn: http://pytorch.org/xla/release/1.6/index.html#torch_xla.distributed.xla_multiprocessing.spawn
    .. _hvd_run: https://horovod.readthedocs.io/en/latest/api.html#module-horovod.run
    .. versionchanged:: 0.4.2
        ``backend`` now accepts `horovod` distributed framework.
    """
    _assert_backend(backend)
    if kwargs_dict is None:
        kwargs_dict = {}
    for comp_model_cls in registered_computation_models:
        if backend not in comp_model_cls.available_backends:
            continue
        comp_model_cls.spawn(
            fn, args=args, kwargs_dict=kwargs_dict, nproc_per_node=nproc_per_node, backend=backend, **kwargs
        ) 


def all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM") -> Union[torch.Tensor, float]:
    """Helper method to perform all reduce operation.
    Args:
        tensor: tensor or number to collect across participating processes.
        op: reduction operation, "SUM" by default. Possible values: "SUM", "PRODUCT", "MIN", "MAX", "AND", "OR".
            Horovod backend supports only "SUM", "AVERAGE", "ADASUM", "MIN", "MAX", "PRODUCT".
    Returns:
        torch.Tensor or number
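    Examples:
        A minimal sketch, assuming a distributed context is already initialized:

        .. code-block:: python

            import torch
            import ignite.distributed as idist

            # sum one value per process across all participating processes
            t = torch.tensor(idist.get_rank(), dtype=torch.float)
            t = idist.all_reduce(t, op="SUM")
            # on every process: t == 0 + 1 + ... + (world_size - 1)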
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.all_reduce(tensor, op) 


def all_gather(tensor: Union[torch.Tensor, float, str]) -> Union[torch.Tensor, float, List[float], List[str]]:
    """Helper method to perform all gather operation.
    Args:
        tensor: tensor or number or str to collect across participating processes.
    Returns:
        torch.Tensor of shape ``(world_size * tensor.shape[0], tensor.shape[1], ...)`` if input is a tensor or
        torch.Tensor of shape ``(world_size, )`` if input is a number or
        List of strings if input is a string
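    Examples:
        A minimal sketch, assuming a distributed context is already initialized:

        .. code-block:: python

            import torch
            import ignite.distributed as idist

            # each process contributes one row; rows are stacked on every process
            t = torch.full((1, 3), float(idist.get_rank()))
            t = idist.all_gather(t)
            # t has shape (world_size * 1, 3)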
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.all_gather(tensor) 


def broadcast(tensor: Union[torch.Tensor, float, str], src: int = 0) -> Union[torch.Tensor, float, str]:
    """Helper method to perform broadcast operation.
    Args:
        tensor: tensor or number or str to broadcast to participating processes.
            Make sure to respect the dtype of the torch tensor input on all processes, otherwise execution will crash.
        src: source rank. Default, 0.
    Returns:
        torch.Tensor or string or number
    Examples:
        .. code-block:: python
            if idist.get_rank() == 0:
                t1 = torch.rand(4, 5, 6, device=idist.device())
                s1 = "abc"
                x = 12.3456
            else:
                t1 = torch.empty(4, 5, 6, device=idist.device())
                s1 = ""
                x = 0.0
            # Broadcast tensor t1 from rank 0 to all processes
            t1 = idist.broadcast(t1, src=0)
            assert isinstance(t1, torch.Tensor)
            # Broadcast string s1 from rank 0 to all processes
            s1 = idist.broadcast(s1, src=0)
            # >>> s1 = "abc"
            # Broadcast float number x from rank 0 to all processes
            x = idist.broadcast(x, src=0)
            # >>> x = 12.3456
    .. versionadded:: 0.4.2
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.broadcast(tensor, src=src) 


def barrier() -> None:
    """Helper method to synchronize all processes.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    _model.barrier() 


def set_local_rank(index: int) -> None:
    """Method to hint the local rank in case the torch native distributed context is created by the user
    without using :meth:`~ignite.distributed.utils.initialize` or :meth:`~ignite.distributed.utils.spawn`.
    Usage:
        The user sets up the torch native distributed process group:
        .. code-block:: python
            import ignite.distributed as idist
            def run(local_rank, *args, **kwargs):
                idist.set_local_rank(local_rank)
                # ...
                dist.init_process_group(**dist_info)
                # ...
    Args:
        index: local rank or current process index
    """
    from ignite.distributed.comp_models.base import ComputationModel
    ComputationModel._ext_local_rank = index 


def _set_model(model: Any, temporary: bool = False) -> None:
    global _model, _need_to_sync
    _model = model
    _need_to_sync = True
    if not isinstance(_model, _SerialModel) and not temporary:
        _need_to_sync = False


def _assert_backend(backend: str) -> None:
    backends = available_backends()
    if backend not in backends:
        raise ValueError(f"Backend should be one of '{backends}'")


def initialize(backend: str, **kwargs: Any) -> None:
    """Initializes distributed configuration according to provided ``backend``
    Examples:
        Launch single node multi-GPU training with the ``torch.distributed.launch`` utility.
        .. code-block:: python
            # >>> python -m torch.distributed.launch --nproc_per_node=4 main.py
            # main.py
            import torch
            import ignite.distributed as idist
            def train_fn(local_rank, a, b, c):
                import torch.distributed as dist
                assert dist.is_available() and dist.is_initialized()
                assert dist.get_world_size() == 4
                device = idist.device()
                assert device == torch.device(f"cuda:{local_rank}")
            idist.initialize("nccl")
            local_rank = idist.get_local_rank()
            train_fn(local_rank, a, b, c)
            idist.finalize()
    Args:
        backend: backend to use: `nccl`, `gloo`, `xla-tpu`, `horovod`.
        kwargs: acceptable kwargs according to provided backend:
            - "nccl" or "gloo" : timeout(=timedelta(minutes=30)).
            - "horovod" : comm(=None), more info: `hvd_init`_.
    .. _hvd_init: https://horovod.readthedocs.io/en/latest/api.html#horovod.torch.init
    .. versionchanged:: 0.4.2
        ``backend`` now accepts `horovod` distributed framework.
    """
    if not (has_xla_support or has_native_dist_support or has_hvd_support):
        # nothing to do => serial model
        # maybe warn about this
        return
    _assert_backend(backend)
    for comp_model_cls in registered_computation_models:
        if backend not in comp_model_cls.available_backends:
            continue
        _set_model(comp_model_cls(backend, **kwargs)) 


def finalize() -> None:
    """Finalizes distributed configuration. For example, in case of native pytorch distributed configuration,
    it calls ``dist.destroy_process_group()``.
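    Examples:
        A minimal sketch pairing ``initialize`` with ``finalize``:

        .. code-block:: python

            import ignite.distributed as idist

            idist.initialize("nccl")
            # ... distributed computations ...
            idist.finalize()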
    """
    _model.finalize()
    _set_model(_SerialModel()) 


def show_config() -> None:
    """Helper method to display distributed configuration via ``logging``.
    """
    # setup parallel logger
    logger = setup_logger(__name__)
    logger.info(f"distributed configuration: {model_name()}")
    logger.info(f"backend: {backend()}")
    logger.info(f"device: {device().type}")
    logger.info(f"hostname: {hostname()}")
    logger.info(f"world size: {get_world_size()}")
    logger.info(f"rank: {get_rank()}")
    logger.info(f"local rank: {get_local_rank()}")
    logger.info(f"num processes per_node: {get_nproc_per_node()}")
    logger.info(f"num nodes: {get_nnodes()}")
    logger.info(f"node rank: {get_node_rank()}") 


def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Callable:
    """Decorator to filter handlers with respect to a rank number
    Args:
        rank: rank number of the handler (default: 0).
        with_barrier: synchronization with a barrier (default: False).
    .. code-block:: python
        engine = ...
        @engine.on(...)
        @one_rank_only() # means @one_rank_only(rank=0)
        def some_handler(_):
            ...
        @engine.on(...)
        @one_rank_only(rank=1)
        def some_handler(_):
            ...
    """
    def _one_rank_only(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
            ret = None
            if get_rank() == rank:
                ret = func(*args, **kwargs)
            if with_barrier:
                barrier()
            return ret
        return wrapper
    return _one_rank_only