Commit Graph

54 Commits

Author SHA1 Message Date
cyy
2bcfbf2505 [Distributed] [17/N] Fix clang-tidy warnings in torch/csrc/distributed/ (#138465)
Follows  #137404

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138465
Approved by: https://github.com/ezyang
2024-10-24 04:58:49 +00:00
362ca54f03 [c10d][Partial-Graph Overlap] Support calling .wait_tensor() within compiled region on output tensor of eager async_op=True collective (#137763)
This PR aims to support the following use case:
```python
def all_reduce_eager(x):
    y = x * x
    req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
    assert isinstance(req, torch.distributed.Work)
    return y

@torch.compile(fullgraph=True)
def all_reduce_wait_compiled(y):
    torch.ops.c10d_functional.wait_tensor(y)
    return y * y
```
where the collective is issued in eager (with `async_op=True`) but waited in compiled region.

This is important for internal use cases such as TorchRec, where we issue collectives in eager for SparseArch all_to_all but want to wait for them in compiled region at beginning of OverArch, so that the all_to_all can be overlapped with the DenseArch compute that runs in parallel.

------

Test commands:
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_eager_async_allreduce_inductor_wait`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives`
- `pytest -rA test/test_fx.py::TestDCE::test_keep_collectives_no_overload`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_unwaited`
- `pytest -rA test/distributed/test_c10d_functional_native.py::TestWithNCCL::test_work_registry`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_unwaited`
- `pytest -rA test/distributed/test_c10d_nccl.py::CommTest::test_work_registry`
- `pytest -rA test/distributed/_tensor/test_tensor_ops.py::DistTensorOpsTest::test_equal`
- `pytest -rA test/distributed/_tensor/test_random_ops.py::DistTensorRandomOpTest::test_manual_seed`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_asymmetric_compilation`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_scalar`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_speculation_divergence`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_automatic_dynamic_tensor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_dim_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_graph_break_empty_graph_still_collective`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_scalar_missing_source`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_compiler_collectives_type_mismatch`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_ddp_baseline_aot_eager_multiprocess`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_setattr`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_fsdp_unspecialized_forced_getattr_no_inline`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_aot_eager_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_ddp_inductor_static_graph`
- `pytest -rA test/distributed/test_dynamo_distributed.py::TestMultiProc::test_hf_bert_fsdp_activation_checkpointing`
- `pytest -rA test/distributed/_tensor/test_experimental_ops.py::DistOtherOpsTest::test_bernoulli`
- `pytest -rA test/distributed/_tensor/test_dtensor_compile.py::TestDTensorCompileE2E::test_tp_compile_fullgraph_is_seq_parallel_True`
- `pytest -rA test/distributed/test_inductor_collectives.py::TestCollectivesMultiProc::test_allreduce_inductor_cudagraph_trees`
- `python benchmarks/dynamo/torchbench.py --ci --accuracy --timing --explain --inductor --device cuda --inference --bfloat16 --total-partitions 2 --partition-id 1 --output inference_torchbench.csv --only moco`

------

Differential Revision: [D64511994](https://our.internmc.facebook.com/intern/diff/D64511994)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137763
Approved by: https://github.com/yifuwang
2024-10-21 06:02:57 +00:00
a0c7029a75 [c10d][Reland] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931) (#135653)
We introduced the dispatchable backend for a ProcessGroup and collective in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup to clean up the option of a ProcessGroup and ask users to either set timeout or backend later on or directly create backend after creating a PG.

Also PGNCCL is using option class from ProcessGroup but we actually should use Option from backend class. So this PR is to make the type or name to be aligned with what we are doing in cpp side. I don't change the signature for the public API, so they still use args named "pg_options"

We need to make changes to the test to make it aligned with the change.

This is try to reland D62008954 by fixing internal errors.

Differential Revision: [D62483294](https://our.internmc.facebook.com/intern/diff/D62483294/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135653
Approved by: https://github.com/wz337, https://github.com/H-Huang
2024-09-16 19:56:42 +00:00
351ba3e67c Revert "[c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)"
This reverts commit 65864d01341d006955579b145f78547314ceb14b.

Reverted https://github.com/pytorch/pytorch/pull/132931 on behalf of https://github.com/ZainRizvi due to This PR is breaking builds internally due to the removal of ProcessGroup::Options ([comment](https://github.com/pytorch/pytorch/pull/132931#issuecomment-2321862402))
2024-08-30 16:27:40 +00:00
65864d0134 [c10d] Remove Option for ProcessGroup and Expose backend Options to reflect the correct code structure (#132931)
We introduced the dispatchable backend for a ProcessGroup and collective in https://github.com/pytorch/pytorch/issues/86225. This PR is a follow-up cleanup to clean up the option of a ProcessGroup and ask users to either set timeout or backend later on or directly create backend after creating a PG.

Also PGNCCL is using option class from ProcessGroup but we actually should use Option from backend class. So this PR is to make the type or name to be aligned with what we are doing in cpp side. I don't change the signature for the public API, so they still use args named "pg_options"

We need to make changes to the test to make it aligned with the change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132931
Approved by: https://github.com/H-Huang
2024-08-29 22:40:12 +00:00
ed327876f5 [codemod] c10:optional -> std::optional (#126135)
Generated by running the following from PyTorch root:
```
find . -regex ".*\.\(cpp\|h\|cu\|hpp\|cc\|cxx\)$" | grep -v "build/" | xargs -n 50 -P 4 perl -pi -e 's/c10::optional/std::optional/'
```

`c10::optional` is just an alias for `std::optional`. This removes usages of that alias in preparation for eliminating it entirely.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126135
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/albanD, https://github.com/aaronenyeshi
2024-05-14 19:35:51 +00:00
cyy
4457cd9a30 [Distributed] [7/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124987)
This PR continues to clean clang-tidy warnings in torch/csrc/distributed/c10d, following #124701.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124987
Approved by: https://github.com/malfet
2024-05-11 00:03:52 +00:00
724c7491d0 Revert " [Distributed] [7/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124987)"
This reverts commit b3fd94d15ef49c99ffa32a8226d1f00b0cc26f68.

Reverted https://github.com/pytorch/pytorch/pull/124987 on behalf of https://github.com/ezyang due to broke downstream extensions ([comment](https://github.com/pytorch/pytorch/pull/124987#issuecomment-2083956511))
2024-04-30 00:37:53 +00:00
cyy
b3fd94d15e [Distributed] [7/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124987)
This PR continues to clean clang-tidy warnings in torch/csrc/distributed/c10d, following #124701. In addition, libfmt dependency is added in CMake code to enable using it in the headers. The libfmt has to be added as private dependency to torch_cuda and torch_hip because they include torch/csrc/distributed/c10d/Utils.hpp which uses libfmt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124987
Approved by: https://github.com/malfet
2024-04-27 07:22:27 +00:00
cyy
ea61c9cb29 [Distributed] [5/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124043)
This PR continues to fix some clang-tidy warnings in distributed/c10d code, following https://github.com/pytorch/pytorch/pull/124032.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124043
Approved by: https://github.com/ezyang
2024-04-23 00:43:50 +00:00
cyy
c2596fd3e0 [Distributed] [4/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#124032)
This PR continues to fix some clang-tidy warnings in distributed/c10d code, following https://github.com/pytorch/pytorch/pull/123312.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124032
Approved by: https://github.com/Skylion007
2024-04-16 00:42:18 +00:00
4e9094533e [c10d/nccl-pg] allow user to pass process group description (#123472)
Summary:
We need a way to allow user set a customized description for a process group, e.g. FSDP, PP.

Here are several use cases of user specified group_desc:
- Logging: we can easily match a log line and understand what's this collective/pg is used to.
- Pytorch traces (e.g. Kineto, Execution Trace) can benefit from the PG desc since trace analysis, benchmarks will be able to easily differentiate PG purpose like FSDP, PP.
- Lower layer collectives(e.g. NCCL) debug: we will be able to expose PG desc to NCCL communicator so NCCL layer operations can be easily correlated to a PG.

Solution: Add a group_desc field to c10d

Differential Revision: D55781850

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123472
Approved by: https://github.com/kwen2501
2024-04-12 08:44:21 +00:00
cyy
6d8bb0e984 [Distributed] [1/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#122884)
This PR fixes some clang-tidy warnings in distributed code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122884
Approved by: https://github.com/kwen2501
2024-03-31 09:06:35 +00:00
78b945484b [c10d] Extend NCCL communicator splitting to more use cases (#114916)
Previously we could only use `ncclCommSplit` when we knew all backends were connected on all shards (due to the need to perform a NOCOLOR split), which in practice meant we could only use it for subgroups that were copies of the entire world.

This change allows for specifying a bound device id to `init_process_group` which tells the pg and its backends that the specified device, and the specified device only, will be associated with this rank.

This guarantee lets us do an early connect (which we could not previously do due to how ProcessGroupNCCL infers devices based on tensors and not the rank number).  And by doing the early connect, we have the guarantee ranks are connected and can perform nocolor splits when needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114916
Approved by: https://github.com/kwen2501
2023-12-07 15:13:01 +00:00
bb7ac12cbf [ProcessGroupNCCL] Avoid recording stream for broadcast and scatter (#112896)
Summary: Follows PR #111431, save memory for DTensor init

Test Plan: Sandcastle

Reviewed By: wanchaol

Differential Revision: D50985365

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112896
Approved by: https://github.com/wanchaol
2023-11-07 15:44:04 +00:00
ec18ef62f4 Native c10d_functional ops (#110570)
This PR introduces a native version of c10d_functional ops. The main goal is to add collective support in AOTInductor and allow collective ops to work in multi-threaded native runtimes.

The native version also incorporated API improvements we wished to implement in Python c10d_functional:

- Removed `ranks` and `group_size` from collective op signatures which were proven to be redundant.
- Use tensor storage as opposed to `void*` to resolve in-flight work.

The native process group registration/resolution mechansim is only used for native c10d_functional in the PR. It will become the single source of truth in upcoming PRs.

The upcoming PRs will implement Inductor/AOTInductor support for c10d_functional, after which native c10d_functional will replace Python c10d_functional.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110570
Approved by: https://github.com/wanchaol
2023-10-25 22:56:06 +00:00
18cc8a92ac [ProcessGroupNCCL] Avoid recording stream for synchronous ops (#111431)
For synchronous ops (i.e. `asyncOp = False`), we don't want to record streams because we know that the NCCL stream will join back to the "current" stream right after this op. So we might just as well keep the stream ownership of the input/output tensors unchanged. The benefit would be that the allocation/free of the tensors would look deterministic to the "current" stream so that the caching allocator can reuse memory pool for this stream in a clever way.

To prevent the input/output tensors from being recycled by python, we rely on the stashing mechanism in ProcessGroupNCCL (which can be also turned on by setting `TORCH_NCCL_AVOID_RECORD_STREAMS=1`).

This mechanism change is for libraries like FSDP which uses `all_gather_into_tensor` and `reduce_scatter_tensor` in a synchronous way and which cannot set `TORCH_NCCL_AVOID_RECORD_STREAMS=1` for their users. And therefore, this change is limited to these two collectives for now.

Cc: @awgu @janeyx99 @albanD
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111431
Approved by: https://github.com/H-Huang
2023-10-19 00:41:09 +00:00
317e39a8ad [C10d] Cleanup collective sequence number. (#109136)
Sequence numbers must be associated with a Work object
if we want to use it as a way to report collective progress.

The API surface change is introducing Work::getSequenceNumber, which
should eventually be exposed to python.

The bulk of this change is changing gloo to make the sequence number
be always in use and weave it to the dozens subclasses of Work.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109136
Approved by: https://github.com/fduwjj
2023-09-26 17:17:04 +00:00
9a1b6d44bb [C10d] Add PG::enableCollectivesTiming to make it dynamically enabled. (#108814)
Collectives timing gates the tracking when a collective starts on device.

Currently it's enabled by set the NCCL_ENABLE_TIMING env var.

The goal of this PR is to make it possible to dynamically enable that flag so users of the PG hooks don't have to set that flag in order to have their hooks work.

The design is that once set, all new collectives will have such behavior so we track it on each Work object.

We make enableTiming_ atomic in PGNCCL to avoid races on non-TSO hardware.

To ensure consistency, we copy its value during Work construction and replace all previous usage of enableTiming_ from the PG with usages from the Work, which now has an immutable value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108814
Approved by: https://github.com/wconstab, https://github.com/fduwjj
ghstack dependencies: #108813
2023-09-20 19:47:41 +00:00
2bca5f2af7 [C10D] Track pg name in c++. (#108813)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108813
Approved by: https://github.com/wconstab
2023-09-15 01:10:29 +00:00
45128ab67c [Reland] Add OnCompletion Hook to ProcessGroup (#106988) (#107233)
This allows infra/trainers to get detailed stats about communication
efficiencies without know anything about what model or distributed
training paradigms have been used. This is helpful as infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainer can have access to all
collectives used by the model authors.

This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107233
Approved by: https://github.com/kumpera
2023-08-15 17:35:14 +00:00
fd214aa8be Revert "Add OnCompletion Hook to ProcessGroup (#106988)"
This reverts commit ba1da47e8fa95ca0dd8b2d63430f7eb54fdbbccb.

Reverted https://github.com/pytorch/pytorch/pull/106988 on behalf of https://github.com/huydhn due to Sorry for reverting you change, but it is failing Windows build with some linker error.  The Windows failures on PR looks legit ([comment](https://github.com/pytorch/pytorch/pull/106988#issuecomment-1678580899))
2023-08-15 08:24:33 +00:00
ba1da47e8f Add OnCompletion Hook to ProcessGroup (#106988)
This allows infra/trainers to get detailed stats about communication
efficiencies without know anything about what model or distributed
training paradigms have been used. This is helpful as infra/trainer
package usually prefers to be as model/algorithm agnostic as possible.
Therefore, we cannot assume that infra/trainer can have access to all
collectives used by the model authors.

This commit adds an `OnCompletion` hook to `ProcessGroupNCCL` which
will be fired on every work completion event.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106988
Approved by: https://github.com/kumpera, https://github.com/H-Huang
ghstack dependencies: #107140, #107141, #107160
2023-08-15 04:32:23 +00:00
dd6319198d Apply clang-format to distributed/c10d folder (#107140)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107140
Approved by: https://github.com/H-Huang
2023-08-14 23:16:38 +00:00
3a01c056f5 [PyTorch][ET] Collect Process Groups Mapping Info (#104373)
Summary: Add the logics and interface to log ProcessGroup comms configuration (unique ID, type, and ranks info).

Test Plan:
Testing in HPC:
```
TORCH_LOGS=all ../buck-out/v2/gen/fbcode/c8344b52091f4f7f/hpc/models/ads/__ads_10x_launcher__/ads_10x_launcher.par  +launcher=local launcher.num_trainers=4 +data_loader=random data_loader.num_batches=2000
```
Example output in ET:
```
    {
      "name": "## process_group:init ##", "id": 3, "rf_id": 1, "parent": 2, "fw_parent": 0, "seq_id": -1, "scope": 7, "tid": 1, "fw_tid": 0, "op_schema": "",
      "inputs": ["[{'pg_id': 140538064364672, 'backend_id': 140538060772480, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}, {'pg_id': 140538064363904, 'backend_id': 140538042628864, 'backend_config': 'cuda:nccl', 'ranks': {0: 0, 1: 1, 2: 2, 3: 3}}]"], "input_shapes": [[]], "input_types": ["String"],
      "outputs": [], "output_shapes": [], "output_types": []
    },
```

Differential Revision: D46321690

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104373
Approved by: https://github.com/kwen2501
2023-07-25 03:34:53 +00:00
9165d46b89 DDP + C10D sparse all_reduce changes (#103916) (#104256)
Summary:

reland of https://github.com/pytorch/pytorch/pull/103916

## Changes

prototyping sparse allreduce using the sparse dispatch key. When passing in sparse tensors into `dist.allreduce()` we can execute our dispatched function.

prior to this change, passing a sparse tensor into `allreduce()` will error out with `Tensor must be dense...`

## Example script

```python
# python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 this_script.py

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    a = torch.tensor([[0, 2.], [3, 0]]).to(rank)
    a = a.to_sparse()
    print(f"rank {rank} - a: {a}")
    dist.all_reduce(a)

if __name__ == "__main__":
    main()
```

output:
```
rank 1 - a: tensor(indices=tensor([[0, 1],
                       [1, 0]]),
       values=tensor([2., 3.]),
       device='cuda:1', size=(2, 2), nnz=2, layout=torch.sparse_coo)
allreduce_sparse_cuda_
tensor.is_sparse() = 1
in ProcessGroupNCCL::allreduceSparse
rank 0 - a: tensor(indices=tensor([[0, 1],
                       [1, 0]]),
       values=tensor([2., 3.]),
       device='cuda:0', size=(2, 2), nnz=2, layout=torch.sparse_coo)
allreduce_sparse_cuda_
tensor.is_sparse() = 1
in ProcessGroupNCCL::allreduceSparse
```

Test Plan:
Testing commands (OSS):

```
# python
pytest test/distributed/test_c10d_nccl.py -vsk test_sparse_allreduce_ops

# c++
build/bin/ProcessGroupNCCLTest --gtest_filter=ProcessGroupNCCLTest.testSparseAllreduce
```

Testing commands (internal, ondemand GPU):
ddp tests:
```
buck build mode/opt -c hpc_comms.use_ncclexp=default //caffe2/test/distributed:c10d --show-full-output

# Get the .par file from the previous command and use it below
TORCH_SHOW_CPP_STACKTRACE=1 /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/c8344b52091f4f7f/caffe2/test/distributed/__c10d__/c10d.par -r test_ddp_set_sparse_metadata
```

c10d tests:
```
# build tests and run with log output (python)
buck build mode/opt -c hpc_comms.use_ncclexp=default //caffe2/test/distributed:c10d --show-full-output
NCCL_DEBUG=WARN /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/c8344b52091f4f7f/caffe2/test/distributed/__c10d__/c10d.par -r test_sparse_allreduce_ops

# python
NCCL_DEBUG=WARN buck test mode/opt -c hpc_comms.use_ncclexp=default //caffe2/test/distributed:c10d -- --exact 'caffe2/test/distributed:c10d - test_sparse_allreduce_ops (test_c10d_nccl.ProcessGroupNCCLTest)'

# c++
NCCL_DEBUG=WARN buck run mode/opt -c hpc_comms.use_ncclexp=default //caffe2/test/cpp/c10d:ProcessGroupNCCLTest -- --gtest_filter=ProcessGroupNCCLTest.testSparseAllreduce
```

Differential Revision: D47056695

Pulled By: H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104256
Approved by: https://github.com/rohan-varma
2023-06-28 00:37:52 +00:00
436d035dc7 Revert "DDP + C10D sparse all_reduce changes (#103916)"
This reverts commit fed5fba6e4ee3f221bac481798c5a31f785ba75e.

Reverted https://github.com/pytorch/pytorch/pull/103916 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/103916#issuecomment-1608412325))
2023-06-26 22:37:58 +00:00
fed5fba6e4 DDP + C10D sparse all_reduce changes (#103916)
Summary:
## Changes

prototyping sparse allreduce using the sparse dispatch key. When passing in sparse tensors into `dist.allreduce()` we can execute our dispatched function.

prior to this change, passing a sparse tensor into `allreduce()` will error out with `Tensor must be dense...`

## Example script

```python
# python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 this_script.py

import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    a = torch.tensor([[0, 2.], [3, 0]]).to(rank)
    a = a.to_sparse()
    print(f"rank {rank} - a: {a}")
    dist.all_reduce(a)

if __name__ == "__main__":
    main()
```

output:
```
rank 1 - a: tensor(indices=tensor([[0, 1],
                       [1, 0]]),
       values=tensor([2., 3.]),
       device='cuda:1', size=(2, 2), nnz=2, layout=torch.sparse_coo)
allreduce_sparse_cuda_
tensor.is_sparse() = 1
in ProcessGroupNCCL::allreduceSparse
rank 0 - a: tensor(indices=tensor([[0, 1],
                       [1, 0]]),
       values=tensor([2., 3.]),
       device='cuda:0', size=(2, 2), nnz=2, layout=torch.sparse_coo)
allreduce_sparse_cuda_
tensor.is_sparse() = 1
in ProcessGroupNCCL::allreduceSparse
```

Test Plan:
Testing commands (OSS):

```
# python
pytest test/distributed/test_c10d_nccl.py -vsk test_sparse_allreduce_ops

# c++
build/bin/ProcessGroupNCCLTest --gtest_filter=ProcessGroupNCCLTest.testSparseAllreduce
```

Testing commands (internal, ondemand GPU):
ddp tests:
```
buck build mode/opt -c hpc_comms.use_nccl=exp //caffe2/test/distributed:c10d --show-full-output

# Get the .par file from the previous command and use it below
TORCH_SHOW_CPP_STACKTRACE=1 /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/c8344b52091f4f7f/caffe2/test/distributed/__c10d__/c10d.par -r test_ddp_set_sparse_metadata
```

c10d tests:
```
# build tests and run with log output (python)
buck build mode/opt -c hpc_comms.use_nccl=exp //caffe2/test/distributed:c10d --show-full-output
NCCL_DEBUG=WARN /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/c8344b52091f4f7f/caffe2/test/distributed/__c10d__/c10d.par -r test_sparse_allreduce_ops

# python
NCCL_DEBUG=WARN buck test mode/opt -c hpc_comms.use_nccl=exp //caffe2/test/distributed:c10d -- --exact 'caffe2/test/distributed:c10d - test_sparse_allreduce_ops (test_c10d_nccl.ProcessGroupNCCLTest)'

# c++
NCCL_DEBUG=WARN buck run mode/opt -c hpc_comms.use_nccl=exp //caffe2/test/cpp/c10d:ProcessGroupNCCLTest -- --gtest_filter=ProcessGroupNCCLTest.testSparseAllreduce
```

Differential Revision: D46724856

Pulled By: H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103916
Approved by: https://github.com/rohan-varma
2023-06-26 20:42:17 +00:00
22e8a61d9b Implement coalesced reduce_scatter_tensor (#103561)
Map of #101157.

This PR adds support for coalesced `reduce_scatter_tensor` calls in the following syntax:

Sync communication style:
```
with dist._coalescing_manager():
     for i in range(num_coll):
         dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
```

Async communication style:
```
with dist._coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])

# do a bunch of other things
cm.wait()
# do things that depend on the reduce-scatters' results
```
Each `reduce_scatter_tensor` call can be independent in terms of their data and buffer locations. But could be executed in parallel by supported backends (like NCCL).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103561
Approved by: https://github.com/fegin
2023-06-15 20:11:12 +00:00
97180aca5e Enables barrier to support the specified device (#99589)
Enables barrier to support the specified device, e.g cuda/custom device. There is some discussion here: https://github.com/pytorch/pytorch/issues/97938#issue-1646833919

Today, there are two limitations of barrier:
One is that barrier does not support custom  #device:
fbdb86c174/torch/csrc/distributed/c10d/ProcessGroup.hpp (L512-L522)

The second is that there is a special valid for nccl when device_id is not None, which is an assumption for cuda and nccl bindings, and also hinders custom device.
789070986c/torch/distributed/distributed_c10d.py (L3504-L3508)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99589
Approved by: https://github.com/kwen2501
2023-05-17 05:26:04 +00:00
daed3bf8f9 Implement coalesced all_gather_into_tensor (#101157)
This PR adds support for the following use cases:
- Sync style:
```
with dist._coalescing_manager():
     for i in range(num_coll):
         dist.all_gather_into_tensor(output_tensors[i], input_tensors[i])
```
- Async style:
```
with dist._coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.all_gather_into_tensor(output_tensors[i], input_tensors[i])

# do a bunch of other things
cm.wait()
# do things that depend on the all-gather's
```
Each `all_gather_into_tensor` would be independent in terms of data and their buffer location. But could be executed in parallel by supported backends (like NCCL).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101157
Approved by: https://github.com/kumpera, https://github.com/wanchaol
2023-05-11 20:58:47 +00:00
0848ed21b8 [c10d] Figure out device to use for object collectives (#100954)
Fixes https://github.com/pytorch/pytorch/issues/97938

this pr is clone from https://github.com/pytorch/pytorch/pull/100238, which is important to me. But
@kwen2501 has not resolved the confliction. So, this pr is submitted to resolve the confliction.
the only confliction is `distributed_c10d.py:2653`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100954
Approved by: https://github.com/kwen2501
2023-05-11 01:49:09 +00:00
3a09aa5977 [c10d] Faster coalescing (#98793)
### Description
The PR aims at reducing CPU overhead of context manager style coalescing.

By "context manager style coalescing", we mean:
Sync style:
```
with _coalescing_manager():
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
```
Async style:
```
with _coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
cm.wait()
```
In previous implementation, each collective in the `num_coll` loop actually calls into the C++ backend, accumulating pybind overhead.

In the new implementation, we capture the collectives at Python level, and only fire towards C++ at the exit of the coalescing manager.

### Tests
In current PR, the "fast path" only applies to all-reduce.
- Flattened 512M: 16.38 ms, including CPU time 131.21 us
- Old _coalescing_manager 64 x 8M: 22.19 ms, including CPU time 2865 us
- New _coalescing_manager 64 x 8M: 16.93 ms, including CPU time 635 us

Hence a 4x reduction in CPU overhead (dependent on `num_coll`).

Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
2023-04-24 21:27:26 +00:00
9861ec9785 Revert "[c10d] Faster coalescing (#98793)"
This reverts commit db456ab83da6a505dcebc128903d5ee4fc2d5712.

Reverted https://github.com/pytorch/pytorch/pull/98793 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-04-21 09:15:04 +00:00
db456ab83d [c10d] Faster coalescing (#98793)
### Description
The PR aims at reducing CPU overhead of context manager style coalescing.

By "context manager style coalescing", we mean:
Sync style:
```
with _coalescing_manager():
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
```
Async style:
```
with _coalescing_manager(async_ops=True) as cm:
     for i in range(num_coll):
         dist.all_reduce(tensors[i])
cm.wait()
```
In previous implementation, each collective in the `num_coll` loop actually calls into the C++ backend, accumulating pybind overhead.

In the new implementation, we capture the collectives at Python level, and only fire towards C++ at the exit of the coalescing manager.

### Tests
In current PR, the "fast path" only applies to all-reduce.
- Flattened 512M: 16.38 ms, including CPU time 131.21 us
- Old _coalescing_manager 64 x 8M: 22.19 ms, including CPU time 2865 us
- New _coalescing_manager 64 x 8M: 16.93 ms, including CPU time 635 us

Hence a 4x reduction in CPU overhead (dependent on `num_coll`).

Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
2023-04-19 20:17:58 +00:00
2973994259 fix typo in comments under torch/csrc/distributed (#96062)
This PR fixes typos in comments and messages of `.cpp` and `.hpp` files under `torch/csrc/distributed` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96062
Approved by: https://github.com/ngimel
2023-03-07 02:56:41 +00:00
b2ea1d06aa Collective dispatching from Process Group (#91257)
Fixes https://github.com/pytorch/pytorch/issues/90932
Fixes https://github.com/pytorch/pytorch/issues/90659

Remove redundant collection operation definitions by calling the ops directly from `ProcessGroup`

Context:
https://github.com/pytorch/pytorch/issues/86225

Differential Revision: [D42854676](https://our.internmc.facebook.com/intern/diff/D42854676)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91257
Approved by: https://github.com/kwen2501
2023-02-09 18:31:28 +00:00
cyy
f172feae0d More tidy fixes (#93069)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93069
Approved by: https://github.com/Skylion007
2023-01-27 06:40:50 +00:00
77c2a8a11f Clang-Tidy: Improve ctors by removing unnecessary copies and initializations (#91538)
Apply clang-tidy fixups to prefer member initializer and modernize-pass-by-value. This is a mostly a noop, but it should make a few ctors slighlty more readable and more efficient. Also drops in some missing moves that prevents a lot of unnecessary copying.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91538
Approved by: https://github.com/ezyang
2022-12-31 07:19:30 +00:00
7a0f29b776 Allow Process Group to support multiple backends (#88330) (#90997)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88330

### Implementation
Move backend-specific (NCCL, Gloo, etc) collective implementations to corresponding `Backend` class. Update ProcessGroup to support multiple backends and use dispatcher to calls backends based on tensor device type.

### Changes

#### c++ changes (ProcessGroup files, `Ops.cpp`, `init.cpp`)
- Update pybind definitions for new process group base class and new backend class
- Update pybinded backend class with collective definitions to keep BC with Python PG instances (e.g. `dist.ProcessGroupGloo`, `dist.ProcessGroupNCCL`) which are used in tests
- Switch `ProcessGroupGloo`, `ProcessGroupNCCL`, `ProcessGroupMPI`, `ProcessGroupUCC` to derive from the `Backend` class.
- Update CPU/CUDA `Ops.cpp` and `OpsImpl.cpp` to perform this dispatching by querying the backend using the device type
- Update internal dispatched implementation of `barrier` to use a tensor which allows operation to be dispatched.
- Update `allgather` collective to use `TensorList`. For some reason it was using the default implementation of `allgather` rather than dispatching it correctly. I still don't understand why and had originally filed an issue in 85122.

#### python changes (`distributed_c10d.py`, test files)
- Add BackendConfig class to specify the default configurations of backends and `get_backend_config()` API
- `get_backend()` deprecation warning
- `init_process_group` how returns a generic `ProcessGroup` object, it contains a list of backends (the ones stated above) which it will dispatch operations to.
- `new_group` updated to return the same as above
- Update `test_c10d_gloo.py`, Update `DistributedDataParallelTest` to use `init_process_group`, Update `ReducerTest`, update `test_broadcast_coalesced_gloo` to move from PG instance and gloo options
- Update `test_c10d_nccl.py`, Update `DistributedDataParallelTest` to use `init_process_group`
- Specific tests updated: `test_Backend_enum_class`

### Changes missing
- lazy initialization of backends
- support parsing of BackendConfig

### open questions
- Pure Python PG extensions (https://github.com/pytorch/pytorch/pull/66338)

# Example

This is a basic script (using 2 backends within a process group)

```python
# python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 basic_scenario.py
import torch.distributed as dist
import torch
import os

if __name__ == "__main__":
    rank = os.environ.get("RANK")
    # initialize with both gloo and nccl
    dist.init_process_group()
    # with gloo
    dist.all_reduce(torch.tensor([1.0]))
    print(f"Rank {rank} finished")
    # with nccl
    dist.all_reduce(torch.tensor([1.0], device=f"cuda:{rank}"))
```

Test Plan: Imported from OSS

Differential Revision: D42069829

Pulled By: H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90997
Approved by: https://github.com/awgu, https://github.com/fduwjj
2022-12-16 23:15:00 +00:00
1ad0048b64 Refactor distribuetd to use absolute header path (#85780)
Headers under torch/csrc/distributed may be referened with relative path, e.g., "<c10d/...>". However, relative path cannot be gracefully handled by Meta internal build when the NCCL PG is hipified to support AMD/RCCL because the "hipified" header files are generated in other directories. Moreover, using absolute path for header inclusion is the state-of-the-art in most components in Pytorch. Thus, this patch refactors all header paths in torch/csrc/distributed to be absolute.

See D39835774 for more details about Meta internal complication.

**How to test**: commit 9e5d199 removes -I./torch/csrc/distributed in compile options. Thus use it to verify we don't miss any relative path use of torch/csrc/distributed headers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85780
Approved by: https://github.com/kumpera, https://github.com/huydhn
2022-09-30 05:13:50 +00:00
a50d8864fc Revert "Refactor distribuetd to use absolute header path (#85780)"
This reverts commit 668082718aefce95ecc1b1c312ea6f127b2c662e.

Reverted https://github.com/pytorch/pytorch/pull/85780 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but it breaks build due to a missing file <c10d/Store.hpp>
2022-09-30 02:04:29 +00:00
668082718a Refactor distribuetd to use absolute header path (#85780)
Headers under torch/csrc/distributed may be referened with relative path, e.g., "<c10d/...>". However, relative path cannot be gracefully handled by Meta internal build when the NCCL PG is hipified to support AMD/RCCL because the "hipified" header files are generated in other directories. Moreover, using absolute path for header inclusion is the state-of-the-art in most components in Pytorch. Thus, this patch refactors all header paths in torch/csrc/distributed to be absolute.

See D39835774 for more details about Meta internal complication.

**How to test**: commit 9e5d199 removes -I./torch/csrc/distributed in compile options. Thus use it to verify we don't miss any relative path use of torch/csrc/distributed headers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85780
Approved by: https://github.com/kumpera
2022-09-30 00:27:24 +00:00
74ead61944 [2/N] [Dispatchable Collectives] Extract ProcessGroup::Work into a separate class and update references (#83680)
### Changes
- Move ProcessGroup::Work into its own class and update all the references to it / header includes.

#### Motivation
In the future PRs we will repurpose ProcessGroup to instead contain a list of Backends (ProcessGroupNCCL/Gloo/UCC) and perform dispatching to them based on tensor type. This change is prevent a circular dependency with ProcessGroup depending on Backend and Backend depending on ProcessGroup::Work.

Differential Revision: [D38839212](https://our.internmc.facebook.com/intern/diff/D38839212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83680
Approved by: https://github.com/kwen2501
2022-09-14 13:05:58 +00:00
5e25c2b4cc Add missing spaces to error messages in PG (#84780)
Just some formatting, no real changes. See https://github.com/pytorch/pytorch/pull/83285#discussion_r966469992
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84780
Approved by: https://github.com/kit1980
2022-09-10 00:50:02 +00:00
24a084eda6 [c10d] Fix async error in batch_isend_irecv (#82450)
Summary:
`batch_isend_irecv` previously required the use of `torch.cuda.synchronize` to avoid data race conditions. This was because the ncclStreams were recorderd in the returned ncclWork object _before_ a ncclGroupEnd by the `_batch_p2p_manager` was issued. Thus, the `req.wait()` was effectively waiting on nothing, leading to the later operators working on incorrect intermediate data.

This fix:
- keeps track of ncclStreams to wait on, and records them in the work objects after the batch manager issues a ncclGroupEnd
- renames the `_batch_p2p_manager` to `_coalescing_manager` for generality
- removes the explicit check for NCCL backend inside `_batch_p2p_manager` in distributed_c10.py and moves the manager start/end to ProcessGroup.hpp, in order to transparently work with all process groups

Test Plan: Modified the unittest for `batch_isend_irecv` to check that received tensors are the same as expected tensors. Verified that the test fails before the change, and passes after the change.

Differential Revision: D38100789

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82450
Approved by: https://github.com/kwen2501
2022-08-08 17:50:22 +00:00
f6a45f7984 [distributed] Make DDP work with python process group (#79176)
This PR enables python process group usage with DDP by doing the following:

- Surface PG::Work::getFuture() as overridable()
- Use Work::getFuture() to retrieve values from a PG.
- Add _create_work_from_future python method that creates a Work object that wraps a Future.

To test this changes we use both strategies to run DDP with a python based PG.

The reason for offering two methods is that both have short-comings.

The wrapper method is harder to troubleshoot as there's no visibility of how the future is used.

The subclass method has memory management issues as can be noticed in the test suite by having to keep Work instances alive by storing them in PG fields.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79176
Approved by: https://github.com/rohan-varma
2022-06-28 17:14:21 +00:00
e757cf40cc [c10d] Make broadcast as a custom op
Summary:
This patch makes broadcast as a custom op such that it's dispatcher
passable. It's one part of the effort to route comm ops to the dispatcher
such that tracing mechanisms that relies on the dispatcher can trace them,
e.g., LazyTensor and AOTAutograd.

Test Plan:
python test/distributed/test_c10d_nccl.py -k test_broadcast_ops
python test/distributed/test_c10d_gloo.py -k test_broadcast_basics
...and other existing distributed tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76722

Approved by: https://github.com/pritamdamania87
2022-06-14 01:54:29 +00:00
d6f22abbcc [PyTorch Distributed] Fix batch_isend_irecv
Summary:
`batch_isend_irecv` previously only worked for two-rank cases,
otherwise it would hang, e.g. pytorch/pytorch#73960. This Diff extends
`batch_isend_irecv` to support more than two ranks. The fix is by treating
the operation more like a collective rather than two-rank P2P when selecting
the communicator, since there can be more ranks participating in the batch call than "my" rank and "my" peer.

Rules:
- If `batch_isend_irecv` is the first collective call (including collectives and
  all-to-all) in the `group` given as the argument, then all ranks of the
  `group` are expected to participate in this call.
- Otherwise, if it is not the first collective call in the `group` (i.e. the
  communicator has been initialized), then batched P2P communication involving
  only subset of processes of the `group` is allowed.

Test Plan:
Added p2p_tests.py testing the following patterns:
+    sendrecv_neighbor(input, output)       # Ring like neighbor exchange
+    sendrecv_ripple(input, output)         # Exchange with growing distance (pytorch/pytorch#73960)
+    sendrecv_P2P(input, output)            # Single P2P operation
+    isendrecv_P2P(input, output)           # Single non-blocking P2P operation
+    isendrecv_P2P_batch(input, output, 0)  # batched P2P between only two ranks
+    isendrecv_P2P_batch(input, output, 1)  # batched P2P within a new group created for two ranks

Differential Revision: D35122664

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74701
Approved by: https://github.com/mingzhe09088, https://github.com/osalpekar
2022-04-13 05:55:00 +00:00
3b30c8da88 Add logging for ProcessGroup backends. (#73702)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73702

This PR adds logging to track usage of ProcessGroup backends. The
change involves adding an `init()` method to the ProcessGroup base class and
modifying all the subclasses to call it.

We can't add logging directly in the ProcessGroup constructor since
`getBackendName()` is pure virtual and can't be called from the base class
constructor. See
https://isocpp.org/wiki/faq/strange-inheritance#calling-virtuals-from-ctors for
more details.
ghstack-source-id: 150381852

Test Plan: check logging

Reviewed By: cbalioglu

Differential Revision: D34596295

fbshipit-source-id: 5baa7bfb53bdcd1b4032292c4e7715467f983288
(cherry picked from commit b2344144fe0c25db02c2e958f2acd772e9cd4dc8)
2022-03-08 17:45:42 +00:00