diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py
index 6b58d7069df4..011154bba326 100644
--- a/ignite/distributed/comp_models/base.py
+++ b/ignite/distributed/comp_models/base.py
@@ -276,7 +276,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
 
     @abstractmethod
-    def barrier(self) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         pass
 
 
@@ -357,5 +357,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return tensor
 
-    def barrier(self) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         pass
diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py
index dcfc227d1437..fcbdbfb64e38 100644
--- a/ignite/distributed/comp_models/horovod.py
+++ b/ignite/distributed/comp_models/horovod.py
@@ -1,4 +1,5 @@
 import warnings
+from distutils.version import LooseVersion
 from typing import Any, Callable, Mapping, Optional, Tuple, cast
 
 import torch
@@ -6,6 +7,7 @@
 from ignite.distributed.comp_models.base import ComputationModel
 
 try:
+    import horovod
     import horovod.torch as hvd
 
     try:
@@ -23,6 +25,7 @@
 
 if has_hvd_support:
     HOROVOD = "horovod"
+    HOROVOD_VERSION = horovod.__version__
 
     class _HorovodDistModel(ComputationModel):
         """Private class for `Horovod `_ distributed computation model."""
@@ -192,7 +195,15 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return hvd.broadcast(tensor, root_rank=src)
 
-    def barrier(self) -> None:
-        # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
-        # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
-        hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        if LooseVersion(HOROVOD_VERSION) < LooseVersion("0.23.0"):
+            if len(args) or len(kwargs):
+                warnings.warn(
+                    f"Arguments {list(args) + list(kwargs)} are not passed to horovod barrier method. "
+                    f"Please use horovod version>='0.23.0'"
+                )
+            # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
+            # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
+            hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
+        else:
+            hvd.barrier(*args, **kwargs)
diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index 5be4083fa1b5..f7c656b8259e 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -437,8 +437,8 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         dist.broadcast(tensor, src=src)
         return tensor
 
-    def barrier(self) -> None:
-        dist.barrier()
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
+        dist.barrier(*args, **kwargs)
 
 
 def _expand_hostlist(nodelist: str) -> List[str]:
     """Expand a compressed hostlist string and returns all hosts listed.
diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py
index ebf55240f4f3..d889d251197e 100644
--- a/ignite/distributed/comp_models/xla.py
+++ b/ignite/distributed/comp_models/xla.py
@@ -159,5 +159,5 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         xm.all_reduce("sum", [tensor])
         return tensor
 
-    def barrier(self) -> None:
-        xm.rendezvous("barrier")
+    def barrier(self, *args: Any, tag: str = "barrier", **kwargs: Any) -> None:
+        xm.rendezvous(tag, *args, **kwargs)
diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index 2b3a1e59fafc..c567e8a2763c 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -418,12 +418,28 @@ def broadcast(
     return _model.broadcast(tensor, src=src, safe_mode=safe_mode)
 
 
-def barrier() -> None:
-    """Helper method to synchronize all processes."""
+def barrier(*args: Any, **kwargs: Any) -> None:
+    """Helper method to synchronize all processes.
+
+    Args:
+        args: acceptable args according to provided backend
+        kwargs: acceptable kwargs according to provided backend
+
+        - | "nccl" or "gloo" : ``group`` (default, GroupMember.WORLD), ``async_op`` (default, False),
+          | ``device_ids`` (default, None).
+
+        - | "horovod" : for version >= "0.23.0", ``process_set`` (default, global_process_set).
+
+        - | "xla-tpu" : ``tag``, ``payload`` (default, b""), ``replicas`` (default, []).
+
+    .. versionchanged:: 0.5.1
+        Method now accepts ``args`` and ``kwargs`` for all supported backends.
+
+    """
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
-    _model.barrier()
+    _model.barrier(*args, **kwargs)
 
 
 def set_local_rank(index: int) -> None:
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 10c7eef7b69a..d959e4ddd90d 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -215,14 +215,14 @@ def _test(data_src, data_others, safe_mode):
         idist.broadcast(None, src=0)
 
 
-def _test_distrib_barrier(device):
+def _test_distrib_barrier(device, *args, **kwargs):
 
     t = torch.tensor([idist.get_rank()], device=device, dtype=torch.float)
     true_res = sum([i for i in range(idist.get_world_size())])
 
     if idist.get_rank() == 0:
         t += 10.0
-    idist.barrier()
+    idist.barrier(*args, **kwargs)
 
     tt = idist.all_reduce(t)
     assert tt.item() == true_res + 10.0
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index 55ce5ebb7647..015cfcdc6b89 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -255,21 +255,40 @@ def test_idist_broadcast_gloo(distributed_context_single_node_gloo):
     _test_distrib_broadcast(device)
 
 
+from torch.distributed import GroupMember
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_idist_barrier_nccl(distributed_context_single_node_nccl):
+@pytest.mark.parametrize(
+    "args,kwargs",
+    [
+        ([], {}),
+        ([GroupMember.WORLD, False], {}),
+        ([GroupMember.WORLD, True], {}),
+    ],
+)
+def test_idist_barrier_nccl(distributed_context_single_node_nccl, args, kwargs):
 
     device = idist.device()
-    _test_distrib_barrier(device)
+    _test_distrib_barrier(device, *args, **kwargs)
 
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
-def test_idist_barrier_gloo(distributed_context_single_node_gloo):
+@pytest.mark.parametrize(
+    "args,kwargs",
+    [
+        ([], {}),
+        ([GroupMember.WORLD, False], {}),
+        ([GroupMember.WORLD, True], {}),
+    ],
+)
+def test_idist_barrier_gloo(distributed_context_single_node_gloo, args, kwargs):
 
     device = idist.device()
-    _test_distrib_barrier(device)
+    _test_distrib_barrier(device, *args, **kwargs)
 
 
 def _test_idist_methods_overhead(ok_factor):
diff --git a/tests/ignite/distributed/utils/test_xla.py b/tests/ignite/distributed/utils/test_xla.py
index c7dc58c1dc8e..f8a459e878c6 100644
--- a/tests/ignite/distributed/utils/test_xla.py
+++ b/tests/ignite/distributed/utils/test_xla.py
@@ -178,23 +178,41 @@ def test_idist_broadcast_xla_in_child_proc(xmp_executor):
 @pytest.mark.tpu
 @pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
 @pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_idist_barrier_xla():
+@pytest.mark.parametrize(
+    "args,kwargs",
+    [
+        ([], {}),
+        ([b"test_payload", []], {}),
+        ([b"test_payload", []], {"tag": "test_barrier"}),
+        ([], {"payload": b"test_payload", "replicas": [], "tag": "test_barrier"}),
+    ],
+)
+def test_idist_barrier_xla(args, kwargs):
 
     device = idist.device()
-    _test_distrib_barrier(device)
+    _test_distrib_barrier(device, *args, **kwargs)
 
 
-def _test_idist_barrier_xla_in_child_proc(index):
+def _test_idist_barrier_xla_in_child_proc(index, args, kwargs):
     device = idist.device()
-    _test_distrib_barrier(device)
+    _test_distrib_barrier(device, *args, **kwargs)
 
 
 @pytest.mark.tpu
 @pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
 @pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_idist_barrier_xla_in_child_proc(xmp_executor):
+@pytest.mark.parametrize(
+    "args,kwargs",
+    [
+        ([], {}),
+        ([b"test_payload", []], {}),
+        ([b"test_payload", []], {"tag": "test_barrier"}),
+        ([], {"payload": b"test_payload", "replicas": [], "tag": "test_barrier"}),
+    ],
+)
+def test_idist_barrier_xla_in_child_proc(xmp_executor, args, kwargs):
     n = int(os.environ["NUM_TPU_WORKERS"])
-    xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(), nprocs=n)
+    xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(args, kwargs), nprocs=n)
 
 
 @pytest.mark.tpu
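For reference, a minimal usage sketch of the updated ``idist.barrier`` API described in the ``ignite/distributed/utils.py`` docstring above. This is not part of the diff: the launcher setup, backend choice, and function name are illustrative assumptions; backend-specific arguments are simply forwarded to ``dist.barrier``, ``hvd.barrier``, or ``xm.rendezvous`` as implemented in the hunks above.

    import ignite.distributed as idist


    def training(local_rank: int) -> None:
        # No arguments: same behaviour as before on every backend.
        idist.barrier()

        # Backend-specific arguments are forwarded unchanged, e.g.:
        #   "nccl"/"gloo"       -> idist.barrier(async_op=False)    # torch.distributed.barrier
        #   "xla-tpu"           -> idist.barrier(tag="sync_point")  # xm.rendezvous
        #   "horovod" >= 0.23.0 -> idist.barrier()                  # hvd.barrier


    if __name__ == "__main__":
        # Assumed launcher setup for this sketch: 2 gloo processes on one node.
        with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
            parallel.run(training)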