for pipeline parallel, we can have multiple FSDP roots (chunks)
```
model = nn.Sequential([chunk0, chunk1])
fully_shard(model.chunk0)
fully_shard(model.chunk1)
```
we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation
```
from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model.chunk0, model.chunk1])
```
unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165024
Approved by: https://github.com/mori360
When the autobucketing pass is registered as aot_eager backend `fw_compiler` and `bw_compiler`, this pr ensures the tensors are all-gathers on "cpu/cuda" device instead of "meta" device.
When we do `dist.all_gather_object`, it will create new bytestorage outside no_dispatch [here](a2e2e1d8c0/torch/distributed/distributed_c10d.py (L3303)), which is on meta device. Thus, I updated the code to use `unset_fake_temporarily`, which would gather RealTensor from other ranks.
It is needed to unblock the aot_eager+autobucketing pass in this [PR](https://github.com/pytorch/torchtitan/pull/1813).
Otherwise, I hit the error as follows:
```bash
traceback : Traceback (most recent call last):
File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
return f(*args, **kwargs)
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
self.train_step(data_iterator)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
loss = self.forward_backward_step(input_dict, labels)
File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
pred = model_parts[0](inputs, **extra_inputs, **extra_args)
File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
return super().__call__(*args, **kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
raise BackendCompilerFailed(
self.compiler_fn, e, inspect.currentframe()
).with_traceback(e.__traceback__) from None
File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
compiled_fn, _ = aot_stage2_compile(
~~~~~~~~~~~~~~~~~~^
aot_state,
^^^^^^^^^^
...<4 lines>...
inference_compiler,
^^^^^^^^^^^^^^^^^^^
)
^
File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
return aot_stage2_autograd(aot_state, aot_graph_capture)
File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
schedule_overlap_bucketing(gm)
~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
).run()
~~~^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
dist.all_gather_object(
~~~~~~~~~~~~~~~~~~~~~~^
gathered_runtime_estimations, runtime_estimations, pg
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
return func(*args, **kwargs)
File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
input_tensor, local_size = _object_to_tensor(obj, current_device, group)
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
byte_tensor = torch.ByteTensor(byte_storage).to(device)
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta". This is no longer allowed; the devices must match.
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165063
Approved by: https://github.com/eellison
Summary: For D84399286, failing ads ne deterministic tests now. These tests are especially brittle with subtle bitwise numerics changes. Will reenable for fbcode once e2e validation tests are performed
Test Plan: N/A
Differential Revision: D84514361
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165328
Approved by: https://github.com/izaitsevfb
The match for backward nodes might be in a different submod, so we should check all submod for potential matches.
In flex attention, this could happen if `mask_mod` has operations (such as index) that increase the seq_nr of the forward graph nodes. Then the backward flex_attention nodes cannot find a match in its own subgraph.
```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```
Also tested on torchtitan joint_graph_runner branch. The flex_attention backward nodes are annotated now.
```
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" LOG_RANK=0 TRAIN_FILE="torchtitan.train" TORCHFT_LIGHTHOUSE="http://localhost:29510" PYTORCH_ALLOC_CONF="expandable_segments:True" torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint="localhost:0" --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml --model.name joint_graph_runner.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
Skip test_compiled_autograd_attribution on s390x
It fails both on s390x and x86_64 at least under some circumstances. Disable it for now until on s390x until it works reliably.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163647
Approved by: https://github.com/malfet
## Problem
Okay there's limitations with today's `atomically_apply_size_hint` though it works for most observed failures we've seen so far. However, it's easy to come up with an edge case.
Suppose you encounter this setup.
```
a: [s0 + u0]
b: [s1 + u1]
c: [u2 + u3]
d: [u100]
```
Today, we use a few heuristics to specify the LHS and RHS for replacements.
10d2734d9b/torch/_inductor/sizevars.py (L730-L759)
It's possible to end up with these replacement rules. Notice how there's no replacement for `s1 + u1` and `u2 + u3` :( That's because today picking the LHS and RHS matters a lot, and `s1 + u1` & `u2 + u3` happened to end up on the RHS.
```
s0 + u0 => s1 + u1
s0 + u0 => u2 + u3 # overrides previous replacement; each expr only gets one replacement
s0 + u0 => u100 # overrides previous replacement; ditto
```
I believe what we really want is this: everybody gets a replacement! And they all should (eventually) settle at the same canonical expr (i.e. `u100`) when running the replacement several times.
```
s1 + u1 ==> s0 + u0
u2 + u3 ==> s0 + u0
s0 + u0 ==> u100
```
We can just short-cut this by using the canonical expr as the replacement.
```
s1 + u1 ==> u100
u2 + u3 ==> u100
s0 + u0 ==> u100
```
## Implementation
I offer one way to deal with this:
1. assure every expression has one canonical replacement (i.e. `u100`)
2. if two expressions are equal (inferred from `deferred_runtime_asserts`), then they must have the same canonical replacement
We can implement the above with union find.
* Whenever you see `Eq(lhs, rhs)` then do `union(lhs, rhs)`.
* Whenever you want to find the canonical replacement for a given expr then do `find(expr)`.
* When picking the canonical replacement we can use a few heuristics like (1) prefer a fully backed expr, (2) replacing with sub-expressions, and whatever we'd like.
Differential Revision: [D84549260](https://our.internmc.facebook.com/intern/diff/D84549260)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164324
Approved by: https://github.com/laithsakka
Update the torch-xpu-ops commit to [intel/torch-xpu-ops@ce9db1](ce9db15136), includes:
- Fix test_barrier hang by using static global rank in ProcessGroupXCCL
- Update install_xpu_headers only when content should change to speedup recompilation
- Add global rank information to communication logging
- Remove duplicate normalization from FFT methods
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165321
Approved by: https://github.com/EikanWang
The wrapper enable to share test body implementation while eliminating need test class by hand. As an example, this change converts the whole DTensorTest to use local tensor mode.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165383
Approved by: https://github.com/ezyang
Follow up to #165098 - adding bf16 support for the backward pass. To avoid BC breaking changes/losing precision, we upcast the parameters to fp32 after the op gets called, and downcast the gradients to bf16 before returning.
For testing, we upcast to fp32 before calling the reference function. We increase the tolerance to 1e-2 for bf16 inputs because of a difference in casting calculations between python's `x.to(torch.bfloat16)` and cpp's `x.to(at::kBFloat16)` (after comparing intermediate tensors, we found that the numerics diverge after the final casting). We don't explicitly cast in the CPP op but rather let autograd/optimizer handle it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165325
Approved by: https://github.com/andrewor14
Creates the fork/join stream ops. These ops are passthrough ops which mutate all of their args (without actually performing any computation on them) so that during functionalization, implicit dependencies are added on all of their args. This allows us to prevent reordering during our pre/post grad graph passes.
Make custom ops inplace
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162900
Approved by: https://github.com/anijain2305
ghstack dependencies: #163027, #162899, #163028
Stores streams in a global object look table that maps a dynamo selected index to objects. This index is generated during tracing, and at runtime, a helper function is called from the bytecode to populate this map.
This differs from the previous implementation that simply mapped IDs to the associated objects. This required specialization on the IDs of the specific objects, while this new approach does not.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162899
Approved by: https://github.com/anijain2305
ghstack dependencies: #163027
### Implementation of #151705
This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.
To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:
1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)
### Summary of This PR
This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.
### Notable Changes
1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tililng suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I’ve copied the configuration from the existing matmul templates. However, this may not be optimal—it currently takes a long time to tune, and I think there must be a better way to tackle this.
@eellison @jansel @PaulZhang12 @shunting314
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
Current implementation hardcodes 4D input and output tensor shapes
Change that by computing `output_conv_shape` for any number of input dims
Replace `[.., .., .., slice]` with `[..., slice]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165241
Approved by: https://github.com/ezyang
Summary:
### Problem
ArrayRef's `equals()`does elementwise quality using `==` operator. This can cause a DDE for unbacked symints since `==` operator calls `guard_bool`.
```
// SymInt.h
bool operator==(const SymInt& o) const {
return sym_eq(o).guard_bool(__FILE__, __LINE__);
}
```
### Solution
Adds `sym_equals()` to do elementwise equality for `SymIntArrayRef`. Use this instead of `equals()` for `SymIntArrayRef`.
Reviewed By: guangy10, pianpwk, muchulee8
Differential Revision: D84168401
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165112
Approved by: https://github.com/Skylion007
This PR introduces a way to compile a region of FX graph using `fx.traceback.annotate`.
### UX
1) In the user code, mark the region that you want to be compiled with inductor using `with fx_traceback.annotate({"compile_with_inductor": 0})`. As of now, we just rely on the string `compile_with_inductor` and ignore the integer. As the needs arise, we can update the logic.
Example
```
def fn(x, y):
sin = torch.sin(x)
with fx_traceback.annotate({"compile_with_inductor": 0}):
mul = sin * y
add = mul + 1
return torch.sin(add)
```
2) You have to instruct the compiler to use the annotations with `compile_fx_annotated_nodes_with_inductor` transformation. This is somewhat controversial, and a user might expect that just setting annotation is enough. But for now to control the blast radius, we need to explicitly do this. One such example is
```
# Set the fw and bw compiler of aot_autograd to `compile_fx_annotated_nodes_with_inductor`
def aot_eager_regional_inductor():
return aot_autograd(
fw_compiler=compile_fx_annotated_nodes_with_inductor,
bw_compiler=compile_fx_annotated_nodes_with_inductor,
)
```
3) Fixable in short-term - You have to wrap the user code in `torch.fx.traceback.preserve_node_meta` to ensure that annotations are propagated to the compiler. This is fixable, just need to make CI happy.
### Implementation
1) Relies on `CapabilityBasedPartitioner` to "scoop" out regions based on annotations, and then create subgraphs in the main graph.
2) Call `torch._inductor.standalone_compile` on these subgraphs, and jam the returned callable into the FX graph at the place of call_module
Resulting graph looks something like this - search for `torch__inductor_standalone_compile_inner`
Forward graph
```
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10]", primals_2: "f32[10]"):
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
sin: "f32[10]" = torch.ops.aten.sin.default(primals_1)
# No stacktrace found for following nodes
inner = torch__inductor_standalone_compile_inner(sin, primals_2)
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:68 in fn, code: add = mul + 1
getitem: "f32[10]" = inner[0]; inner = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
sin_1: "f32[10]" = torch.ops.aten.sin.default(getitem)
return (sin_1, primals_1, primals_2, sin, getitem)
```
Backward graph
```
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10]", primals_2: "f32[10]", sin: "f32[10]", add: "f32[10]", tangents_1: "f32[10]"):
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
cos_1: "f32[10]" = torch.ops.aten.cos.default(primals_1); primals_1 = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:70 in fn, code: return torch.sin(add)
cos: "f32[10]" = torch.ops.aten.cos.default(add); add = None
mul_1: "f32[10]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None
# No stacktrace found for following nodes
inner = torch__inductor_standalone_compile_inner(mul_1, sin, primals_2); mul_1 = sin = primals_2 = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:67 in fn, code: mul = sin * y
getitem: "f32[10]" = inner[0]
getitem_1: "f32[10]" = inner[1]; inner = None
# File: /data/users/anijain/pytorch2/test/dynamo/test_regional_inductor.py:64 in fn, code: sin = torch.sin(x)
mul_4: "f32[10]" = torch.ops.aten.mul.Tensor(getitem_1, cos_1); getitem_1 = cos_1 = None
return (mul_4, getitem)
```
### Some issue raised in the HOP meeting
1) CSE will not differentiate different meta custom nodes and do wrong thing.
2) SAC - The recomputed forward will be smaller than the forward. Will we compile a smaller region than?
3) What happens if you have a op in the middle which does not disturb the topology, is it still 1 subgraph?
4) What happens with the nesting of `fx_traceback.annotate`? Are there any ordering requirements?
5) What are we going to use the annotations for?
a) compile flex
b) streams
c) nn.Module info to organize MoE components for pipelining
d) PP stages
e) Rename graph nodes for more debugging
f) No nested regional compile
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164776
Approved by: https://github.com/SherlockNoMad
ghstack dependencies: #165188
(Extract out the algorithm from https://github.com/pytorch/pytorch/pull/160266.)
Build a graph to search for the path from source placement to destination placement (with device order). Currently solution introduces too many all-gathers and missing the opportunity for all-to-all when redistribute, especially when we consider the device order.
### How to build the graph:
When operator of Shard, think of collective op as operation on a stack of device axis:
- I, J are tensor dimensions;
- X, Y, Z, Y are ordered mesh dimensions.
<img width="357" height="253" alt="image" src="https://github.com/user-attachments/assets/23bb3cc3-0506-4071-9053-3c525cf0e526" />
Detailed collective op transition is implemented in `DTensorRedistributePlanner.get_next_state`.
### How to find the min cost path:
Assign weight to different type of collective ops and use Dijkstra to find the min cost path from the graph we build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164902
Approved by: https://github.com/ezyang