- During op dispatch local tensor is supposed to collect rng state from CPU and CUDA
devices so that it can be reset before execution of the op for each such that ops
with randomness produces the same result for all ranks (note that we are planning a
separate change to add support of per rank rng state). Previously we relied on
op input arguments to deduce which devices to get rng state from. Which doesn't work
for factory functions such torch.randn. Hence this changes switches to uncondionally
collecting rng state from all devices.
- Fixing per rank specific computations in _MaskedPartial and Shard placements discovered
during test enablement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796 Enum contains duplicate value: {value}
PIE808 Unnecessary start argument in range
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
This PR enables all PIE rules on ruff, there are already some enabled rules from this family, the new added rules are
```
PIE796 Enum contains duplicate value: {value}
PIE808 Unnecessary start argument in range
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165814
Approved by: https://github.com/ezyang
- During op dispatch local tensor is supposed to collect rng state from CPU and CUDA
devices so that it can be reset before execution of the op for each such that ops
with randomness produces the same result for all ranks (note that we are planning a
separate change to add support of per rank rng state). Previously we relied on
op input arguments to deduce which devices to get rng state from. Which doesn't work
for factory functions such torch.randn. Hence this changes switches to uncondionally
collecting rng state from all devices.
- Fixing per rank specific computations in _MaskedPartial and Shard placements discovered
during test enablement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165716
Approved by: https://github.com/ezyang
https://github.com/pytorch/pytorch/pull/164820 introduced a bug that `_StridedShard` will call parent class `Shard`'s `split_tensor` method, thus results in incorrect data locality. (I think @ezyang spotted this issue, but we have no test to capture this)
Meanwhile, I notice another bug that when we normalize a `_StridedShard`'s placement, it will also trigger parent class `Shard`'s `split_tensor` method because it will create a Shard class [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I think we never test `distribute_tensor` for `_StridedShard` before. So I added a test here to compare against ordered shard.
Using classmethod because the _split_tensor logic is different between `Shard` and `_StridedShard`. Basically I want to shard on local tensors without initializing the Shard object:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
1. Run distributed job with B200 runner, periodically.
2. discovered generic distributed test issue that certain unit test hard-coded ranks, calling for require_exact_world_size(world_size) API instead of require_world_size(world_size).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159323
Approved by: https://github.com/eqy
Co-authored-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
Improve FakeTensor cache to handle SymNode and tracing properly.
For now, when we're proxy tracing just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity and our cache doesn't preserve that. It can be fixed (and I left some notes in _validate_symbolic_output_for_caching() how) but it's not worth it for now.
If we aren't proxy tracing then caching is fine.
Thus these changes:
1. Our cache key needs to include whether we were actively tracing or not - this way if we create a cache entry when we weren't tracing and then we try to use it when we ARE tracing it gets rerun.
2. If there's a SymNode in the output then bypass tracing.
3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could just be a single step (it's still two parts internally but only a single outer try/except).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #165266, #164717
In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.
The problem is occuring because in sharding_prop we use FakeTensors to compute an operation size (so we don't have to use the full "real" data). We turn off proxy tracing while we're doing that because we don't want the FakeTensor ops to end up in the graph. We then use that size when doing later operations.
Normally this is no problem - but when those sizes are dynamic shapes then we have a problem - the proxy tracer wants to track the provenance of all shape operations (`s1*s2`) but since tracing is disabled it doesn't see the operation and when we then use the result shape later on the proxy tracer gets all confused (because the SymNode appeared out of nowhere).
At first we were thinking to never disable shape tracing - but that caused a slew of other downstream problems (lots of code that actually needs the shape tracing to be disabled) so instead we enable having a "sym tracing override" and surgically when we disable proxy tracing we leave shape tracing enabled.
After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
While enabling this test discovered lack of support for sub meshes. Added limited support
for sub meshes by properly computing rank coordinates for a given sub mesh. The implementation
follows similar approach to collectives. We infer all sub meshes for the given dimensions and
compute each rank's coordinates with respect to is sub mesh.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165596
Approved by: https://github.com/ezyang
Adding ag+mm support for the case, when gather_dim is last dim of matmul (reduction dim).
When we decompose matmul by reduction dimension we result in partials that needs additional reduction,
we allocate memory for accumulator.
Decomposition should not produce small (thin) mms that can not efficiently load the GPU. Limiting for minimal size of the shard 1024 (found empirically by testing in torchtitan).
scaled_mm is not supported yet for this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163068
Approved by: https://github.com/ngimel
Bucketing of multiple dtypes to be processed in one bucketed collective.
First target is to bucket bf16 and f32, but already can be used with other dtypes.
For now multidtype bucketing is only supported with "custom_ops" mode.
Non custom_ops needs additional work on inductor side.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162470
Approved by: https://github.com/eellison
The goal of this PR is to avoid storing the explicit `mesh` Tensor inside each DeviceMesh, and instead compute it on-the-fly when the end user needs it, and try to replace all of its internal usages with `_layout` and the newly-introduced `_global_rank_permutation` Tensor. The name of this attribute is up for debate. The advantage of the `_global_rank_permutation` Tensor is that it is _the same_ Tensor for the root mesh and all its children, so it doesn't need to be copied/reallocated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165554
Approved by: https://github.com/fduwjj
**Summary**
The load-balancing problem can be modeled as [identical-machines scheduling](https://en.wikipedia.org/wiki/Identical-machines_scheduling) problem. We already provided an easy-to-extend interface in #161062 for
implementing load-balancing and in this PR we start with adding a Round-Robin solution as an example
and also a verification. This can be easily adapted to other solutions like Shortest-processing-time-first/
Longest-processing-time-first with extra padding added for collectives.
- Added a new type of `_LoadBalancer` implementation `_PTRRLoadBalancer` which is designed for
`flex_attention()`. This load-balance strategy analyzes the `BlockMask` sparsity info and perform
Round-Robin (unlike traditional Round-Robin doing it in circular order, we do in zig-zag order).
- Make `_context_parallel_buffers` and `context_parallel_unshard` handle batched load-balance
index (previously it can only handle non-batched load-balance index), like in `create_cp_block_mask`.
**Test**
`pytest test/distributed/tensor/test_attention.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163617
Approved by: https://github.com/fegin
Bucketing a number of smallish improvements:
- Account for bucketing in overlap calculation: if an in-flight collective exists with the same bucket key, reduce new collectives estimated time by its latency time
- Update compute domination so we are ordering based on compute idx, as opposed to compute depth, so we never reorder compute. this makes it a bit easier to reason about memory, and pre-fetching, although we can exploring reordering in the future.
- When we wait on a collective, force all collectives on the same process group as it that were enqueued prior to the collective to wait as well.
Better Memory Handling:
- Pre-fetch limiting - when scheduling collectives for overlap, only pre-fetch up to a certain distance, then schedule off-path collectives (which are typically memory reducing).
- When we are above peak memory, schedule waits.
TODO:
- for each compute node, we know its original memory in the graph. we could limit pre-fetching that goes across peak memory
- By scheduling off-path collectives for overlap, we reduce memory, but if there weren't enough compute for overlap, we need to proactively schedule them. not an issue yet on examples.
- config some hard coded constants, clean up enablement (can do in subsequent pr)
On small llama 2d backward :
578 of 618 potentially hideable collectives hidden
original mem 14.4GB, rescheduled mem, 15.9GB
on forward:
254/256 potentially hideable collectives hidden
original mem 5.8 gb, reshceduled mem 5.8GB
WIP: adding tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165318
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945, #165059
https://github.com/pytorch/pytorch/pull/164820 introduced a bug that `_StridedShard` will call parent class `Shard`'s `split_tensor` method, thus results in incorrect data locality. (I think @ezyang spotted this issue, but we have no test to capture this)
Meanwhile, I notice another bug that when we normalize a `_StridedShard`'s placement, it will also trigger parent class `Shard`'s `split_tensor` method because it will create a Shard class [here](0c14f55de6/torch/distributed/tensor/_api.py (L783)). I think we never test `distribute_tensor` for `_StridedShard` before. So I added a test here to compare against ordered shard.
Using classmethod because the _split_tensor logic is different between `Shard` and `_StridedShard`. Basically I want to shard on local tensors without initializing the Shard object:
```
local_tensor = _StridedShard._make_shard_tensor(dim, tensor, mesh, mesh_dim, split_factor=split_factor)
local_tensor = Shard._make_shard_tensor(dim, tensor, mesh, mesh_dim)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165533
Approved by: https://github.com/XilunWu
These happen when building with CMAKE_BUILD_TYPE=RelWithAssert
This should fix two types of failures that started with https://github.com/pytorch/pytorch/pull/163665
Disclaimer that I used a lot of AI since I don't how pybind works or what refcounts and pointers are, so idk if this is a good solution, or even a solution at all (fwiw the tests pass now)
The first one type is
Truncated:
```
default_pg, _ = _new_process_group_helper(
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2096, in _new_process_group_helper
backend_class = creator_fn(dist_backend_opts, backend_options)
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/distributed/fake_pg.py", line 25, in _create_fake_pg
return FakeProcessGroup._create_internal(
RuntimeError: new_refcount != 1 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/c10/util/intrusive_ptr.h":319, please report a bug to PyTorch. intrusive_ptr: Cannot increase refcount after it reached zero.
Exception raised from retain_ at /var/lib/jenkins/workspace/c10/util/intrusive_ptr.h:319 (most recent call first):
C++ CapturedTraceback:
#4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > const> (), c10::SetStackTraceFetcher(std::function<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0
#5 c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) from ??:0
#6 c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) from ??:0
#7 c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) from ??:0
#8 void pybind11::class_<c10d::FakeProcessGroup, (anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup> >::init_instance<(anonymous namespace)::IntrusivePtrNoGilDestructor<c10d::FakeProcessGroup>, 0>(pybind11::detail::instance*, void const*) from init.cpp:0
#9 pybind11::detail::type_caster_generic::cast(void const*, pybind11::return_value_policy, pybind11::handle, pybind11::detail::type_info const*, void* (*)(void const*), void* (*)(void const*), void const*) from :0
#10 pybind11::cpp_function::initialize<torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> >, int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v>(torch::distributed::c10d::(anonymous namespace)::c10d_init(_object*, _object*)::{lambda(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >)#127}&&, c10::intrusive_ptr<c10d::FakeProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup> > (*)(int, int, c10::intrusive_ptr<c10d::FakeProcessGroup::Options, c10::detail::intrusive_target_default_null_type<c10d::FakeProcessGroup::Options> >), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) from init.cpp:0
```
and I fix it here by getting rid of `DontIncreaseRefcount` and using make_intrusive to do the ref count handling instead. However, I also had to move the constructor to be public, which I think is not good, based on the reasoning of the original PR
The other one type is
```
Traceback (most recent call last):
File "/var/lib/jenkins/workspace/test/test_testing.py", line 2415, in test_no_warning_on_import
self.assertEqual(out, "")
File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 4233, in assertEqual
raise error_metas.pop()[0].to_error( # type: ignore[index]
AssertionError: String comparison failed: "/opt/conda/envs/py_3.10/lib/python3.10/s[352 chars]):\n" != ''
- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/distributed/__init__.py:29: FutureWarning: pybind11-bound class 'torch._C._distributed_c10d.FakeProcessGroup' is using an old-style placement-new '__init__' which has been deprecated. See the upgrade guide in pybind11's docs. This message is only visible when compiled in debug mode.
- if is_available() and not torch._C._c10d_init():
To execute this test, run the following from the base repo dir:
python test/test_testing.py TestImports.test_no_warning_on_import
```
which I fix by getting rid of the `__init__` which I think is ok since it'll just error if you try to make one?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165479
Approved by: https://github.com/ezyang
Summary: This diff fixes a stress test failure by adding a new binary echo4.py and modifying the existing echo1.py binary. The changes are made in both fbcode and xplat directories. The api_test.py file is updated to use the new echo4.py binary, and the BUCK file is updated to include the new binary.
Test Plan:
```
buck test -j 18 'fbcode//mode/opt' fbcode//caffe2/test/distributed/elastic/multiprocessing:api_test -- --exact 'caffe2/test/distributed/elastic/multiprocessing:api_test - test_binary_redirect_and_tee (api_test.StartProcessesListAsBinaryTest)' --run-disabled --stress-runs 20 --record-results
```
```
buck test -j 18 'fbcode//mode/opt' fbcode//caffe2/test/distributed/elastic/multiprocessing:api_test -- --exact 'caffe2/test/distributed/elastic/multiprocessing:api_test - test_binary (api_test.StartProcessesListAsBinaryTest)' --run-disabled --stress-runs 20 --record-results
```
https://www.internalfb.com/intern/testinfra/testrun/17732923648474906https://www.internalfb.com/intern/testinfra/testrun/15481123834815653
Differential Revision: D83623694
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164353
Approved by: https://github.com/d4l3k
for pipeline parallel, we can have multiple FSDP roots (chunks)
```
model = nn.Sequential([chunk0, chunk1])
fully_shard(model.chunk0)
fully_shard(model.chunk1)
```
we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model.chunk0, model.chunk1])
```
unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
The wrapper enable to share test body implementation while eliminating need test class by hand. As an example, this change converts the whole DTensorTest to use local tensor mode.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165383
Approved by: https://github.com/ezyang
### Implementation of #151705
This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.
To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:
1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)
### Summary of This PR
This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.
### Notable Changes
1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tililng suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I’ve copied the configuration from the existing matmul templates. However, this may not be optimal—it currently takes a long time to tune, and I think there must be a better way to tackle this.
@eellison @jansel @PaulZhang12 @shunting314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
Current implementation hardcodes 4D input and output tensor shapes
Change that by computing `output_conv_shape` for any number of input dims
Replace `[.., .., .., slice]` with `[..., slice]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165241
Approved by: https://github.com/ezyang
(Extract out the algorithm from https://github.com/pytorch/pytorch/pull/160266.)
Build a graph to search for the path from source placement to destination placement (with device order). Currently solution introduces too many all-gathers and missing the opportunity for all-to-all when redistribute, especially when we consider the device order.
### How to build the graph:
When operator of Shard, think of collective op as operation on a stack of device axis:
- I, J are tensor dimensions;
- X, Y, Z, Y are ordered mesh dimensions.
<img width="357" height="253" alt="image" src="https://github.com/user-attachments/assets/23bb3cc3-0506-4071-9053-3c525cf0e526" />
Detailed collective op transition is implemented in `DTensorRedistributePlanner.get_next_state`.
### How to find the min cost path:
Assign weight to different type of collective ops and use Dijkstra to find the min cost path from the graph we build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164902
Approved by: https://github.com/ezyang
The custom op will fetch the required K and V. Currently, the forward pass is just an all-gather, and the backward pass is a reduce-scatter. While the logic is the same as all_gather_tensor_autograd, the custom op avoids the Autograd warning that wait_tensor() is registered to autograd.
For the next step, we should explore how to interpolate the required communication based on the information from BlockMask.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163185
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542, #164500
1. https://github.com/pytorch/pytorch/pull/164111/ adds the support of splitting BlockMask. But BlockMask actually has B=1 case that the BlockMask will be broadcast. This PR adds the support of B=1 case.
2. The original split_args_kwargs_into_chunks doesn't initialize the default specs correctly. Since we now use tree_flatten and tree_unflatten to do split, we should also use tree_map to initialize the default spec. This will actually support the case when the values are not torch.Tensor, which were only supported if users explicitly provide the shard spec.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165306
Approved by: https://github.com/H-Huang
`context_parallel()` being a context manager has annoyed users. Now that we plan to redesign CP's UX to explicitly ask users to:
1. Wrap the attention op into an `nn.Module`
2. Lift any buffers that are not sequence agnostic to input
We can replace `context_parallel()` with two functional APIs: `_context_parallel_shard` and `_enable_context_parallel_dispatcher`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164500
Approved by: https://github.com/XilunWu
ghstack dependencies: #162542
A LocalTensor is a tensor subclass which simulates a tensor that is
distributed across SPMD ranks. A LocalTensor might be size N, but in fact
there are world_size shards/replicas of it stored internally. When you do a
plain PyTorch operation on it, we apply the operation to each shard; when you
do a collective, we do the mathematically equivalent operation on the local
shards. A LocalTensor is associated with a list of ranks which specify
which ranks it holds local tensors for.
NB, this is NOT a DataParallel like abstraction where you can run operations
on multiple different GPUs. It is intended purely for *debugging* purposes,
the overhead is almost certainly too high to keep eight GPUs (even the C++
autograd needs multithreading to keep up!) (It might potentially be possible
to trace through this with torch.compile and then compile it with CUDA graphs
but this is currently a non-goal.)
In order to handle MPMD, we provide a helper decorator that allows you to
run a function with no side effects for each LocalTensor shard and combine
results back into LocalTensor or LocalIntNode.
Note: This PR convert all DTensor ops and some DTensor tests to illustrate
intended usage and ensure conrrectness. In subsequent PR more tests will be
converted. DUring test conversion we aim to share as much as possible of
test logic between multi-process / multi-threaded and local tensor tests.
We would like to developers to be able to run both flavors of the tests.
Note: This work is based on the original proposal
by @ezyang (WIP PR https://github.com/pytorch/pytorch/pull/162753).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164537
Approved by: https://github.com/ezyang