Summary: We added group split in D78300794 and remote_group_merge in D78450094. We first want to upstream this change to PGNCCLx as well so that NCCLx can use this new API and we can continue our c10d clean up in https://github.com/pytorch/pytorch/pull/158488.
Test Plan:
CI
```
buck test -c hpc_comms.use_ncclx=stable comms/ncclx/pg/tests:test_c10d_ncclx -- test_group_split_and_merge
```
Rollback Plan:
Differential Revision: D78521060
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158790
Approved by: https://github.com/d4l3k
This updates ProcessGroupGloo to support per operation timeouts. Previously the timeouts were ignored even if they were set.
* This checks if the timeout is `kUnsetTimeout` and conditionally uses the provided timeout or the default timeout from the context.
* This exposes `set_timeout` as a standard method on ProcessGroup/Backend so we can test the global timeout.
Test plan:
```
pytest test/distributed/test_c10d_gloo.py -v -k allreduce_timeout
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158128
Approved by: https://github.com/H-Huang, https://github.com/fduwjj
This enables Gloo CUDA when used with a backend that supports GPUDirect which currently is only the IBVERBS backend.
This requires some changes to Gloo which are in https://github.com/pytorch/gloo/pull/441
Since we're now depending on gloo_cuda we need to split ProcessGroupGloo into two pieces, one with the CPU bits (libtorch_cpu) and one with CUDA kernels in libtorch_cuda. This unfortunately requires some major refactoring as some CPU code is shared across both.
The gloo submodule is updated to depend on the new Gloo changes
Test plan:
```py
import os
import time
transport = "TCP"
#transport = "IBVERBS"
os.environ["GLOO_DEVICE_TRANSPORT"] = transport
rank = int(os.environ["RANK"])
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
ibv = "mlx5_0:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_9:1,mlx5_10:1,mlx5_11:1".split(",")[rank]
ibv_name, ibv_port = ibv.split(":")
os.environ["TORCH_GLOO_IBV_NAME"] = ibv_name
os.environ["TORCH_GLOO_IBV_PORT"] = ibv_port
os.environ["TORCH_GLOO_IBV_INDEX"] = "3"
import torch
import torch.distributed as dist
dist.init_process_group("gloo")
rank = dist.get_rank()
# initial sanity check
#device = "cpu"
#t = torch.zeros(10, device=device)
#dist.all_reduce(t)
#print("sanity complete")
device = "cpu"
iters = 10
warmup_iters = 2
for nelem in [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000]:
t = torch.zeros(nelem, device=device)
torch.cuda.current_stream().synchronize()
for i in range(warmup_iters):
dist.all_reduce(t)
torch.cuda.current_stream().synchronize()
start = time.perf_counter()
for i in range(iters):
dist.all_reduce(t)
torch.cuda.current_stream().synchronize()
dur = (time.perf_counter() - start)
qps = iters/dur
bandwidth_gb = t.nbytes * iters / dur / 1e9
gb = t.nbytes / 1e9
if rank == 0:
print(f"{transport=} {device=} {iters=} {nelem=} {qps=} {gb=} {bandwidth_gb=}\n", end="")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153406
Approved by: https://github.com/fduwjj
This adds lazy initialization support to ProcessGroupGloo via `TORCH_GLOO_LAZY_INIT` or via `create_device(..., lazy_init=True)`
This is still a draft PR as there's one race condition when doing coalesced operations that needs to be fixed upstream in Gloo first. Depends on https://github.com/facebookincubator/gloo/pull/427 landing first
This also updates the gloo submodule to include the required changes.
Test plan:
added lazy init test variants
```
pytest -v test/distributed/test_c10d_gloo.py -k Lazy
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150801
Approved by: https://github.com/fduwjj
Summary:
X-link: https://github.com/facebookincubator/gloo/pull/423
This modifies `connectFullMesh` to take in a shared_ptr<IStore> instead of a reference. This is an API breaking change but fairly easy to work around.
To have backwards compatibility in PyTorch during the commit phase we add a new ifdef `GLOO_SHARED_STORE` which can provide backwards compatibility until we update the pinned Gloo version in pytorch OSS repo.
This also adds a new `wait_get` method to `IStore` which will allow us to do a more efficient operation in PyTorch TCPStore. PyTorch's `Store::get` automatically waits so we want to make sure we can avoid waiting twice to reduce network traffic.
This change will land simultaneously in PyTorch and Gloo repos.
Test Plan:
```
buck2 test //gloo/... //caffe2/caffe2/contrib/gloo:
```
Differential Revision: D72084111
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150230
Approved by: https://github.com/fduwjj
This adds a `reduce_scatter` implementation for ProcessGroupGloo. This is a pretty naive implementation as it does 1 allreduce per rank but may be useful for testing in FSDP etc. There was an existing implementation of reduce_scatter_tensor/reduce_scatter_tensor_coalesed that has a very similar implementation but requires a fixed tensor size per rank.
If users find these functions to be too slow we can address them as issues arise.
Gloo now supports all major distributed operations. Quite a few of these were added by @rohan-varma and @yifuwang but they didn't update the support chart. We also have `CUDAWork` variants of most operations so those were also added to the chart.
Test plan:
```
pytest -v test/distributed/test_c10d_gloo.py -k reduce_scatter
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149869
Approved by: https://github.com/fduwjj
This adds AVG support to ProcessGroupGloo to better support FSDP on CPU. I expect there will be more issues but this is easy enough to support in a naive fashion.
This applies to both reduce and allreduce.
This is a simple SUM + division and may not be the most numerically stable but that's expected. FSDP for low precision data types implements pre/post divide and uses SUM instead.
Test plan:
```
pytest -v test/distributed/test_c10d_gloo.py
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149781
Approved by: https://github.com/fduwjj
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
y = x * x
req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
assert isinstance(req, torch.distributed.Work)
return y
@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
torch.ops.c10d_functional.wait_tensor(y)
return y * y
x = torch.ones(1280, 1280, device="cuda") + self.rank
with allow_inflight_collective_as_graph_input_ctx():
y = all_reduce_eager(x)
z = all_reduce_wait_compiled(y)
```
where the collective is issued in eager (with `async_op=True`) but waited in compiled region.
This is important for internal use cases such as TorchRec, where we issue collectives in eager for SparseArch all_to_all but want to wait for them in compiled region at beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.
----
**Update**: Did two items to prevent regression to existing use cases:
1. Added memory-stressed test case to test_c10d_nccl.py `test_unwaited` to cover existing user's "not calling work.wait() for non-functional collective" use case
2. Gated all new `register_work()` / `unregister_work()` calls with `c10d::allow_inflight_collective_as_graph_input()` check, which is a new context manager that requires explicit user enablement (i.e. not on by default, so should not affect existing users).
The risk of this new version of PR causing regression should be very low.
------
Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_wait_tensor`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`
------
Differential Revision: [D65023311](https://our.internmc.facebook.com/intern/diff/D65023311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
This adds logs if we can't acquire locks in NCCLUtils and ProcessGroupNCCL for 30s.
This is motivated by some deadlocks were seeing and it's unclear if it's in NCCL or on the PyTorch side of things.
This required replacing most `std::mutex` with `std::timed_mutex` and `std::condition_variable_any` as appropriate.
Test plan:
existing CI for regressions
will add unit tests on `C10D_LOCK_GUARD`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134131
Approved by: https://github.com/c-p-i-o, https://github.com/fduwjj
This adds logs if we can't acquire locks in NCCLUtils and ProcessGroupNCCL for 30s.
This is motivated by some deadlocks were seeing and it's unclear if it's in NCCL or on the PyTorch side of things.
This required replacing most `std::mutex` with `std::timed_mutex` and `std::condition_variable_any` as appropriate.
Test plan:
existing CI for regressions
will add unit tests on `C10D_LOCK_GUARD`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134131
Approved by: https://github.com/c-p-i-o, https://github.com/fduwjj
This PR continues to clean clang-tidy warnings in torch/csrc/distributed/c10d, following #124701. In addition, libfmt dependency is added in CMake code to enable using it in the headers. The libfmt has to be added as private dependency to torch_cuda and torch_hip because they include torch/csrc/distributed/c10d/Utils.hpp which uses libfmt.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124987
Approved by: https://github.com/malfet
### Motivation
Despite our plan to reduce gloo usage, it is still being widely used as testing tool (in both the PyTorch CI and user tests) for code that only uses nccl in real world scenario. There's some coverage issues around all-gather and reduce-scatter variants, which are currently worked around in ugly ways (e.g. [this](b9e86bc93d/torch/distributed/_functional_collectives_impl.py (L216-L219)) and [this](b9e86bc93d/torch/distributed/_functional_collectives_impl.py (L262-L272))). For native funcol I ran into the same issues but I'd rather just fix the coverage.
### This PR
We already have a fallback impl for `_reduce_scatter_base`, which is composed from all-reduce + scatter. The scatter was not necessary. It introduces extra communication, sync point, and forced the impl to fail on `asyncOp=True`. This PR does the following:
- Simulate reduce-scatter with `allreduce(inp).chunk(world_size)[rank]`. This is still 2x communication than a real reduce-scatter (since all-reduce = reduce-scatter + all-gather), but it's strictly better than what we have now.
- By doing the above, the comm becomes async and we don't have to fail on `asyncOp=True`.
- The general logic is implemented in `reduce_scatter_tensor_coalesced`. `_reduce_scatter_base` just calls it with single input/output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118911
Approved by: https://github.com/shuqiangzhang
ghstack dependencies: #118910
### Motivation
Despite our plan to reduce gloo usage, it is still being widely used as testing tool (in both the PyTorch CI and user tests) for code that only uses nccl in real world scenario. There's some coverage issues around all-gather and reduce-scatter variants, which are currently worked around in ugly ways (e.g. [this](b9e86bc93d/torch/distributed/_functional_collectives_impl.py (L216-L219)) and [this](b9e86bc93d/torch/distributed/_functional_collectives_impl.py (L262-L272))). For native funcol I ran into the same issues but I'd rather just fix the coverage.
**I think it's reasonable to think of this as a fix rather than adding new features. This is orthogonal to the potential reduction of gloo usage**.
### This PR
This PR adds `ProcessGroupGloo::allgather_into_tensor_coalesced`. This is very straightforward - `ProcessGroupGloo` already supports `allgather_coalesced`, to which we can funnel `allgather_into_tensor_coalesced`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118910
Approved by: https://github.com/shuqiangzhang
There was missing support for bfloat scalars. When I use gloo backend
`torch.distributed.init_process_group(backend='gloo')`
and run
`torch.nn.parallel.DistributedDataParallel(model)`
and _model_ has Bfloat16 features I receive following error:
`RuntimeError: Invalid scalar type`
This change fix this issue.
c10::BFloat16 defines conversions from/to float, so calculations are made on float for bfloat.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113557
Approved by: https://github.com/XilunWu, https://github.com/jgong5
Fixes#111422
allreduce_sparse_cuda gets dispatched to allreduce_sparse which doesnt exist for gloo. However, gloo has an existing implementation so this is just fixing the dispatching to that.
The reason CI didn't catch this is because we are calling the backend directly. Added a test which calls the public API (dist.XYZ) and goes through the dispatcher
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111485
Approved by: https://github.com/fduwjj
This reverts commit 314a502eb04c6382e2cc9af0573533efba54109d.
Changes since original PR:
Reland 1
* rename torch.distributed.hooks to torch.distributed._hooks
Reland 2
* make _hooks importable even if !distributed.is_available()
* handle cuda driver exit intermittent failure caused by new cuda api usage in callback caller (see prev PR in stack)
(original PR https://github.com/pytorch/pytorch/pull/108815 desc copied below)
Expose a set of observability hooks into C10D such that our users can
detect collectives failure both faster and more easily.
The design is similar to NCCL desync debug that it minimized the
overhead by doing most of the work out of the main thread.
This PR introduces a new module torch.distributed.hooks that exposes the following set of methods:
register_collective_start_hook
register_collective_end_hook
register_process_group_hook
The process group hook exposes PG creation on the member ranks and call them inline from the
the PG creation code. This is fine since this happens during initialization and a limited number of times.
The collective start/end hooks are fired from a single background thread. It reads
events from a C++ queue and dispatches over.
Queue notification is oddly done using a pipe, this is needed so python can abort the thread on shutdown
and have it as background thread. This is not possible with more reasonable choices like a condvar.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111072
Approved by: https://github.com/malfet
ghstack dependencies: #111061