Adds [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator to enforce explicit scheduling dependencies in FX graphs. This prevents unwanted operation reordering/fusion by giving nodes additional dependencies, which we also respect in inductor by adding weakdeps on the additional dependencies.
This can be generally useful (such as for ordering collectives) but in this case I am using it so that fusions do not interfere with aten planned comm-compute overlap.
There's definitely some similarity with the `with_effects` hop. Talked with @angelayi - when @zou3519 is back we will figure out how we want to consolidate.
The implementation needs to be a subgraph (as opposed to `with_effects`) because inductor relies on `V.graph.current_node`. Changing the signature of the node with `with_effects` breaks this, and additionally, also breaks striding constraints on the wrapped node - see this [TODO](aed66248a0/torch/fx/experimental/proxy_tensor.py (L1246-L1249)). By maintaining the node with its original calling structure in subgraph this all works.
Example transformation:
Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```
After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```
The mm operation now explicitly depends on add completing first, and mul depends on mm, with original operations preserved in subgraphs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
Summary:
As titled.
Without the diff, we got P1963055009
With the diff passing in the enviroment, we can do correct sym_int deduction:
https://fburl.com/mlhub/p5zy7o28
Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:unbacked_symints -- test_sdfpa_unbacked_strides --print-passing-details --env TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 --env TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(u0, 0)"
```
Without the fix: P1964887260
With the fix: P1964888579
Differential Revision: D83211018
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163925
Approved by: https://github.com/ColinPeppler
## Context
An example from Qwen2-7B
- This come from running torch.compile with a sequence length that is
divisible by 8 (no padding needed). Call this `Run1`.
- If we then run the compiled model with a difference length that isn't
divisible by 8 (requires padding). Call this `Run2`.
- Then we'll see this error.
```
File "/var/tmp/torchinductor_nobody/2w/c2wby7ilxbna45xrtrrfjqpeutwouruviu2742ockunnd2bleeiz.py", line 1963, in call
buf24 = torch.ops.aten._scaled_dot_product_efficient_attention_backward.default(reinterpret_tensor(buf18, (s85, 3584 // s19, s48, 512 // (512 // s19)), (s48*(512 // (512 // s19))*(3584 // s19), 512 // (512 // s19), (512 // (512 // s19))*(3584 // s19), 1), 0), buf20, buf21, buf22, buf23, getitem, getitem_1, getitem_2, getitem_3, 0.0, [True, True, True, False], scale=0.08838834764831845)
File "torch/_ops.py", line 841, in __call__
return self._op(*args, **kwargs)
RuntimeError: attn_bias is not correctly aligned (strideM). attn_bias.stride(2) = 6102, and should be a multiple of 4.
```
- We only see the error because we did not recompile on `Run2`. Instead we ran the inputs on the same graph as `Run1`.
### A bit more on why.
Here we check whether to realize the unpadded buffer (unwrapped slice) which we want for `Run1` but not for `Run2`.
0897affcd5/torch/_inductor/lowering.py (L2687-L2694)
## Fix
Size hint doesn't guard, so the fix is to use `guard_or*` to guard.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163083
Approved by: https://github.com/eellison
# Feature
This PR supports lowering `IndexPutFallback` through Inductor's FX converter. The approach is very similar to the one taken in https://github.com/pytorch/pytorch/pull/162686.
Compared to `ScatterFallback`, this required one additional change: the value of `self.op_overload` for `IndexPutFallback` was inaccurate. Previously, it used `aten.index_put`, which would result in unsound FX IR. The existing Python/C++ codegen use `aten.index_put_`, since the fallback mutates its input. This PR changes `self.op_overload` to match that.
# Test plan
Added a CI test lowering deterministic index put via the FX converter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162863
Approved by: https://github.com/angelayi
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).
Changes Included
- Implemented device_assert op and overrides has_side_effect to return True to avoid removal by dead code elimination.
- Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor.
- Added lowering for torch.ops.aten._assert_async.msg to convert assert calls into the ops_handler.
- Implemented the codegen method for the device_assert op. This supports generating C++ and Triton code.
- Added test cases to verify both "should throw" and "should not throw" scenarios.
Fixes#147282
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos, https://github.com/atalman
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).
Changes Included
- Implemented device_assert op and overrides has_side_effect to return True to avoid removal by dead code elimination.
- Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor.
- Added lowering for torch.ops.aten._assert_async.msg to convert assert calls into the ops_handler.
- Implemented the codegen method for the device_assert op. This supports generating C++ and Triton code.
- Added test cases to verify both "should throw" and "should not throw" scenarios.
Fixes#147282
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).
Changes Included
- Implemented device_assert op and overrides has_side_effect to return True to avoid removal by dead code elimination.
- Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor.
- Added lowering for torch.ops.aten._assert_async.msg to convert assert calls into the ops_handler.
- Implemented the codegen method for the device_assert op. This supports generating C++ and Triton code.
- Added test cases to verify both "should throw" and "should not throw" scenarios.
Fixes#147282
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos
scaled_grouped_mm's kernel only supports column-major on the second operand. I -think- this is just for efficiency reasons. But inductor treats that buffer as flexible and may tweak the strides to be row-major instead, as seen in the issue.
~Tagging the op as "needs_fixed_stride_order"/"needs_exact_strides" does not work. Inductor only considers those tags for ops that don't have registered lowering (not sure if this is intended). scaled_grouped_mm does have a lowering, so we never check its tags.~ From discussion below, the op tags are expected to work.
FIXES https://github.com/pytorch/pytorch/issues/159097
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159134
Approved by: https://github.com/eellison
Rewriting bucketing of all_gather and reduce_scatter with defining of "merge graph" via torch function.
`all_gather_merge_fn_to_trace`
`reduce_scatter_merge_fn_to_trace`
(Instead of creating nodes and doing FakeTensor prop manually)
This allows to experiment with merge function.
Used foreach_copy_ in merging function for all_gather - added lowering for inductor for `foreach_copy_`
Adding topological sort after bucketing passes (comment in post_grad.py):
```
# Fx collectives bucketing passes require topological sort for the cases:
# when bucketed collectives have users before the last collective in the bucket
# AND when inputs of bucketed collective have ancestors after the first collective in the bucket.
#
# In this case we can not manually pick the place for bucketed collective insertion.
# But we are guaranteed by the bucketing (independent collectives in the bucket),
# that it is possible to reorder nodes to satisfy all ordering requirements.
#
# --- before bucketing ---
# in0 = ...
# wait_ag0 = ag(in0)
# user0(wait_ag0)
# ...
# pre_in1 = ...
# in1 = transform(pre_in1)
# wait_ag1 = ag(in1)
# user1(wait_ag1)
#
# --- after bucketing ---
#
# in0 = ...
# user(wait_ag0) <--- wait_ag0 is defined only after bucketed collective.
#
# pre_in1 = ...
# in1 = transform(pre_in1)
# ag_bucket(in0+in1)
# wait_bucket
# wait_ag0 = wait_bucket[0]
# wait_ag1 = wait_bucket[1]
# user1(wait_ag1)
````
Correctness of the passes verified by loss curve for llama3 8b for simple_fsdp and for autoparallel:
<img width="1364" height="495" alt="Screenshot 2025-07-22 at 14 27 28" src="https://github.com/user-attachments/assets/67b2cabb-3206-450b-b529-e23c24292fc6" />
<img width="1355" height="509" alt="Screenshot 2025-07-22 at 14 27 56" src="https://github.com/user-attachments/assets/4d0e6b25-2eb1-47b2-8d68-dcec185239c4" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158663
Approved by: https://github.com/wconstab
## Fixes https://github.com/pytorch/pytorch/issues/157959
## mini repro from issue
```c++
import torch
from torch import nn
class Foo(nn.Module):
def __init__(
self,
use_parameter: bool
) -> None:
super().__init__()
self.b = 101
if use_parameter:
self.b = nn.Parameter(torch.Tensor([self.b]), requires_grad=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# return x + self.b
# return x - self.b
return x / self.b
# return x * self.b
torch.manual_seed(42)
x = torch.rand((5, 5))
expected = Foo(False)(x)
models = [
Foo(False),
Foo(True),
torch.compile(Foo(False), fullgraph=True),
torch.compile(Foo(True), fullgraph=True),
]
for m in models:
print((m(x) - expected).sum())
```
all outputs equal zero except the result of torch.compile(Foo(False), fullgraph=True)
## summary:
when divisor is a scalar, inductor will lower div to mul the scalar's reciprocal.
this could lead precision lost in c++ kernel. but not in triton kernel
## why:
Generated C++ kernel; thanks to @xmfan for supplying the code.
```c++
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(25L); x0+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(16L)))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
auto tmp1 = static_cast<float>(0.009900990099009901);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
tmp3.store(out_ptr0 + static_cast<int64_t>(x0));
}
if(C10_UNLIKELY(x0 >= static_cast<int64_t>(16L) && x0 < static_cast<int64_t>(25L)))
{
auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
auto tmp1 = static_cast<float>(0.009900990099009901);
auto tmp2 = at::vec::Vectorized<float>(tmp1);
auto tmp3 = tmp0 * tmp2;
tmp3.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
}
}
}
}
}
```
The float type in C typically has 6 to 7 significant digits, while the double type has 15 to 16 significant digits.
```c++
#include <iostream>
#include <iomanip>
int main() {
auto tmp1 = static_cast<float>(0.009900990099009901);
auto tmp2 = static_cast<double>(0.009900990099009901);
std::cout << std::setprecision(20) << "tmp1 = " << tmp1 << std::endl;
std::cout << std::setprecision(20) << "tmp2 = " << tmp2 << std::endl;
return 0;
}
```
the ouput is
```bash
tmp1 = 0.0099009899422526359558
tmp2 = 0.0099009900990099011103
```
`auto tmp1 = static_cast<float>(0.009900990099009901);` This will cause tmp1 to become 0.0099009, resulting in a loss of precision, so the final result will not match the expected value.
I also found that the bug occurred at that position
86d8af6a6c/torch/_inductor/lowering.py (L6238)
The commit states that the precision lost is expected in cuda implementation.
original commit
03439d4c1c
cuda implementation
0636c11811/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu (L36-L38)
What is interesting is that the Triton kernel works correctly due to the precision of float type in python.
```python
def triton_poi_fused_div_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 25
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = 0.009900990099009901
tmp2 = tmp0 * tmp1
tl.store(out_ptr0 + (x0), tmp2, xmask)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158231
Approved by: https://github.com/eellison
When select has data dependent input, we cant tell if the actual index shall be index+size or index.
to avoid throwing dde, we allocate a new unbacked symbol to represent the storage offset of the
output view and we compute its value dynamically at runtime when inductor is lowered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157605
Approved by: https://github.com/ColinPeppler
### What
- Use `statically_known_true` over `guard_size_oblivious` in cases where we're checking an optimization path. Otherwise, it will DDE and we can't take the safe/slower path.
- For broadcast checks, use `fallback=False` if we encounter a DDE. Typically, unbackeds would be ≥2 and that falls inline with size-oblivious reasoning (i.e. when `size_oblivious=True`).
### Example DDE
```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0)
Caused by: (_inductor/lowering.py:488 in broadcast_symbolic_shapes)
```
```
torch._inductor.exc.InductorError: LoweringException: GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq((u0//387), 1) (unhinted: Eq((u0//387), 1)). (Size-like symbols: u0)
Caused by: (_inductor/ir.py:2797 in create)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155267
Approved by: https://github.com/eellison
When select has data dependent input, we cant tell if the actual index shall be index+size or index.
to avoid throwing dde, we allocate a new unbacked symbol to represent the storage offset of the
output view and we compute its value dynamically at runtime when inductor is lowered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157605
Approved by: https://github.com/ColinPeppler