ff164e9161
Update on "[Inductor] Mutable custom op pattern matching and safety checker"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as `rms_norm` and `static_scaled_fp8_quant` that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently those passes match against the `auto_functionalized(mutable_op)` wrapper, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to view + clone + functional op + `copy_`. The specific view operations vary depending on which inputs are mutated, which makes it difficult to write stable patterns that match the view+op combinations.
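Roughly, for the toy `mylib::foo_inplace` op defined further below, the v2 wrapper lowers to something shaped like this (a conceptual sketch of the view + clone + op + copy_ structure, not the exact lowering, which depends on which views are mutated):
```python
import torch

def decomposed_foo_inplace(x: torch.Tensor) -> torch.Tensor:
    base = torch.as_strided(x, x.shape, x.stride(), 0)   # regenerate the mutated view from its base
    scratch = base.clone()                                # functional working copy
    scratch_view = torch.as_strided(scratch, x.shape, x.stride(), 0)
    torch.ops.mylib.foo_inplace(scratch_view)             # run the op on the copy
    x.copy_(scratch)                                      # write the result back to the original
    return scratch
```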
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, pattern matchers cannot reliably detect op sequences or guarantee correct execution order.
High-level idea (a minimal FX sketch follows the list):
- Identify mutating ops from their operator schemas
- For each mutated tensor, find all aliasing storages (including views/aliases) via GraphAliasTracker
- Insert a DEP_OP node after each mutation
- Redirect later users of the aliased storages to depend on the DEP_OP
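A minimal sketch of these steps over a `torch.fx.Graph` (illustrative only: `dep_op` stands in for the PR's `op_for_dependencies`, the helper names are made up, only the first argument is treated as mutated, and the real aliasing analysis done by GraphAliasTracker is not reproduced):
```python
import torch
from torch import fx

def _mutates_inputs(node: fx.Node) -> bool:
    # A call mutates an input if its operator schema marks any argument as written.
    if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
        return False
    return any(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in node.target._schema.arguments
    )

def add_implicit_edges(graph: fx.Graph, dep_op) -> None:
    order = {n: i for i, n in enumerate(graph.nodes)}
    for node in list(graph.nodes):
        if not _mutates_inputs(node):
            continue
        mutated = node.args[0]  # simplification: assume the first arg is the mutated tensor
        if not isinstance(mutated, fx.Node):
            continue
        # Insert the dependency node right after the writer.
        with graph.inserting_after(node):
            dep = graph.call_function(dep_op, (mutated,), {"writer_token": node})
        # Later readers of the mutated tensor must now go through the dep node.
        for user in list(mutated.users):
            if user is not node and user is not dep and order.get(user, -1) > order[node]:
                user.replace_input_with(mutated, dep)
```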
Example:
Custom op definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
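With the pattern/replacement pair above, registration could look roughly like this (a hedged sketch using Inductor's existing `register_replacement`/`fwd_only`/`PatternMatcherPass` helpers; the pass object name and wiring are assumptions, not this PR's actual code):
```python
import torch
from torch._inductor.pattern_matcher import PatternMatcherPass, fwd_only, register_replacement

# Hypothetical pass container for the rewrite.
mutable_custom_op_pass = PatternMatcherPass()

register_replacement(
    pattern,                           # search function defined above
    replacement,                       # replacement function defined above
    [torch.empty(3), torch.empty(3)],  # example inputs for tracing (x, out)
    fwd_only,                          # inference-only tracing of the pattern
    mutable_custom_op_pass,
)
```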
Pattern graph after add_implicit_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutating a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
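The dumps below can be reproduced (assuming the toy ops above are registered; the 3-element inputs match the `[3]` shapes in the graphs) with a driver along these lines:
```python
import torch

compiled_f = torch.compile(f, fullgraph=True)
x, out = torch.randn(3), torch.empty(3)
result = compiled_f(x, out)
# f clones its inputs, so x is untouched and the result is (x + 1) + 2.
torch.testing.assert_close(result, x + 3)
```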
Before the mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
After decomposing auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
After add_implicit_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
After remove_implicit_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
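The interesting part here is that the mutation happens through `x.view(-1)` while the later read goes through `x` itself, so the pass must recognize that both refer to the same storage. A quick standalone illustration of that aliasing:
```python
import torch

x = torch.zeros(2, 3)
v = x.view(-1)
# The view shares storage with x, so a write through v is a write to x.
assert v.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
v.add_(1)
assert x.sum().item() == 6.0  # a later read of x observes the write through v
```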
Before the mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
After decomposing auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
After add_implicit_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
After remove_implicit_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 16:07:25 -08:00
2e61e1e740
Update base for Update on "[Inductor] Mutable custom op pattern matching and safety checker"
...
2025-11-06 16:07:25 -08:00
14334d0607
Update on "[Inductor] Mutable custom op pattern matching and safety checker"
...
2025-11-06 15:49:12 -08:00
60f98d5a27
Update base for Update on "[Inductor] Mutable custom op pattern matching and safety checker"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why current pattern matcher not support the raw custom mutating op ?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case : mutates a clone of graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern match happened foo_inplace + bar -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
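The reason `bar_out(x, outx)` has to wait for `foo_inplace(x.view(-1))` is that the view shares storage with `x`. A quick standalone check of that aliasing (illustrative only, not part of the pass):
```python
import torch

x = torch.zeros(4)
v = x.view(-1)
# v is a view: it shares storage with x, so writes through v are visible via x
assert v.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
v.add_(1)
print(x)  # tensor([1., 1., 1., 1.])
```
This is why the dependency pass has to look through views and aliases rather than only at direct uses of the mutated node.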
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 15:49:12 -08:00
403bff4473
Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as `rms_norm` and `static_scaled_fp8_quant` that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently these patterns match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
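A minimal sketch of how such dependency edges can be inserted into an FX graph. Names like `is_mutating`, `add_dependency_edges`, and the `dep_op` handle are illustrative only; the real pass also resolves views/aliases via GraphAliasTracker instead of assuming the first argument is the mutated one:
```python
from torch import fx

def is_mutating(node: fx.Node) -> bool:
    # An op mutates an input iff its schema marks some argument as written to.
    if node.op != "call_function" or not hasattr(node.target, "_schema"):
        return False
    return any(
        a.alias_info is not None and a.alias_info.is_write
        for a in node.target._schema.arguments
    )

def add_dependency_edges(graph: fx.Graph, dep_op) -> None:
    # After every mutating node, insert a dep node that carries the mutated
    # tensor plus an explicit writer_token edge to the mutation, then make
    # later readers of that tensor go through the dep node.
    order = {n: i for i, n in enumerate(graph.nodes)}
    for node in list(graph.nodes):
        if not is_mutating(node):
            continue
        mutated = node.args[0]  # simplification: treat the first arg as mutated
        if not isinstance(mutated, fx.Node):
            continue
        with graph.inserting_after(node):
            dep = graph.call_function(dep_op, (mutated,), {"writer_token": node})
        for user in list(mutated.users):
            if user in order and order[user] > order[node]:
                user.replace_input_with(mutated, dep)
    graph.lint()
```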
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
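For context, a pattern/replacement pair like the one above is typically registered through Inductor's pattern matcher roughly as follows (a sketch building on the `pattern`/`replacement` functions defined earlier; exact arguments and tracing mode depend on the PyTorch version and on whether the graph has a backward):
```python
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

my_patterns = PatternMatcherPass()

def register():
    # Example inputs are only used to trace pattern/replacement into FX graphs.
    x = torch.empty(3)
    out = torch.empty(3)
    register_replacement(
        pattern,      # search graph: foo_inplace(x); bar_out(x, out)
        replacement,  # replacement graph: foobar_out(x, out)
        [x, out],     # example inputs
        fwd_only,     # tracing mode for inference graphs
        my_patterns,  # pass to register into
    )
```
With the implicit edges in place, the traced pattern graph above (including its op_for_dependencies nodes) is what actually gets matched against the target graph.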
Case: mutates a clone of graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 15:08:43 -08:00
67dbbe261d
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as `rms_norm` and `static_scaled_fp8_quant` that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently these patterns match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
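The storage/alias bookkeeping mentioned above can be pictured as mapping every node to the node that owns its underlying storage. A hypothetical, heavily simplified version of what GraphAliasTracker computes (only the view ops that show up in the dumps here):
```python
import torch
from torch import fx

# Ops whose output aliases their first input (subset, for illustration).
_VIEW_OPS = {
    torch.ops.aten.view.default,
    torch.ops.aten.as_strided.default,
    torch.ops.aten.alias.default,
}

def storage_roots(graph: fx.Graph) -> dict:
    """Map every node to the node that owns its underlying storage."""
    root = {}
    for node in graph.nodes:
        if node.op == "call_function" and node.target in _VIEW_OPS:
            base = node.args[0]
            root[node] = root.get(base, base)
        else:
            # placeholders, clones, and compute ops allocate fresh storage here
            root[node] = node
    return root
```
Two nodes with the same root share storage, so a write through either of them must be ordered before later reads through the other.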
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutates a clone of graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
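Graphs like the ones below can be reproduced by compiling `f` (a sketch; assumes the `mylib` ops above are registered, and uses the standard Inductor log artifact name):
```python
import torch

compiled_f = torch.compile(f, fullgraph=True)
# Run with TORCH_LOGS="post_grad_graphs" to print graphs like the ones below.
compiled_f(torch.zeros(3), torch.zeros(3))
```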
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 15:08:43 -08:00
9b4ac45d2f
Revert "[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer ( #166165 )"
...
This reverts commit eefa16342c9f322b56c7c0cd6d309c3ed8f0b882.
Reverted https://github.com/pytorch/pytorch/pull/166165 on behalf of https://github.com/jeanschmidt due to Breaking internal tests D86216934 ([comment](https://github.com/pytorch/pytorch/pull/166165#issuecomment-3499645688 ))
2025-11-06 22:34:48 +00:00
e4034677c1
Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as `rms_norm` and `static_scaled_fp8_quant` that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently these patterns match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
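The removal step (after matching) is the mirror image of the insertion: every dependency node is replaced by the tensor it wraps and then erased, restoring the original data-flow edges. A minimal sketch, again with an illustrative `dep_op` handle standing in for the op_for_dependencies operator:
```python
from torch import fx

def remove_dependency_edges(graph: fx.Graph, dep_op) -> None:
    """Replace every dep node with the tensor it wraps and erase it."""
    for node in list(graph.nodes):
        if node.op == "call_function" and node.target is dep_op:
            wrapped = node.args[0]
            node.replace_all_uses_with(wrapped)
            graph.erase_node(node)
    graph.lint()
```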
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutates a clone of graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 14:34:08 -08:00
1015260e07
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops (such as `rms_norm` and `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), which has incompatible semantics that break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutating a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 14:34:08 -08:00
a45a17f65e
Fix boxcox to return same result for same input in one batch ( #166986 )
...
Summary:
The SIMD path is using the SLEEF version of pow, which is slightly different from std::pow. The fix is to use the same vectorized code (with partial load and store) for the trailing data as well, to ensure consistency between results.
Deploy:
Need to make a hotfix in waas to monitor release signals, since this diff can cause testing failures in veloski and waas release correctness tests.
Test Plan: Sandcastle.
Differential Revision: D86218207
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166986
Approved by: https://github.com/swolchok
2025-11-06 22:33:26 +00:00
c5593e75b3
Fix flaky memory profiler test ( #167168 )
...
Fixes #167037
Do not check the exact number of frames.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167168
Approved by: https://github.com/angelayi
2025-11-06 21:39:44 +00:00
c90a976370
Update pythoncapi_compat.h ( #167138 )
...
Update to commit 44c8e14bbbb5d5135ae90957036a61397e4df577.
Should slightly simplify https://github.com/pytorch/pytorch/pull/166342
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167138
Approved by: https://github.com/albanD
2025-11-06 21:31:58 +00:00
d144382dc9
Move enrich_profiler_metadata config import out of gm.recompile() ( #167114 )
...
Fixes T243967987
Move `enrich_profiler_metadata` from `torch._dynamo.config` to `torch.fx.experimental._config`.
We cannot import anything inside recompile(); it caused a performance regression internally. We moved the config so we can import it at the top of `graph_module.py` without causing a circular import.
We also cannot delete the old config right now because some internal unit tests rely on copies of the old `graph_module.py` file. But we should be able to delete the old config soon after this PR lands.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167114
Approved by: https://github.com/angelayi
2025-11-06 21:21:40 +00:00
78827c5e00
Distributed Autotuning ( #163369 )
...
This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product.
Currently, when we run an SPMD program, we compile the ranks independently. As a result, the autotuning is repeated on every rank. So for an 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times.
Distributed autotuning uses collectives to distribute the autotuning across the ranks so each rank autotunes 1/worldsize the total operators. So in our 8-GPU example we would only perform 8 autotunes total (one on each rank) rather than 64.
There are several advantages:
1. Faster autotuning times - each CPU/GPU does less work total
2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program.
Results:
In testing using llama3 8B on torchtitan max-autotune time was reduced from 52s -> 26s and exhaustive-autotuning was reduced from 2009s -> 613s.
Usage:
The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_AUTOTUNE.
Co-authored-by: @PaulZhang12
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163369
Approved by: https://github.com/PaulZhang12
2025-11-06 21:10:21 +00:00
ab1e734cd7
[ez] avoid log spam when random data is generated ( #166919 )
...
It's annoying to see a full screen of this warning when running fx_graph_runnable files saved in tlparse.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166919
Approved by: https://github.com/eellison
2025-11-06 21:05:20 +00:00
888958ad6c
Prevent torch._check causing graph breaks ( #164676 )
...
Handle `torch._check` in `TorchInGraphFunctionVariable.call_function`. Basically, it has two arguments - a predicate (bool) and a message (callable). If the predicate is a constant, evaluate `torch._check`: if the predicate is true, compilation just proceeds and nothing happens; if it is false, `torch._check` will raise an exception.
If the predicate is not constant, we manually emit a proxy. I tried to build as_proxy() inside NestedUserFunctionVariable but failed, so I create it here. I try to extract the message: if it's a function, I retrieve it; if not, I set it to None. Maybe we could extract it if the message is a closure, but I'm not sure how.
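A minimal sketch of the behavior described above (the function and inputs here are illustrative, not from the PR): with this change, a `torch._check` inside a compiled function no longer forces a graph break.
```python
import torch

@torch.compile(fullgraph=True)  # fullgraph=True would raise if a graph break occurred
def f(x):
    # Constant predicate: evaluated at compile time; the callable message is
    # only invoked if the check fails.
    torch._check(x.dim() == 2, lambda: "expected a 2D tensor")
    return x * 2

print(f(torch.randn(4, 3)))
```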
Fixes #163668
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164676
Approved by: https://github.com/williamwen42 , https://github.com/mlazos
Co-authored-by: William Wen <william.wen42@gmail.com >
2025-11-06 21:00:48 +00:00
d19f36bea1
[BE][Ez]: Update fmtlib submodule to 12.1.0 ( #166983 )
...
Fixes some compiler idiosyncrasies, improves C++ support, and includes bugfixes and performance optimizations. This is a header-only, minor library change, so it should be low risk and improve the performance of our formatting/loggers. Also allows fmtlib to be used in more constexpr contexts.
Full changelog here: https://github.com/fmtlib/fmt/releases/tag/12.1.0
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166983
Approved by: https://github.com/atalman
2025-11-06 20:39:00 +00:00
096c9356de
[CUDA][cuBLASLt] addmm -- enable 2D bias in the Lt path when followed by an activation ( #165548 )
...
As per title.
This one is based off [#163955 ](https://github.com/pytorch/pytorch/pull/163955 ), but I will rebase once it is merged.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165548
Approved by: https://github.com/eqy
2025-11-06 20:29:32 +00:00
03dea563f4
Add guidance on how to migrate kernels to the libtorch stable ABI ( #167112 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167112
Approved by: https://github.com/janeyx99
2025-11-06 20:27:27 +00:00
2e83ae2de7
[pp] Add reduce_grad Action ( #166449 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166449
Approved by: https://github.com/wconstab , https://github.com/sanketpurandare
2025-11-06 20:02:46 +00:00
77b70970f7
[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel ( #167182 )
...
Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036 , which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs.
Test Plan:
Inductor test (fbcode):
`INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"`
Tritonbench (fbcode):
`clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy`
Tritonbench(oss):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy`
Unit Tests(oss):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`
Differential Revision: D86376880
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167182
Approved by: https://github.com/mlazos , https://github.com/jananisriram
2025-11-06 19:55:38 +00:00
c9b2db73ca
[Sigmoid][Delta Update][2/N] update delta update api to load original value first before casting to target dtype ( #167039 )
...
Summary: The current delta update has a strong assumption that the non-lowered weights share the same tensor dtype as the lowered version. This is not true by design. When the dtype mismatches, the data loading will load the data into an unexpected dtype, which introduces undefined behavior. This diff aims to close the gap by always loading the tensor with its original dtype first and then casting to the desired dtype.
Test Plan:
No more NaN values!
{P2022339213}
Reviewed By: kqfu
Differential Revision: D86181685
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167039
Approved by: https://github.com/henryoier
2025-11-06 19:31:18 +00:00
ba2e6b0b4f
[ROCm] Enable StaticCudaLauncher for ROCm ( #166492 )
...
This PR enables ROCm/HIP support for PyTorch's StaticCudaLauncher, which provides static compilation and launching of Triton kernels. The implementation has been tested on AMD MI300 and MI200 hardware.
**Changes**
**Python (torch/_inductor/runtime/)**
- static_cuda_launcher.py: Added ROCm detection, .hsaco binary support, and ROCm-specific scratch parameter handling
- triton_heuristics.py: Updated device type checks to support both cuda and hip
**C++ (torch/csrc/)**
- Module.cpp: Enabled StaticCudaLauncher for ROCm builds
- inductor/static_cuda_launcher.cpp: Added HIP API equivalents for all CUDA driver calls
- inductor/static_cuda_launcher.h: Updated header guard
**Tests (test/inductor/)**
- test_static_cuda_launcher.py: Removed @skipIfRocm decorators and updated binary file handling
**Enabled Unit Tests**
All tests in test/inductor/test_static_cuda_launcher.py now pass on ROCm:
1. test_basic
2. test_unsigned_integers
3. test_signed_integers
4. test_basic_1arg
5. test_constexpr
6. test_implied_constant
7. test_kernel_no_args
8. test_high_shared_mem
9. test_too_high_shared_mem
10. test_kernel_empty_tensor
11. test_kernel_many_args
12. test_basic_compile
13. test_incompatible_code
14. test_static_launch_user_defined_triton_kernels
15. test_empty_tensor
16. test_any
17. test_disable_static_cuda_launcher
In addition to this, the following tests from test/inductor/test_codecache.py also pass:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
2. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
3. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True
4. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False
5. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False
6. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True
The following tests are skipped since triton bundling is necessary for StaticCudaLauncher:
1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True
2. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166492
Approved by: https://github.com/jeffdaily
2025-11-06 19:29:35 +00:00
8523a64c4b
Fix python -m build: error: unrecognized arguments: --no-build-isolation ( #166848 )
...
Fixes #166326
The PR fixes the following error:
```
python -m build: error: unrecognized arguments: --no-build-isolation
```
The regression has been introduced in the [commit](50d418f69f (diff-e5a6ba9ea3717e5913cd885e81f143937ea727282edd6939479a2a60b1051bf5R73) ) in the scope of [PR](https://github.com/pytorch/pytorch/pull/156712 ).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166848
Approved by: https://github.com/seemethere
2025-11-06 19:13:37 +00:00
9fef18e31d
[ROCm] Enable multi-arch compilation and unit tests for AOT Inductor ( #166357 )
...
## Summary
This PR adds multi-architecture kernel compilation support for ROCm in PyTorch's AOT Inductor module, enabling a single compiled model to run across multiple AMD GPU architectures (MI200, MI300, MI350, etc.) without recompilation.
## Implementation
- **Multi-arch compilation pipeline**: Compiles LLVM IR to multiple GPU architectures and bundles them using `clang-offload-bundler`
- **Architecture detection**: Automatically detects target architectures from `torch.cuda.get_arch_list()`, with overrides via `PYTORCH_ROCM_ARCH` environment variable
- **ROCm-specific utilities**: New `rocm_multiarch_utils.py` module handles ROCm toolchain integration
- **Test infrastructure**: Adapted AOT Inductor tests to support both CUDA and ROCm compilation paths
## Testing
Successfully tested on:
- MI200
- MI300
**Enabled tests:**
- `test_simple_multi_arch`
- `test_compile_after_package_multi_arch`
- `test_compile_with_exporter`
- `test_compile_with_exporter_weights`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166357
Approved by: https://github.com/jeffdaily
2025-11-06 19:08:15 +00:00
aaea391b62
[annotate][export] Add annotation to assertion nodes in export ( #167171 )
...
Fixes #166906
```
python test/export/test_export.py -k test_annotate_on_assert
```
The assertions are not marked with annotation because these nodes are created in `apply_runtime_assertion_pass`. Currently the annotation will only be added if the nodes are created during tracing. So we need to manually add the annotation.
Nodes added in `apply_runtime_assertion_pass` will have the same annotation as the input node to the assertion.
Output graph:
Note that `_assert_scalar_default_1` is not annotated because it's an assertion on the size of `x`, which is not annotated.
```
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[s77]", y: "i64[]"):
# No stacktrace found for following nodes
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
# Annotation: {'moo': 0} File: /data/users/shangdiy/pytorch/test/export/test_export.py:729 in forward, code: x = torch.cat([x, x])
cat: "f32[2*s77]" = torch.ops.aten.cat.default([x, x]); x = None
# Annotation: {'moo': 0} File: /data/users/shangdiy/pytorch/test/export/test_export.py:730 in forward, code: b = y.item()
item: "Sym(u0)" = torch.ops.aten.item.default(y); y = None
ge_1: "Sym(u0 >= 4)" = item >= 4
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 4 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
# No stacktrace found for following nodes
mul_1: "Sym(2*s77)" = 2 * sym_size_int_1; sym_size_int_1 = None
le: "Sym(2*s77 <= u0)" = mul_1 <= item; mul_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression 2*s77 <= u0 on node 'le'"); le = _assert_scalar_default_1 = None
# Annotation: {'moo': 0} File: /data/users/shangdiy/pytorch/test/export/test_export.py:732 in forward, code: return x * b
mul: "f32[2*s77]" = torch.ops.aten.mul.Tensor(cat, item); cat = item = None
return (mul,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167171
Approved by: https://github.com/angelayi
2025-11-06 18:57:30 +00:00
7206668f7c
Update torch.var documentation to use modern API ( #167209 )
...
## Summary
Fix outdated unbiased parameter references in normalization module documentation. Replace deprecated torch.var(input, unbiased=False/True) with modern torch.var(input, correction=0/1) API throughout BatchNorm, InstanceNorm, LayerNorm, and GroupNorm docstrings.
## Changes
- torch/nn/modules/batchnorm.py: Updated 4 instances across BatchNorm1d, BatchNorm2d, BatchNorm3d, and SyncBatchNorm
- torch/nn/modules/instancenorm.py: Updated 3 instances across InstanceNorm1d, InstanceNorm2d, and InstanceNorm3d
- torch/nn/modules/normalization.py: Updated 2 instances in LayerNorm and GroupNorm
## Test plan
Mathematical behavior remains identical: unbiased=False ≡ correction=0 (biased estimator), unbiased=True ≡ correction=1 (unbiased estimator). Documentation now uses consistent modern API terminology with no functional changes to code behavior.
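A quick illustration of the equivalence described in the test plan (a small sketch, not part of the PR; the deprecated `unbiased` form is kept here only for comparison):
```python
import torch

x = torch.randn(10)

# Biased estimator: unbiased=False is equivalent to correction=0
assert torch.allclose(torch.var(x, correction=0), torch.var(x, unbiased=False))
# Unbiased estimator: unbiased=True is equivalent to correction=1 (the default)
assert torch.allclose(torch.var(x, correction=1), torch.var(x, unbiased=True))
```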
Fixes #166804
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167209
Approved by: https://github.com/albanD
2025-11-06 18:52:22 +00:00
7729de07d3
Build libgomp (gcc-13) from src on AArch64 ( #166549 )
...
This improves thread-scaling on AArch64 (see details on #155795 )
Fixes : #155795
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166549
Approved by: https://github.com/malfet
2025-11-06 18:31:03 +00:00
73078f305f
Add missing super().setUp() ( #167163 )
...
In a trunk failure today, we saw the same test running on both trunk and slow shards. The reason is that this test didn't invoke `super().setUp()`, so test features like slow and disabled tests didn't apply to it.
I used Claude to find all test classes with a `setUp()` method that didn't call `super().setUp()` and patched all of them.
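A minimal sketch of the pattern being enforced (the class and test names here are made up for illustration):
```python
from torch.testing._internal.common_utils import TestCase, run_tests

class MyKernelTests(TestCase):
    def setUp(self):
        super().setUp()  # without this, slow/disabled-test handling is skipped
        self.device = "cpu"

    def test_noop(self):
        self.assertTrue(True)

if __name__ == "__main__":
    run_tests()
```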
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167163
Approved by: https://github.com/malfet
2025-11-06 17:55:23 +00:00
ea7add4837
fix static_input_indices subclass remapping under training ( #167127 )
...
We have some logic to figure out "given which inputs have static indices in the pre-subclass-desugaring graph, figure out the static indices in the post-subclass-desugaring graph", and it was busted for training.
Separately, we should probably not have to do this logic at all - as @eellison mentioned, inputs/outputs in the graph are less likely to be tweaked through graph passes, so it would be more convenient and less hassle if we just stashed whether a given input was static directly on the Descriptor for it. I did not end up doing that in this PR though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167127
Approved by: https://github.com/ezyang
2025-11-06 17:34:35 +00:00
0ed4119420
[ROCm][CI] Run rocm.yml and inductor-rocm.yml every 3rd hour ( #167220 )
...
Even after [reducing frequency of rocm.yml and inductor-rocm.yml to per hour](https://github.com/pytorch/pytorch/pull/166870 ), we are still observing queueing on MI2xx runners as of Nov 6 2025 10:30AM CST:
<img width="470" height="191" alt="{DFECE929-174D-4EE4-9448-D43AA1AF0B53}" src="https://github.com/user-attachments/assets/014b2266-7c60-44e5-9a32-3ebea64232b6 " />
We think it's because we had to move the periodic.yml workflow runs to the MI210 runners in light of the Cirrascale runners not being available: https://github.com/pytorch/pytorch/issues/166866 . We observe [increased queueing](https://hud.pytorch.org/queue_time_analysis?dateRange=7&startDate=2025-10-30T16%3A00%3A48.381Z&endDate=2025-11-06T16%3A00%3A48.381Z&granularity=hour&chartType=bar&repos=pytorch%2Fpytorch&category=machine_type&machineTypes=linux.rocm.gpu.2&items=linux.rocm.gpu.2 ) after the point where we added periodic jobs to the MI210 runners.
<img width="453" height="252" alt="linux rocm gpu 2_queueing" src="https://github.com/user-attachments/assets/532984cf-046b-4a02-a096-f17364632da3 " />
This PR temporarily changes the rocm.yml and inductor-rocm.yml workflows to run on a 3-hourly basis rather than every hour, until the Cirrascale outage is resolved.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167220
Approved by: https://github.com/jeffdaily
2025-11-06 17:23:23 +00:00
03fd2b796e
[Flight Recorder] Reverted to include stack traces for dump pipe triggered FR dump ( #167023 )
...
[Flight Recorder] Reverted to include stack traces for dump pipe triggered FR dump (#167023 )
Summary:
We should also retry if including stack traces failed. The change was introduced in https://github.com/pytorch/pytorch/pull/164591
Test Plan: eyes
Reviewed By: fduwjj
Differential Revision: D86248484
2025-11-06 09:16:29 -08:00
fd7bf9ce10
[Inductor] Fix unbacked float symbol handling in kernel codegen ( #166890 )
...
When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.
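A minimal repro sketch of the scenario described above (the exact ops in the linked issues may differ; `capture_scalar_outputs` is needed so `.item()` is captured as an unbacked symbol instead of graph-breaking):
```python
import torch
import torch._dynamo

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def clamp_with_tensor_bounds(x, lo, hi):
    # .item() on float tensors produces unbacked float symbols (e.g. zuf0)
    return torch.clamp(x, min=lo.item(), max=hi.item())

out = clamp_with_tensor_bounds(torch.randn(8), torch.tensor(-0.5), torch.tensor(0.5))
```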
Fixes : #166888 #163674
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890
Approved by: https://github.com/eellison , https://github.com/mlazos
2025-11-06 17:14:31 +00:00
41c9eeecec
Update Sphinx dependencies ( #164901 )
...
This pull request updates the PyTorch documentation build system to support newer versions of Sphinx and its related dependencies, improves coverage checking for undocumented objects, and adds configuration enhancements to the docs build. The most important changes are grouped below.
**Dependency Upgrades and Compatibility:**
* Upgraded `sphinx` to version 7.2.6 and updated related documentation dependencies (`breathe`, `exhale`, `docutils`, `myst-nb`, `sphinx-design`, `myst-parser`, and others) in `.ci/docker/requirements-docs.txt` to ensure compatibility with Python 3.13 and improve documentation generation. [[1]](diffhunk://#diff-b5577a8e38a2e4c5d91865096b259738cc1dbcb97921abb73045dae0255b1479L1-L12) [[2]](diffhunk://#diff-b5577a8e38a2e4c5d91865096b259738cc1dbcb97921abb73045dae0255b1479L39-R45) [[3]](diffhunk://#diff-b5577a8e38a2e4c5d91865096b259738cc1dbcb97921abb73045dae0255b1479L59-R64)
* Replaced the editable install of `pytorch_sphinx_theme2` with a pinned version for stability in documentation builds.
**Documentation Coverage and Build Improvements:**
* Updated the coverage check logic in `.ci/pytorch/python_doc_push_script.sh` to parse the new Sphinx 7.2.6+ coverage report format, extracting the undocumented count from the statistics table for more reliable coverage validation.
**Configuration and Formatting Enhancements:**
* Introduced `autosummary_filename_map` in `docs/source/conf.py` to resolve duplicated autosummary output filenames for functions and classes with the same name, improving documentation clarity.
**Minor Documentation Formatting:**
* Removed an unused `:template:` directive from `docs/source/quantization-support.md` for cleaner autosummary output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164901
Approved by: https://github.com/albanD
2025-11-06 17:14:26 +00:00
bfc0ba4af9
nn.Linear: nD contiguous input + bias -- dispatch to addmm also when weight is sparse (#166071 )
...
As per title.
It seems safe to generalize to arbitrary contiguous inputs since `at::matmul` is likely to do the flattening to avoid `baddmm`.
Additionally, we require the bias to be 1D and contiguous, which guarantees it is fused with no copies.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166071
Approved by: https://github.com/ngimel
2025-11-06 16:50:12 +00:00
3fdc5dbf1d
Make CUDA preload logic more straightforward ( #167046 )
...
I.e. remove the distinction between the two cases and always preload the full set of libraries.
For some reason, when one uses `virtualenv` instead of `venv`,
preloading `cudart` works, but it fails to find cudnn or cublasLT later on.
Fix it by getting rid of the partial preload logic for one of the cases and always preloading the full set of libraries.
Test plan on stock Ubuntu:
```
pip install virtualenv
virtualenv --symlinks -p python3.11 --prompt virtv venv-virt
source venv-virt/bin/activate
pip install torch
python -c 'import torch'
```
Fixes https://github.com/pytorch/pytorch/issues/165812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167046
Approved by: https://github.com/atalman
2025-11-06 16:30:16 +00:00
cc477f6009
[inductor] Use runtime estimations in iterative sink waits pass ( #167081 )
...
Split of https://github.com/pytorch/pytorch/pull/162469 to be under 2K
reorder iterative part
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167081
Approved by: https://github.com/eellison
ghstack dependencies: #167080
2025-11-06 16:14:48 +00:00
7b055a0103
Add per_process_memory_fraction to PYTORCH_CUDA_ALLOC_CONF ( #161035 )
...
torch.cuda.memory.set_per_process_memory_fraction allows setting
an upper bound on how much device memory is allocated. This PR
exposes this setting via an environment variable.
For example, PYTORCH_CUDA_ALLOC_CONF="per_process_memory_fraction:0.5"
will limit the device memory to half of the available memory.
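For reference, the equivalent in-process control mentioned above looks like this (a small sketch; the fraction and device index are illustrative):
```python
import torch

if torch.cuda.is_available():
    # Same effect as PYTORCH_CUDA_ALLOC_CONF="per_process_memory_fraction:0.5"
    torch.cuda.set_per_process_memory_fraction(0.5, device=0)
```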
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161035
Approved by: https://github.com/ngimel , https://github.com/eqy
2025-11-06 16:10:16 +00:00
da2eb31b82
[MTIA][PyTorch] Add mtia as native device for PyTorch tests ( #167089 )
...
Summary: Add MTIA as a native device type in PyTorch.
Test Plan: CI
Reviewed By: PatriceVignola
Differential Revision: D80111801
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167089
Approved by: https://github.com/andyanwang , https://github.com/nautsimon , https://github.com/albanD
2025-11-06 15:43:45 +00:00
2005b5f548
[inductor] Use runtime estimations in iterative reorder collectives pass ( #167080 )
...
Split of https://github.com/pytorch/pytorch/pull/162469 to be under 2K
reorder iterative part
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167080
Approved by: https://github.com/eellison
2025-11-06 14:20:49 +00:00
b2d72a4008
Revert "Don't hardcode double argument for reduction base ( #166951 )"
...
This reverts commit a74fe75c450277eb88a95c764e8b0a664a550a86.
Reverted https://github.com/pytorch/pytorch/pull/166951 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/166951#issuecomment-3497253260 ))
2025-11-06 13:26:04 +00:00
80ec2ab78e
[8/N] Fix unused loop variables in tests ( #166921 )
...
This PR continues to fix or remove unused loop variables in tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166921
Approved by: https://github.com/mlazos
2025-11-06 12:20:00 +00:00
c724f0097d
[2/N] Use key in dict for existence checks ( #167174 )
...
This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`.
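For illustration, the two forms compared in this PR:
```python
d = {"a": 1}

# Preferred: membership test directly on the dict
if "a" in d:
    pass

# Equivalent, but goes through an extra keys-view method call
if "a" in d.keys():
    pass
```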
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167174
Approved by: https://github.com/mlazos
2025-11-06 12:13:47 +00:00
a51208c656
Check cluster_dims attribute exists before access ( #167187 )
...
Error in Helion CI's AMD job: https://github.com/pytorch/helion/actions/runs/19118581048/job/54633730633
```
> (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
if hasattr(binary, "metadata")
else ()
)
),
"function": get_first_attr(binary, "function", "cu_function"),
"runner": get_first_attr(binary, "run", "c_wrapper"),
"math": math_lib,
"torch": torch_lib,
"triton": triton_lib,
}
E torch._inductor.exc.InductorError: AttributeError: 'KernelMetadata' object has no attribute 'cluster_dims'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167187
Approved by: https://github.com/oulgen
2025-11-06 08:02:57 +00:00
ed4aa449b6
CustomOp Inline Fusion ( #165952 )
...
Add Inline Fusion Support for Custom Op Autotuning
--------------------------------------------------
This PR extends PyTorch Inductor's custom op autotuning with inline fusion capabilities, enabling the winning decomposition to be inlined directly into the computation graph for fusion with surrounding operations.
### Usage
```python
def decompose_k_implementation(
    a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
    """Matrix multiply with k-way decomposition."""
    ...

@torch.library.custom_op("my_lib::matmul_relu", mutates_args={})
def custom_matmul_relu_dk(
    a: torch.Tensor, b: torch.Tensor, k_splits: int
) -> torch.Tensor:
    return torch.relu(decompose_k_implementation(a, b, k_splits))

register_custom_op_autotuning(
    custom_op=custom_matmul_relu_dk,
    configs=[
        CustomOpConfig(k_splits=2),
        CustomOpConfig(k_splits=4),
        CustomOpConfig(k_splits=8),
        CustomOpConfig(k_splits=32),
        CustomOpConfig(k_splits=64),
    ],
    name="decompose_k_autotuned",
    input_gen_fns={
        "a": lambda fake: torch.randn_like(fake, device='cuda'),
        "b": lambda fake: torch.randn_like(fake, device='cuda'),
    }
)
```
### How It Works
Enable optimizations from Inductor by inlining the best decomposition, allowing fusion with surrounding elementwise operations and other graph-level optimizations. This provides potentially better performance and memory efficiency.
During the custom op autotuning phase, we still benchmark all CustomOpConfigs to find the fastest implementation. Then, during inline fusion, Inductor inlines the decompositions into the main graph, converting the winning choice into individual ComputedBuffer IR nodes (fusable). At the end, Inductor automatically fuses the inlined operations with surrounding elementwise ops (e.g., bias add, ReLU, scaling). Note that the winning choice must be a SubgraphChoiceCaller (decomposition-based) rather than an ExternKernelChoice for inlining to work. If an ExternKernelChoice is returned, no inlining happens.
Performance Results
Benchmarked on matmul+relu workload with decompose-k fusion (H100 GPU, 15 test shapes):
<img width="782" height="377" alt="Screenshot 2025-11-04 at 12 43 11 AM" src="https://github.com/user-attachments/assets/22131d4c-a8ce-4f55-bdcd-ac758ddad8cd " />
Metric | Result
-- | --
Average Speedup vs ATen | 1.28x
Max Speedup vs ATen | 1.41x
The performance comparisons are detailed in the plots below. In most use cases, inline fusion gains better performance compared to the ATen baseline and the current torch.compile.
<img width="4874" height="3545" alt="image" src="https://github.com/user-attachments/assets/190a1233-412f-4f34-84cd-9b7cb582f504 " />
**Test**: `test_decompose_k_with_fusion` demonstrates decompose-k with inline fusion enabled.
--------------
### Integration into mm.py decomposeK with a config flag enable_inline_subgraph_fusion=True (deprecated to avoid breaking async compilation; already removed from this PR)
FP32:
<img width="738" height="357" alt="Screenshot 2025-11-04 at 12 05 08 AM" src="https://github.com/user-attachments/assets/ee421d22-c426-42f2-8dcd-4dcc547d6219 " />
FP16:
<img width="769" height="403" alt="Screenshot 2025-11-04 at 12 13 49 AM" src="https://github.com/user-attachments/assets/346d1ffc-15af-40b0-9378-cf9b297711c2 " />
The TCF column represents torch compile fusion, which is close to custom_op decomposek. The difference might be due to different candidate k values.
#### Usage:
Note: this only happens when benchmark_epilogue_fusion is disabled, i.e., when multi_template_buffer is not used.
```python
# Define the matmul+relu function
def matmul_relu(x, y):
    return torch.nn.functional.relu(torch.matmul(x, y))

# Compile with inline subgraph fusion enabled
@torch.compile
def compiled_matmul_relu(x, y):
    return matmul_relu(x, y)

# Reset dynamo to ensure clean compilation
torch._dynamo.reset()
with config.patch(
    {
        "max_autotune": True,
        # CRITICAL: These two flags enable inline subgraph fusion
        "benchmark_epilogue_fusion": False,  # Must be False for inline fusion!
        "enable_inline_subgraph_fusion": True,  # Enable inline fusion
    }
):
    # Compile and run
    result = compiled_matmul_relu(a, b)
    torch.cuda.synchronize()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165952
Approved by: https://github.com/PaulZhang12 , https://github.com/eellison
2025-11-06 06:59:10 +00:00
9eebda944d
make narrow_tensor_symint DDE-free ( #166379 )
...
https://github.com/pytorch/pytorch/issues/158081
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166361
2025-11-06 06:09:22 +00:00
09d8953fb4
Update tensorpipe submodule ( #167108 )
...
To pick up a single change, 2b4cd91092, that should fix compilation errors with clang-21
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167108
Approved by: https://github.com/Skylion007
2025-11-06 06:08:13 +00:00
8b2365094d
Expose torch.compiler.config.force_disable_caches as a public API ( #166699 )
...
Exposing this flag because some upstream frameworks (like vLLM) could benefit from knowing whether torch.compile caches are enabled, so they can adjust their own caching behavior.
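A small sketch of how a framework might consult the newly public flag (the usage here is illustrative, not from the PR):
```python
import torch
import torch.compiler.config

# Skip framework-level caching when torch.compile caches are force-disabled.
if torch.compiler.config.force_disable_caches:
    print("torch.compile caches are disabled; disabling framework cache too")
```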
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166699
Approved by: https://github.com/oulgen , https://github.com/mlazos
2025-11-06 05:59:05 +00:00
7b423c2d21
[user-streams] Mark stream ops as side effectful ( #167152 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167152
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167141 , #167151
2025-11-06 05:03:18 +00:00
46b3f913b3
[user-streams] Add record/wait ops ( #167151 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167151
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167141
2025-11-06 05:03:18 +00:00
f7b7f40a6f
[user-streams] Enable stream ops to work in eager ( #167141 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167141
Approved by: https://github.com/Lucaskabela
2025-11-06 05:03:18 +00:00
91337ae3ff
[audio hash update] update the pinned audio hash ( #167031 )
...
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml ).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167031
Approved by: https://github.com/pytorchbot
2025-11-06 04:57:05 +00:00
eea951758f
[dynamo, 3.14] disable dynamo cpython tests in 3.14 (again) ( #167000 )
...
The previous PR was not enough to prevent errors caused by cpython dynamo tests in 3.14
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167000
Approved by: https://github.com/mlazos , https://github.com/guilhermeleobas
2025-11-06 04:34:33 +00:00
3feea296a5
torch.fx: add debug-level logging to Interpreter.run_node ( #117351 ) ( #166622 )
...
### Summary
Adds a debug-level logging statement to torch.fx.Interpreter.run_node, as proposed in [#117351 ](https://github.com/pytorch/pytorch/issues/117351 ), to make FX graph execution traceable when debugging or instrumenting model transformations.
When debug logging is enabled, each executed node emits a single structured log line formatted via `LazyString(lambda: n.format_node())`, deferring string construction unless logging is active.
### Example Output
With `logging.DEBUG` enabled:
```
run_node x = x()
run_node add = _operator.add(x, 1)
run_node clamp = torch.clamp(add, min=0.0, max=5.0)
run_node output = output(clamp)
```
With `logging.DEBUG` disabled no additional output is produced (unchanged default behavior).
### Test Plan
Verified locally with Python 3.11 on macOS using a PyTorch build from source.
- With `logging.DEBUG` enabled: each node emits a debug log via LazyString.
- With `logging.DEBUG` disabled: no additional output.
- Confirmed all `Interpreter` tests pass locally:
`pytest test/test_fx.py -k "Interpreter"`
Updated the example output to reflect the new `_format_fx_node` helper and inclusion of `kwargs`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166622
Approved by: https://github.com/aorenste
2025-11-06 04:33:09 +00:00
c3c3653418
[1/N] Add return types of Python functions ( #167162 )
...
This PR adds return types to some Python functions. Most of them return `None`. The types were added automatically by ruff `ANN` rules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167162
Approved by: https://github.com/Lucaskabela
2025-11-06 04:32:14 +00:00
f72772b184
[PP] make runtime dbg log print custom actions ( #167113 )
...
Previously the log only printed if the default implementation for an
action was used, now it prints before dispatching to custom registered
actions.
Tested by running on autoparallel graph runner and observing forward
pass action logged
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167113
Approved by: https://github.com/sanketpurandare , https://github.com/Skylion007
2025-11-06 04:20:50 +00:00
981dd71893
Refactor: extract OperatorArgsKwargsView from parseIValuesToPyArgsKwargs ( #166368 )
...
Intended to make it easier to reuse this logic for processing operator arguments as IValues in following PR(s).
Testing: python test/test_python_dispatch.py (broke during development, seems to work now)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166368
Approved by: https://github.com/albanD
2025-11-06 04:18:54 +00:00
d31599f40b
[7/N] Fix unused loop variables in tests ( #167043 )
...
This PR continues to fix or remove unused loop variables in tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167043
Approved by: https://github.com/Lucaskabela
2025-11-06 03:36:59 +00:00
85fab6c9b0
Fix duplicate benchmarking entries for addmm ( #166652 )
...
There have been duplicate entries for addmm in the dashboard. This PR fixes the duplicate entries issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166652
Approved by: https://github.com/yangw-dev
2025-11-06 03:25:03 +00:00
c08ce30d18
[ci][cpu] Update compiler to GCC-13 in jammy-aarch64 ( #166849 )
...
This is needed because manylinux uses GCC-13 since #152825
As a result of the current compiler version mismatches, we've seen tests passing in jammy-aarch64 pre-commit CI but failing for wheels built in manylinux
Related to: #166736
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166849
Approved by: https://github.com/robert-hardwick , https://github.com/malfet , https://github.com/Skylion007 , https://github.com/atalman
2025-11-06 03:14:16 +00:00
e1a1aeaf5b
[1/N] Use key in dict for existence checks ( #167035 )
...
This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167035
Approved by: https://github.com/janeyx99
2025-11-06 02:25:10 +00:00
943227f57b
[c10d] Fix split_group bug by having the parent pg option deep copied ( #167125 )
...
Summary: Inside the group_split API, we share the reference to the PG option with the parent PG if a PG option is not explicitly specified. This is bad because if we split the parent PG multiple times, we will run into errors.
Test Plan: UT + internal test.
Differential Revision: D86225394
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167125
Approved by: https://github.com/Skylion007
2025-11-06 02:08:05 +00:00
3a2d75a086
Change template 'Release highlight for proposed Feature'->'New Feature for Release' ( #167145 )
...
Makes it simpler and more clear
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167145
Approved by: https://github.com/huydhn
2025-11-06 02:01:57 +00:00
69af74972b
Bugfix to forward autodiff causing different datatype 2 ( #165784 )
...
Fixes #160513
## The Problem Summary
The issue boiled down to data type promotion logic. The code base has two different functions that deal with dtype promotion. For purely multi-dimensional tensor operations, the C++ code gets triggered, and that follows the NumPy dtype promotion logic. That is why in #160513 NDim tensors are fine, as NDim dtypes get precedence. The issue came with Python scalars and 0Dim tensors. When "scalars" are detected, a Python implementation of the dtype promotion logic gets triggered (torch/_prims_common/__init__.py:1544). Since this is in Python, the implementation can't distinguish a wrapped number from a 0Dim tensor and thus will just take the highest dtype, which is that of the Python-double wrapped number.
## The Fix
The Python implementation of dtype promotion had to know where the scalar came from. Once the scalar's origin can be distinguished, the appropriate dtype can be set. The first approach was to try to expose the `is_wrapped_number` method, but this came with a big issue. During `forward_ad` the derivative of those scalars turned out to be `ZeroTensor`s. The `ZeroTensor` internally uses a hack to initialize a meta dtype tensor which skips expensive dispatch operations. But the copy would not grab everything, especially the `is_number_wrapped_` property. I thought about modifying the copy, but that seemed to go against the spirit of what the copy was intended for; plus the tests for `is_wrapped_number_` require `dim > 0`, and a scalar `ZeroTensor` is a meta dtype tensor, which complicates things.
So I chose the route of creating a new property called `was_wrapped_number` and exposed this property to the python tensor API. I had to modify the autograd code generation to set `was_wrapped_number` in the mul, add, and div operations in `VariableType.cpp`. Once this property was set, the dtype promotion logic could be updated to consider wrapped numbers and 0Dim numbers. Once that hierarchy was taken care of, the buggy behavior was fixed.
I wrote a new ops testing module, `TestForwardADWithScalars`. This bug was unique and required a new testing paradigm. It only tests multiply, add, and divide; I chose these because the affected operations boil down to these three.
[edit]: Just used the `efficientzerotensor` meta and converted that to a Python number. Since the wrapped number is converted back to a Python number, dtype promotion is preserved. This was achieved by setting the forward-grad zero tensor of a wrapped number with a wrapped-number flag, since the tangent of a wrapped number should still be a wrapped number. After that, this specific ZeroTensor was sent through as a meta type in `BinaryOps.cpp` to get the appropriate dtype for the resulting arithmetic.
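A hedged repro sketch of the symptom described above: multiplying by a Python scalar during forward-mode AD should not upcast a float32 tensor.
```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3, dtype=torch.float32)
with fwAD.dual_level():
    dual = fwAD.make_dual(x, torch.ones_like(x))
    out = dual * 2.0  # Python scalar, wrapped as a 0-dim double internally
    primal, tangent = fwAD.unpack_dual(out)
    # With the bug, the tangent could be promoted to float64;
    # with the fix both stay float32.
    print(primal.dtype, tangent.dtype)
```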
@ezyang @OihanJoyot
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165784
Approved by: https://github.com/ezyang
2025-11-06 01:59:53 +00:00
7432676187
[MPS] Fix crash in BCELoss backwards with reduction="none" and inputs with trailing 1s in shape ( #166786 )
...
Fixes #166746 by removing squeezes that caused shape mismatches when calling backwards through `BCELoss(reduction='none')`.
Based on running these tests, it seems MPSGraph can handle inputs without squeezing.
```
python test/test_mps.py TestMPS -k test_bce
python test/test_mps.py TestConsistency -k binary_cross
```
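A minimal sketch of the previously failing case (assumes an MPS-capable machine); inputs with a trailing 1 in the shape crashed in backward:
```python
import torch

if torch.backends.mps.is_available():
    loss_fn = torch.nn.BCELoss(reduction="none")
    inp = torch.rand(4, 1, device="mps", requires_grad=True)
    target = torch.rand(4, 1, device="mps")
    loss_fn(inp, target).sum().backward()
    print(inp.grad.shape)  # torch.Size([4, 1])
```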
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166786
Approved by: https://github.com/malfet
2025-11-06 01:55:38 +00:00
fd5edda1ed
Reland "Add model code stack trace to torch.profile ( #166677 )" ( #167110 )
...
```python
python test/test_fx.py -k profiler
```
Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen.
We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace.
`map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry.
One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove.
`aot_eager` calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or `_dynamo.config.enrich_profiler_metadata=True`.
<img width="1188" height="565" alt="Screenshot 2025-10-31 at 4 40 52 PM" src="https://github.com/user-attachments/assets/41e8113f-3e6d-439b-bffd-cfbf0c03a47a " />
Example code gen'd.
```
def forward(self, args_list):
args_iter = iter(args_list)
arg0_1 = next(args_iter)
arg1_1 = next(args_iter)
args_list.clear()
_rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__()
repeated_subgraph0 = self.repeated_subgraph0
_rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__()
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None
_rf_invoke_subgraph.__exit__(None, None, None)
_rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__()
getitem = invoke_subgraph[0]; invoke_subgraph = None
_rf_getitem.__exit__(None, None, None)
return (getitem,)
_rf.__exit__(None, None, None)
def forward(self, arg0_1, arg1_1):
_rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__()
_rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__()
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
_rf_mul.__exit__(None, None, None)
_rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__()
sin = torch.ops.aten.sin.default(mul); mul = None
_rf_sin.__exit__(None, None, None)
_rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__()
add = torch.ops.aten.add.Tensor(sin, 5); sin = None
_rf_add.__exit__(None, None, None)
return (add,)
_rf.__exit__(None, None, None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167110
Approved by: https://github.com/pianpwk
2025-11-06 01:14:27 +00:00
872d1daec2
Avoid DDE in narrow with unbacked start ( #166361 )
...
Slice knows how to handle an unbacked start, so we do not need to offset start before calling slice; we can leave it to slice.
The only edge case is when start < 0 and start + length == 0: in that case slice and narrow would deviate,
so for that case we pass dim_size instead of start + length.
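A small illustration of the edge case: with a negative start and start + length == 0, narrow and a literal slice would otherwise deviate.
```python
import torch

x = torch.arange(10)
print(torch.narrow(x, 0, -2, 2))  # tensor([8, 9]): the last two elements
print(x[-2:0])                    # tensor([], dtype=torch.int64): empty slice
```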
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-06 01:04:19 +00:00
6cd57e6fc2
[cuBLAS] Force tensor-core-no-reduction algo in cuBLASLt for n=1 cases ( #166735 )
...
Ostensibly useful for batch-invariance purposes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166735
Approved by: https://github.com/ngimel
2025-11-06 00:50:42 +00:00
d29efba8fa
Move almalinux docker image to DEVTOOLSET 13 ( #167018 )
...
1. Update general Almalinux image to Devtoolset 13.
2. Fix ROCm images, which were missing devtoolset-13.
This image is used by the Linux Job in test-infra.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167018
Approved by: https://github.com/sudharssun , https://github.com/d4l3k
2025-11-06 00:34:40 +00:00
a344069f2a
Add missing skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) to test/test_transformers.py ( #166969 )
...
This PR adds missing skips for efficient attention tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166969
Approved by: https://github.com/jeffdaily
2025-11-05 23:16:51 +00:00
af829c0dad
[ROCm] Skip nvfp4 tests on ROCm ( #167066 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167066
Approved by: https://github.com/jeffdaily , https://github.com/slayton58
2025-11-05 23:15:17 +00:00
3869aa115b
fix fr reset api ( #166970 )
...
Summary:
- there are various places that access fr's `entries_` field
- if we empty `entries_` on reset, those accesses can result in an error
- so we only perform a soft delete instead of clearing out the entries completely
- only reset `id_` on reset
- keep track of a `reset_epoch` which increments every time reset is called
- dump_entries only returns entries from the latest epoch
- APIs that access entries also check whether the reset epoch matches
- make `next_` always track the index in the circular buffer - this change was needed to make the soft delete's implementation easier (see the sketch below)
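A minimal sketch (not the actual FlightRecorder code) of the soft-delete scheme described above: reset() bumps an epoch instead of clearing entries, and readers only see entries recorded in the latest epoch.
```python
class RingLog:
    def __init__(self, capacity: int):
        self.entries = [None] * capacity
        self.next_ = 0          # always the raw index into the circular buffer
        self.id_ = 0
        self.reset_epoch = 0

    def record(self, payload):
        self.entries[self.next_ % len(self.entries)] = (self.reset_epoch, self.id_, payload)
        self.next_ += 1
        self.id_ += 1

    def reset(self):
        self.id_ = 0            # only reset id_ ...
        self.reset_epoch += 1   # ... and soft-delete by bumping the epoch

    def dump_entries(self):
        return [e for e in self.entries if e is not None and e[0] == self.reset_epoch]
```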
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com ). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/166970 ).
* #166972
* #166971
* __->__ #166970
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166970
Approved by: https://github.com/fduwjj
2025-11-05 23:06:00 +00:00
47eb34b7ac
[ATEN][CUDA] Reduce register pressure in radix_sort_pairs to improve torch.sort performance ( #167094 )
...
# Summary
This PR improves `torch.sort` and `torch.unique` performance by **15% to 50%** on NVIDIA GPUs by optimizing CUDA register allocation in radix sort operations.
The key change: specialize `OpaqueType<N>` to use native integer types (uint8_t, uint16_t, uint32_t, uint64_t) for common sizes (1, 2, 4, 8 bytes) instead of `char data[N]`. This enables more efficient register allocation while preserving the template deduplication strategy.
The following table shows the speedup on various input shapes and GPUs. Sorting is performed on the last dimension, and baseline torch version is 2.9.0.
| GPU | input shape | input dtype | **Before** **(ms)** | After (ms) | Speedup |
| ---- | ----------- | ----------- | ------------------- | ---------- | ------- |
| H100 | (16, 1e6) | int32 | 1.61 | 1.37 | 1.18× |
| H100 | (1, 1e8) | int32 | 6.6 | 5.0 | 1.3× |
| H20 | (16, 1e6) | int64 | 3.57 | 3.03 | 1.18× |
| H20 | (1, 1e8) | int64 | 19.3 | 13.0 | 1.48× |
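For context, a rough micro-benchmark along the lines of the table above might look like the following sketch; it assumes a CUDA device, and absolute numbers will vary by GPU and PyTorch build.
```python
import time
import torch

x = torch.randint(0, 2**31 - 1, (16, 1_000_000), dtype=torch.int32, device="cuda")
torch.sort(x, dim=-1)                 # warm up
torch.cuda.synchronize()
t0 = time.perf_counter()
torch.sort(x, dim=-1)
torch.cuda.synchronize()
print(f"{(time.perf_counter() - t0) * 1e3:.2f} ms")
```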
# Analysis
`torch.sort` and `torch.unique` use `radix_sort_pairs`, which internally calls `cub::DeviceRadixSort::SortPairs`. Since values are only copied (never compared), we cast them to `OpaqueType<sizeof(value_t)>` to minimize template instantiations. For example, both `int32` and `float32` values map to the same `OpaqueType<4>`.
## The Problem
The previous `char data[N]` implementation causes inefficient register allocation. Here is one reason I found from the SASS code. For 8-byte types:
- `char data[8]`: the compiler may allocate 8 registers (one per byte)
- `uint64_t data`: the compiler allocates 2 registers (standard 64-bit handling)
This happens because the compiler doesn't recognize char[8] as a cohesive 64-bit value, treating each byte independently, which increases register pressure and reduces GPU occupancy.
From Nsight Compute, when using `char data[8]`, register usage per thread is 166 and the corresponding theoretical occupancy is 18.75%. When using native `uint64_t`, register usage per thread is 80 and the corresponding theoretical occupancy is 37.5%.
## The Solution
Specialize `OpaqueType<N>` for common sizes using native integer types:
```
// Before
template <int N> struct alignas(N) OpaqueType { char data[N]; };
// After
template <int N> struct alignas(N) OpaqueType { char data[N]; }; // fallback
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
```
This preserves the template deduplication strategy (all 8-byte types still use the same `OpaqueType<8>` instantiation) while enabling better register allocation.
# Testing & Compatibility
## Testing:
✅ Correctness tests pass for various input types (bfloat16, int32, float32, int64), shapes, and dimensions (1, 2, 3)
✅ Register usage reduction verified with NSight Compute
✅ Linter passes
## Compatibility:
✅ No API/ABI changes
✅ Template instantiation count unchanged
# Reference
For detailed analysis, please refer to my previous blog: [Performance Optimization of torch.sort on GPU](https://yywangcs.notion.site/Performance-Optimization-of-torch-sort-on-GPU-192fc9f5d8058018a1bec1efa35da3f9 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167094
Approved by: https://github.com/ngimel , https://github.com/Skylion007
2025-11-05 22:34:19 +00:00
08200280ce
[CP][BE][3/N] Add _templated_ring_attention to the backward compatibility stub ( #166991 )
...
While `_templated_ring_attention` is a private API, it is unfortunately used by some packages.
Add it to __all__ so that people can still use it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166991
Approved by: https://github.com/XilunWu
ghstack dependencies: #166456 , #166501
2025-11-05 22:22:55 +00:00
ad7a57262c
[12/N] Apply ruff UP035 rule ( #166929 )
...
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929
Approved by: https://github.com/Lucaskabela
2025-11-05 22:06:19 +00:00
711a775878
fix nccl estimations ( #167093 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167093
Approved by: https://github.com/kwen2501 , https://github.com/eellison
2025-11-05 22:01:49 +00:00
e9a688f02e
[DebugMode] output, tensor id annotations for DebugMode ( #165076 )
...
Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)`
Example output for `test_debug_mode_mm`, with both enabled:
```
torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$12: f32[8, 32]| S(0)
aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0))
redistribute_input(1, S(0) -> R)
redistribute_input(t$4: f32[1, 32], trace: S(0)->R)
_c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0) -> t$6: f32[8, 32]
_c10d_functional::wait_tensor(t$7: f32[8, 32]) -> t$8: f32[8, 32]
aten::mm(t$9: f32[1, 8], t$10: f32[8, 32]) -> t$11: f32[1, 32]
<method 'sum' of 'torch._C.TensorBase' objects>(dt$13: f32[8, 32]| S(0)) -> dt$17: f32[]| P
aten::sum(dt$14: f32[8, 32]| S(0))
aten::sum(t$15: f32[1, 32]) -> t$16: f32[]
```
Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic.
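A hedged usage sketch of the new flags; the keyword arguments come from this PR's description, and `debug_string()` is assumed to be the existing dump method.
```python
import torch
from torch.utils._debug_mode import DebugMode

with DebugMode(record_output=True, record_ids=True) as dm:
    a = torch.randn(8, 8)
    b = torch.randn(8, 32)
    torch.mm(a, b)
print(dm.debug_string())
```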
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076
Approved by: https://github.com/zpcore
2025-11-05 22:00:11 +00:00
e69aaaf45a
[user-streams] Add backward test ( #167021 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167021
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167019
2025-11-05 21:24:44 +00:00
fd8f368d31
[user-streams] Add graph annotation checks ( #167019 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167019
Approved by: https://github.com/Lucaskabela
2025-11-05 21:24:44 +00:00
13d2cc7bd2
Remove python workaround for ContextDecorator ( #167049 )
...
This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049
Approved by: https://github.com/Skylion007
2025-11-05 20:56:04 +00:00
c6c913d18e
Add torch::stable::Tensor sizes and strides ( #165153 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165153
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991 , #165152
2025-11-05 20:55:34 +00:00
ef3f953966
Revert "[DebugMode] output, tensor id annotations for DebugMode ( #165076 )"
...
This reverts commit a64c7d740428010d700b4bcd395af8a7b2d5c21f.
Reverted https://github.com/pytorch/pytorch/pull/165076 on behalf of https://github.com/wdvr due to Sorry but this is breaking internally. See diff [D86245252](https://l.workplace.com/l.php?u=https%3A%2F%2Fwww.internalfb.com%2Fdiff%2FD86245252&h=AT1oPbS1XTv6HjYeYdxmDMW1-jlT0pS8yBO2iSfbPfUB9ydsEjFXBNT56QhV1v5TKc4_QaQNxykNowSKmb4fgenjOyCv20NuL7oV_Id5fhh32hhv1IpjgsDJYK-PBFfSfv_miLIWfNgj902KcgXojbBgDcDzQeS9lNt0GQ ) for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/165076#issuecomment-3493358159 ))
2025-11-05 20:52:43 +00:00
ea44f12bce
[13/N] Apply ruff UP035 rule ( #167048 )
...
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048
Approved by: https://github.com/Skylion007
2025-11-05 20:51:53 +00:00
a74fe75c45
Don't hardcode double argument for reduction base ( #166951 )
...
Fixes https://github.com/pytorch/pytorch/issues/43254
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166951
Approved by: https://github.com/ngimel , https://github.com/Skylion007
ghstack dependencies: #166813
2025-11-05 20:34:15 +00:00
6d30666bc1
Revert "[12/N] Apply ruff UP035 rule ( #166929 )"
...
This reverts commit 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52.
Reverted https://github.com/pytorch/pytorch/pull/166929 on behalf of https://github.com/donigian due to Temporarily need to revert this to continue a revert for #165076 . @cyyever Please re-merge after revert of #165076 . ([comment](https://github.com/pytorch/pytorch/pull/166929#issuecomment-3493090596 ))
2025-11-05 20:02:47 +00:00
8e8cbb85ee
Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen ( #166890 )"
...
This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23.
Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see fbd70fb84e/1 and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038 ))
2025-11-05 19:42:39 +00:00
fbd70fb84e
Update typing docs to reference pyrefly ( #166883 )
...
Replacing mypy documentation in the CONTRIBUTING.MD file with pyrefly references. I have made initial changes to https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation, and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883
Approved by: https://github.com/malfet
2025-11-05 19:35:38 +00:00
6c5db82584
[Inductor] Naive foreach autotune support ( #162053 )
...
Initial autotuning support for foreach kernels; 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future. Removing num_warps from the definition to enable autotune support in generated wrapper code.
Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |
After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |
num_warps=8 default due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos , https://github.com/naromero77amd , https://github.com/jeffdaily
Co-authored-by: Nichols A. Romero <nick.romero@amd.com >
2025-11-05 19:27:23 +00:00
6052a01b71
[BE][Typing][Dynamo] Type torch/_dynamo/variables/dicts.py ( #167022 )
...
Provides type coverage to torch/_dynamo/variables/dicts.py
Coverage report:
`mypy torch/_dynamo/variables/dicts.py --linecount-report /tmp/coverage_log`
Compare before to after - we go from 0 lines and 0 funcs covered to 1547 lines and 89 funcs covered
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167022
Approved by: https://github.com/Skylion007
2025-11-05 19:18:35 +00:00
14b153bcf2
include DTensor metadata when pretty-printing fx.Graphs ( #166750 )
...
Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it:
<img width="1446" height="582" alt="image" src="https://github.com/user-attachments/assets/99ea5ce6-1008-4ba5-b58e-542cd34a340b " />
```
import torch
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.distributed.tensor import distribute_tensor, Shard, Replicate
from torch.utils._debug_mode import DebugMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree
world_size = 8
device_type = "cpu"
fake_store = FakeStore()
torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size)
device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,))
dim = 128
A = torch.randn(8, dim)
B = torch.randn(dim, dim)
dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_()
dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_()
def f(dA, dB):
dy = dA @ dB
loss = dy.sum()
loss.backward()
return dA.grad, dB.grad
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
gm = make_fx(f, tracing_mode='fake')(dA, dB)
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
gm.graph.eliminate_dead_code()
gm.recompile()
gm.print_readable(colored=True)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166750
Approved by: https://github.com/ezyang , https://github.com/wconstab , https://github.com/Skylion007
2025-11-05 18:58:54 +00:00
641de23c96
ci: Add aarch64 docker builds for modern clang ( #166416 )
...
Should enable us to build using some arm optimizations that are only
available on the newest versions of clang.
Signed-off-by: Eli Uriegas <eliuriegas@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166416
Approved by: https://github.com/malfet
2025-11-05 18:55:56 +00:00
89165c0a2b
Update triton to 3.5.1 release ( #166968 )
...
This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968
Approved by: https://github.com/Lucaskabela , https://github.com/njriasan
2025-11-05 18:26:34 +00:00
dcc2ba4ca4
Add some code for exploring the space of accessible size/stride configs via plain views ( #167076 )
...
We are working on a translation from as_strided to view operations, but
only when the as_strided is representable as a plain view. A useful
testing utility in this situation is the ability to enumerate all valid
views on an original tensor. So we have a small test here that shows
it is possible.
To avoid an explosion of states, we don't handle permutes and size=1,
which are degenerate cases (you can always do a single permute and
a series of unsqueezes to get to the final desired state.)
Authored with claude code assistance.
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167076
Approved by: https://github.com/albanD
ghstack dependencies: #166868 , #166867
2025-11-05 18:25:19 +00:00
ad5c7c20e0
Revert "[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI ( #165922 )"
...
This reverts commit 1d3f5e19da068ec1340db041b7105b287a513578.
Reverted https://github.com/pytorch/pytorch/pull/165922 on behalf of https://github.com/atalman due to Introduces Segfault in linux-jammy-cuda12.8-py3.10-gcc11 ([comment](https://github.com/pytorch/pytorch/pull/165922#issuecomment-3492667312 ))
2025-11-05 18:13:57 +00:00
c86540f120
Revert "Add model code stack trace to torch.profile ( #166677 )"
...
This reverts commit c00696144dae1f02e04ce345480b55e46c7d32a8.
Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3492658160 ))
2025-11-05 18:11:11 +00:00
c17aa0f113
[ROCm] Enable group gemm through CK ( #166334 )
...
Fixes #161366
All four matrix dimension combinations are supported: 2d-2d, 2d-3d, 3d-3d, and 3d-2d. The corresponding test cases in test_matmul_cuda pass
for both the forward and backward pass.
The CK path is enabled for gfx942, gfx950.
ToDo: Need to enable support on gfx90a since the CK kernel used in this commit produces a GPU error;
it might require a different CK kernel config, based on the profiler result on gfx90a.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166334
Approved by: https://github.com/atalman
2025-11-05 18:03:59 +00:00
4ff068c33a
[Code Clean] Replace assert with if statement and raise AssertionError ( #166935 )
...
Including:
- `torch/profiler/profiler.py`
Fixes part of #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166935
Approved by: https://github.com/fffrog , https://github.com/albanD
2025-11-05 17:59:16 +00:00
0c7a4a6b48
[Inductor] Fix unbacked float symbol handling in kernel codegen ( #166890 )
...
When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.
Fixes: #166888
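A hedged repro sketch of the pattern described above; it assumes a CUDA device so a Triton kernel is actually generated, and `capture_scalar_outputs` so the `.item()` call is traced.
```python
import torch

torch._dynamo.config.capture_scalar_outputs = True  # trace .item() instead of graph-breaking

@torch.compile(fullgraph=True)
def clamp_by(x, lo):
    # .item() on a float tensor produces an unbacked float symbol (e.g. zuf0)
    return torch.clamp(x, min=lo.item())

if torch.cuda.is_available():
    print(clamp_by(torch.randn(8, device="cuda"), torch.tensor(0.25, device="cuda")))
```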
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890
Approved by: https://github.com/eellison
2025-11-05 17:50:08 +00:00
f93ee16fb6
[CI] Parse xml and upload json while running ( #166988 )
...
Then we can point a ClickHouse ingestor at this S3 path and get the results into ClickHouse while the job is running.
Use filelock to make sure each JSON is uploaded only once so we don't end up with duplicates in ClickHouse.
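A rough sketch (a hypothetical helper, not the actual CI script) of uploading each JSON exactly once by guarding the upload with a file lock:
```python
from pathlib import Path
from filelock import FileLock

def upload_once(json_path: Path, upload) -> None:
    marker = json_path.with_suffix(".uploaded")
    with FileLock(str(json_path) + ".lock"):
        if marker.exists():
            return            # another process already uploaded this file
        upload(json_path)     # e.g. push to the S3 prefix the ingestor watches
        marker.touch()
```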
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166988
Approved by: https://github.com/izaitsevfb
2025-11-05 17:19:24 +00:00
9c2c3dbc15
Revert "Update triton to 3.5.1 release ( #166968 )"
...
This reverts commit b4e4ee81d386db922d8f63359f9870eff1f44052.
Reverted https://github.com/pytorch/pytorch/pull/166968 on behalf of https://github.com/malfet due to It might have caused deadlock/test timeouts, see d4dcd0354c/1 ([comment](https://github.com/pytorch/pytorch/pull/166968#issuecomment-3492399396 ))
2025-11-05 17:12:30 +00:00
d4dcd0354c
[pytree][dynamo] add test to ensure tree_map preserves dict order ( #166236 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166236
Approved by: https://github.com/mlazos
2025-11-05 17:04:40 +00:00
aba2fa3259
Fix clang-21 warnings ( #166859 )
...
Fixes compiler warnings thrown by Clang-21
Fixes #166755
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166859
Approved by: https://github.com/aditew01 , https://github.com/fadara01 , https://github.com/malfet
2025-11-05 16:55:51 +00:00
d2d13bf62d
Invert unary read and write for fusion ( #161404 )
...
For [this repro](https://gist.github.com/eellison/75a99616a0fcca0436316bbfd8987fae ), this enables fusion of `to_blocked` with the prior `to_mx` calculation, so that there is only a single kernel per tensor, resulting in a 10% speedup of the non-conversion code (need to update my local devserver to 12.9 to time the matmul as well).
The `to_mx` kernel has a contiguous write:
```Py
op6_op7: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op6_op7.writes = [MemoryDep('buf6', c0, {c0: 2097152}), MemoryDep('buf7', c0, {c0: 67108864})]
op6_op7.unmet_dependencies = []
op6_op7.met_dependencies = [MemoryDep('arg1_1', c0, {c0: 67108864})]
op6_op7.outputs = [
buf6: ComputedBuffer
buf6.layout = FixedLayout('cuda:0', torch.float32, size=[8192, 256], stride=[256, 1])
buf6.users = [
NodeUser(node=SchedulerNode(name='op7'), can_inplace=False, is_weak=False),
NodeUser(node=SchedulerNode(name='op9'), can_inplace=False, is_weak=False),
]
buf7: ComputedBuffer
buf7.layout = FixedLayout('cuda:0', torch.float8_e4m3fn, size=[8192, 256, 32], stride=[8192, 32, 1])
buf7.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)]
]
```
While the `to_blocked` has a single discontiguous read and a single contiguous write.
```Py
op9: SchedulerNode(ComputedBuffer)
op9.writes = [MemoryDep('buf9', c0, {c0: 2097152})]
op9.unmet_dependencies = [ MemoryDep('buf6', 32768*((c0//32768)) + 8192*(((ModularIndexing(c0, 1, 16))//4)) + 256*(ModularIndexing(c0, 16, 32)) + 4*(ModularIndexing(c0, 512, 64)) + (ModularIndexing(ModularIndexing(c0, 1, 16), 1, 4)), {c0: 2097152})]
op9.met_dependencies = []
op9.outputs = [
buf9: ComputedBuffer
buf9.layout = FixedLayout('cuda:0', torch.float8_e8m0fnu, size=[2097152], stride=[1])
buf9.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)]
]
```
To enable fusion, we invert the read, giving op9 a contiguous read and a discontiguous write. More explanation here: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9
[Tlparse with this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ).
[Tlparse without this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 ).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161404
Approved by: https://github.com/shunting314
2025-11-05 16:10:52 +00:00
7a6ff88196
Widen ops support to take in IntHOArrayRef vs only std::vec ( #165152 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991
2025-11-05 16:00:24 +00:00
59563dfe56
Refactor out headeronly ArrayRef ( #164991 )
...
Differential Revision: [D85091961](https://our.internmc.facebook.com/intern/diff/D85091961 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164991
Approved by: https://github.com/swolchok
2025-11-05 16:00:24 +00:00
5c639466f7
Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel ( #167003 )"
...
This reverts commit 658c5f879c37142b1df51c7eb6c5a5bb06318597.
Reverted https://github.com/pytorch/pytorch/pull/167003 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19093785744/job/54553796743 ) [HUD commit link](658c5f879c ) ([comment](https://github.com/pytorch/pytorch/pull/167003#issuecomment-3491527704 ))
2025-11-05 14:30:15 +00:00
0b4dd08e04
[dynamo] Introduce _set_lru_cache ( #167038 )
...
Addresses the short-term plan for https://github.com/pytorch/pytorch/issues/166926 . This PR can't be defaulted on; that would be terrible for cache lookup times.
There's a proper fix in the works by @williamwen42.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167038
Approved by: https://github.com/williamwen42
2025-11-05 09:05:11 +00:00
edd8d356b6
fixes keyerror when loading parameter with unsaved optimizer state ( #165228 )
...
Fixes #164257
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165228
Approved by: https://github.com/fegin
2025-11-05 08:07:46 +00:00
658c5f879c
[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel ( #167003 )
...
Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036?fbclid=IwY2xjawN3RL1leHRuA2FlbQIxMQBicmlkETExOEcxcnVhNVA1TzRSVmhiAR63GOEpJbZA-JhQ0CSj9ji8H_RHBUhDwYNDtxjOYfDol56OGqmC4r7jPP96Fw_aem_bWvtMfVifLQrnpv1YB_fJA , which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs.
Test Plan:
Inductor test (fbcode):
`INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"`
Tritonbench (fbcode):
`clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy`
Tritonbench(oss):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy`
Unit Tests(oss):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`
Differential Revision: D86231180
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167003
Approved by: https://github.com/jananisriram
2025-11-05 06:51:30 +00:00
59a6c83dfe
[fx] Add strict argument validation to Interpreter.boxed_run ( #166784 )
...
# Summary
This PR fixes an issue where `torch.fx.Interpreter.boxed_run` would silently ignore extra input arguments instead of validating the argument count.
Previously, `boxed_run` would only consume as many inputs as there were placeholder nodes and then clear the entire `args_list`, hiding potential bugs. This change introduces a strict check to ensure `len(args_list)` matches the number of placeholder nodes, raising a `RuntimeError` on a mismatch.
Fixes #166583 .
# Changes
* Validate `len(args_list)` against the number of placeholder nodes at the beginning of `boxed_run`.
* Raise a `RuntimeError` with a clear message ("extra arguments" or "missing arguments") if the counts do not match.
* Move `args_list.clear()` to only execute after successful validation and environment setup. If an error is raised, `args_list` is preserved for debugging.
# Testing
* Added `test_interpreter_boxed_run_argument_validation` to `test/test_fx.py`.
* This test covers three scenarios:
1. Correct number of arguments (succeeds, `args_list` is cleared).
2. Extra arguments (raises `RuntimeError`, `args_list` is preserved).
3. Missing arguments (raises `RuntimeError`, `args_list` is preserved).
# User-facing impact / BC notes
This is a bug fix. Code that was incorrectly passing the wrong number of arguments to `boxed_run` will now fail fast with a `RuntimeError` instead of executing silently with unintended inputs. Correctly written code is unaffected.
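A hedged illustration of the new behavior (the exact error message may differ):
```python
import torch
from torch import fx

def f(x):
    return x + 1

gm = fx.symbolic_trace(f)
args = [torch.ones(2), torch.ones(2)]   # one extra argument for a single placeholder
try:
    fx.Interpreter(gm).boxed_run(args)
except RuntimeError as e:
    print("rejected:", e)
print("args preserved for debugging:", len(args))  # still 2 with the fix
```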
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166784
Approved by: https://github.com/ezyang , https://github.com/xmfan
2025-11-05 06:39:32 +00:00
431dfe8692
[dynamo] extend collections.defaultdict support with *args, **kwargs and custom default_factory ( #166793 )
...
Fixes #166238
Extend `collections.defaultdict` to accept `*args` and `**kwargs` in the constructor, and also support a custom `default_factory`, such as `dd.default_factory` (a `GetAttrVariable`).
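A hedged sketch of the kind of pattern this enables inside `torch.compile` (exact supported forms may differ):
```python
import collections
import torch

@torch.compile(fullgraph=True)
def f(x):
    dd = collections.defaultdict(lambda: 0, {"a": 1})  # *args plus a custom factory
    dd["b"] += 1                                       # triggers default_factory
    factory = dd.default_factory                       # accessed as a GetAttrVariable
    return x + dd["a"] + dd["b"] + factory()

print(f(torch.ones(2)))
```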
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166793
Approved by: https://github.com/guilhermeleobas
2025-11-05 06:09:39 +00:00
c00696144d
Add model code stack trace to torch.profile ( #166677 )
...
```python
python test/test_fx.py -k profiler
```
Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen.
We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace.
`map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry.
One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove.
`aot_eager` calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or `_dynamo.config.enrich_profiler_metadata=True`.
<img width="1188" height="565" alt="Screenshot 2025-10-31 at 4 40 52 PM" src="https://github.com/user-attachments/assets/41e8113f-3e6d-439b-bffd-cfbf0c03a47a " />
Example code gen'd.
```
def forward(self, args_list):
args_iter = iter(args_list)
arg0_1 = next(args_iter)
arg1_1 = next(args_iter)
args_list.clear()
_rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__()
repeated_subgraph0 = self.repeated_subgraph0
_rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__()
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None
_rf_invoke_subgraph.__exit__(None, None, None)
_rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__()
getitem = invoke_subgraph[0]; invoke_subgraph = None
_rf_getitem.__exit__(None, None, None)
return (getitem,)
_rf.__exit__(None, None, None)
def forward(self, arg0_1, arg1_1):
_rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__()
_rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__()
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
_rf_mul.__exit__(None, None, None)
_rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__()
sin = torch.ops.aten.sin.default(mul); mul = None
_rf_sin.__exit__(None, None, None)
_rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__()
add = torch.ops.aten.add.Tensor(sin, 5); sin = None
_rf_add.__exit__(None, None, None)
return (add,)
_rf.__exit__(None, None, None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677
Approved by: https://github.com/ezyang
2025-11-05 06:08:34 +00:00
9ffc480c5a
Add min/max support for barebones uint types ( #166813 )
...
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166813
Approved by: https://github.com/Skylion007
2025-11-05 04:44:21 +00:00
14956eaef4
[ROCm][CI] revert ROCm magma commit hash to last known good ( #167044 )
...
PR https://github.com/pytorch/pytorch/pull/166693 updated the magma commit hash, but this has been linked to ROCm 7.1 CI failures. Go back to the last known working magma version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167044
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com >
2025-11-05 04:18:04 +00:00
066c5c57a9
Fix typo in gloo_hip library name ( #166502 )
...
The typo was never noticed; conditions to enable it require system gloo: `-DUSE_SYSTEM_GLOO=ON -DUSE_GLOO=ON -DUSE_DISTRIBUTED=ON -DUSE_ROCM=ON`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166502
Approved by: https://github.com/jerryzh168 , https://github.com/cyyever
2025-11-05 04:14:01 +00:00
08ef852a4b
[unified v2][apple] Clean up APPLETVOS from caffe2 ( #166953 )
...
Summary: This is not used, so delete it
Test Plan:
```
$ buck targets xplat/... > /dev/null
```
Reviewed By: dtolnay
Differential Revision: D86125712
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166953
Approved by: https://github.com/seemethere
2025-11-05 03:09:56 +00:00
56fc99915b
Fix typos in complex numbers docs ( #166671 )
...
This PR fixes two small typos in the complex numbers docs:
1. "numbercial" -> "numerical"
2. "easily to switch" -> "easily switch to"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166671
Approved by: https://github.com/jcaip , https://github.com/Arpitha781 , https://github.com/mlazos , https://github.com/cyyever
2025-11-05 03:05:06 +00:00
5863ba1b2e
[12/N] Apply ruff UP035 rule ( #166929 )
...
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929
Approved by: https://github.com/Lucaskabela
2025-11-05 03:03:41 +00:00
a743f9eeb5
Revert "Avoid DDE in narrow with unbacked start ( #166361 )"
...
This reverts commit ed45c5f38df6aa419c67d139d932c2c94404223a.
Reverted https://github.com/pytorch/pytorch/pull/166361 on behalf of https://github.com/malfet due to Looks like it broke test_torchfuzz subtests, see 01e6e35c7f/1 ([comment](https://github.com/pytorch/pytorch/pull/166361#issuecomment-3488916766 ))
2025-11-05 02:39:55 +00:00
53b03f1a2b
Revert "make narrow_tensor_symint DDE-free ( #166379 )"
...
This reverts commit d7e2d0ad301b5d0db049bf5d2a2fc7ff9c89c58c.
Reverted https://github.com/pytorch/pytorch/pull/166379 on behalf of https://github.com/malfet due to Need to revert previous PR in the stack ([comment](https://github.com/pytorch/pytorch/pull/166379#issuecomment-3488910172 ))
2025-11-05 02:36:46 +00:00
cd5d810c3a
Annotation should be deepcopied ( #167017 )
...
The annotation should be deepcopied. Otherwise all nodes with the same `seq_nr` share the same underlying dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167017
Approved by: https://github.com/yiming0416
2025-11-05 02:22:33 +00:00
01e6e35c7f
Send / recv support in local tensor ( #166595 )
...
This change introduces LocalRunnerMode, which allows you to run multiple
SPMD functions concurrently. SPMD functions execute one at a time,
yielding execution while waiting for send or receive operations
to complete. Send and receive peer operations are only supported while running
under LocalRunnerMode.
The example test in this change demonstrates how ranks send data
to the next peer and receive data from the previous peer (ring).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166595
Approved by: https://github.com/wconstab , https://github.com/ezyang
2025-11-05 01:36:44 +00:00
bcd159bcdd
Fix the vmap op fallback bug ( #166032 )
...
## The bug
In some environments, if we run:
```py
def inner_func(x):
return x.to(torch.float32, memory_format=torch.channels_last)
x = torch.randn(2, 2, 3, 4, device="cpu", dtype=torch.float64)
torch.vmap(inner_func)(x)
```
we get:
```
E RuntimeError: Batching rule not implemented for aten::to.dtype_layout; the fallback path doesn't work on out= or view ops.
```
Otherwise, it would always fall back and result in an error for ops like `to.dtype` and `to.dtype_layout`, even though the kernels are registered.
## The cause
The alias key of `FuncTorchBatchedDecomposition` is not properly translated to runtime dispatch keys when updating the dispatch table of `OperatorEntry::dispatchTable_`. [[link](984b096d10/aten/src/ATen/core/dispatch/OperatorEntry.cpp (L500-L501) )]
The [`getRuntimeDispatchKeySet`](f3fa560dec/c10/core/DispatchKeySet.cpp (L62) ) uses if-else statements to translate all other alias keys but `FuncTorchBatchedDecomposition`.
This would result in not finding the kernel in many cases.
## The fix
This PR adds one more `if` statement to `getRuntimeDispatchKeySet` to map `FuncTorchBatchedDecomposition` to the corresponding runtime dispatch key, `FuncTorchBatched`,
so that the dispatch table can be properly updated.
This fix allows people to use ops inside vmaps in more environments and across more compilers.
## Why does it work without the PR
As long as the `FuncTorchBatchedDecomposition` [[link](51319ca090/aten/src/ATen/functorch/BatchRulesDecompositions.cpp (L35) )]
is registered before the fallback method of `FuncTorchBatched` [[link](d311a3d1dc/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp (L759) )], everything runs fine.
In this case, it relies on the registration of the fallback method to update the dispatch table, which flushes all the kernels in `OperatorEntry::kernels_` into `dispatchTable_`, among which there are kernels registered with `FuncTorchBatchedDecomposition`.
## When does it fail
However, the order of the op registration and the fallback registration is not guaranteed at all.
It relies on the C++ static initialization order, which varies from environment to environment.
With our compiler, the fallback registration goes first, and the alias-key kernels under `FuncTorchBatchedDecomposition` come later and do not get flushed into the dispatch table by the fallback registration.
Therefore, the dispatcher cannot find the kernel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166032
Approved by: https://github.com/albanD
2025-11-05 01:16:58 +00:00
64ae31c5d3
[HOP][print] Add HOP subclass for printing ( #166660 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166660
Approved by: https://github.com/angelayi , https://github.com/anijain2305
Co-authored-by: Angela Yi <yiangela7@gmail.com >
2025-11-05 01:16:49 +00:00
45da6e1fe1
[CD] Upload XPU inductor benchmark test reports to s3 ( #166954 )
...
As the title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166954
Approved by: https://github.com/atalman
2025-11-05 01:02:57 +00:00
39160dba0c
shrink_group implementation to expose ncclCommShrink API ( #164518 )
...
Closes #164529
To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink ) API to PyTorch.
This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.
For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-11-05 00:54:40 +00:00
f2fbc81c50
[RFC] Add experimental Pallas TorchInductor backend ( #166822 )
...
Very simple Pallas TorchInductor backend
Given
```
import torch
def f(x, y):
return x.sin() + y
torch._inductor.config.cuda_backend="pallas"
x = torch.randn(4).cuda()
y = torch.randn(4).cuda()
compiled = torch.compile(f, backend="inductor", fullgraph=True)
torch.testing.assert_close(compiled(x, y), f(x, y))
```
it outputs
```
import torch
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from torch.utils import dlpack as torch_dlpack
def pallas_fused_add_sin_56b646d2_kernel(in_ptr0, in_ptr1, out_ptr0):
tmp0 = in_ptr0[...]
tmp1 = jnp.sin(tmp0)
tmp2 = in_ptr1[...]
tmp3 = tmp1 + tmp2
out_ptr0[...] = tmp3
def pallas_fused_add_sin_56b646d2_main(in_ptr0, in_ptr1, out_ptr0, stream=None):
# Convert Torch -> JAX for inputs
in_ptr0_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr0))
in_ptr1_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr1))
# Prepare output spec from PyTorch tensor
# Map PyTorch dtype to JAX dtype string
_torch_dtype_to_jax = {
torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,
torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,
torch.uint8: jnp.uint8, torch.bool: jnp.bool_,
}
out_spec = jax.ShapeDtypeStruct(out_ptr0.shape, _torch_dtype_to_jax[out_ptr0.dtype])
compiled = pl.pallas_call(
lambda *refs: pallas_fused_add_sin_56b646d2_kernel(*refs),
out_shape=out_spec,
grid=(1,),
)
res = compiled(in_ptr0_jax, in_ptr1_jax)
# Copy result back into the provided torch output tensor
res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))
out_ptr0.copy_(res_t)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166822
Approved by: https://github.com/jansel
ghstack dependencies: #166976 , #166982
2025-11-05 00:52:41 +00:00
4271ffe918
don't produce invalid grid configs ( #166974 )
...
Proper fix for #164048 , fixes gather too, reverts #164049
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166974
Approved by: https://github.com/eqy
2025-11-05 00:20:27 +00:00
7eefcfb1db
[BE][Typing][Dynamo] Type torch/_dynamo/variables/ctx_manager.py ( #166878 )
...
Provides type coverage to torch/_dynamo/variables/ctx_manager.py
Coverage report:
`mypy torch/_dynamo/variables/ctx_manager.py --linecount-report /tmp/coverage_log`
Compare before to after - we go from 0 lines and 0 funcs covered to 1541 lines and 144 funcs covered
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166878
Approved by: https://github.com/Skylion007
2025-11-04 23:54:18 +00:00
4b12c0344d
Add default .github/copilot-instructions.md and item in .gitignore for allowing local changes ( #166864 )
...
Fixes [#166850 ](https://github.com/pytorch/pytorch/issues/166850 )
- Create a default `.github/copilot-instructions.md` file (used Claude Sonnet 4.5 in Copilot).
- Add `.github/copilot-instructions.md` to the `.gitignore` file.
The prompt used is below, which is preset by Copilot:
```
Analyze this codebase to generate or update `.github/copilot-instructions.md` for guiding AI coding agents.
Focus on discovering the essential knowledge that would help an AI agents be immediately productive in this codebase. Consider aspects like:
- The "big picture" architecture that requires reading multiple files to understand - major components, service boundaries, data flows, and the "why" behind structural decisions
- Critical developer workflows (builds, tests, debugging) especially commands that aren't obvious from file inspection alone
- Project-specific conventions and patterns that differ from common practices
- Integration points, external dependencies, and cross-component communication patterns
Source existing AI conventions from `**/{.github/copilot-instructions.md,AGENT.md,AGENTS.md,CLAUDE.md,.cursorrules,.windsurfrules,.clinerules,.cursor/rules/**,.windsurf/rules/**,.clinerules/**,README.md}` (do one glob search).
Guidelines (read more at https://aka.ms/vscode-instructions-docs ):
- If `.github/copilot-instructions.md` exists, merge intelligently - preserve valuable content while updating outdated sections
- Write concise, actionable instructions (~20-50 lines) using markdown structure
- Include specific examples from the codebase when describing patterns
- Avoid generic advice ("write tests", "handle errors") - focus on THIS project's specific approaches
- Document only discoverable patterns, not aspirational practices
- Reference key files/directories that exemplify important patterns
Update `.github/copilot-instructions.md` for the user, then ask for feedback on any unclear or incomplete sections to iterate.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166864
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com >
2025-11-04 23:53:56 +00:00
661b639663
use_cpp_bmm_template supports more use cases ( #165469 )
...
Summary: In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous, but the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. This diff specifically checks for contiguity within the 2D matrix of each batch, and enables more uses of the cpp bmm template.
Differential Revision: D84561331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165469
Approved by: https://github.com/desertfire
2025-11-04 23:47:17 +00:00
0cd809f60c
[inductor][AMD] Filter out invalid Triton Configs for MI350X _scaled_mm ( #166442 )
...
Summary: Mirrors the change done in D81180838 but for inductor. Without this change, running _scaled_mm on the MI350X accelerator would crash.
Test Plan: HIP_VISIBLE_DEVICES=7 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run mode/opt-amd-gpu -m rocm70 -c fbcode.rocm_arch=mi350 scripts/jchunx/gemm:scaled_mm_microbench -- --csv_file /home/jchunx/scripts/fp8_shapes.csv --backend triton,aten --fast_accum=true 2>&1 | tee ~/logs/scaled_mm.log
Reviewed By: bilal
Differential Revision: D85694383
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166442
Approved by: https://github.com/bilal
2025-11-04 23:47:11 +00:00
a96728d188
Clarify safety of CUDA graph memory pool sharing across graphs that are replayed in arbitrary order. ( #166975 )
...
Some users at the PyTorch conference were asking me whether it is safe to share a memory pool among CUDA graphs that never run concurrently but may run in arbitrary order, as long as they don't depend upon each other's outputs. Even though capture order doesn't match replay order in this situation, this is safe. However, our docs confusingly said this wasn't allowed. This update is intended to help with that. Since vLLM essentially depends upon this behavior, I call it out specifically.
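A hedged sketch of the pattern the doc update blesses: two graphs sharing one memory pool, never run concurrently, replayed in arbitrary order (assumes a CUDA device and that neither graph reads the other's outputs).
```python
import torch

x = torch.randn(8, device="cuda")
pool = torch.cuda.graph_pool_handle()

g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    y1 = x * 2
with torch.cuda.graph(g2, pool=pool):
    y2 = x + 1

g2.replay()  # replay order need not match capture order
g1.replay()
print(y1.sum().item(), y2.sum().item())
```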
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166975
Approved by: https://github.com/eellison , https://github.com/BoyuanFeng
2025-11-04 23:36:03 +00:00
c1e91bd4c3
[export] Codemod unittests to use new graph capture API ( #166957 )
...
Summary:
as title.
Test Plan:
pytest test/functorch/test_aot_joint_with_descriptors.py
pytest test/higher_order_ops/test_local_map.py
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166957
Approved by: https://github.com/angelayi , https://github.com/yushangdi
2025-11-04 22:55:30 +00:00
d7e2d0ad30
make narrow_tensor_symint DDE-free ( #166379 )
...
https://github.com/pytorch/pytorch/issues/158081
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166361
2025-11-04 22:43:15 +00:00
81038fd326
Revert "Add model code stack trace to torch.profile ( #166677 )"
...
This reverts commit e8052f2f99de1fb7284e38082ff5714e17cd9562.
Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/malfet due to Broke lint, please rebase, we've moved from mypy to pyrefly ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3488219996 ))
2025-11-04 22:26:35 +00:00
e020fb3431
[Minor][Inductor] move some combo kernel log from warning to debug ( #166993 )
...
The combo kernel pass warns for long reductions and large pointwise kernels. This becomes too spammy for users such as vLLM.
This PR moves these logs from warning to debug. I validated that the spammy log is removed on llama-3.1-8B.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166993
Approved by: https://github.com/zou3519 , https://github.com/eellison
2025-11-04 22:09:27 +00:00
e8052f2f99
Add model code stack trace to torch.profile ( #166677 )
...
```python
python test/test_fx.py -k profiler
```
Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen.
We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace.
`map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry.
One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove.
`aot_eager` calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or `_dynamo.config.enrich_profiler_metadata=True`.
<img width="1188" height="565" alt="Screenshot 2025-10-31 at 4 40 52 PM" src="https://github.com/user-attachments/assets/41e8113f-3e6d-439b-bffd-cfbf0c03a47a " />
Example code gen'd.
```
def forward(self, args_list):
args_iter = iter(args_list)
arg0_1 = next(args_iter)
arg1_1 = next(args_iter)
args_list.clear()
_rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__()
repeated_subgraph0 = self.repeated_subgraph0
_rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__()
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None
_rf_invoke_subgraph.__exit__(None, None, None)
_rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__()
getitem = invoke_subgraph[0]; invoke_subgraph = None
_rf_getitem.__exit__(None, None, None)
return (getitem,)
_rf.__exit__(None, None, None)
def forward(self, arg0_1, arg1_1):
_rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__()
_rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__()
mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
_rf_mul.__exit__(None, None, None)
_rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__()
sin = torch.ops.aten.sin.default(mul); mul = None
_rf_sin.__exit__(None, None, None)
_rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__()
add = torch.ops.aten.add.Tensor(sin, 5); sin = None
_rf_add.__exit__(None, None, None)
return (add,)
_rf.__exit__(None, None, None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677
Approved by: https://github.com/ezyang
ghstack dependencies: #166676
2025-11-04 22:05:36 +00:00
a64c7d7404
[DebugMode] output, tensor id annotations for DebugMode ( #165076 )
...
Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)`
Example output for `test_debug_mode_mm`, with both enabled:
```
torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$12: f32[8, 32]| S(0)
aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0))
redistribute_input(1, S(0) -> R)
redistribute_input(t$4: f32[1, 32], trace: S(0)->R)
_c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0) -> t$6: f32[8, 32]
_c10d_functional::wait_tensor(t$7: f32[8, 32]) -> t$8: f32[8, 32]
aten::mm(t$9: f32[1, 8], t$10: f32[8, 32]) -> t$11: f32[1, 32]
<method 'sum' of 'torch._C.TensorBase' objects>(dt$13: f32[8, 32]| S(0)) -> dt$17: f32[]| P
aten::sum(dt$14: f32[8, 32]| S(0))
aten::sum(t$15: f32[1, 32]) -> t$16: f32[]
```
Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic.
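For reference, a minimal usage sketch (assuming `DebugMode` is importable from `torch.utils._debug_mode` and exposes `debug_string()`, neither of which is spelled out in this note):
```python
import torch
from torch.utils._debug_mode import DebugMode

with DebugMode(record_output=True, record_ids=True) as dm:
    torch.mm(torch.randn(8, 8), torch.randn(8, 32))
# Each recorded op line now carries t$<id> / dt$<id> annotations and "-> output" info.
print(dm.debug_string())
```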
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076
Approved by: https://github.com/zpcore
2025-11-04 21:30:46 +00:00
cdca63db8c
Fix quoting in pytest_cache.py invocations ( #166955 )
...
The job identifier in particular can contain spaces, so it needs to be quoted.
Fixes e.g. https://github.com/pytorch/pytorch/actions/runs/19063797853/job/54449422160#step:15:52
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166955
Approved by: https://github.com/Skylion007
2025-11-04 21:28:19 +00:00
ed45c5f38d
Avoid DDE in narrow with unbacked start ( #166361 )
...
Slice already knows how to handle an unbacked start, so we do not need to offset start before calling slice; we can leave that to slice.
The only edge case is when start < 0 and start + length == 0: there, slice and narrow would deviate,
so for that case we pass dim_size instead of start + length.
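A small illustration of that edge case (values chosen for concreteness):
```python
import torch

x = torch.arange(10)
# narrow with a negative start keeps the last three elements:
print(x.narrow(0, -3, 3))   # tensor([7, 8, 9])
# a naive slice with end = start + length == 0 would instead be empty:
print(x[-3:0])              # tensor([], dtype=torch.int64)
# hence the end bound must be dim_size (10) rather than start + length (0) here.
```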
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-04 21:24:57 +00:00
7f0e932136
[dynamo] don't use LocalSource for temp variables created by side_effects ( #166917 )
...
Fixes https://github.com/pytorch/pytorch/issues/166900
Implementation notes:
- I tried to disallow guard generation before side effect application in order to futureproof improper guard generation. However, this was not feasible since it is possible to realize lazy VTs while generating side effects (e.g. realizing a constant variable that is used in a deque update).
- `codegen_save_tempvars` now generates `TempLocalSource` for create temporary variables now, so that they won't get confused with `LocalSource` - we should error out when we attempt to create guards for `TempLocalSource`. I considered using `SyntheticLocalSource`, but that has additional `subguards_allowed` behavior that we may not want to have for temp variables.
- We moved the guard installation for constant user-defined pytree objects from `as_python_constant` to `__init__`. Objects created outside the compile-region will be guarded, while objects created inside the compile-region will not be guarded.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166917
Approved by: https://github.com/anijain2305
2025-11-04 21:16:18 +00:00
2673f8b007
Fix torch.linalg.eig inductor stride mismatch ( #162484 )
...
Fixes #159445
### Summary
- Fixed a stride layout issue in the `torch.linalg.eig` meta kernel that prevented successful compilation with the inductor backend: the meta kernel was producing incorrect row-major strides.
- LAPACK/BLAS libraries (the underlying implementation) expect a column-major layout; see the small illustration below.
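A sketch of the mismatch (exact strides depend on the backend, so take the values as illustrative):
```python
import torch

A = torch.randn(4, 4)
w, v = torch.linalg.eig(A)
print(v.shape, v.stride())  # eager eigenvectors come back column-major, e.g. (1, 4)
# Before this fix the meta kernel reported row-major strides (e.g. (4, 1)), so compiling
# code that calls torch.linalg.eig with the inductor backend failed on the stride mismatch.
```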
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162484
Approved by: https://github.com/isuruf
2025-11-04 21:06:58 +00:00
4e1bd16738
inductor: Switch quiesce to use timer based implementation. ( #166581 )
...
The major change is to switch to a timer-based implementation. Additionally,
we get rid of the context manager for turning off the compile pool. We
still have the warmup calls.
Note that this only modifies the async_compile methods, the fx pool is
left running.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166581
Approved by: https://github.com/masnesral
ghstack dependencies: #166467
2025-11-04 21:01:49 +00:00
871d0cd196
If USE_CUDA=1 is set, do not fallback to no CUDA ( #166982 )
...
So many times I build PyTorch only to notice chef nuked my nvcc and I wasted 30m building a CPU version; let's hard error fast.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166982
Approved by: https://github.com/malfet
ghstack dependencies: #166976
2025-11-04 20:51:14 +00:00
2bba37309b
[inductor] runtime estimations disable use_nccl_estimator by default ( #166973 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166973
Approved by: https://github.com/eellison , https://github.com/jathu
2025-11-04 20:48:22 +00:00
b4e4ee81d3
Update triton to 3.5.1 release ( #166968 )
...
This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968
Approved by: https://github.com/Lucaskabela , https://github.com/njriasan
2025-11-04 20:34:13 +00:00
3283eaa5ba
Upload test stats for trunk/sha tag ( #166916 )
...
Noticed that workflow runs for `trunk/{sha}` tags (issued by autorevert) don't populate the test_run_s3 ClickHouse table.
This PR addresses this by changing the gate condition for uploading test stats.
See https://github.com/pytorch/pytorch/actions/runs/19054297956/job/54421254448#step:8:23
as evidence that HEAD_BRANCH is correctly populated for trunk tags.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166916
Approved by: https://github.com/huydhn , https://github.com/clee2000
2025-11-04 20:33:56 +00:00
397d9fe2ae
[inductor] coordesc not tune XBLOCK for mix-order-reduction ( #166669 )
...
For mix-order reduction, we currently force XBLOCK to be 1 to simplify codegen, so don't tune it in coordinate-descent tuning.
Differential Revision: [D86224689](https://our.internmc.facebook.com/intern/diff/D86224689 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166669
Approved by: https://github.com/jansel , https://github.com/mlazos , https://github.com/eellison , https://github.com/v0i0
2025-11-04 20:27:07 +00:00
d77c24caac
Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel ( #165036 )"
...
This reverts commit 0e1a88904f4a5e30634b196678b56e1d6ec074f5.
Reverted https://github.com/pytorch/pytorch/pull/165036 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19059329909/job/54439919668 ) [HUD commit link](0e1a88904f ) ([comment](https://github.com/pytorch/pytorch/pull/165036#issuecomment-3487846555 ))
2025-11-04 20:13:33 +00:00
cef98ae5cb
[aotd] Compiled saved tensor hooks context ( #166887 )
...
Draft to expose the compiled saved tensor hooks context so hooks can be applied selectively.
Exposes node, fw_graph, bw_graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166887
Approved by: https://github.com/bdhirsh
2025-11-04 20:07:00 +00:00
52ea135f77
[BE] Delete Python-3.9 stdlib definitions from torch.package ( #166768 )
...
And simplify the entire function to just assert and return
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166768
Approved by: https://github.com/cyyever , https://github.com/atalman
2025-11-04 19:33:14 +00:00
a5f3035aaf
More pyrefly local errors ( #166976 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166976
Approved by: https://github.com/maggiemoss , https://github.com/Skylion007
2025-11-04 18:51:35 +00:00
1d3f5e19da
[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI ( #165922 )
...
Fix and regression test for https://github.com/pytorch/pytorch/issues/165801
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165922
Approved by: https://github.com/malfet , https://github.com/atalman , https://github.com/Skylion007 , https://github.com/drisspg
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com >
Co-authored-by: Andrey Talman <atalman@fb.com >
2025-11-04 18:46:43 +00:00
496277a8ff
[ROCm][CI] Lower runner check gpu count for distributed jobs ( #166961 )
...
This is a PR to temporarily relieve the queueing that is caused by an mi250 node outage. See this ticket for more information:
https://github.com/pytorch/pytorch/issues/166866
It relaxes the GPU count check to allow distributed jobs to run on 2-GPU runners
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166961
Approved by: https://github.com/jeffdaily
2025-11-04 18:44:21 +00:00
53f75cd5ba
Fixed some syntax errors in SECURITY.md file. ( #166718 )
...
Fixed some syntax errors in the SECURITY.md file, including PyTorch capitalization problems, some grammatical inconsistencies, etc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718
Approved by: https://github.com/mikaylagawarecki
2025-11-04 18:18:38 +00:00
527b1109a8
Delete deprecated fp32 precision warnings ( #166956 )
...
The deprecation warning led to warning spamming in PyTorch APIs, like
torch.compile. This is not how a deprecation warning should go: if we
add a deprecation warning, we'd better update our built-in APIs to
prevent warning spam.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166956
Approved by: https://github.com/albanD
2025-11-04 17:50:04 +00:00
3144713325
subproc_pool: Add support for enabling quiesce via a timer ( #166467 )
...
This adds the capability to subproc pool to enable quiesce via a timer
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166467
Approved by: https://github.com/masnesral
2025-11-04 17:37:41 +00:00
eefa16342c
[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer ( #166165 )
...
Prefer unfused addmm when there is at least one elementwise/reduction consumer; see the sketch below.
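An illustration of the rewrite this preference enables (a sketch, not Inductor's actual pass):
```python
import torch

a, b, bias = torch.randn(8, 16), torch.randn(16, 32), torch.randn(32)
y_fused   = torch.addmm(bias, a, b).relu()   # bias fused into the matmul
y_unfused = (torch.mm(a, b) + bias).relu()   # bias add can instead fuse with the relu epilogue
assert torch.allclose(y_fused, y_unfused, atol=1e-5)
```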
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165
Approved by: https://github.com/eellison
2025-11-04 17:23:04 +00:00
d02f68f484
[BE] Use [[maybe_unused]] ( #166865 )
...
Instead of the `(void) foo; // Unused parameter` trick, as this is a C++17 standard feature.
Will replace further repetitions of the same pattern soon after.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166865
Approved by: https://github.com/mikaylagawarecki , https://github.com/Skylion007 , https://github.com/janeyx99
2025-11-04 17:08:28 +00:00
68eb55c4b2
Add model code stack trace to cuda.memory._snapshot ( #166676 )
...
We store a mapping between generated fx graph code and original model code stack trace in `fx.traceback._FX_METADATA_REGISTRY`. And we do a post-processing on the memory snapshot to append the original model stack trace information.
To achieve this, the biggest change we had to make in `aot_eager` mode is to give each generated fx graph a unique stack trace, i.e. it cannot just be `<eval_with_key>`. We set co_filename to **pretend** that the code comes from the `co_filename` file. Now instead of `<eval_with_key>` in the stack trace, we get something like `fx_generated_3a4b5c6d7e8f9a0.py`.
`augment_with_fx_traces` arg is added to `torch.cuda.memory._snapshot` and `_dump_snapshot`. When the arg is set to True, a post-processing will run to populate the original model stack trace to the snapshot frames.
The new behavior of GraphModule can be controlled by `TORCH_ENRICH_RPOFILER_STACK_TRACE` or `_dynamo.config.enrich_profiler_metadata=True`.
Alternative:
Instead of setting co_filename, we can also do it like below:
Note that if we do it this way, we will need to dump the file to make the graph module torch-scriptable. TorchScript requires source access in order to carry out compilation, so we need to make sure original .py files are available.
```
key = filename
globals_copy = globals.copy()
globals_copy["__file__"] = key
globals_copy["__name__"] = key
linecache.lazycache(key, globals_copy)
exec(compile(src, key, "exec"), globals)
```
Other changes:
- Update `MemoryViz.js` to display fx node information and original model code if exist
```
python test/test_fx.py -k test_lineno_map
python test/test_fx.py -k test_custom_traceback_raised
python test/test_public_bindings.py
python test/test_cuda.py -k test_fx_memory
python test/test_fx.py -k test_informative_co_filename
python test/test_fx.py -k test_autowrap_functions
python test/dynamo/test_utils.py -k test_inductor_provenance
```
```python
# Profile with memory snapshot
torch.cuda.memory._record_memory_history()
with torch._dynamo.config.patch("enrich_profiler_stack_trace", True):
    compiled = torch.compile(mod, backend="aot_eager", fullgraph=True)
    result = compiled(torch.randn(10, 10, device="cuda:0"))
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle", augment_with_fx_traces=True)
torch.cuda.memory._record_memory_history(enabled=None)
```
<img width="913" height="711" alt="Screenshot 2025-10-30 at 10 40 44 AM" src="https://github.com/user-attachments/assets/8d7a1833-f98d-4756-b666-1d63ab57b27b " />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166676
Approved by: https://github.com/albanD , https://github.com/ezyang
2025-11-04 17:01:02 +00:00
8d4b8ab430
[ez] Print some more test timing info in the logs ( #166447 )
...
You can just subtract timestamps, but this makes it easier
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166447
Approved by: https://github.com/Skylion007
2025-11-04 16:45:22 +00:00
afd50bdd29
[CI] Use smaller amx + avx2 runners for inductor test? ( #164989 )
...
Results from CI:
No failures, but runs generally take longer, maybe a ~20% increase in time.
However, the smaller runner is ~25% of the cost of the current runner, so in terms of cost this is a decrease.
If the 20% is too much, we can try the 4x larger runners, which are about half the cost of the current runner, so they would probably still result in cost savings with hopefully less impact on time.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164989
Approved by: https://github.com/BoyuanFeng , https://github.com/huydhn
2025-11-04 16:43:06 +00:00
56dfd4c74b
Add CUDA MXFP4 scaled mm support via. FBGEMM ( #166526 )
...
Summary:
* Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA
* Add testing
Signed-off-by: Simon Layton <simonlayton@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526
Approved by: https://github.com/drisspg , https://github.com/ngimel
2025-11-04 15:53:16 +00:00
24db5c4451
[inductor] do not hard fail on FakePG with nccl estimator ( #166869 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166869
Approved by: https://github.com/eellison
ghstack dependencies: #166521
2025-11-04 15:22:38 +00:00
cc8bfd1206
Docker release build: Use 13.0.0 nvidia docker ( #166904 )
...
Forward fix for failing Docker release builds
Related to: https://github.com/pytorch/pytorch/issues/166897
Nightly Docker build failure https://github.com/pytorch/pytorch/actions/runs/18900508440/job/53946606434
Due to missing base image:
```
ERROR: failed to build: failed to solve: docker.io/nvidia/cuda:13.0.2-devel-ubuntu22.04: not found
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166904
Approved by: https://github.com/tinglvv , https://github.com/malfet
2025-11-04 13:58:10 +00:00
c45b156605
Fix DeepSeek scaling tensor handling ( #166752 )
...
Summary:
cuBlasLt enforces size/stride requirements for 1x128 and 128x128 blockwise scaling
kernels, some of which weren't being handled, causing silent incorrect
answers especially for 128x128 scaling cases.
cuBlasLt enforces the following layouts ([docs](https://docs.nvidia.com/cuda/cublas/#scaling-factors-layouts )) for deepseek-style
scaling; for `A: MxN`, `B: KxN` you have:
```Py
L = K // 128
L4 = round_up(L, 4)
1x128 x 128x128:
* A_scale: [M, K // 128], stride: [1, M]
* B_scale: [L4, N // 128], stride: [1, L4]
128x128 x 1x128:
* A_scale: [L4, M // 128], stride: [1, L4]
* B_scale: [N, K // 128], stride: [1, N]
1x128 x 1x128:
* A_scale: [M, K // 128], stride: [1, M]
* B_scale: [N, K // 128], stride: [1, N]
```
Notable here is the `L4` term, which means that we must round up to the nearest multiple of 4 blocks
in the `K` dimension. This wasn't enforced previously, and caused silent wrong answers
where `(K // 128) % 4 != 0`.
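A concrete sketch of the L4 rounding rule (shapes are illustrative):
```python
def round_up(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

M, N, K = 256, 512, 384
L = K // 128            # 3 blocks along K
L4 = round_up(L, 4)     # 4: scale tensors must cover a multiple of 4 blocks along K
# e.g. for 128x128 x 1x128 scaling, A_scale is [L4, M // 128] with stride [1, L4];
# sizing it with L instead of L4 is what produced silent wrong answers when (K // 128) % 4 != 0.
```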
Subscribers: @vkuzo
Signed-off-by: Simon Layton <simonlayton@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166752
Approved by: https://github.com/drisspg , https://github.com/vkuzo
2025-11-04 13:32:24 +00:00
8fff7e36b4
[xpu][test] Add UT for expandable segments ( #166495 )
...
# Motivation
This PR aims to reuse some UT to validate the expandable segment feature.
# Additional Context
Currently, the failure is related to the internal tracker `GSD-11403`; we will get the fix when upgrading the driver to `ci-neo-master-034630` or greater.
TODO: add conv and gemm tests to this test case when upgrading the driver.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166495
Approved by: https://github.com/albanD , https://github.com/EikanWang , https://github.com/gujinghui
ghstack dependencies: #166299 , #166292 , #166424
2025-11-04 08:01:35 +00:00
82fa2aa269
DTensor: Fix trivial as_strided case, add alias support ( #166867 )
...
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166867
Approved by: https://github.com/albanD
ghstack dependencies: #166868
2025-11-04 07:18:32 +00:00
09e0285608
[xpu][feature][inductor] Enable decompose_mm_pass and UT on Intel GPU ( #166613 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166613
Approved by: https://github.com/hl475
2025-11-04 06:58:05 +00:00
d980d8dc79
[dynamo] Implement __sym_float__ for SymBool to fix multiplication TypeError ( #165264 )
...
Fixes #164684
### Description
Symbolic tracing fails during multiplication between a `SymBool` and a `Tensor`. This scenario is triggered when `.item()` is called on a 0-dim boolean tensor within a `torch.compile` region. In compile mode, this yields a `SymBool`, and the subsequent `SymBool * FakeTensor` operation is unsupported, leading to a `TypeError` or a data-dependent `UserError`.
### Solution
This PR addresses the issue at the type-conversion level, as suggested by reviewers.
The root cause of the TypeError is that torch.sym_float() (which is called by _maybe_convert_to_dtype during type promotion for aten.mul) lacks a conversion path for SymBool and incorrectly falls back to builtins.float(SymBool).
This fix addresses this by implementing the __sym_float__(self) method within the SymBool class (defined in torch/__init__.py).
The torch.sym_float(a) utility function is already designed to check for hasattr(a, "__sym_float__") before falling back to builtins.float(). By adding this method, SymBool instances now correctly advertise their ability to be cast to SymFloat. The new method implementation leverages self.node.sym_float() to correctly convert the symbolic boolean value to its symbolic float representation (0.0 or 1.0), resolving the TypeError at its source.
This approach is more fundamental than modifying a specific operation in builtin.py and ensures SymBool can be correctly promoted to SymFloat in any operation, while still preserving its boolean nature for control flow operations like guard_or_false (which is verified by a new test case).
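A standalone toy mimicking that dispatch (not PyTorch's actual classes): torch.sym_float-style conversion prefers `__sym_float__` when the object provides it, instead of falling back to `builtins.float()`:
```python
class ToySymBool:
    """Stand-in for SymBool; only demonstrates the __sym_float__ hook."""
    def __init__(self, val):
        self.val = val
    def __sym_float__(self):
        # The real SymBool.__sym_float__ builds a SymFloat from self.node.sym_float().
        return 1.0 if self.val else 0.0

def sym_float(a):
    # Mirrors the documented behavior of torch.sym_float(): use __sym_float__ if available.
    if hasattr(a, "__sym_float__"):
        return a.__sym_float__()
    return float(a)

print(sym_float(ToySymBool(True)))   # 1.0, via __sym_float__, with no TypeError
```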
### Verification
1. **Bug Reproduced**: The initial `UserError: Could not guard on data-dependent expression` was successfully reproduced with the script from the issue. As shown below
<img width="1369" height="945" alt="Screenshot 2025-10-13 at 10 29 05" src="https://github.com/user-attachments/assets/8daa4555-3347-4af5-906a-02150b8df9d1 " />
2. **Fix Validated**: After applying the code changes, the same script now runs to completion, printing `✅ eager success` and `✅ compile success`. As shown below
<img width="1228" height="82" alt="Screenshot 2025-10-13 at 10 29 21" src="https://github.com/user-attachments/assets/94c4f143-b898-4dda-9bff-0ad5450a30fa " />
3. Added a new test class DynamoOpPromotionTests to test/dynamo/test_misc.py with three new test cases:
1. test_symbool_tensor_mul_does_not_fail: Verifies that the original bug report code (with .item() + *) no longer raises an error when compiled.
2. test_symbool_guard_or_false: Verifies that this fix does not cause a regression for guard_or_false(SymBool) (the concern raised by reviewers).
3. test_symbool_tensor_mul: Verifies the behavior of Tensor(bool) * Tensor(float) (without .item()) for completeness.
All new tests were added and pass locally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165264
Approved by: https://github.com/laithsakka , https://github.com/Lucaskabela
2025-11-04 06:33:20 +00:00
c7d00de115
[xpu][fix] Fix XPU oneDNN memory query bug: pointer to array ( #166830 )
...
# Motivation
I believe this is a bug - here's why:
In [dnnl_common_types.h](98132c4908/include/oneapi/dnnl/dnnl_common_types.h (L116-L125) ), `md_padded_dims` is defined as a pointer to an `int64_t[12]` array;
we can confirm this from the implementation in [memory_desc.cpp](98132c4908/src/common/memory_desc.cpp (L746-L748) ), where the member indeed points to an internal array.
# Solution
Therefore, when accessing `md_padded_dims`, we should first dereference the pointer and then index into it; using it directly without dereferencing corrupts memory.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166830
Approved by: https://github.com/EikanWang
2025-11-04 06:12:40 +00:00
d3cf90ada5
Revert "[inductor] require shape in TritonCSEVariable ( #162275 )"
...
This reverts commit c21868b4359586550b12e1d9102283c792f45dff.
Reverted https://github.com/pytorch/pytorch/pull/162275 on behalf of https://github.com/izaitsevfb due to breaking test_rms_norm_bwd_float32_split_reductions_True_shape2 ([comment](https://github.com/pytorch/pytorch/pull/162275#issuecomment-3484049109 ))
2025-11-04 06:06:18 +00:00
0e1a88904f
[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel ( #165036 )
...
Make sure you're on cutlass 4.2.0+
Test Plan:
Tritonbench(oss):
`clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy`
Unit Tests(oss):
`clear; python test/inductor/test_cutedsl_grouped_mm.py`
Differential Revision: D82010227
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165036
Approved by: https://github.com/alexsamardzic , https://github.com/drisspg , https://github.com/mlazos
2025-11-04 05:58:58 +00:00
3232caa078
[XPU][Fix] Register convolution_overrideable for flops count ( #166839 )
...
Fixes #166838
1. Register the `convolution_overrideable` key for the flop counter. CUDA relies on keys such as `cudnn_convolution`; devices like `XPU` fall back to `convolution_overrideable`. Without the correct registration, the flop counter will silently return 0 for XPU in this line:
e1d011d6eb/torch/_inductor/analysis/profile_analysis.py (L178-L179)
2. Enable the corresponding tests for XPU in `test_analysis.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166839
Approved by: https://github.com/guangyey , https://github.com/EikanWang , https://github.com/jansel
2025-11-04 05:56:29 +00:00
a6c6acea9d
[11/N] Apply ruff UP035 rule ( #166225 )
...
This PR continues to apply ruff UP035 rule to inductor code. ruff UP035 rule aims to use Python 3.10 syntax and libraries.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166225
Approved by: https://github.com/aorenste
2025-11-04 04:53:40 +00:00
55be1cc739
[dynamo, 3.14] add explicit SymFloat int conversion ( #166902 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166902
Approved by: https://github.com/malfet , https://github.com/pianpwk
ghstack dependencies: #166757 , #166894 , #166895
2025-11-04 04:38:03 +00:00
344cebda52
[dynamo, 3.14] disable cpython dynamo unittests if 3.14 ( #166895 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166895
Approved by: https://github.com/guilhermeleobas
ghstack dependencies: #166757 , #166894
2025-11-04 04:38:03 +00:00
ba72c6b981
[dynamo, 3.14] fix dynamo error message test for 3.14 ( #166894 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166894
Approved by: https://github.com/malfet
ghstack dependencies: #166757
2025-11-04 04:38:03 +00:00
888efcc453
[dynamo, 3.14] support tracing type.__dict__[__annotations__].__get__ to trace through typing.get_type_hints ( #166757 )
...
This is covered by `test_get_type_hints` in test/dynamo/test_repros.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166757
Approved by: https://github.com/Lucaskabela
2025-11-04 04:38:03 +00:00
24aa9a2ef7
[ROCm][CI] Add distributed testing back to trunk.yml ( #166915 )
...
Adding distributed testing back to trunk since we have been observing [reasonable queueing](https://hud.pytorch.org/queue_time_analysis?dateRange=30&startDate=2025-10-05T01%3A44%3A55.924Z&endDate=2025-11-04T01%3A44%3A55.925Z&granularity=week&chartType=bar&repos=pytorch%2Fpytorch&category=machine_type&machineTypes=linux.rocm.gpu.gfx942.1&items=linux.rocm.gpu.gfx942.1 ) based on current MI3xx capacity.
Partially addresses https://github.com/pytorch/pytorch/issues/166108 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166915
Approved by: https://github.com/jeffdaily
2025-11-04 04:29:29 +00:00
f70faf2b9a
[xpu][feature] Introduce PeerToPeerAccess API for XPU ( #166424 )
...
# Motivation
This PR introduces support for peer-to-peer (P2P) access between devices, including querying and enabling P2P connections between two devices.
It supports two categories of allocations:
- Regular allocations;
- Expandable segment allocations.
# Additional Context
The follow-up is that we should use this feature to optimize our copy kernel when P2P is supported.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166424
Approved by: https://github.com/gujinghui , https://github.com/albanD
ghstack dependencies: #166299 , #166292
2025-11-04 04:03:28 +00:00
167e64ba1a
[xpu][feature] Support expandable segment feature for XPU ( #166292 )
...
# Motivation
This PR intends to add expandable segment feature support on XPU. This will help
- Reduce memory fragmentation;
- Gradually map physical pages into virtual address space as needed.
# Additional Context
The traditional caching allocator frequently allocates and frees device memory blocks. However, over time, with varying tensor size, the device address space becomes fragmented. Even when there's enough total free memory, a lack of contiguous space can cause large allocations to fail.
The **expandable segment** feature addresses this by dynamically extending physical memory within a reserved virtual address range, reducing fragmentation and minimizing reallocation overhead.
The potential drawbacks are
- Virtual memory overhead;
- Potential page mapping overhead;
- Increased complexity.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166292
Approved by: https://github.com/albanD , https://github.com/EikanWang , https://github.com/gujinghui
ghstack dependencies: #166299
2025-11-04 04:03:28 +00:00
875b18d53c
[xpu][feature] Introduce ExpandableSegment for XPU ( #166299 )
...
# Motivation
This PR intends to add `ExpandableSegment` struct, which is used to help support the expandable segment feature. I split it to a single PR to facilitate the code review.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166299
Approved by: https://github.com/EikanWang , https://github.com/albanD , https://github.com/gujinghui
2025-11-04 04:03:28 +00:00
eec3749c44
[DebugMode] .fwd_stack_trace for autograd bwd ops ( #166842 )
...
In #166440, I didn't realize you could turn on anomaly mode while disabling NaN checks to get these stacks. This adds them to `debug_mode.operators[*].fwd_stack_trace`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166842
Approved by: https://github.com/yushangdi , https://github.com/mikaylagawarecki
2025-11-04 03:28:43 +00:00
40133fe966
Fix MSVC C++ compilation error of pycore_stackref.h header ( #165686 )
...
Wraps the header in a C file and compiles it using a C compiler, which should support designated initializers.
Fixes issue #160647
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165686
Approved by: https://github.com/williamwen42
2025-11-04 02:51:16 +00:00
f288433d3e
[dynamo] Raise on as_python_constant error on getattr ( #166909 )
...
This ensures that we graph break at the right time, leading to the right
stack trace.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166909
Approved by: https://github.com/tugsbayasgalan
2025-11-04 02:45:59 +00:00
864633fca0
[xpu][test] Enable test_fxir_backend tests for XPU ( #166493 )
...
This PR enables `test_fxir_backend.py`'s formerly skipped XPU tests. No additional changes are needed for the features.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166493
Approved by: https://github.com/angelayi , https://github.com/EikanWang
2025-11-04 02:14:46 +00:00
c21868b435
[inductor] require shape in TritonCSEVariable ( #162275 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162275
Approved by: https://github.com/mlazos
ghstack dependencies: #164158
2025-11-04 02:13:41 +00:00
a0a8eca01a
Fixes torch.compile(nn.ModuleList()) changes bool() behavior ( #159208 )
...
Fixes #159139
## The Cause
The bug occurs because the OptimizedModule wrapper in torch._dynamo.eval_frame doesn't implement the `__len__` method. This causes Python's bool() check to fall back to the default object truthiness (always True) instead of correctly evaluating containers with len() == 0 as False.
## The Fix
A very easy fix: I just added a `__len__` method to the OptimizedModule class in torch._dynamo.eval_frame to delegate the call to the original module:
```python
def __len__(self):
    """
    Proxy the len() call to the original module to fix truthiness checks.
    """
    return len(self._orig_mod)
```
This successfully fixes the issue; the script now works as expected.
## Reproduction Script
```python
import torch
import torch.nn as nn
# Create an empty nn.ModuleList
original = nn.ModuleList()
# Compile it using torch.compile
compiled = torch.compile(original)
# Compare their boolean evaluations
print(f"bool(original): {bool(original)}")
print(f"bool(compiled): {bool(compiled)}")
# Trigger failure if they differ
assert bool(original) == bool(compiled), "BUG: truthiness behavior mismatch after compilation"
```
## Output
bool(original): False
bool(compiled): False
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159208
Approved by: https://github.com/Lucaskabela
Co-authored-by: pushkar-hue <pushkarsharma.rtm@gmail.com >
Co-authored-by: Lucas Kabela <lucasakabela@gmail.com >
2025-11-04 02:12:10 +00:00
0958f307d9
Add _heapq polyfill ( #161093 )
...
----
* Redirect `_heapq.*` functions to the python implementation
* Handle TypeError in PolyfilledFunctionVariable to raise observed exceptions
* Implement `__next__` method in IteratorVariable class
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161093
Approved by: https://github.com/Lucaskabela
2025-11-04 02:11:33 +00:00
7551507c41
[BE][Typing][Dynamo] Type torch/_dynamo/variables/builtin.py ( #166745 )
...
Provides type coverage to torch/_dynamo/variables/builtin.py
### Coverage report:
`mypy torch/_dynamo/variables/builtin.py --linecount-report /tmp/coverage_log`
Compare before to after - we go from 2213 lines and 64 funcs covered to 3212 lines and 85 funcs covered
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166745
Approved by: https://github.com/williamwen42
2025-11-04 01:33:10 +00:00
f92834d477
Fix unused assignments ( #166791 )
...
This PR cleans up unused assignments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166791
Approved by: https://github.com/xmfan
2025-11-04 01:07:19 +00:00
e1fc01bef8
Enable clang-tidy on some excluded headers ( #166835 )
...
This PR enables clang-tidy on some excluded headers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166835
Approved by: https://github.com/Skylion007
2025-11-04 00:37:32 +00:00
22a745737a
Remove ifndef C10_MOBILE around aoti_torch_abi_version impl ( #166882 )
...
See if after the headeronly migration the mobile build would still fail.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166882
Approved by: https://github.com/mikaylagawarecki
2025-11-04 00:37:22 +00:00
ee708ea96c
fix test_type_hints ( #163150 )
...
Fixes #163149
### Summary:
Fixes mypy type checking failures in `test_type_hints` by consolidating typing imports and eliminating duplicate/conflicting import patterns that caused mypy to fail resolving type annotations.
### Impact:
- `test_type_hints` works fine now
- module: tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163150
Approved by: https://github.com/Skylion007
2025-11-04 00:29:22 +00:00
64819e3701
[Pytorch] Improve conversion from bf16 on aarch64/NEON ( #166880 )
...
Summary:
Conversion from/to bfloat16 was not getting covered by the conversion templates, because these used bfloat16_t as the data type instead of the custom c10::BFloat16.
Conversion by casting from/to bfloat16_t is broken in clang-[17, 20] and fixed in clang-21.
Because PyTorch does not currently have CI running binaries compiled with clang-21, we won't implement that approach for now.
We are currently only adding conversion from bfloat16, as it can be implemented by zero-extending into a 4-byte float.
We've observed the following performance improvements, when compiling with clang-19 and targeting armv9a+sve2:
Before:
bfloat16_t->uint8 ===> 423.583us
bfloat16_t->int8 ===> 424.090us
bfloat16_t->int16 ===> 430.817us
bfloat16_t->int64 ===> 571.547us
bfloat16_t->double ===> 459.089us
After:
bfloat16_t->uint8 ===> 123.783us ----> 342% higher throughput
bfloat16_t->int8 ===> 131.575us -----> 322% higher throughput
bfloat16_t->int16 ===> 136.794us ----> 315% higher throughput
bfloat16_t->int64 ===> 177.699us ----> 322% higher throughput
bfloat16_t->double ===> 165.556us ---> 277% higher throughput
Test Plan:
Correctness:
buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch
Performance:
buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test
Differential Revision: D86119613
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166880
Approved by: https://github.com/mcfi , https://github.com/aditew01
2025-11-04 00:19:42 +00:00
79ff2c66c8
Revert "Fix unused assignments ( #166791 )"
...
This reverts commit 5125872aeb737fa20ea2ec08338e9342cba694e7.
Reverted https://github.com/pytorch/pytorch/pull/166791 on behalf of https://github.com/cyyever due to incomplete PR ([comment](https://github.com/pytorch/pytorch/pull/166791#issuecomment-3483116247 ))
2025-11-04 00:13:50 +00:00
665a411351
Revert "[CUDA] Skip pynvml test on platforms that don't have complete support ( #159689 )"
...
This reverts commit 68e31e2f814f9f6a9fb87381367e6b33e17c1c2b.
Reverted https://github.com/pytorch/pytorch/pull/159689 on behalf of https://github.com/izaitsevfb due to breaking internal tests [D86127316] ([comment](https://github.com/pytorch/pytorch/pull/159689#issuecomment-3483095879 ))
2025-11-04 00:10:14 +00:00
5c89bdb461
[MPS] Fix smooth_l1_loss backward for fp16 ( #166687 )
...
- Enable the fp16 implementation for CPU by using `convert_to_float` primitives instead of `convert_bfloat16_float` and extending the bf16 implementation to half
- Simplify OpInfo definitions for the backward
Originally this PR used `AT_DISPATCH_ALL_TYPES_AND(kHalf,`, but it caused an ICE with gcc-13 when compiled with SVE128:
```
/opt/rh/gcc-toolset-13/root/usr/bin/c++ -DAT_BUILD_ARM_VEC256_WITH_SLEEF -DAT_PER_OPERATOR_HEADERS -DBUILD_ONEDNN_GRAPH -DCAFFE2_BUILD_MAIN_LIB -DCAFFE2_PERF_WITH_SVE=1 -DCPUINFO_SUPPORTED_PLATFORM=1 -DENABLE_IPC_FABRIC -DFMT_HEADER_ONLY=1 -DFXDIV_USE_INLINE_ASSEMBLY=0 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_POSIX_FALLOCATE=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DKINETO_NAMESPACE=libkineto -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNNP_CONVOLUTION_ONLY=0 -DNNP_INFERENCE_ONLY=0 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DUSE_C10D_GLOO -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_MIMALLOC -DUSE_RPC -DUSE_TENSORPIPE -DXNN_LOG_LEVEL=0 -D_FILE_OFFSET_BITS=64 -Dtorch_cpu_EXPORTS -I/pytorch/build/aten/src -I/pytorch/aten/src -I/pytorch/build -I/pytorch -I/pytorch/nlohmann -I/pytorch/moodycamel -I/pytorch/third_party/mimalloc/include -I/pytorch/torch/csrc/api -I/pytorch/torch/csrc/api/include -I/pytorch/caffe2/aten/src/TH -I/pytorch/build/caffe2/aten/src/TH -I/pytorch/build/caffe2/aten/src -I/acl -I/acl/include -I/pytorch/build/caffe2/../aten/src -I/pytorch/torch/csrc -I/pytorch/torch/headeronly -I/pytorch/third_party/miniz-3.0.2 -I/pytorch/third_party/kineto/libkineto/include -I/pytorch/third_party/kineto/libkineto/src -I/pytorch/third_party/cpp-httplib -I/pytorch/aten/src/ATen/.. -I/pytorch/third_party/FXdiv/include -I/pytorch/c10/.. -I/pytorch/third_party/pthreadpool/include -I/pytorch/third_party/cpuinfo/include -I/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/include -I/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src -I/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/include -I/pytorch/third_party/NNPACK/include -I/pytorch/third_party/FP16/include -I/pytorch/third_party/tensorpipe -I/pytorch/build/third_party/tensorpipe -I/pytorch/third_party/tensorpipe/third_party/libnop/include -I/pytorch/third_party/kleidiai -I/pytorch/third_party/fmt/include -I/pytorch/build/third_party/ideep/mkl-dnn/include -I/pytorch/third_party/ideep/mkl-dnn/src/../include -I/pytorch/third_party/onnx -I/pytorch/build/third_party/onnx -I/pytorch/third_party/flatbuffers/include -isystem /pytorch/build/third_party/gloo -isystem /pytorch/cmake/../third_party/gloo -isystem /pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /pytorch/third_party/protobuf/src -isystem /opt/OpenBLAS/include -isystem /pytorch/third_party/XNNPACK/include -isystem /pytorch/cmake/../third_party/eigen -isystem /pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /pytorch/third_party/ideep/include -isystem /pytorch/INTERFACE -isystem /pytorch/third_party/nlohmann/include -isystem /pytorch/third_party/concurrentqueue -isystem /pytorch/build/include -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_PYTORCH_QNNPACK -DAT_BUILD_ARM_VEC256_WITH_SLEEF -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -DC10_NODEPRECATED -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=range-loop-construct -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -faligned-new -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-dangling-reference -Wno-error=dangling-reference -Wno-stringop-overflow -DHAVE_SVE_CPU_DEFINITION -DHAVE_SVE256_CPU_DEFINITION 
-DHAVE_ARM_BF16_CPU_DEFINITION -O3 -DNDEBUG -DNDEBUG -fPIC -fdiagnostics-color=always -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -D__NEON__ -DBLAS_HAS_SBGEMM -Wall -Wextra -Wdeprecated -Wunused -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wredundant-move -Wno-interference-size -Wno-maybe-uninitialized -fvisibility=hidden -pthread -fopenmp -O3 -march=armv8-a+sve+bf16 -D__ARM_FEATURE_BF16 -DCPU_CAPABILITY_SVE -msve-vector-bits=256 -DCPU_CAPABILITY=SVE256 -DCPU_CAPABILITY_SVE256 -MD -MT caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp.o -MF caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp.o.d -o caffe2/CMakeFiles/torch_cpu.dir/__/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp.o -c /pytorch/build/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp
during RTL pass: expand
In file included from /pytorch/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp:6,
from /pytorch/build/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp.SVE256.cpp:1:
/pytorch/aten/src/ATen/native/cpu/Loops.h: In function ‘void at::native::SVE256::vectorized_loop(char**, int64_t, int64_t, func_t&&, vec_func_t&&) [with func_t = at::native::{anonymous}::smooth_l1_backward_cpu_kernel(at::TensorIterator&, const c10::Scalar&, double)::<lambda()>::<lambda()>::<lambda(scalar_t, scalar_t, scalar_t)>&; vec_func_t = at::native::{anonymous}::smooth_l1_backward_cpu_kernel(at::TensorIterator&, const c10::Scalar&, double)::<lambda()>::<lambda()>::<lambda(at::vec::SVE256::Vectorized<c10::Half>, at::vec::SVE256::Vectorized<c10::Half>, at::vec::SVE256::Vectorized<c10::Half>)>&]’:
/pytorch/aten/src/ATen/native/cpu/Loops.h:200:1: internal compiler error: in expand_insn, at optabs.cc:8185
200 | vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
| ^~~~~~~~~~~~~~~
Please submit a full bug report, with preprocessed source.
See <http://bugzilla.redhat.com/bugzilla > for instructions.
Preprocessed source stored into /tmp/ccgYMlTo.out file, please attach this to your bugreport.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166687
Approved by: https://github.com/Skylion007
2025-11-03 23:54:54 +00:00
7b64ad906c
[FSDP][Replicate] got rid of reshard_after_forward and updated test cases ( #166469 )
...
**Summary:** I have gotten rid of reshard_after_forward and shard_placement as inputs for replicate, as there will be no sharding. I have also updated all the necessary tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166469
Approved by: https://github.com/weifengpy
ghstack dependencies: #166433 , #166459
2025-11-03 23:48:18 +00:00
d944279def
[FSDP][Replicate] added two replicate overload declarations and changed device_mesh to mesh ( #166459 )
...
**Summary:** Just like in fully_shard, I added two overload replicate functions. The `@overload` declarations are necessary because the `@contract` decorator uses `ParamSpec` to capture function parameters, which creates a generic `_ContractFn` protocol signature (`*args: _P.args, **kwargs: _P.kwargs`) that Pyrefly cannot properly type-check when calling the function with explicit keyword arguments. In addition, to make the API cleaner, I changed the device_mesh input argument to mesh to match fully_shard's formatting.
**Test Cases**
1. pytest test/distributed/_composable/test_replicate_with_fsdp.py
2. pytest test/distributed/_composable/test_replicate_training.py
3. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166459
Approved by: https://github.com/weifengpy
ghstack dependencies: #166433
2025-11-03 23:35:21 +00:00
5048e4701d
explicitly remove call_mod_node_to_replace after inlining the submodule in `const_fold._inline_module` ( #166871 )
...
Summary:
https://github.com/pytorch/pytorch/pull/166609 updated `is_impure` check to now check ops inside a subgraph to decide whether a `call_module` node is pure or not.
This change of behavior affects dead code elimination, commonly run as `gm.graph.eliminate_dead_code()`. Specifically, dead code elimination will not erase a node that has no users if this node has side effect or is impure. With above mentioned pr, dead code elimination no longer eliminates unused subgraphs that contain side-effectful ops.
This affects `const_fold.split_const_subgraph`, what this function does is:
1. split a graph into two submodules, one containing all const ops and one containing non-const ops
2. inline the submodule containing non-const ops back to main graph.
3. run dead code elimination to remove the unused non-const submodule.
With PR #166609, step 3 no longer erases the unused module. As an example, the exported graph
```
graph():
%x : [num_users=2] = placeholder[target=x]
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
%empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([5, 10], [0, 1]), kwargs = {device: cpu, pin_memory: False})
%bernoulli : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%empty_permuted, 0.6), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli), kwargs = {})
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul, 0.6), kwargs = {})
return (div,)
```
After running const_fold, empty_permuted is const-folded, the rest of ops are not, and the main graph looks like
```
graph():
%x : [num_users=3] = placeholder[target=x]
%_fx_const_folded_attrs : [num_users=2] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
%bernoulli_p : [num_users=1] = call_function[target=torch.ops.aten.bernoulli.p](args = (%_fx_const_folded_attrs, 0.6), kwargs = {})
%mul_tensor : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, %bernoulli_p), kwargs = {})
%div_tensor : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%mul_tensor, 0.6), kwargs = {})
%submod_1 : [num_users=0] = call_module[target=submod_1](args = (%x, %_fx_const_folded_attrs), kwargs = {})
return (div_tensor,)
```
`submod_1` is dangling, unused, and just inlined into the graph.
## Fix
This PR updates the `const_fold._inline_module` function to explicitly remove the unused non-const submodule after its ops have been inlined into the main graph.
Test Plan:
Added a test in `test_fx_const_fold.py`.
The test would have failed before this PR because it yields the above example graph, leaving an unused `call_module[target=submod_1]` op.
With this PR, the module is erased from the main graph correctly.
Differential Revision: D86056354
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166871
Approved by: https://github.com/blaine-rister , https://github.com/mlazos
2025-11-03 23:23:10 +00:00
616314cfd5
[FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp ( #166433 )
...
**Summary:** I have created a new composable replicate API that's integrated into FSDP's codebase with minimal changes. The key changes: when we use DDPMeshInfo, we use Replicate placements, prevent initial sharding of parameters, and set the world size to 1 to skip all-gathers and reduce-scatters.
**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py
2. pytest test_pp_composability.py
3. pytest test_replicate_with_fsdp.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166433
Approved by: https://github.com/weifengpy
2025-11-03 23:20:23 +00:00
2b7e4c3ef2
[DCP] Add option to use PrefixStore to create checkpoint background process ( #166560 )
...
Summary:
DCP checkpoint background process currently determines the port used for pg via get_free_port().
During checkpoint background process initialization, gloo pg init occasionally times out on the first call but succeeds in a subsequent call.
We hypothesized that the timeouts are related to the port being used, and the solution would be to create the pg with PrefixStore and reuse the master port.
This diff adds the option for checkpoint background process to use PrefixStore with MASTER_ADDR + MASTER_PORT.
The default behavior is unchanged. Enabling the new PrefixStore behavior requires setting "DCP_USE_PREFIX_STORE" env var to "1".
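A minimal opt-in sketch based on the description above (the env var name and value come from this diff; everything else is illustrative):
```python
import os

# Opt in to the PrefixStore-based process group for the checkpoint background process;
# leaving the variable unset keeps the existing get_free_port() behavior.
os.environ["DCP_USE_PREFIX_STORE"] = "1"
```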
context:
https://fb.workplace.com/groups/319878845696681/permalink/1516883985996155/
Differential Revision: D84928180
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166560
Approved by: https://github.com/meetv18
2025-11-03 23:08:12 +00:00
6c98657239
Add some Triton related suppressions that don't show on CI ( #166868 )
...
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166868
Approved by: https://github.com/maggiemoss , https://github.com/zou3519
2025-11-03 22:54:50 +00:00
86b2d82e84
Revert "[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer ( #166165 )"
...
This reverts commit 94f2657c4b534136aa8958bc35d44ceac5ccd60c.
Reverted https://github.com/pytorch/pytorch/pull/166165 on behalf of https://github.com/izaitsevfb due to breaks test_LinearAndSoftmax_codegen test ([comment](https://github.com/pytorch/pytorch/pull/166165#issuecomment-3482926991 ))
2025-11-03 22:52:41 +00:00
eea8ff2d34
Fix torch.full with dynamic tensor fill_value in torch.compile ( #166554 )
...
Fixes #166253
## Summary
When `torch.full` is called with a 0-D tensor as `fill_value` inside a `torch.compile`'d function, the value was being incorrectly cached, causing subsequent calls with different values to return the first value.
## Root Cause
The Dynamo handler for `torch.full` was calling `aten._local_scalar_dense` to convert tensor fill_values to Python scalars at compile time, which baked the value into the compiled graph as a constant.
## Solution
Modified the Dynamo handler to decompose `torch.full(size, tensor_fill_value)` into `empty(size).fill_(tensor_fill_value)` when `fill_value` is a `TensorVariable`, keeping the fill value dynamic in the compiled graph.
## Testing
Added test case that verifies torch.full works correctly with dynamic tensor fill_values across multiple calls and dtypes.
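A minimal repro sketch of the pattern the fix and test exercise (shapes and values are illustrative):
```python
import torch

@torch.compile
def f(fill):
    # fill is a 0-D tensor used as torch.full's fill_value
    return torch.full((2, 2), fill)

a = f(torch.tensor(1.0))   # tensor of 1.0s
b = f(torch.tensor(5.0))   # before this fix, the 1.0 could be baked in and returned here too
```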
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166554
Approved by: https://github.com/Lucaskabela
2025-11-03 21:44:10 +00:00
11f73d78c8
[export] Downgrade captured buffers as normal constants. ( #166777 )
...
Summary:
make_fx() will register tensor constants as new buffers while tracing a shuffle graph for dynamo graph capture. This breaks the invariant that the resulting graph looks identical to the original eager model in terms of state dict.
So we need to de-register the buffers and set them as plain tensor constants.
Test Plan:
pytest test/export/test_experimental.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166777
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #166775 , #166776
2025-11-03 21:28:42 +00:00
7d1b976146
[export] Make dict_keys_getitem tracable. ( #166776 )
...
Summary:
dict_keys_getitem can show up in the bytecode, but it uses dict.keys(), which is not fx-traceable.
fx.wrap should make it a standalone function in the graph to be invoked later with real inputs.
Test Plan:
pytest test/export/test_experimental.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166776
Approved by: https://github.com/jamesjwu
ghstack dependencies: #166775
2025-11-03 21:28:42 +00:00
27cfdd9e77
[export] Return more information from tracing context in graph capture. ( #166775 )
...
Summary:
As titled: we should return an entire tracing_context object instead of only fake_mode, since the tracing context contains the full set of information.
Test Plan:
pytest test/export/test_experimental.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166775
Approved by: https://github.com/tugsbayasgalan
2025-11-03 21:28:42 +00:00
01d8d8584b
[MTIAGraph][Pytorch][2.1/n] Add API to destroy graph C++ instance ( #166806 )
...
I missed this API for MTIAGraph in D84457757(https://github.com/pytorch/pytorch/pull/165963 )
Differential Revision: [D86026706](https://our.internmc.facebook.com/intern/diff/D86026706/ )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166806
Approved by: https://github.com/albanD
ghstack dependencies: #166805
2025-11-03 21:11:40 +00:00
b8855e7b0b
Add conv ops to operator microbenchmark ( #166331 )
...
Adding `conv` (conv1d, conv2d, conv3d) to the list of operator microbenchmarks run in the CI script (`.ci/pytorch/test.sh`), ensuring convolution operators are now benchmarked alongside existing ones.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166331
Approved by: https://github.com/huydhn , https://github.com/jbschlosser
2025-11-03 20:54:52 +00:00
6725ee89c8
Fix cuda blas build error due to extra && ( #166811 )
...
Fixes #166810
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166811
Approved by: https://github.com/slayton58 , https://github.com/Skylion007 , https://github.com/malfet
2025-11-03 20:35:26 +00:00
3a38ec78e1
[inductor] Expand use of generic benchmark function ( #164938 )
...
Use the more generic `Benchmarker.benchmark` function to allow benchmarking other devices that support the required functionality, for example prologue and epilogue fusion can be benchmarked for triton CPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164938
Approved by: https://github.com/nmacchioni , https://github.com/eellison
2025-11-03 20:15:25 +00:00
77b9399d83
[random] Add generator arg to rand*_like APIs ( #166160 )
...
Fixes #165865
## What this PR does?
- [x] Add `generator` arg to `rand*_like` APIs (`rand_like()`, `randn_like()`, `randint_like()`).
- [x] Add unit tests for `rand*_like` APIs
- [x] Add corresponding arg docs
- [x] Refactor `rand*_like()` codes in `TensorFactories.cpp`
- [x] Add corresponding and former missed items in `VmapModeRegistrations.cpp`
## Example (using `rand_like()`)
```python
gen0 = torch.Generator()
gen1 = torch.Generator()
gen2 = torch.Generator()
gen0.manual_seed(42)
gen1.manual_seed(42)
gen2.manual_seed(2025)
tensor = torch.empty(10)
t0 = torch.rand_like(tensor, generator=gen0)
t1 = torch.rand_like(tensor, generator=gen1)
t2 = torch.rand_like(tensor, generator=gen2)
assert torch.equal(t0, t1)
assert not torch.equal(t2, t0)
assert not torch.equal(t2, t1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166160
Approved by: https://github.com/cyyever , https://github.com/albanD
2025-11-03 19:58:45 +00:00
83cd626365
[opaque_obj_v2] make_fx support ( #165005 )
...
By wrapping the python objects with FakeScriptObject(FakeOpaqueQueue) we restrict users to do anything to this object. torch.compile support can be easily enabled by the rest of [this stack](https://github.com/pytorch/pytorch/pull/163936 ) and existing support for ScriptObjects.
One thing to note is that by default in functionalization we mark all ops that take in FakeScriptObjects as being effectful. Should this be the case for these custom ops that take in python objs?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165005
Approved by: https://github.com/zou3519
2025-11-03 19:48:37 +00:00
5125872aeb
Fix unused assignments ( #166791 )
...
This PR cleans up unused assignments.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166791
Approved by: https://github.com/xmfan
2025-11-03 19:45:01 +00:00
c10975d2e6
Revert "Avoid DDE in narrow with unbacked start ( #166361 )"
...
This reverts commit c76199980d09198964409919335e86cc6e3dc575.
Reverted https://github.com/pytorch/pytorch/pull/166361 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166361#issuecomment-3482194351 ))
2025-11-03 19:41:07 +00:00
68e31e2f81
[CUDA] Skip pynvml test on platforms that don't have complete support ( #159689 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159689
Approved by: https://github.com/msaroufim , https://github.com/Skylion007
2025-11-03 19:40:20 +00:00
ee1bc3f0d5
Manylinux ROCm docker images. use devtoolset-13 ( #166764 )
...
Update the devtoolset in Manylinux 2.28 ROCm builds. devtoolset-11 is too old and does not support compiling with C++20 properly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166764
Approved by: https://github.com/sudharssun , https://github.com/jeffdaily
2025-11-03 19:32:33 +00:00
612ead1619
[distributed] Replace assert statements with AssertionError exceptions ( #165216 )
...
Replaces 71 assert statements across 11 files in `torch.distributed` with explicit if-checks raising AssertionError, to prevent the assertions from being stripped when Python runs with the -O flag.
Fixes #164878
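An illustration of the transformation (names are illustrative):
```python
group = object()

# Before: stripped entirely when Python runs with -O
assert group is not None, "process group must be initialized"

# After: always checked, regardless of interpreter flags
if group is None:
    raise AssertionError("process group must be initialized")
```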
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165216
Approved by: https://github.com/albanD
2025-11-03 19:30:48 +00:00
3af1f7bbf4
[easy][MTIAGraph][Pytorch] clang-format files ( #166805 )
...
Per suggestion from the previous PR(https://github.com/pytorch/pytorch/pull/165963 ), separating clang-format changes.
Differential Revision: [D86031474](https://our.internmc.facebook.com/intern/diff/D86031474/ )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166805
Approved by: https://github.com/Skylion007 , https://github.com/albanD
2025-11-03 19:27:09 +00:00
71a2e93547
[cuDNN][SDPA] Check-in test for #166211 ( #166570 )
...
Repros without the need for specific tensor data.
Should be passing with cuDNN frontend 1.15.0 which current `main` has.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166570
Approved by: https://github.com/atalman
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com >
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com >
2025-11-03 19:21:14 +00:00
c76199980d
Avoid DDE in narrow with unbacked start ( #166361 )
...
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start < 0 and start + length == 0; in that case slice and narrow would deviate, so we pass dim_size instead of start + length.
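A small worked example (not from the PR) of why that case is special:
```python
import torch

x = torch.arange(10)

# narrow(dim, start, length) with a negative start wraps around:
# start=-3, length=3 selects elements 7, 8, 9.
print(x.narrow(0, -3, 3))    # tensor([7, 8, 9])

# A slice ending at start + length == 0 would be empty,
print(x[-3:0])               # tensor([], dtype=torch.int64)
# so the end passed to slice must be the dimension size instead.
print(x[-3:x.size(0)])       # tensor([7, 8, 9])
```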
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-03 19:13:40 +00:00
e3bd7bd1f4
[FP8] Enable FP16 output support for torch scaled_mm when using CUTLASS on SM90 ( #166744 )
...
Summary: NVIDIA uses CUTLASS for row-wise scaling prior to cuBLAS version 12.9. This change enables support for FP16 data type for both bias and output when using CUTLASS.
Test Plan:
pytest -svv test/test_scaled_matmul_cuda.py
Test results on cuda-12.4:
```
test/test_scaled_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_row_wise_bfloat16_cuda PASSED [0.0022s]
test/test_scaled_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_row_wise_float16_cuda PASSED [0.0023s]
test/test_scaled_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_row_wise_float32_cuda SKIPPED [0.0005s]
======================= 51 passed, 516 skipped in 5.26s ========================
```
Test results on cuda-12.9:
```
test/test_scaled_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_row_wise_bfloat16_cuda PASSED [0.0046s]
test/test_scaled_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_row_wise_float16_cuda PASSED [0.0040s]
test/test_scaled_matmul_cuda.py::TestFP8MatmulCUDA::test_scaled_mm_vs_emulated_row_wise_float32_cuda PASSED [0.0038s]
======================= 70 passed, 482 skipped in 5.88s ========================
```
Reviewed By: pranavsharma, RandySheriff
Differential Revision: D84169910
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166744
Approved by: https://github.com/slayton58
2025-11-03 19:10:16 +00:00
aa4a8c9b92
[Inductor][Triton][FP8] Support tile-wise (1x128) scaling in Inductor ( #165132 )
...
Summary:
Support tile-wise `1x128` scaling in Inductor Triton for FP8 GEMMs, i.e. scaling values along tensors `a` and `b` represent a `1x128` slice of input.
NOTE: Block-wise `128x128` and `1x128` scaling is only supported in CUDA 12.9+; therefore, tile-wise scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run tile-wise scaling (as with deepseek-style scaling).
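A rough sketch (not the Inductor template itself) of what `1x128` tile-wise scaling means for a single operand; the block size and fp8 dtype here are assumptions for illustration:
```python
import torch

M, K, BLOCK = 256, 512, 128
a = torch.randn(M, K)

# One scale per contiguous 1x128 slice along K, so the scale tensor
# for `a` has shape (M, K // BLOCK).
a_blocks = a.view(M, K // BLOCK, BLOCK)
scale_a = a_blocks.abs().amax(dim=-1).clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
a_fp8 = (a_blocks / scale_a.unsqueeze(-1)).view(M, K).to(torch.float8_e4m3fn)
print(a_fp8.shape, scale_a.shape)  # torch.Size([256, 512]) torch.Size([256, 4])
```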
Test Plan:
Works out-of-the-box with TritonBench:
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ENABLE_PERSISTENT_TMA_MATMUL=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/{opt,inplace} pytorch/tritonbench:run -- --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm --metrics tflops,accuracy --m 256 --n 768 --k 512 --output="/home/jananisriram/personal/random_bench.csv" --scaling-pair=BlockWise1x128,BlockWise1x128 --atol=1e-2 --rtol=0.5
```
Differential Revision: D84025878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165132
Approved by: https://github.com/eqy , https://github.com/drisspg , https://github.com/njriasan
2025-11-03 18:37:13 +00:00
fa0fd6be13
Revert "[FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp ( #166433 )"
...
This reverts commit bcad4f2e68e2a93a2855c1c22f0856fbb7c729e2.
Reverted https://github.com/pytorch/pytorch/pull/166433 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166433#issuecomment-3481929476 ))
2025-11-03 18:31:20 +00:00
2f3f88f445
Revert "[FSDP][Replicate] added two replicate overload declarations and changed device_mesh to mesh ( #166459 )"
...
This reverts commit d67d807270e070bbb873af61ea944ed98b52b9cf.
Reverted https://github.com/pytorch/pytorch/pull/166459 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166433#issuecomment-3481929476 ))
2025-11-03 18:31:20 +00:00
d67d807270
[FSDP][Replicate] added two replicate overload declarations and changed device_mesh to mesh ( #166459 )
...
**Summary:** Just like in fully_shard, I added two overload replicate functions. The `@overload` declarations are necessary because the `@contract` decorator uses `ParamSpec` to capture function parameters, which creates a generic `_ContractFn` protocol signature (`*args: _P.args, **kwargs: _P.kwargs`) that Pyrefly cannot properly type-check when calling the function with explicit keyword arguments. In addition, to make the API cleaner, I changed the `device_mesh` input argument to `mesh` to match fully_shard's formatting.
**Test Cases**
1. pytest test/distributed/_composable/test_replicate_with_fsdp.py
2. pytest test/distributed/_composable/test_replicate_training.py
3. pytest test/distributed/_composable/test_composability/test_pp_composability.py -k test_replicate_pp
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166459
Approved by: https://github.com/weifengpy
ghstack dependencies: #166433
2025-11-03 18:20:07 +00:00
bcad4f2e68
[FSDP][Replicate] final version integrating 1D device mesh replicate into fsdp ( #166433 )
...
**Summary:** I have created a new composable replicate API that's integrated into FSDP's codebase with minimal changes. The key changes I made: when we use DDPMeshInfo, we use Replicate placements, prevent initial sharding of parameters, and set the world size to 1 to skip all-gathers and reduce-scatters.
**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py
2. pytest test_pp_composability.py
3. pytest test_replicate_with_fsdp.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166433
Approved by: https://github.com/weifengpy
2025-11-03 18:09:27 +00:00
5b17ef30d0
Update docs-build to c7i ( #166727 )
...
This updates the docs-build nightly configuration to match other uses of the _linux-build.yml workflow using `runner_prefix` rather than `runner` directly. The default runner defined in _linux-build.yml is the c7i variant so this also updates the runner appropriately.
Relates to pytorch/test-infra#7175 . While moving to c7i costs 5% more, CPU-intensive jobs should run roughly 15-20% faster, resulting in a cost reduction of 10-15% for those jobs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166727
Approved by: https://github.com/huydhn
2025-11-03 18:02:09 +00:00
7b2992685b
Update test jobs in pull workflow to c7i ( #165646 )
...
Relates to pytorch/test-infra#7175 . While moving to c7i costs 5% more, CPU-intensive jobs should run roughly 15-20% faster, resulting in a cost reduction of 10-15% for those jobs.
This PR updates for the following test job suite that seem to benefit from the newer hardware:
* backwards_compat
* numpy_2_x
* ONNX default
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165646
Approved by: https://github.com/jordanconway , https://github.com/huydhn
2025-11-03 18:00:09 +00:00
f3fa560dec
Integrate NVIDIA cuSolver backend into ATen/Linalg (initial implementation for eig/eigval) ( #166715 )
...
### Summary
Adds support for NVIDIA’s cuSolver backend to torch.linalg.eig and torch.linalg.eigvals within the ATen/Linalg framework.
### Motivation
Extending PyTorch’s Linalg backends with NVIDIA’s cuSolver enables faster execution of torch.linalg.eig and torch.linalg.eigvals, complementing existing MAGMA and CPU implementations.
The speedup observed on consumer hardware (RTX4070/Ryzen 5700x) is in the order of **2x**, with preliminary testing on HPC hardware (H100, EPYC 9454) suggesting **up to 10x speedup**.
### Details
- Implements cuSolver support for linalg_eig and linalg_eigvals using the interface described in [NVIDIA cuSolver documentation](https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgeev ) as introduced in CUDA 12.8 [CUDA 12.8 release notes](https://docs.nvidia.com/cuda/archive/12.8.0/cuda-toolkit-release-notes/index.html )
- Follows the existing MAGMA backend design, adapting it for cuSolver’s cusolverDnXgeev API.
- Integrates with existing eig/eigvals dispatch mechanism.
- No automatic CPU↔GPU backend switching. (Happy to discuss)
- Verified via existing Linalg test coverage; no new tests introduced in this PR.
- Tested successfully against both test_linalg.py including slow test suites.
- Tested MAGMA fallback successfully using CUDA 12.4. (observed unrelated test failures)
### Impact
- Enables much faster solving of eigenvalue problems
- Maintains numerical consistency and test stability across backends.
- No change to public API or user-facing behavior.
Special thanks to @AlbanD for prior feedback and discussions regarding the PR and @lezcano for feedback on the related testing PR [https://github.com/pytorch/pytorch/pull/166322 ](https://github.com/pytorch/pytorch/pull/166322 ).
Happy to discuss backend dispatch strategy, results from performance and stability testing can be seen here [https://dev-discuss.pytorch.org/ ](https://dev-discuss.pytorch.org/t/cusolver-dnxgeev-faster-cuda-eigenvalue-calculations/3248/7 )
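For reference, the public API this accelerates (backend selection happens internally, so user code is unchanged):
```python
import torch

A = torch.randn(512, 512, dtype=torch.float64, device="cuda")
w, v = torch.linalg.eig(A)        # eigenvalues and right eigenvectors
w_only = torch.linalg.eigvals(A)  # eigenvalues only
```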
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166715
Approved by: https://github.com/lezcano , https://github.com/albanD
2025-11-03 17:44:22 +00:00
984b096d10
[ROCm][CI] Change rocm.yml and inductor-rocm.yml cron schedule to run every hour ( #166870 )
...
Temporary PR to change the rocm.yml and inductor-rocm.yml workflows to run on an hourly basis rather than on every commit. This is caused by the following:
We are observing cirrascale network timeouts as of 11/03/2025. [HUD Link](94f2657c4b/1 )
[SEV](https://github.com/pytorch/pytorch/issues/166866 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166870
Approved by: https://github.com/jeffdaily
2025-11-03 17:33:11 +00:00
104b868618
Fix build error by checking cuda version in CUDAGreenContext ( #166800 )
...
Fixes #166799
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166800
Approved by: https://github.com/mlazos , https://github.com/eqy , https://github.com/malfet
2025-11-03 16:41:38 +00:00
94f2657c4b
[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer ( #166165 )
...
Prefer unfused addmm when there is at least one elementwise/reduction consumer.
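Roughly the shape of the rewrite, sketched in eager PyTorch (the actual pass operates on Inductor IR):
```python
import torch

def before(bias, a, b):
    # Fused bias add inside the GEMM
    return torch.relu(torch.addmm(bias, a, b))

def after(bias, a, b):
    # Unfused: the bias add can now fuse into the pointwise consumer (relu)
    return torch.relu(torch.mm(a, b) + bias)
```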
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165
Approved by: https://github.com/eellison
2025-11-03 15:50:32 +00:00
3f6538febd
Remove tools from BC linter ( #166858 )
...
Signed-off-by: Edward Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166858
Approved by: https://github.com/albanD
2025-11-03 15:42:54 +00:00
f33abae695
Switch to pyrefly as only type checker ( #166197 )
...
This formally switches PyTorch over from MyPy to Pyrefly as its only type checker, and should help reduce the noise in lintrunner right now. I will fast-follow with PRs silencing existing errors and will work over the weekend to keep trunk clean while we roll this out.
test:
`lintrunner init`
`lintrunner`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166197
Approved by: https://github.com/ezyang , https://github.com/seemethere , https://github.com/albanD
2025-11-03 15:32:56 +00:00
73da7a40b6
[MPS] Error out when BatchNorm is called for Complex ( #166215 )
...
Or when BatchNorm or LayerNorm is called for Long types
Discovered while trying to enable `test_ops.py` for MPS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166215
Approved by: https://github.com/dcci , https://github.com/kulinseth , https://github.com/Skylion007
ghstack dependencies: #166214
2025-11-03 15:24:09 +00:00
335b5c7d4b
Avoid std::copy_n in CopyKernel and IndexKernel ( #143544 )
...
This PR simplifies `std::copy_n` calls in CopyKernel and IndexKernel. `std::copy_n` is used to create a data pointer array from the input data pointers. However, a more careful review reveals that the dest pointers are actually aliases of the original pointers, so we can remove the pointer manipulations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143544
Approved by: https://github.com/albanD
2025-11-03 15:16:04 +00:00
76bb27e248
Revert "Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed ( #164939 )" ( #165910 )" ( #166812 )
...
This reverts commit e6ba4d072510464c846f2013822f9388210eb907.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166812
Approved by: https://github.com/SherlockNoMad
2025-11-03 15:06:11 +00:00
a2da69385a
Remove nightly pth check from pyrefly ( #166857 )
...
Signed-off-by: Edward Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166857
Approved by: https://github.com/albanD
2025-11-03 14:53:49 +00:00
d177900723
[Code Clean] Clean asserts in torch/ao/quantization (root, quantizer, backend_config) ( #165433 )
...
Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/~
- torch/ao/quantization/quantizer/
- torch/ao/quantization/backend_config/
Partially fixes #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165433
Approved by: https://github.com/mlazos , https://github.com/fffrog , https://github.com/cyyever
2025-11-03 14:52:37 +00:00
61bcc8d75a
Revert "Fixes torch.compile(nn.ModuleList()) changes bool() behavior ( #159208 )"
...
This reverts commit 21b48f8dfa7685699df4c97c0ba373d5364230d9.
Reverted https://github.com/pytorch/pytorch/pull/159208 on behalf of https://github.com/atalman due to Broke internal tests ([comment](https://github.com/pytorch/pytorch/pull/159208#issuecomment-3480743499 ))
2025-11-03 14:10:01 +00:00
1656b253c5
Revert "[MPS] Fix smooth_l1_loss backward for fp16 ( #166687 )"
...
This reverts commit 4e7232c5daf753e04e8f4189229e3c33888a33e5.
Reverted https://github.com/pytorch/pytorch/pull/166687 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/19027214755/job/54332952760 ) [HUD commit link](95ab09cb54 ) ([comment](https://github.com/pytorch/pytorch/pull/166687#issuecomment-3480694316 ))
2025-11-03 14:05:25 +00:00
5d6230779d
Revert "Give full Dynamo stack traces in CI ( #160417 )"
...
This reverts commit e0791fc11dc0024a828495985898b29120dcc4c1.
Reverted https://github.com/pytorch/pytorch/pull/160417 on behalf of https://github.com/atalman due to test/dynamo/test_aot_compile.py::TestAOTCompile::test_aot_compile_graph_break_error_fmt [GH job link](https://github.com/pytorch/pytorch/actions/runs/19028849833/job/54339349886 ) [HUD commit link](e0791fc11d ) ([comment](https://github.com/pytorch/pytorch/pull/160417#issuecomment-3480680049 ))
2025-11-03 14:00:20 +00:00
a4077b568f
Revert "[MPS] Error out when BatchNorm is called for Complex ( #166215 )"
...
This reverts commit 9261a1fb128412201ef009d30844a2417364d73b.
Reverted https://github.com/pytorch/pytorch/pull/166215 on behalf of https://github.com/atalman due to sorry need to revert https://github.com/pytorch/pytorch/pull/166687 ([comment](https://github.com/pytorch/pytorch/pull/166215#issuecomment-3480661671 ))
2025-11-03 13:56:32 +00:00
ae038f871b
[inductor] Collectives estimations: option to use nccl estimator for fx node ( #166521 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166521
Approved by: https://github.com/eellison
2025-11-03 13:11:54 +00:00
defac66e39
[xla hash update] update the pinned xla hash ( #166845 )
...
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml ).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166845
Approved by: https://github.com/pytorchbot
2025-11-03 11:32:14 +00:00
061fa73c97
Reapply "Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed ( #164939 )" ( #165910 )" ( #166812 )
...
This reverts commit 5a3930abbc19eac9a179455df82e206e69765ed2.
Reverted https://github.com/pytorch/pytorch/pull/166812 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166812#issuecomment-3480004525 ))
2025-11-03 11:16:15 +00:00
9501405de6
[caffe2] Ignore -Wswitch-enum warnings ( #166760 )
...
Summary: Projects that use `-Wswitch-enum` will encounter issues when building and using *PyTorch* (`caffe2`). Address these issues to empower more rigorous upstream compiler warnings/errors.
Test Plan: CI Pass
Differential Revision: D85893917
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166760
Approved by: https://github.com/atalman
2025-11-03 09:37:47 +00:00
e0791fc11d
Give full Dynamo stack traces in CI ( #160417 )
...
Signed-off-by: Edward Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160417
Approved by: https://github.com/SherlockNoMad
2025-11-03 08:51:21 +00:00
e1d011d6eb
[2/N] Change C-style casts to static_cast or reinterpret_cast ( #165891 )
...
A follow-up of #165750 to clean up C casts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165891
Approved by: https://github.com/Skylion007
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com >
2025-11-03 08:02:58 +00:00
3f5401020b
[3/N] Add clang-tidy readability checks ( #164692 )
...
This PR adds two checks:
```
readability-static-definition-in-anonymous-namespace
Finds static function and variable definitions
in anonymous namespace.
readability-named-parameter
Find functions with unnamed arguments.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164692
Approved by: https://github.com/Skylion007
2025-11-03 07:28:21 +00:00
5a3930abbc
Revert "Back out "Do not decompose in functionalization/proxy tensor if autograd wouldn't have decomposed ( #164939 )" ( #165910 )" ( #166812 )
...
This reverts commit e6ba4d072510464c846f2013822f9388210eb907.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166812
Approved by: https://github.com/SherlockNoMad
2025-11-03 07:21:20 +00:00
a5f00077fc
torch.cond supports autograd now ( #165908 )
...
Signed-off-by: Edward Yang <ezyang@meta.com >
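A minimal example of what this enables (a sketch; exact coverage may vary):
```python
import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4, requires_grad=True)
y = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
y.sum().backward()   # differentiating through cond is what this PR adds
print(x.grad)
```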
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165908
Approved by: https://github.com/zou3519 , https://github.com/ydwu4 , https://github.com/Skylion007
2025-11-03 06:16:15 +00:00
69fb3ebb5d
Fix: type promotion in FakeTensor ( #166522 )
...
Fixes #166042
common_dtype was being assigned the first datatype even when a different value was passed via type_promotion. This adds a condition to handle that case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166522
Approved by: https://github.com/Lucaskabela
2025-11-03 06:11:35 +00:00
1c4ced2eaf
[2/N] Correctly use test parameters ( #166783 )
...
This PR fixes unused test parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166783
Approved by: https://github.com/mlazos
2025-11-03 05:36:52 +00:00
392acee68a
[6/N] Remove unused loop variables in tests ( #166785 )
...
This PR removes unused loop variables in tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166785
Approved by: https://github.com/Skylion007 , https://github.com/mlazos
2025-11-03 03:52:52 +00:00
fee1ac927d
[DebugMode] add stack traces ( #166440 )
...
Captures stack trace for torch_dispatch calls, under `with DebugMode(record_stack_trace=True)`: Traces aren't rendered in debug string, but are in `.stack_trace` for each log.
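A usage sketch, assuming the `torch.utils._debug_mode` import path:
```python
import torch
from torch.utils._debug_mode import DebugMode  # assumed import path

with DebugMode(record_stack_trace=True) as debug_mode:
    torch.randn(4).sum()

for call in debug_mode.operators:
    # Traces live on each log entry, not in the rendered debug string.
    print(call, call.stack_trace)
```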
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166440
Approved by: https://github.com/yushangdi
2025-11-03 02:48:09 +00:00
4a7fefd7c7
[dynamo] fix pos-only names so they can be collected in **kwargs ( #166798 )
...
See the new testcase for more details. It fails on trunk and is fixed by this PR.
```python
In [1]: def func(a, /, **kwargs):
...: return a, kwargs
In [2]: func(1, a=2)
Out[2]: (1, {'a': 2})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166798
Approved by: https://github.com/guilhermeleobas
2025-11-03 02:40:34 +00:00
3b4315940d
[export] Fix static_input_indices for aot_export_joint ( #166761 )
...
`static_input_indices` is used for cudagraphs to determine which input indices are static and will not have changing addresses. Since export never integrated with cudagraphs this information was not necessary. But now we need it!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166761
Approved by: https://github.com/BoyuanFeng
2025-11-03 01:57:51 +00:00
3eddf04922
Revert "Add min/max support for barebones uint types ( #166813 )"
...
This reverts commit 9c22bbb2dce31b854e3387db77eaff501434f352.
Reverted https://github.com/pytorch/pytorch/pull/166813 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166813#issuecomment-3478450413 ))
2025-11-02 22:50:36 +00:00
7c203b8420
[BE] Using std::move to reduce copy constructor calls by one. ( #163599 )
...
inspired by https://github.com/pytorch/pytorch/pull/163416
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163599
Approved by: https://github.com/Skylion007
2025-11-02 21:54:58 +00:00
3ca216ae17
Add claude skills for uint support and AT_DISPATCH_V2 ( #166814 )
...
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166814
Approved by: https://github.com/Skylion007 , https://github.com/malfet
ghstack dependencies: #166813
2025-11-02 21:36:19 +00:00
9c22bbb2dc
Add min/max support for barebones uint types ( #166813 )
...
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166813
Approved by: https://github.com/Skylion007
2025-11-02 21:36:19 +00:00
6268883f9c
[MPS] Refactor torch.cat and add fast path for contiguous inputs ( #166556 )
...
In many cases when the fast path is used, the performance is pretty similar to what it used to be. However, with tensors on the order of about 1000 elements there is a modest speedup, which increases as the number of input tensors increases and the number of dimensions increases.
This script was used for performance comparison: <1f04647bbf/cat/perf0.py >
Before change:
```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000843 ms, 0.010431 ms, 0.08, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000838 ms, 0.013467 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000792 ms, 0.009457 ms, 0.08, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000834 ms, 0.010694 ms, 0.08, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000627 ms, 0.000641 ms, 0.98, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001172 ms, 0.006493 ms, 0.18, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000812 ms, 0.006148 ms, 0.13, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000686 ms, 0.009382 ms, 0.07, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000738 ms, 0.006532 ms, 0.11, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003835 ms, 0.193963 ms, 0.02, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.552435 ms, 0.690500 ms, 0.80, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.488799 ms, 0.708988 ms, 0.69, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 0.000799 ms, 0.005997 ms, 0.13, cat, [[tensor(shape[1000]), tensor(shape[1000])]], {'dim': 0}
13: 0.000916 ms, 0.011791 ms, 0.08, cat, [[tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]], {'dim': 0}
14: 0.001028 ms, 0.012269 ms, 0.08, cat, "[[tensor(shape[1000]), tensor(shape[1000]), tensor(shape[1000]), tensor(shape[1000]), tensor(shape[1...", {'dim': 0}
15: 0.001127 ms, 0.025197 ms, 0.04, cat, "[[tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), tensor(s...", {'dim': 0}
16: 0.321997 ms, 0.142815 ms, 2.25, cat, [[tensor(shape[1000000]), tensor(shape[1000000])]], {'dim': 0}
17: 1.989967 ms, 1.013615 ms, 1.96, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
18: 3.161745 ms, 0.965378 ms, 3.28, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
19: 3.416246 ms, 0.972278 ms, 3.51, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```
After change:
```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000902 ms, 0.011074 ms, 0.08, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000899 ms, 0.010453 ms, 0.09, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000771 ms, 0.005843 ms, 0.13, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000776 ms, 0.010449 ms, 0.07, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000616 ms, 0.000600 ms, 1.03, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001150 ms, 0.007624 ms, 0.15, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000728 ms, 0.007949 ms, 0.09, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000671 ms, 0.005458 ms, 0.12, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000770 ms, 0.006590 ms, 0.12, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003835 ms, 0.190193 ms, 0.02, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.529047 ms, 0.734389 ms, 0.72, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.512615 ms, 0.531172 ms, 0.97, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 0.000740 ms, 0.004288 ms, 0.17, cat, [[tensor(shape[1000]), tensor(shape[1000])]], {'dim': 0}
13: 0.000955 ms, 0.004119 ms, 0.23, cat, [[tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]], {'dim': 0}
14: 0.001037 ms, 0.004578 ms, 0.23, cat, "[[tensor(shape[1000]), tensor(shape[1000]), tensor(shape[1000]), tensor(shape[1000]), tensor(shape[1...", {'dim': 0}
15: 0.001115 ms, 0.004918 ms, 0.23, cat, "[[tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), tensor(shape[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]), tensor(s...", {'dim': 0}
16: 0.334119 ms, 0.145008 ms, 2.30, cat, [[tensor(shape[1000000]), tensor(shape[1000000])]], {'dim': 0}
17: 2.419846 ms, 0.984192 ms, 2.46, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
18: 3.117338 ms, 1.000345 ms, 3.12, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
19: 3.047707 ms, 0.971730 ms, 3.14, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166556
Approved by: https://github.com/malfet
2025-11-02 21:27:05 +00:00
16212f0d6b
[Sparse] support for exp op ( #166801 )
...
support for exp op in Sparse tensors
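A minimal usage sketch of what this enables (assuming exp is applied to the specified values, as with other sparse unary ops):
```python
import torch

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0])
s = torch.sparse_coo_tensor(i, v, (2, 3))

out = torch.exp(s)   # exp applied to the sparse tensor's values
print(out)
```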
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166801
Approved by: https://github.com/eqy
2025-11-02 21:14:43 +00:00
c8adc08b3b
[Fix] Optimize max unpooling index validation using aminmax ( #165394 )
...
Replace separate min() and max() calls with single aminmax() call in max_unpool_out_mps_template to improve performance by reducing tensor traversals from O(2n) to O(n).
Changes:
- Use indices.aminmax() instead of separate indices.min()/max() calls
- Add required ATen/ops/aminmax.h header for AT_PER_OPERATOR_HEADERS
- Maintain identical bounds checking logic and error handling
This optimization is particularly beneficial for large indices tensors, improving cache locality and reducing computational overhead.
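The idea in plain Python for reference (the PR's change is in the MPS C++ kernel):
```python
import torch

indices = torch.randint(0, 1000, (1_000_000,))

# Two traversals of the tensor:
lo, hi = indices.min(), indices.max()

# One traversal:
lo, hi = indices.aminmax()
```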
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165394
Approved by: https://github.com/cyyever , https://github.com/Skylion007
2025-11-02 19:42:02 +00:00
23b57a445c
Remove setup-env instructions; it's confusing ( #166749 )
...
Signed-off-by: Edward Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166749
Approved by: https://github.com/mlazos
2025-11-02 19:22:53 +00:00
6c7cad6972
Use Python 3.10 typing ( #148418 )
...
Use Python 3.10 typing in some files
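Illustrative before/after of the typing style this migrates to:
```python
# Before: typing-module generics
from typing import Dict, List, Optional

def lookup(key: Optional[str], table: Dict[str, List[int]]) -> Optional[List[int]]:
    return table.get(key) if key is not None else None

# After: Python 3.10 syntax
def lookup(key: str | None, table: dict[str, list[int]]) -> list[int] | None:
    return table.get(key) if key is not None else None
```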
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148418
Approved by: https://github.com/mlazos
2025-11-02 16:16:52 +00:00
bb54296258
Fix source_fn_stack being None ( #166728 )
...
Summary: Apparently source_fn_stack can be empty
Test Plan: CI
Differential Revision: D85956753
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166728
Approved by: https://github.com/SS-JIA , https://github.com/Skylion007 , https://github.com/mlazos , https://github.com/atalman
2025-11-02 13:50:16 +00:00
5e05a0ae99
Revert "Fix: list index out of range with softmax when using 0 dim ( #166547 )"
...
This reverts commit 0674e0a0f14775f920296e9dfb8b61e4960bf99d.
Reverted https://github.com/pytorch/pytorch/pull/166547 on behalf of https://github.com/atalman due to Fail: test/test_torchfuzz_repros.py::TestFuzzerCompileIssues::test_fuzzer_issue_163971 [GH job link](https://github.com/pytorch/pytorch/actions/runs/19008635308/job/54286552036 ) [HUD commit link](0674e0a0f1 ) ([comment](https://github.com/pytorch/pytorch/pull/166547#issuecomment-3477962809 ))
2025-11-02 13:29:03 +00:00
298666631b
[user-streams] Switch to fx annotations at trace time ( #166472 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166472
Approved by: https://github.com/anijain2305
ghstack dependencies: #164819 , #165211 , #165212 , #165356 , #164523 , #162905 , #166471
2025-11-02 11:55:51 +00:00
e471800dce
[user-streams] cleanup StreamVariable signature ( #166471 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166471
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #164819 , #165211 , #165212 , #165356 , #164523 , #162905
2025-11-02 11:55:51 +00:00
18f4259626
[dynamo] Remove retrieving objects by ID ( #162905 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162905
Approved by: https://github.com/anijain2305
ghstack dependencies: #164819 , #165211 , #165212 , #165356 , #164523
2025-11-02 11:55:43 +00:00
d962bed157
[user-streams] Add basic stream tests ( #164523 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164523
Approved by: https://github.com/anijain2305
ghstack dependencies: #164819 , #165211 , #165212 , #165356
2025-11-02 11:55:37 +00:00
76780b1a3d
[user-streams] Handle returning the current stream with/without device index ( #165356 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165356
Approved by: https://github.com/anijain2305
ghstack dependencies: #164819 , #165211 , #165212
2025-11-02 11:55:30 +00:00
cee03634da
[user-streams] Track symbolic current stream ( #165212 )
...
merge into stream tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165212
Approved by: https://github.com/anijain2305
ghstack dependencies: #164819 , #165211
2025-11-02 11:55:22 +00:00
bc03d7c974
[user-streams] Add current stream source ( #165211 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165211
Approved by: https://github.com/anijain2305
ghstack dependencies: #164819
2025-11-02 11:55:15 +00:00
f013e804c8
[user-streams] Fix stream graph output semantics ( #164819 )
...
Previously, we would stash a single stream value we constructed at trace time in a global and return the same value from repeated calls to the graph.
With this PR, we construct the stream value in advance, reference the constructed value in the graph via the lookup table, and if that value is returned as an output, read the value from the lookup table and return it (in bytecode, not as a graph output, since we don't support arbitrary stream outputs).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819
Approved by: https://github.com/anijain2305
2025-11-02 11:55:08 +00:00
0674e0a0f1
Fix: list index out of range with softmax when using 0 dim ( #166547 )
...
Fixes #163971
Problem:
PyTorch's inductor compiler crashed with IndexError: list index out of range when compiling code that uses 0-dimensional tensors with operations like torch.softmax(scalar_tensor, dim=0).
A 0-dim tensor has shape = torch.Size([]) (empty shape)
```
ndim = 0 (zero dimensions)
len(shape) = 0 (no indices to access)
# Line 972: Pad other_shape to match inp dimensions
other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape)
# For scalar tensors:
# inp_ndim = 0 # as input is scalar
# other_shape = []
# Result: [1] * (0 - 0) + [] = [] (still empty!)
dim = match.kwargs["dim"] # dim = 0
if isinstance(dim, int):
dim = (dim,)
# crash is happening here!
return all(statically_known_true(other_shape[d] == 1) for d in dim)
# ^^^^^^^^^^^^^^^^
# Tries other_shape[0] but other_shape = [] (empty!)
# → IndexError: list index out of range
```
The function _other_is_broadcasted_in_dim() is an optimization check for a softmax fusion pattern. It verifies whether it's safe to rewrite:
```
# From
scaled = inp * other
result = scaled - scaled.amax(dim, keepdim=True)
# To this more stable form:
result = (inp - inp.amax(dim, keepdim=True)) * other
```
The optimization is only valid if other is constant across the reduction dimension (i.e., broadcasted to size 1 in that dimension). Otherwise, scaling changes which element is the maximum.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166547
Approved by: https://github.com/jansel , https://github.com/eellison , https://github.com/leslie-fang-intel
2025-11-02 06:43:34 +00:00
b7d348a907
[vision hash update] update the pinned vision hash ( #166771 )
...
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml ).
Update the pinned vision hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166771
Approved by: https://github.com/pytorchbot
2025-11-02 04:24:38 +00:00
9f9dbe0a9a
add a curve for customized compilation in the kernel benchmarking scripts ( #166697 )
...
It's nice to add a curve with customized compilation options so that we can compare the perf improvement of new features side by side.
E.g. for mix-order-reduction, by running the following command
```
python benchmarks/dynamo/genai_layers/benchmark.py --tolerance=1e-2 --exit-on-accuracy-failure --visualize rmsnorm_backward --custom-compile-name="compiled-no-fusion" --custom-compile-options='{"triton.mix_order_reduction":false}'
```
I get following output:
```
Geomean speedup for benchmark RMSNormBackward
eager 11 data points
compiled 11 data points, 15.82x speedup
quack 11 data points, 15.45x speedup
liger 11 data points, 14.06x speedup
compiled-no-fusion 11 data points, 10.26x speedup
```
The output shows that the feature on average improves perf by `15.82 / 10.26 = 1.54x` across all the shapes tested. (I removed one shape, (32768, 32768), whose rnumel is too large to be representative.)
The new curve also shows up in the figure:
<img width="3564" height="2368" alt="RMSNormBackward_bench" src="https://github.com/user-attachments/assets/1ffac2bc-e726-4f1e-806d-e9e5de711492 " />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166697
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #166053 , #166382 , #166461 , #166585 , #166675
2025-11-01 22:09:56 +00:00
a19e92d433
report geomean for norm bwd benchmarking ( #166675 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166675
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #166053 , #166382 , #166461 , #166585
2025-11-01 22:09:56 +00:00
c3dc0c7089
[Inductor] mix order reduction heuristics and tuning ( #166585 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166585
Approved by: https://github.com/jansel , https://github.com/PaulZhang12
ghstack dependencies: #166053 , #166382 , #166461
2025-11-01 22:09:48 +00:00
04d6a6f339
[inductor] Make mix-order-reduction split size not depends on split-reduction heuristics ( #166461 )
...
Split size is critical for mix-order reduction perf, while the one picked by the split-reduction heuristics can be very bad for mix-order reduction.
<img width="1197" height="596" alt="Screenshot 2025-10-27 at 11 17 16 PM" src="https://github.com/user-attachments/assets/7faa11ad-3a7a-4b29-90ed-e85fc01077ea " />
For the first shape in the chart, split reduction picks a split size around 2000 and results in poor perf. It is important to let mix-order reduction decide the split size itself. (ss_8 in the chart means split-size == 8.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166461
Approved by: https://github.com/jansel , https://github.com/v0i0
ghstack dependencies: #166053 , #166382
2025-11-01 22:09:40 +00:00
0573747b6a
[inductor] more aggressive mix order reduction ( #166382 )
...
More aggressive mix order reductions so that when rnumel is larger than 1024 we can still generate the fused kernel. Also use more warps in that case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166382
Approved by: https://github.com/jansel , https://github.com/v0i0
ghstack dependencies: #166053
2025-11-01 22:09:32 +00:00
a663eb9c80
[FlexFlash] CuteDSL flat indexer needs to be colexigraphic in coordinate space ( #166657 )
...
Benchmarks on Hopper:
Note the triton impl is not using max-autotune because I didn't feel like waiting for 90x plots.
<img width="12517" height="5995" alt="combined_comparison" src="https://github.com/user-attachments/assets/d94debd9-920d-4413-b51f-b8e906e4fb01 " />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166657
Approved by: https://github.com/v0i0 , https://github.com/mlazos , https://github.com/eellison
ghstack dependencies: #166359
2025-11-01 21:18:51 +00:00
764c54ecae
[DebugMode] dispatch call hooks ( #166348 )
...
Adds customizable hooks on `__torch_dispatch__` calls for logging/recording arbitrary values.
Recording hooks store the hook outputs for each call at `debug_mode.operators[*].record`
```python
with DebugMode() as debug_mode, DebugMode.dispatch_hooks(record_hook = some_func):
# some compute
...
```
Logging hooks annotate the string dump:
```python
with DebugMode() as debug_mode, DebugMode.dispatch_hooks(log_hook = some_func):
...
```
Adds default hooks `DebugMode.record_outputs()` and `DebugMode.log_tensor_hashes()`, for checking numerical equivalence. The hashing hook borrows from the Observer. Example dump:
```
aten::sum(dt: f32[8, 32]| S(0))
aten::sum(t: f32[1, 32]) # {'hash': 3.2215590476989746}
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) # {'hash': 204.8783062621951}
_c10d_functional::wait_tensor(t: f32[8, 32]) # {'hash': 204.8783062621951}
aten::mm(t: f32[1, 8], t: f32[8, 32]) # {'hash': 12.014171155635267}
aten::sum(t: f32[1, 32]) # {'hash': 3.2215590476989746}
aten::t(t: f32[1, 8]) # {'hash': 3.7167285680770874}
aten::detach(t: f32[8, 1]) # {'hash': 3.7167285680770874}
...
```
On the FSDP2 / simple FSDP NE in https://github.com/pytorch/pytorch/pull/164939 , with hashing, this produces 2 log dumps (FSDP2: P2010198620, simple FSDP: P2010198963). I asked Claude to check the hashes, it wrote an analysis script, and was able to guess RMS norm as the root cause: P2010195076
Another throw-away example for logging per-op memory usage: https://gist.github.com/pianpwk/372082bf29467aa4aa25cb26dee24aea
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166348
Approved by: https://github.com/yushangdi
2025-11-01 21:10:43 +00:00
0d81bb7f9c
[3/N] Use 'is' in callable comparisons ( #166780 )
...
It is generally advised to use `is/is not` for comparisons against torch functions.
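A short illustration of the preferred style:
```python
import torch

def describe(fn):
    # `is` compares identity, which is what we want for torch callables;
    # `==` may be ambiguous or rely on custom __eq__.
    if fn is torch.add:
        return "addition"
    if fn is torch.mul:
        return "multiplication"
    return "other"

print(describe(torch.add))  # addition
```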
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166780
Approved by: https://github.com/Skylion007
2025-11-01 20:23:56 +00:00
82fafb3304
Revert "Make PT2 compile backprop through custom op without autograd key a hard error ( #166367 )"
...
This reverts commit 84776e13744db6d59b41a063bb8714e2bffe7a06.
Reverted https://github.com/pytorch/pytorch/pull/166367 on behalf of https://github.com/atalman due to backends/xnnpack/test/recipes/test_xnnpack_recipes.py::TestXnnpackRecipes::test_all_models_with_recipes [GH job link](https://github.com/pytorch/pytorch/actions/runs/18999845549/job/54266149620 ) [HUD commit link](84776e1374 ) ([comment](https://github.com/pytorch/pytorch/pull/166367#issuecomment-3476757660 ))
2025-11-01 20:14:22 +00:00
401c2f9657
[FP8][H100][TF32] Disable tf32 for emulated reference computation in test_scaled_mm_vs_emulated_block_wise ( #162997 )
...
Fails with 2 mismatches otherwise
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162997
Approved by: https://github.com/Skylion007
2025-11-01 20:13:11 +00:00
13549e0e10
Revert "Avoid DDE in narrow with unbacked start ( #166361 )"
...
This reverts commit 1aef88c72d3aef629b20e97a188c9dc4bab46a1a.
Reverted https://github.com/pytorch/pytorch/pull/166361 on behalf of https://github.com/atalman due to examples/models/llama/tests/test_export_llama_lib.py::ExportLlamaLibTest::test_has_expected_ops_and_op_counts [GH job link](https://github.com/pytorch/pytorch/actions/runs/18993202115/job/54257916041 ) [HUD commit link](1aef88c72d ) ([comment](https://github.com/pytorch/pytorch/pull/166361#issuecomment-3476752974 ))
2025-11-01 20:07:01 +00:00
82d86bacf3
[inductor] track reduction before splitting ( #166053 )
...
Keep track of the reduction before splitting.
In the mix-order reduction context, if one of the reductions is split, it becomes much harder to fuse it with the other reduction. Track the metadata of the reduction before splitting to make the fusion possible.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166053
Approved by: https://github.com/jansel
2025-11-01 19:41:21 +00:00
3b5d38a3bc
Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. ( #166277 )
...
Fix https://github.com/pytorch/pytorch/issues/163894
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166277
Approved by: https://github.com/Lucaskabela
2025-11-01 19:26:20 +00:00
84776e1374
Make PT2 compile backprop through custom op without autograd key a hard error ( #166367 )
...
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166367
Approved by: https://github.com/bdhirsh
2025-11-01 17:01:31 +00:00
b3861ac8e7
[reland] Warn if AccumulateGrad stream does not match producer node stream ( #166136 )
...
ghstack-source-id: 59641aa32dc6fd027abf3276017432b693aa71f8
Pull-Request-resolved: https://github.com/pytorch/pytorch/pull/165065
Fixes #ISSUE_NUMBER
Opening a new PR for codev
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166136
Approved by: https://github.com/ngimel
2025-11-01 12:33:48 +00:00
4cc64d6234
[inductor] pre grad graph bisecting ( #166344 )
...
A few things to note:
1. Customers like vllm use a custom backend (e.g. VllmBackend), split the graph, and call standalone_compile for each split. If we let the bisector override the backend, we won't bisect through the custom backend. `test_configs.bisect_keep_custom_backend_for_inductor` is used to keep the custom backend if we are bisecting for inductor.
2. pre_grad_graph bisecting and lowering bisecting so far do not compose well with each other, since an issue may just be captured by the first one we try. `test_configs.bisect_pre_grad_graph` is used to enable the 'pre_grad_graph' bisecting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166344
Approved by: https://github.com/eellison
2025-11-01 09:22:21 +00:00
1aef88c72d
Avoid DDE in narrow with unbacked start ( #166361 )
...
Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice.
The only edge case is when start < 0 and start + length == 0; in that case slice and narrow would deviate, so we pass dim_size instead of start + length.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361
Approved by: https://github.com/aorenste
2025-11-01 07:10:23 +00:00
f0745ddb11
Replace c10::call_once with static initialization ( #166381 )
...
This PR replaces c10::call_once calls with static initialization when possible. C++11 semantics guarantees that static initialization is atomic. Static initialization also has lower cost than using c10::call_once.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166381
Approved by: https://github.com/malfet
2025-11-01 07:09:40 +00:00
4316df857c
[3.14] Fix torch.package.importer ( #166767 )
...
That relies on the internal implementation of `pickle._getattribute`, which changed from (i.e. takes an object and a string and returns a tuple)
9ab89c026a/Lib/pickle.py (L316)
to (takes an object and an iterable of strings and returns an object)
631ba3407e/Lib/pickle.py (L315)
Test plan:
```
python -c "import torch; print(torch.package.sys_importer.get_name(torch.cuda.Stream))"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166767
Approved by: https://github.com/williamwen42
2025-11-01 05:05:47 +00:00
9d6597b1e9
Correctly use test parameters ( #166726 )
...
This PR makes use of previously unused arguments in some tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166726
Approved by: https://github.com/rec , https://github.com/albanD , https://github.com/Skylion007
2025-11-01 04:43:31 +00:00
e8fadba28c
[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification ( #160843 )
...
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.
Changes:
1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespecs for `tuple` / `dict`, which are used for `*args` / `**kwargs`. This avoids directly modifying `treespec` instances in ways that rely on the implementation details of the `PyTreeSpec` class (see the usage sketch below).
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.
------
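A hedged usage sketch of the new helpers (argument forms assumed; see the PR for exact signatures):
```python
import torch.utils._pytree as pytree

# Previously: LeafSpec() and manual PyTreeSpec surgery.
leaf = pytree.treespec_leaf()

# Build specs for *args / **kwargs without touching PyTreeSpec internals
# (assuming treespec_tuple takes child specs and treespec_dict a mapping).
args_spec = pytree.treespec_tuple([leaf, leaf])
kwargs_spec = pytree.treespec_dict({"scale": leaf})

spec = pytree.tree_structure(((1, 2), {"scale": 3}))
print(spec.num_children, spec.is_leaf())  # 2 False
```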
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
2025-11-01 04:12:11 +00:00
60333de85d
Revert "Remove setup-env instructions; it's confusing ( #166749 )"
...
This reverts commit 3dc92d69ed40fd952244e54bbda0240928756654.
Reverted https://github.com/pytorch/pytorch/pull/166749 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166749#issuecomment-3475481831 ))
2025-11-01 02:55:56 +00:00
3dc92d69ed
Remove setup-env instructions; it's confusing ( #166749 )
...
Signed-off-by: Edward Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166749
Approved by: https://github.com/mlazos
2025-11-01 01:48:15 +00:00
f91899ca6c
[2/N] Add strict parameter to Python zip calls ( #166257 )
...
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.
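What the flag does, for reference:
```python
names = ["a", "b", "c"]
values = [1, 2]

list(zip(names, values))                # [('a', 1), ('b', 2)] -- silently truncates
list(zip(names, values, strict=True))   # raises ValueError on the length mismatch
```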
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
2025-11-01 00:35:41 +00:00
e2dc32f4ba
Replace decltype(auto) with auto ( #166537 )
...
This PR replaces `decltype(auto)` with `auto` for C++ return type deduction and simplifies some templates.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166537
Approved by: https://github.com/Skylion007
2025-11-01 00:30:23 +00:00
83cc38d9c1
[precompile] Preserve default arguments for dynamo capture ( #166654 )
...
Summary:
Handle the case where there's default arguments on function signature.
Test Plan:
pytest test/export/test_experimental.py -k test_dynamo_graph_capture_default_args
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166654
Approved by: https://github.com/tugsbayasgalan
2025-11-01 00:12:10 +00:00
8d599045cf
add shape check for avg_pool2d ( #161952 )
...
Fix https://github.com/pytorch/pytorch/issues/153312 .
**Example:**
```python
import torch
print(torch.__version__)
tensor = torch.tensor([[ -7.8130e-88, -2.2092e-138, -1.8673e+03, -7.6272e-253, 3.9203e+110,
1.8380e-51, 2.8762e+268, 2.9094e+286, 5.1816e-228, -4.4916e+191,
-7.4057e+80, -9.1955e-18, 5.6536e+225, 8.8364e-175, 1.5053e-226],
[-3.0521e+239, -2.8307e+306, 1.3297e-03, -9.9969e-132, 2.8920e-286,
2.3964e+58, -6.8138e-281, 2.0321e-305, -3.5127e+74, -4.7560e-92,
-8.9403e-99, -1.9739e-187, -2.5124e-173, 2.0458e+295, 4.4992e+52],
[ 6.8752e+21, 1.9332e+189, -8.6940e-189, -6.6743e-15, 1.4691e+41,
1.0338e+63, -2.0779e-28, -7.6642e+104, 1.3390e+284, -8.0859e+194,
8.4600e+107, 4.9115e-44, 1.1665e+285, 5.1275e+203, 9.7580e+303]],
dtype=torch.float64)
try:
res = torch.nn.functional.lp_pool1d(
tensor,
norm_type=-1.38119e+150,
kernel_size=7879455037536781369,
ceil_mode=True,
)
print("CPU result:", res)
except RuntimeError as e:
print(f"CPU error: {e}")
tensor_gpu = tensor.to("cuda:0")
try:
res = torch.nn.functional.lp_pool1d(
tensor_gpu,
norm_type=-1.38119e+150,
kernel_size=7879455037536781369,
ceil_mode=True,
)
print("GPU result:", res)
except RuntimeError as e:
print(f"GPU error: {e}")
```
**Output:**
- before
```
2.9.0a0+git8703deb
CPU result: tensor([[0.],
[0.],
[0.]], dtype=torch.float64)
GPU error: integer out of range
```
- after
```
2.9.0a0+git2e893df
CPU error: integer out of range
GPU error: integer out of range
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161952
Approved by: https://github.com/mingfeima , https://github.com/malfet
2025-10-31 22:52:41 +00:00
fd5da81fdd
[AI Codemod][DevmateFBSourceTestFailureBot] Fix for T243177299 ("Your diff, D85182174, broke some tests") ( #166753 )
...
Summary:
As per title, a bot created this diff because this test broke due to [a different PR.](https://github.com/pytorch/pytorch/pull/166026 )
<Erased bot summary in case anything we don't want to make external.>
Test Plan:
Bot ran the tests and they passed.
<Erased bot test plan in case anything we don't want to make external.>
Differential Revision: D85745809
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166753
Approved by: https://github.com/d4l3k
2025-10-31 22:49:59 +00:00
9261a1fb12
[MPS] Error out when BatchNorm is called for Complex ( #166215 )
...
Or when BatchNorm or LayerNorm is called for Long types
Discovered while trying to enable `test_ops.py` for MPS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166215
Approved by: https://github.com/dcci , https://github.com/kulinseth , https://github.com/Skylion007
ghstack dependencies: #166214 , #166687
2025-10-31 22:44:29 +00:00
d80ae738c9
compile_worker: Make a timer class ( #166465 )
...
This subclass allows us to trigger an action after we haven't seen any activity for a certain number of seconds.
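A generic sketch of the idea (not the PR's actual class), using a resettable threading.Timer:
```python
import threading

class InactivityTimer:
    """Invoke `callback` once no activity has been recorded for `timeout` seconds."""

    def __init__(self, timeout: float, callback) -> None:
        self.timeout = timeout
        self.callback = callback
        self._timer = None

    def record_activity(self) -> None:
        # Each activity cancels the pending fire and restarts the countdown.
        if self._timer is not None:
            self._timer.cancel()
        self._timer = threading.Timer(self.timeout, self.callback)
        self._timer.daemon = True
        self._timer.start()
```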
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166465
Approved by: https://github.com/masnesral
2025-10-31 22:39:31 +00:00
51667435f5
[FlexFlash] Wire up mask_mod + blockmask to flash impl ( #166359 )
...
I have some local changes that I need to push to flash first
https://github.com/Dao-AILab/flash-attention/pull/1970
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166359
Approved by: https://github.com/v0i0
2025-10-31 22:07:40 +00:00
2699f5410b
Revert "[xpu][feature] Integrate OneDNN SDPA training forward/backward into XPU OVERRIDEABLE Backend ( #162454 )"
...
This reverts commit fd68d409ada709450ced3030bde89ec662a3f7b7.
Reverted https://github.com/pytorch/pytorch/pull/162454 on behalf of https://github.com/atalman due to internal build failure ([comment](https://github.com/pytorch/pytorch/pull/162454#issuecomment-3475009089 ))
2025-10-31 21:58:52 +00:00
9970fb97ff
Fix Tril Triu SymInt ( #166627 )
...
Fixes #165613
### Summary:
- This PR fixes an issue where `torch.tril` and `torch.triu` with dynamic diagonal values cause torch.export to incorrectly infer unnecessary constraints between dynamic dimensions.
- Ensured proper SymInt type annotations for diagonal parameter
- Updated C++ implementation to correctly handle SymInt diagonal values.
### Impacts:
module: dynamic shapes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166627
Approved by: https://github.com/ezyang , https://github.com/Skylion007
2025-10-31 21:53:20 +00:00
dfebdcab86
[GraphPartition] cache get_free_symbol_uses ( #166338 )
...
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs.
ee7434be82/torch/_inductor/scheduler.py (L4869-L4885)
I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried aten fallback for torchtitan, which results in 10k+ aten nodes. When processing the 600th node, it takes seconds to run `get_free_symbol_uses()` for a single node.
Why? Because `get_free_symbol_uses()` may recursively call another `get_free_symbol_uses()`, which could recursively run many times.
ee7434be82/torch/_inductor/ir.py (L4541-L4543)
This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.
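The fix pattern, sketched generically (the real change caches on the Inductor IR nodes):
```python
import functools

class IRNode:
    def __init__(self, children):
        self.children = children

    @functools.lru_cache(maxsize=None)  # cached per node instance
    def get_free_symbol_uses(self):
        # Without the cache, each call recurses into children, so repeated
        # queries over a 10k+ node graph become quadratic.
        uses = set()
        for child in self.children:
            uses |= child.get_free_symbol_uses()
        return frozenset(uses)
```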
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338
Approved by: https://github.com/eellison
2025-10-31 21:24:05 +00:00
b09fb481e0
[CD] Upgrade GCC version to 13 for XPU build ( #162474 )
...
Follow #152426
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162474
Approved by: https://github.com/zxiiro , https://github.com/atalman
2025-10-31 21:15:37 +00:00
4e7232c5da
[MPS] Fix smooth_l1_loss backward for fp16 ( #166687 )
...
And enable fp16 implementation for CPU, which simplifies OpInfo definitions for the op
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166687
Approved by: https://github.com/Skylion007
ghstack dependencies: #166214
2025-10-31 21:13:46 +00:00
93a70c717a
Revert "Add CUDA MXFP4 scaled mm support via. FBGEMM ( #166526 )"
...
This reverts commit e3ae0594d16134632ff587c9ab400d4148c83e9f.
Reverted https://github.com/pytorch/pytorch/pull/166526 on behalf of https://github.com/atalman due to Failing internal test ([comment](https://github.com/pytorch/pytorch/pull/166526#issuecomment-3474907536 ))
2025-10-31 21:10:28 +00:00
d97144d31e
[5/N] Remove unused loop variables in tests ( #166716 )
...
This PR removes unused loop variables in tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166716
Approved by: https://github.com/Lucaskabela , https://github.com/Skylion007
2025-10-31 20:47:57 +00:00
e4043884c7
[dynamo, 3.14] fix segfault due to improper create_call_function_ex ( #166678 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166678
Approved by: https://github.com/malfet
2025-10-31 20:44:53 +00:00
4a7bc1d522
[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ ( #166569 )
...
Provides type coverage to ~3000 LOC and 200 methods in `torch/_dynamo/variables/`
This is the first part of the final step to having 100% strict type coverage in dynamo - see previous comments in https://github.com/pytorch/pytorch/pull/166535 (combined into this one PR because ghstack was giving issues...)
### Coverage report:
```
mypy torch_dynamo/variables --linecount-report /tmp/coverage_log
```
Compare before to after - we go from 3826 to 7221 lines covered
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166569
Approved by: https://github.com/williamwen42 , https://github.com/Skylion007
2025-10-31 20:42:27 +00:00
8209a0506b
[Pytorch] Enable aarch64 convert autovec only on clang ( #166739 )
...
Summary: We've noted issues with modern GCC versions. Until further investigation is carried out, we'll leave the code enabled only on clang.
Test Plan: CI
Differential Revision: D85968395
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166739
Approved by: https://github.com/mcfi , https://github.com/Skylion007 , https://github.com/robert-hardwick
2025-10-31 20:22:33 +00:00
70aeb49198
[dynamo] clarify graph break handling/logging in symbolic_convert ( #166587 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166587
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166476 , #166477 , #166586
2025-10-31 20:13:16 +00:00
cf9a834f39
[BE] Move GreenContext implementation details to cpp ( #166462 )
...
- Remove all complex defines logic from the header
- Make GreenContext constructor private, as it should only be created via the static method as singleton
- Delete unused `getContext` and `getGreenContext` methods
- Rename `CUDA_HAS_GREEN_CONTEXT` to `HAS_CUDA_GREEN_CONTEXT()`, which results in a compilation error if one accidentally makes a typo
- Suppress `-Wunused-private-field` if GreenContext is not available
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166462
Approved by: https://github.com/ngimel , https://github.com/eqy
2025-10-31 20:11:02 +00:00
856a7a5298
Add missing device to namedtensor tests ( #166717 )
...
This PR passes the previously unused `device` argument to the tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166717
Approved by: https://github.com/Skylion007
2025-10-31 20:04:41 +00:00
ef8d97efcf
fix broken nn_convolution test ( #166666 )
...
Summary: Broken by an OSS diff landed during oncall by a third-party contributor.
Test Plan: buck test 'fbcode//mode/dev-nosan' fbcode//caffe2/test:nn_convolution -- --run-disabled
Differential Revision: D85899891
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166666
Approved by: https://github.com/atalman , https://github.com/seemethere , https://github.com/Skylion007
2025-10-31 19:59:50 +00:00
d2be06f673
[cpu][fix] Update ACL version to fix crashes with tensor sizes > 2^31-1 ( #165904 )
...
----
- Updates Arm Compute Library (ACL) to v52.6.0
- v52.6.0 contains https://github.com/ARM-software/ComputeLibrary/pull/1201 which fixes crashes with tensors of sizes > 2^31-1
fixes : #165654
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165904
Approved by: https://github.com/malfet
2025-10-31 19:37:26 +00:00
08f4535378
Refactor AOTAutogradCacheEntry into AOTAutogradResult ( #166656 )
...
This PR renames AOTAutogradCacheEntry to AOTAutogradResult and BundledAOTAutogradCacheEntry to BundledAOTAutogradResult. It also moves all corresponding code to a new file, `aot_autograd_result`, which is analogous to Inductor's `output_code.py`.
Having all these be called cache entries made sense when all we used them for was caching. But with AOT compile using BundledAOTAutogradCacheEntry, we want a more generalized naming structure.
This is a no-op change, and all existing tests should pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166656
Approved by: https://github.com/zhxchen17
ghstack dependencies: #166650
2025-10-31 18:54:09 +00:00
30157d30f0
Add regional aot eager support to AOTAutogradCacheEntry ( #166650 )
...
This PR does two things:
- It genericizes `BundledAOTAutogradCacheEntry` to support *any* OutputCode, not just CompiledFxGraphs
- It adds a brand new OutputCode for the `aot_eager_regional_inductor` backend, i.e. a graph module that has regional inductor components in it.
This allows BundledAOTAutogradCache to just integrate nicely with inductor out of the box, but more importantly, it allows the result of aot_autograd to be fully serializable when using `aot_eager_regional_inductor`. This will allow us to AOT precompile cases where we have an eager graph that has scooped up inductor bits.
It's a bit unfortunate that the naming makes BundledAOTAutogradCacheEntry sound like its primary use is for caching, but really the more common use is going to be as an AOTAutogradOutput. It may be worth revisiting how to refactor/rename these in a later PR:
- AOTAutogradCacheEntry -> AOTAutogradResult
- BundledAOTAutogradCacheEntry -> BundledAOTAutogradResult
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166650
Approved by: https://github.com/zhxchen17
2025-10-31 18:54:09 +00:00
b470e59c38
partitioner option to ignore partitioner_tag for abstract usage ( #166725 )
...
Partitioner functionality is appealing to use in different scenarios (e.g. Autoparallel).
We have special logic around the "partitioner_tag" node meta that is only needed for the forward/backward split.
This PR adds an optional argument to skip it and do only a generic split based on inputs/outputs.
Potentially we want to make `_extract_graph_with_inputs_outputs` public (drop the leading underscore) :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166725
Approved by: https://github.com/bdhirsh
2025-10-31 18:50:02 +00:00
85b85f6c2c
Revert "[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification ( #160843 )"
...
This reverts commit 108bb224f77842593009214ebf6258030b934642.
Reverted https://github.com/pytorch/pytorch/pull/160843 on behalf of https://github.com/atalman due to failing internal builds ([comment](https://github.com/pytorch/pytorch/pull/160843#issuecomment-3474354428 ))
2025-10-31 18:31:32 +00:00
b71966f67b
[PyTorch] Improve aarch64 performance of bfloat16 ops - retry ( #166028 ) ( #166641 )
...
Summary:
This PR allows the compiler to better optimize some bfloat16-based operations when run on NEON.
Retrying to land the code after noting that these expressions became available in recent compiler versions.
The existing CI benchmark binary_test.py will measure the affected codepaths.
Benchmarks show measurable improvements on clang-19, when targeting armv9-a+sve2:
Before:
bfloat16 add: 250.503us
bfloat16 sub: 245.674us
bfloat16 neg: 113.945us
bfloat16 abs: 115.953us
bfloat16 reciprocal: 262.602us
After:
bfloat16 add: 203.862us ---> 23% higher throughput
bfloat16 sub: 201.526us ---> 22% higher throughput
bfloat16 neg: 68.416us ---> 67% higher throughput
bfloat16 abs: 71.003us ---> 63% higher throughput
bfloat16 reciprocal: 177.834us ---> 48% higher throughput
Test Plan:
Correctness:
buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch
Performance:
buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test
Reviewed By: mcfi
Differential Revision: D85809843
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166641
Approved by: https://github.com/Skylion007 , https://github.com/malfet
2025-10-31 18:21:04 +00:00
0947765eb9
Cache even more work for return_and_correct_aliasing ( #166365 )
...
Yet another pass found even more work we can move to be done only once. This seems to knock a few microseconds off the DTensor dispatch fast path.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166365
Approved by: https://github.com/bdhirsh
2025-10-31 18:03:05 +00:00
239e7b541a
[ROCm][CI] upgrade nightly wheels to ROCm 7.1 ( #166730 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166730
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com >
2025-10-31 17:30:47 +00:00
ffaa6578b7
Revise deprecation warning for ONNX exporter ( #166692 )
...
Updated deprecation warning for ONNX export to reflect the current state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166692
Approved by: https://github.com/titaiwangms
2025-10-31 17:23:55 +00:00
365ed62f61
Document LibTorch ABI more, add README to headeronly ( #166661 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166661
Approved by: https://github.com/mikaylagawarecki , https://github.com/albanD
2025-10-31 17:18:13 +00:00
fcc1063566
Revert "[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ ( #166569 )"
...
This reverts commit aa9c96af041b26c9c55adac490f3449b98f27d06.
Reverted https://github.com/pytorch/pytorch/pull/166569 on behalf of https://github.com/Lucaskabela due to Lintrunner not fixed due to race condition at landing ([comment](https://github.com/pytorch/pytorch/pull/166569#issuecomment-3474012637 ))
2025-10-31 16:59:33 +00:00
121235956b
update Node.is_impure check if subgraph contains impure ops ( #166609 )
...
Summary:
## Context
When `const_fold.split_const_subgraphs` sees a `call_module` node whose target is a GraphModule, the existing implementation can mark this node as const-foldable when it shouldn't.
For example, a parent graph contains a `call_module` to a subgraph that has no inputs but contains impure ops inside.
```
parent graph():
%sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
return (getitem,)
submodule graph():
%randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cpu, pin_memory: False})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
return (add,)
```
When the `submodule` graph is fed to `const_fold.split_const_subgraphs` directly, it comes out unmodified, since randn is impure.
But if `submodule` is called by a `parent` graph, feeding `parent` to `const_fold.split_const_subgraphs` folds it away.
```
parent after fold graph():
%_fx_const_folded_attrs : [num_users=1] = get_attr[target=_FX_CONST_FOLDED_ATTRS]
return (_fx_const_folded_attrs,)
```
This is because the `node.is_impure()` check inside `const_fold.split_const_subgraphs` falls through, leading the call_module node to be treated as pure.
## Fix
We can update the `fx.node.Node.is_impure` function to check the ops inside a call_module node with an additional `subgraph_has_impure_ops` check (see the sketch after this list):
- if a call_module node calls a GraphModule,
- check whether any call_function nodes are impure ops
- recursively check any call_module nodes that call a GraphModule
If the call_module subgraph has impure ops, return True from `is_impure`.
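A minimal sketch of the recursive check described above (illustrative helper name; not the exact code added to `Node.is_impure`):
```python
import torch.fx as fx

def subgraph_has_impure_ops(gm: fx.GraphModule) -> bool:
    """Return True if the GraphModule (or any nested GraphModule) contains impure ops."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.is_impure():
            return True
        if node.op == "call_module":
            submod = gm.get_submodule(node.target)
            if isinstance(submod, fx.GraphModule) and subgraph_has_impure_ops(submod):
                return True
    return False
```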
Test Plan: added tests to test_fx_const_fold.py
Differential Revision: D85798483
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166609
Approved by: https://github.com/blaine-rister
2025-10-31 16:58:18 +00:00
aa9c96af04
[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ ( #166569 )
...
Provides type coverage to ~3000 LOC and 200 methods in `torch/_dynamo/variables/`
This is the first part of the final step to having 100% strict type coverage in dynamo - see previous comments in https://github.com/pytorch/pytorch/pull/166535 (combined into this one PR because ghstack was giving issues...)
### Coverage report:
```
mypy torch/_dynamo/variables --linecount-report /tmp/coverage_log
```
Compare before to after - we go from 3826 to 7221 lines covered
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166569
Approved by: https://github.com/williamwen42
2025-10-31 16:56:50 +00:00
c3b71d5499
[ROCm][CI] remove relaxed tolerance for tf32 tests ( #166478 )
...
Instead of relaxing tolerances for certain unit tests that exercise TF32 on MI300, skip the tests until hipblaslt accuracy is improved.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166478
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com >
Co-authored-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com >
2025-10-31 16:15:42 +00:00
1e3600b528
[MPS] Move logaddexp/logaddexp2 to Metal and support complex ( #166670 )
...
NOTE: Complex inputs are only supported in `logaddexp`. Since `logaddexp2` does not support complex inputs for CPU, it is not enabled for MPS in this PR either.
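A hedged usage sketch (requires an MPS device; shapes and values are illustrative):
```python
import torch

if torch.backends.mps.is_available():
    a = torch.randn(4, dtype=torch.cfloat, device="mps")
    b = torch.randn(4, dtype=torch.cfloat, device="mps")
    # After this change, logaddexp accepts complex inputs on MPS;
    # logaddexp2 remains real-only, matching CPU behavior.
    out = torch.logaddexp(a, b)
    print(out)
```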
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166670
Approved by: https://github.com/malfet
2025-10-31 16:15:02 +00:00
fee7624bd6
[PT2] set choice handler in config ( #166607 )
...
Summary:
We were setting the custom inductor choice using `torch._inductor.virtualized.V.set_choices_handler(CustomInductorChoices())`. However, this leads to inconsistent behaviors, even for jobs that are submitted back to back.
In this diff, we pass in the choice handler via an inductor config and override the default behavior when the config is provided. This solves the inconsistent behavior.
Test Plan: see D85785892 (internal only)
Differential Revision: D85785879
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166607
Approved by: https://github.com/eellison
2025-10-31 15:40:05 +00:00
24e94e021a
[ROCm][CI] create ROCm 7.1 magma tarball ( #166693 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166693
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com >
2025-10-31 15:20:00 +00:00
69be99ee51
Remove manually synced arch versions in tools/nightly.py ( #166616 )
...
Discussed with @atalman offline. This reduces duplicate changes and the number of files to touch when updating arch versions.
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166616
Approved by: https://github.com/ezyang
2025-10-31 15:11:28 +00:00
034e951b0c
[CUDA][cuBLASLt] addmm -- extend bias fusions to cases with (1 by n) shapes ( #166307 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166307
Approved by: https://github.com/eqy
2025-10-31 14:30:41 +00:00
160ab53dd5
Update weight tensor initialization in RMSNormalization ( #166550 )
...
Ensure the weight is a >1-D tensor for ORT compatibility.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166550
Approved by: https://github.com/titaiwangms
2025-10-31 14:29:27 +00:00
5bcfdae71d
Revert "Make PT2 compile backprop through custom op without autograd key a hard error ( #166367 )"
...
This reverts commit 4acc66f1192ab7743abcc50383aefc5447447f9d.
Reverted https://github.com/pytorch/pytorch/pull/166367 on behalf of https://github.com/atalman due to internal build failures ([comment](https://github.com/pytorch/pytorch/pull/166367#issuecomment-3473150269 ))
2025-10-31 13:44:05 +00:00
4e8ba37ce3
Revert "[BE] Move GreenContext implementation details to cpp ( #166462 )"
...
This reverts commit 5d288bc3f73873887f681e15af83c5525e6a60bd.
Reverted https://github.com/pytorch/pytorch/pull/166462 on behalf of https://github.com/atalman due to Sorry, Reverting. Failure: test/test_matmul_cuda.py::TestMatmulCudaCUDA::test_greencontext_carveout_cuda [GH job link](https://github.com/pytorch/pytorch/actions/runs/18962393091/job/54154156892 ) [HUD commit link](85b035ca9c ) ([comment](https://github.com/pytorch/pytorch/pull/166462#issuecomment-3473060299 ))
2025-10-31 13:20:48 +00:00
26534e9809
Revert "[GraphPartition] cache get_free_symbol_uses ( #166338 )"
...
This reverts commit a6b1ef17173f56ba93ac97ff4384fa4060b5e41e.
Reverted https://github.com/pytorch/pytorch/pull/166338 on behalf of https://github.com/atalman due to Failure: test/nn/test_convolution.py::TestConvolutionNN::test_conv3d_overflow_values [GH job link](https://github.com/pytorch/pytorch/actions/runs/18961173726/job/54149112920 ) [HUD commit link](a6b1ef1717 ) ([comment](https://github.com/pytorch/pytorch/pull/166338#issuecomment-3472980329 ))
2025-10-31 12:57:56 +00:00
657f8c3e21
Revert "Fix torch.full with dynamic tensor fill_value in torch.compile ( #166554 )"
...
This reverts commit 32066772b3dee643b1657b8957f32b5ac8b1390a.
Reverted https://github.com/pytorch/pytorch/pull/166554 on behalf of https://github.com/atalman due to Failure: test/nn/test_pooling.py::TestPoolingNNDeviceTypeCPU::test_max_pool_nan_inf_cpu_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18959368975/job/54144148546 ) [HUD commit link](32066772b3 ) ([comment](https://github.com/pytorch/pytorch/pull/166554#issuecomment-3472976911 ))
2025-10-31 12:55:31 +00:00
b0831930ed
[inductor] Mark / restrict tests that only work if ATen is used for matmul ( #166518 )
...
These tests only work if max_autotune=False (default), which for matmul means falling back to ATen. This PR just documents / makes that transparent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166518
Approved by: https://github.com/eellison
2025-10-31 12:29:06 +00:00
c01636e1bc
Fixes the sparse tensor issue ( #163535 )
...
Fixes #148324
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163535
Approved by: https://github.com/janeyx99
2025-10-31 11:48:31 +00:00
fd68d409ad
[xpu][feature] Integrate OneDNN SDPA training forward/backward into XPU OVERRIDEABLE Backend ( #162454 )
...
This is the second PR split from https://github.com/pytorch/pytorch/pull/156272
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162454
Approved by: https://github.com/guangyey , https://github.com/EikanWang , https://github.com/drisspg
2025-10-31 11:20:38 +00:00
0d3a4f7155
[CD] Enable Inductor performance test for xpu ( #166289 )
...
Add Dynamo benchmark performance tests for XPU backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166289
Approved by: https://github.com/EikanWang , https://github.com/atalman
2025-10-31 10:52:07 +00:00
108bb224f7
[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification ( #160843 )
...
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.
Changes:
1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.
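A hedged usage sketch of the new helpers (exact signatures are assumed from the description above, not verified against the final API):
```python
import torch.utils._pytree as pytree

leaf = pytree.treespec_leaf()                     # replaces LeafSpec()
args_spec = pytree.treespec_tuple([leaf, leaf])   # spec for two positional args
kwargs_spec = pytree.treespec_dict({"x": leaf})   # spec for {"x": ...} kwargs

assert args_spec.num_children == 2                # instead of len(spec.children_specs)
assert leaf.is_leaf()                             # instead of isinstance(spec, LeafSpec)
```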
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
2025-10-31 10:33:16 +00:00
fc8ac1216c
[4/N] Remove unused loop variables in tests ( #166690 )
...
This PR removes unused loop variables in tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166690
Approved by: https://github.com/justinchuby , https://github.com/mlazos
2025-10-31 10:20:48 +00:00
030de07aff
[2/N] Use 'is' in callable comparisons ( #166685 )
...
It is generally advised to use `is/is not` for comparisons against torch functions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166685
Approved by: https://github.com/xmfan , https://github.com/mlazos
2025-10-31 08:08:07 +00:00
7d67a41db4
make FXConverter.generate use V.fake_mode instead of _detect_fake_mode_from_gm ( #166591 )
...
Summary:
FXConverter configures `_node_metadata_hook`, passing in `fake_mode` explicitly, which is relevant for cases down the line like `_generate_triton_call` that insert a `triton_kernel_wrapper_mutation` node.
This `fake_mode` is obtained from `_detect_fake_mode_from_gm`, which can be different from the `V.fake_mode` set by inductor.
For example, while `V.fake_mode` is not None, `_detect_fake_mode_from_gm` can be **None** for a parent graph containing only a submodule which has no input args and only constants
```
parent graph():
%sub : [num_users=1] = call_module[target=sub](args = (), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%sub, slice(None, None, None)), kwargs = {})
return (getitem,)
submodule graph():
%randn : [num_users=1] = call_function[target=torch.ops.aten.randn.default](args = ([5, 10],), kwargs = {device: cuda, pin_memory: False})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%randn, 1), kwargs = {})
return (add,)
```
This discrepancy is problematic: it makes `_node_metadata_hook` run inputs under a different `fake_mode`, or no fake_mode at all, while the rest of lowering uses `V.fake_mode`. In some cases where the input is placed on a custom non-GPU device, it can even fail with "requires device to be started" or a tensor device mismatch.
So this diff updates `FXConverter.generate` to use `V.fake_mode`, which is populated properly by inductor.
Test Plan:
Added a test `test_const_folded_subgraph` in `test_fxir_backend.py`. This test:
- creates a graph module that calls a subgraph with no inputs and containing only const-foldable ops
- const-folds the subgraph
- runs FXConverter.generate and expects the `fake_mode` used for code generation to be non-None
With the prior implementation, which used `_detect_fake_mode_from_gm`, this test would fail because `fake_mode` would be `None`.
With this change the test passes: `fake_mode` is properly collected from `V.fake_mode`, which is not None.
Differential Revision: D85767475
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166591
Approved by: https://github.com/blaine-rister , https://github.com/mlazos , https://github.com/eellison
2025-10-31 05:52:07 +00:00
85b035ca9c
[nativert] Downcast triton double arguments to floats ( #166620 )
...
This diff tries to fix a limitation in the Sigmoid + Triton interaction, where float arguments are not correctly passed: NativeRT passes float arguments as double, while triton kernels were reading them as float, resulting in wrong values.
---
## Limitations in (de)serialization
In triton, float arguments to a kernel are encoded as "fp32" ([code](https://github.com/triton-lang/triton-cpu/blob/main-merged/python/triton/runtime/jit.py#L310-L326 )):
```
elif isinstance(arg, float):
    return ("fp32", None)
```
But it seems that torch export serde uses double ([code](d2eff5d454/torch/_export/serde/export_schema.thrift (L149) )) because Thrift only has the double type:
```
union Argument {
  10: bool as_none;
  20: TensorArgument as_tensor;
  30: list<TensorArgument> as_tensors;
  50: i64 as_int;
  70: list<i64> as_ints;
  80: double as_float;  ===> actually double
  ...
```
The `TritonKernel` constructor loads attributes from a node, where `Constant` represents the variant type, and it only has `double` ([code](d2eff5d454/torch/nativert/graph/Graph.h (L86) )):
```
using Constant = std::variant<
    None,
    int64_t,
    std::vector<int64_t>,
    double,  ===> triton float is loaded as double
```
So, NativeRT passes float arguments (originally in Triton) as double to triton kernels. But, all of the triton backends (nvidia, amd and cpu) are reading them as float because the signature still says `fp32`.
D84423898 was the current workaround: wrapping float arguments with tensors.
## The Fix
Fixing the Thrift definition isn't viable because Thrift only supports the double type. It's also possible to fix this on the triton side by downcasting from double to float, but that would require fixing all backends.
Instead, I think this diff is the most effective approach: when building `TritonKernel`, downcast the float values right after loading the double arguments.
Test Plan:
```
buck test fbcode//mode/opt-amd-gpu fbcode//caffe2/test:test_export --
```
Differential Revision: D85747160
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166620
Approved by: https://github.com/XueningXu
2025-10-31 03:52:20 +00:00
267d0197bf
[dynamo] fix error_on_graph_break bug where non-empty checkpoint results in unwanted graph break resumption ( #166586 )
...
Fixes https://github.com/pytorch/pytorch/issues/166589
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166586
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #166476 , #166477
2025-10-31 03:36:27 +00:00
1dec8a67a8
[dynamo, nested graph breaks] add disable_nested_graph_breaks decorator/context manager ( #166477 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166477
Approved by: https://github.com/Lucaskabela , https://github.com/Skylion007
ghstack dependencies: #166476
2025-10-31 03:36:27 +00:00
797cd80b26
[dynamo, nested graph breaks] codegen dead nested cells correctly ( #166476 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166476
Approved by: https://github.com/Lucaskabela
2025-10-31 03:36:27 +00:00
7d39401fa0
Revert "[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ ( #166569 )"
...
This reverts commit f1e4c42b6ef3d3cea08ab3babb693e3ce42cf08b.
Reverted https://github.com/pytorch/pytorch/pull/166569 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166569#issuecomment-3471180280 ))
2025-10-31 03:31:01 +00:00
e3ae0594d1
Add CUDA MXFP4 scaled mm support via. FBGEMM ( #166526 )
...
Summary:
* Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA
* Add testing
Signed-off-by: Simon Layton <simonlayton@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526
Approved by: https://github.com/drisspg , https://github.com/ngimel
2025-10-31 03:17:27 +00:00
f1e4c42b6e
[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ ( #166569 )
...
Provides type coverage to ~3000 LOC and 200 methods in `torch/_dynamo/variables/`
This is the first part of the final step to having 100% strict type coverage in dynamo - see previous comments in https://github.com/pytorch/pytorch/pull/166535 (combined into this one PR because ghstack was giving issues...)
### Coverage report:
```
mypy torch/_dynamo/variables --linecount-report /tmp/coverage_log
```
Compare before to after - we go from 3826 to 7221 lines covered
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166569
Approved by: https://github.com/williamwen42
2025-10-31 02:57:59 +00:00
d3e511f07c
[Inductor] support masked vectorization for the tail_loop for fp8 datatype ( #163324 )
...
**Summary:**
Support masked vectorization of the tail_loop for the fp8 datatype.
**Example:**
```
import torch
def fn(
    x,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
):
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x,
        scale,
        zero_point,
        quant_min,
        quant_max,
        dtype,
    )
    x = torch.relu(x)
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, dtype
    )
    return x

quant_min = -128
quant_max = 127
dtype = torch.float8_e4m3fn
x = torch.clamp(torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, quant_min, quant_max).to(dtype)
zero_point = 100
scale = 0.01

with torch.no_grad():
    compiled_fn = torch.compile(fn)
    compiled_fn(x, scale, zero_point, quant_min, quant_max, dtype)
```
**Generated code:**
- Before
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const at::Float8_e4m3fn* in_ptr0,
at::Float8_e4m3fn* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
{
auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(100.0);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 - tmp3;
auto tmp5 = static_cast<float>(0.01);
auto tmp6 = at::vec::Vectorized<float>(tmp5);
auto tmp7 = tmp4 * tmp6;
auto tmp8 = (tmp7);
auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
auto tmp10 = tmp9 * tmp3;
auto tmp11 = tmp10.round();
auto tmp12 = tmp11 + tmp3;
auto tmp13 = static_cast<float>(-128.0);
auto tmp14 = at::vec::Vectorized<float>(tmp13);
auto tmp15 = at::vec::maximum(tmp12, tmp14);
auto tmp16 = static_cast<float>(127.0);
auto tmp17 = at::vec::Vectorized<float>(tmp16);
auto tmp18 = at::vec::minimum(tmp15, tmp17);
auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
{
for (int64_t x0_tail = static_cast<int64_t>(432L);x0_tail < static_cast<int64_t>(441L); x0_tail++)
{
auto tmp0 = in_ptr0[static_cast<int64_t>(x0_tail)];
auto tmp1 = c10::convert<float>(tmp0);
auto tmp2 = static_cast<float>(100.0);
auto tmp3 = float(tmp1 - tmp2);
auto tmp4 = static_cast<float>(0.01);
auto tmp5 = float(tmp3 * tmp4);
auto tmp6 = c10::convert<float>(tmp5);
auto tmp7 = std::max(tmp6, decltype(tmp6)(0));
auto tmp8 = float(tmp7 * tmp2);
auto tmp9 = std::nearbyint(tmp8);
auto tmp10 = float(tmp9 + tmp2);
auto tmp11 = static_cast<float>(-128.0);
auto tmp12 = max_propagate_nan(tmp10, tmp11);
auto tmp13 = static_cast<float>(127.0);
auto tmp14 = min_propagate_nan(tmp12, tmp13);
auto tmp15 = c10::convert<at::Float8_e4m3fn>(tmp14);
out_ptr0[static_cast<int64_t>(x0_tail)] = tmp15;
}
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
class Runner:
def __init__(self, partitions):
self.partitions = partitions
def recursively_apply_fns(self, fns):
new_callables = []
for fn, c in zip(fns, self.partitions):
new_callables.append(fn(c))
self.partitions = new_callables
def call(self, args):
arg0_1, = args
args.clear()
assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
# [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
del arg0_1
return (buf0, )
```
- After
```
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0 = async_compile.cpp_pybinding(['const at::Float8_e4m3fn*', 'at::Float8_e4m3fn*'], r'''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C" void kernel(const at::Float8_e4m3fn* in_ptr0,
at::Float8_e4m3fn* out_ptr0)
{
{
for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(441L); x0+=static_cast<int64_t>(16L))
{
{
if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(432L)))
{
auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(100.0);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 - tmp3;
auto tmp5 = static_cast<float>(0.01);
auto tmp6 = at::vec::Vectorized<float>(tmp5);
auto tmp7 = tmp4 * tmp6;
auto tmp8 = (tmp7);
auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
auto tmp10 = tmp9 * tmp3;
auto tmp11 = tmp10.round();
auto tmp12 = tmp11 + tmp3;
auto tmp13 = static_cast<float>(-128.0);
auto tmp14 = at::vec::Vectorized<float>(tmp13);
auto tmp15 = at::vec::maximum(tmp12, tmp14);
auto tmp16 = static_cast<float>(127.0);
auto tmp17 = at::vec::Vectorized<float>(tmp16);
auto tmp18 = at::vec::minimum(tmp15, tmp17);
auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
}
if(C10_UNLIKELY(x0 >= static_cast<int64_t>(432L) && x0 < static_cast<int64_t>(441L)))
{
auto tmp0 = at::vec::Vectorized<at::Float8_e4m3fn>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
auto tmp1 = at::vec::convert<float>(tmp0);
auto tmp2 = static_cast<float>(100.0);
auto tmp3 = at::vec::Vectorized<float>(tmp2);
auto tmp4 = tmp1 - tmp3;
auto tmp5 = static_cast<float>(0.01);
auto tmp6 = at::vec::Vectorized<float>(tmp5);
auto tmp7 = tmp4 * tmp6;
auto tmp8 = (tmp7);
auto tmp9 = at::vec::clamp_min(tmp8, decltype(tmp8)(0));
auto tmp10 = tmp9 * tmp3;
auto tmp11 = tmp10.round();
auto tmp12 = tmp11 + tmp3;
auto tmp13 = static_cast<float>(-128.0);
auto tmp14 = at::vec::Vectorized<float>(tmp13);
auto tmp15 = at::vec::maximum(tmp12, tmp14);
auto tmp16 = static_cast<float>(127.0);
auto tmp17 = at::vec::Vectorized<float>(tmp16);
auto tmp18 = at::vec::minimum(tmp15, tmp17);
auto tmp19 = at::vec::convert<at::Float8_e4m3fn>(tmp18);
tmp19.store(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(9L));
}
}
}
}
}
''')
async_compile.wait(globals())
del async_compile
class Runner:
def __init__(self, partitions):
self.partitions = partitions
def recursively_apply_fns(self, fns):
new_callables = []
for fn, c in zip(fns, self.partitions):
new_callables.append(fn(c))
self.partitions = new_callables
def call(self, args):
arg0_1, = args
args.clear()
assert_size_stride(arg0_1, (1, 7, 7, 9), (441, 63, 9, 1))
buf0 = empty_strided_cpu((1, 7, 7, 9), (441, 63, 9, 1), torch.float8_e4m3fn)
# [Provenance debug handles] cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0:1
cpp_fused_dequantize_per_tensor_quantize_per_tensor_relu_0(arg0_1, buf0)
del arg0_1
return (buf0, )
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163324
Approved by: https://github.com/Xia-Weiwen , https://github.com/mingfeima , https://github.com/jansel
2025-10-31 02:53:56 +00:00
d3be06cbdc
[MTIAGraph][Pytorch][2/n] Add binding for Python to C++, and hook for Pytorch to Fbcode ( #165963 )
...
Summary:
This diff is the binding and hook layer for MTIA Graph, including
1. binding between Python and C++
2. hook between Pytorch and mtia fbcode
<img width="1780" height="754" alt="image" src="https://github.com/user-attachments/assets/31e24e5b-8324-42d8-8d3b-59536bc18340 " />
[Doc](https://docs.google.com/document/d/1Q3xdZAIqhBvuy2HxGDfJyXVmxYXUEeYSZSwsp7bcJF8/edit?tab=t.osb46a42t6wb#heading=h.ayp9tkk08x00 )
Test Plan: Will be tested in the python implementation which will use the binding and hook
Differential Revision: D84457757
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165963
Approved by: https://github.com/malfet , https://github.com/albanD
2025-10-31 02:52:51 +00:00
1129605415
[ROCm][CI] create ROCm 7.1 images for binary builds ( #166665 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166665
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com >
2025-10-31 02:52:37 +00:00
a6b1ef1717
[GraphPartition] cache get_free_symbol_uses ( #166338 )
...
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs.
ee7434be82/torch/_inductor/scheduler.py (L4869-L4885)
I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried an aten fallback for torchtitan, which results in 10k+ aten nodes; by the time the 600th node is processed, a single call to `get_free_symbol_uses()` takes seconds.
Why? Because `get_free_symbol_uses()` may recursively call `get_free_symbol_uses()` on other nodes, so the same work can be repeated many times.
ee7434be82/torch/_inductor/ir.py (L4541-L4543)
This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166338
Approved by: https://github.com/eellison
2025-10-31 02:50:10 +00:00
12577064dd
[MPS] Fix crash when max/min ops called for complex types ( #166214 )
...
Raise an exception, as it's meaningless and results in a segfault otherwise:
```
% python -c "import torch;torch.rand(10, dtype=torch.cfloat, device='mps').amax()"
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: error: 'mps.reduction_max' op operand #0 must be tensor of mps native type values, but got 'tensor<10xcomplex<f32>>'
(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:176:0: note: see current operation: %2 = "mps.reduction_max"(%arg0, %1) <{keep_dims, propagate_nans}> : (tensor<10xcomplex<f32>>, tensor<1xsi32>) -> tensor<1xcomplex<f32>>
/AppleInternal/Library/BuildRoots/4~B6shugDBannYeMBGCfhw7wjvNJOfy4BrawZ7TdI/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1347: failed assertion `original module failed verification'
zsh: abort python -c
```
To be tested by `test_ops.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166214
Approved by: https://github.com/dcci , https://github.com/kulinseth , https://github.com/Skylion007
ghstack dependencies: #166272
2025-10-31 02:37:20 +00:00
24b6eb7727
[Inductor] Enable Custom op Autotune Decompositions and Parameter Tuning ( #164212 )
...
This PR introduces CustomOp autotuning. It allows users to provide a CustomOpConfig:
(1) to register (optional) multiple decomposition implementations for custom operations and
(2) to register parameter tuning knobs and the values they want to tune for the decompositions
so that inductor automatically selects the best-performing variant through Inductor's autotune benchmarking.
Example:
```python
register_custom_op_autotuning(
    custom_op=my_attention_op,
    configs=[
        CustomOpConfig(attention_impl, head_dim=32, method='chunked'),
        CustomOpConfig(attention_impl, head_dim=64, method='tiled'),
        CustomOpConfig(head_dim=128),  # no decompositions
    ],
    input_gen_fns={
        "query": lambda fake: torch.randn_like(fake, device='cuda'),
        "key": lambda fake: torch.randn_like(fake, device='cuda'),
        "value": lambda fake: torch.randn_like(fake, device='cuda'),
    }
)
```
**CustomOpConfig**: Each CustomOpConfig defines exactly one autotuning variant with specific parameter values and optional decomposition implementation with PyTorch aten ops. Users can register their own tuning knobs and optional decomposition functions for the same custom operation. The system automatically benchmarks all variants to select the best performing. If no decomposition is provided in the config, the CustomOp's default implementation will be used.
**Custom Input Generation**: Users can provide custom input generators via an optional `input_gen_fns` to control how synthetic inputs are created during benchmarking. This enables more realistic performance testing by generating inputs that match expected data distributions and characteristics for each tensor argument.
**More examples with autotune logs:**
1. Allow users to register customOp decompositions with tuning parameters for autotuning. Example usage:
```python
from torch._inductor.kernel.custom_op import CustomOpConfig, register_custom_op_autotuning

def decompose_k_implementation(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4) -> torch.Tensor:
    """Matrix multiply with k-way decomposition."""
    # Implementation...with k_splits

@torch.library.custom_op("my_lib::decompose_k", mutates_args=())
def test_decompose_k_op(
    a: torch.Tensor, b: torch.Tensor, k_splits: int
) -> torch.Tensor:
    return decompose_k_implementation(a, b, k_splits)

# Register autotuning with different k_splits values
register_custom_op_autotuning(
    custom_op=test_decompose_k_op,
    configs=[
        CustomOpConfig(decompose_k_implementation, k_splits=2),
        CustomOpConfig(decompose_k_implementation, k_splits=32),
        CustomOpConfig(decompose_k_implementation, k_splits=64),
        CustomOpConfig(k_splits=128),  # can make decomposition optional, then use default impl test_decompose_k_op
        CustomOpConfig(k_splits=256),
    ],
    input_gen_fns={
        "a": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
        "b": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
    }
)
```
Example result:
```
{"num_choices": 6, "num_triton_choices": 0, "best_kernel": "test_decompose_k_autotuned_fallback_default", "best_time": 0.09980800002813339}
AUTOTUNE test_decompose_k_autotuned(256x65536, 65536x1024)
strides: [65536, 1], [1024, 1]
dtypes: torch.float16, torch.float16
test_decompose_k_autotuned_fallback_default 0.0998 ms 100.0%
test_decompose_k_autotuned_decompose_k_implementation_k_splits_2_0 0.1096 ms 91.0% CustomOp decompose_k_implementation_k_splits_2
test_decompose_k_autotuned_decompose_k_implementation_k_splits_32_1 0.1277 ms 78.2% CustomOp decompose_k_implementation_k_splits_32
test_decompose_k_autotuned_decompose_k_implementation_k_splits_64_2 0.1454 ms 68.6% CustomOp decompose_k_implementation_k_splits_64
test_decompose_k_autotuned_decompose_k_implementation_k_splits_128_3 0.1536 ms 65.0% CustomOp decompose_k_implementation_k_splits_128
test_decompose_k_autotuned_decompose_k_implementation_k_splits_256_4 0.2084 ms 47.9% CustomOp decompose_k_implementation_k_splits_256
```
2. Allow users to tune a parameter knob by passing the parameter and its values in the CustomOpConfig.
**Example**
```python
def mlp_variants(input_tensor, gate_weight, up_weight, down_weight, method):
    """MLP implementation with different computational approaches."""
    if method == 0:
        # Standard separate matmuls
        # ... implementation
    elif method == 1:
        # Batched approach with torch.mm
        # ... implementation
    elif method == 2:
        # Fused weights approach
        # ... implementation

@torch.library.custom_op("my_lib::mlp_op", mutates_args=())
def mlp_op(
    input_tensor: torch.Tensor,
    gate_weight: torch.Tensor,
    up_weight: torch.Tensor,
    down_weight: torch.Tensor,
    method: int,
) -> torch.Tensor:
    return mlp_variants(
        input_tensor, gate_weight, up_weight, down_weight, method=method
    )

register_custom_op_autotuning(
    custom_op=mlp_op,
    configs=[
        CustomOpConfig(method=0),
        CustomOpConfig(method=1),
        CustomOpConfig(method=2),
        # method=0 is the default fallback in the original op
    ],
    input_gen_fns={
        "input_tensor": lambda fake: torch.randn_like(fake, device='cuda') * 0.1,
        "gate_weight": lambda fake: torch.randn_like(fake, device='cuda') * 0.05,
        # ... other input generators
    }
)
```
Example result:
```
AUTOTUNE test_mlp_autotuned(4x32x512, 512x1024, 512x1024, 1024x256)
test_mlp_autotuned_mlp_variants_method_2 0.0181 ms 100.0% CustomOp mlp_variants_method_2
test_mlp_autotuned_mlp_variants_method_1 0.0185 ms 97.8% CustomOp mlp_variants_method_1
test_mlp_autotuned_mlp_default_fallback_method_0 0.0198 ms 91.4% CustomOp fallback
```
### Test Suite (`test/inductor/test_custom_op_autotune.py`)
* **RMSNorm autotuning**: Tests different RMSNorm implementations with dynamic input shapes
* **MLP autotuning**: Tests different MLP decomposition and tuning "method" parameter
* **DecomposeK**: Tests different k_splits values for matrix multiplication decomposition with k dim split
* **Multi-parameter tuning**: Tests configs with multiple tuning parameters (scale_mode, chunk_size)
### Next Step:
- Enable Max-autotune with user passed in max-autotune config. https://github.com/pytorch/pytorch/pull/165526/files
- Support inline epilogue fusion for selected best customop decomposition with surrounding elementwise ops. https://github.com/pytorch/pytorch/pull/165952/files
- Support customop autotune considering fusion with multiTemplateBuffer. WIP
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164212
Approved by: https://github.com/zou3519
2025-10-31 02:28:00 +00:00
32066772b3
Fix torch.full with dynamic tensor fill_value in torch.compile ( #166554 )
...
Fixes #166253
## Summary
When `torch.full` is called with a 0-D tensor as `fill_value` inside a `torch.compile`'d function, the value was being incorrectly cached, causing subsequent calls with different values to return the first value.
## Root Cause
The Dynamo handler for `torch.full` was calling `aten._local_scalar_dense` to convert tensor fill_values to Python scalars at compile time, which baked the value into the compiled graph as a constant.
## Solution
Modified the Dynamo handler to decompose `torch.full(size, tensor_fill_value)` into `empty(size).fill_(tensor_fill_value)` when `fill_value` is a `TensorVariable`, keeping the fill value dynamic in the compiled graph.
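A minimal sketch of the decomposition (illustrative; not the exact Dynamo handler code):
```python
import torch

def full_with_tensor_fill(size, fill_value: torch.Tensor, **kwargs) -> torch.Tensor:
    # Keep the fill value as a graph input instead of baking it in as a constant.
    return torch.empty(size, **kwargs).fill_(fill_value)

compiled = torch.compile(lambda v: torch.full((2, 3), v))
print(compiled(torch.tensor(1.0)))  # expected: all ones
print(compiled(torch.tensor(5.0)))  # expected: all fives, not the cached ones
```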
## Testing
Added test case that verifies torch.full works correctly with dynamic tensor fill_values across multiple calls and dtypes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166554
Approved by: https://github.com/Lucaskabela
2025-10-31 00:56:02 +00:00
47f0024310
[CI][BE] Factor out repeated test code ( #166481 )
...
Into `_run_single_arg_fwd`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166481
Approved by: https://github.com/Skylion007
2025-10-31 00:52:50 +00:00
98d640bb11
Remove AT_USE_HIPSPARSE_GENERIC_API ( #166393 )
...
This macro is not used in OSS anymore.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166393
Approved by: https://github.com/ezyang
2025-10-31 00:49:09 +00:00
5d288bc3f7
[BE] Move GreenContext implementation details to cpp ( #166462 )
...
- Remove all complex defines logic from the header
- Make GreenContext constructor private, as it should only be created via the static method as singleton
- Delete unused `getContext` and `getGreenContext` methods
- Rename `CUDA_HAS_GREEN_CONTEXT` to `HAS_CUDA_GREEN_CONTEXT()`, which results in compilation error if one accidentally makes a typo
- Suppress `-Wunused-private-field` is GreenContext is not available
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166462
Approved by: https://github.com/ngimel , https://github.com/eqy
2025-10-31 00:48:01 +00:00
bfb47ec50e
[dynamo] support tracing new typing union syntax X | Y ( #166599 )
...
To do in a follow-up: I think there's an approach to reconstruct typing variables.
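A hedged sketch of code exercising the new-style union syntax inside a compiled region (illustrative; not a test from this PR, requires Python >= 3.10):
```python
import torch

def f(x: torch.Tensor):
    t = int | None                      # types.UnionType built with the X | Y syntax
    return x + 1 if isinstance(1, t) else x - 1

out = torch.compile(f)(torch.zeros(3))  # expected: tensor of ones
```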
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166599
Approved by: https://github.com/SherlockNoMad , https://github.com/anijain2305 , https://github.com/Skylion007
2025-10-30 23:59:27 +00:00
7a0cd8ed09
[ROCm] Disable __builtin_amdgcn_rcpf for gfx90a ( #166454 )
...
Improves accuracy for some failing tests.
test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py::TestClipGradNormWorldSize4::test_clip_grad_norm_2d [GH job link](https://github.com/pytorch/pytorch/actions/runs/18930221123/job/54046876467 ) [HUD commit link](f20bf77874 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166454
Approved by: https://github.com/jerrymannil , https://github.com/jeffdaily
2025-10-30 23:39:00 +00:00
984e64b2cd
[inductor] Fix constant folder ( #166655 )
...
Fixes https://fb.workplace.com/groups/1028545332188949/permalink/1351999569843522/ where the resulting graph of the constant folder uses a sym node that is only created later. Graph diff: https://www.internalfb.com/intern/diffing/?paste_number=2014609054
Before:
```
%full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%select_18 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%full_65, 1, 0), kwargs = {})
%mul_2792 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_18, 0), kwargs = {})
%embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %mul_2792), kwargs = {})
```
After:
```
%full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_150], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %full_default_1), kwargs = {})
...
%sym_size_int_150 : [num_users=7] = call_function[target=torch.ops.aten.sym_size.int](args = (%view_193, 0), kwargs = {})
```
I couldn't figure out a small repro for this :/
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166655
Approved by: https://github.com/eellison
2025-10-30 22:51:28 +00:00
b9bcb37f40
[DebugMode] store stringify args by default ( #166347 )
...
DebugMode currently stores dispatch call args & kwargs, which includes all intermediate tensors and more. This quickly OOMed on GPU when trying to debug some torchtitan / llama 8b models.
This PR defaults to storing the stringified version and adds a flag, `DebugMode(store_original_args=True)`, for users who want to store the original args as-is (and for BC).
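A hedged usage sketch (the import path is an assumption):
```python
import torch
from torch.utils._debug_mode import DebugMode  # import path is an assumption

x = torch.randn(4)
with DebugMode():                           # new default: stringified args only
    x.sum()
with DebugMode(store_original_args=True):   # opt back into storing original args
    x.sum()
```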
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166347
Approved by: https://github.com/yushangdi
2025-10-30 22:12:23 +00:00
7e3b9d105e
[CP][BE][2/2] Refactor the code structure ( #166501 )
...
Our CP codebase now contains several files and we are adding more. This PR refactors the code to consolidate the files into a context_parallel folder but keeps the imports so that existing users of CP won't be affected.
Unfortunately, we have to split this PR into two PRs, as the PyTorch infra cannot accept a 3000+ LoC change in one PR and git does not recognize that _context_parallel/_attention.py was moved from _attention.py, because we want to keep BC.
This is the second PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166501
Approved by: https://github.com/Skylion007
ghstack dependencies: #166456
2025-10-30 22:07:07 +00:00
45c3f02d69
[ROCm] moved gfx1100 back to experimental status for AOTriton ( #166397 )
...
According to the next commit to AOTriton:
8625c4faee
These changes were missed in the 0.11b release:
https://github.com/pytorch/pytorch/pull/161754
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166397
Approved by: https://github.com/jeffdaily
2025-10-30 21:43:01 +00:00
f5543e3741
[wip] fix searchsorted non dense ( #165064 )
...
Fix for https://github.com/pytorch/pytorch/issues/163528
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165064
Approved by: https://github.com/benjaminglass1 , https://github.com/mlazos
2025-10-30 21:21:24 +00:00
5fc2c7a2a1
[ROCm][inductor] More configs for pointwise kernels. ( #166470 )
...
This config improves performance by 250% on some kernels that contain `t1.atomic_add(...)`. Again, we conditionalize for ROCm/HIP, so there is no impact on NV.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166470
Approved by: https://github.com/PaulZhang12 , https://github.com/mlazos , https://github.com/eellison , https://github.com/jansel
2025-10-30 21:20:12 +00:00
7692fa09cd
[Code Clean] Clean asserts in torch/ao/quantization/fx/* ( #165420 )
...
Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/fx/* (177 errors)
Partially fixes #164878.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165420
Approved by: https://github.com/RohitRathore1 , https://github.com/fffrog , https://github.com/albanD
2025-10-30 20:53:36 +00:00
df71b70727
[cuDNN][conv] Re-enable cuDNN for 3D convolutions (fixed in 9.15+) ( #166480 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166480
Approved by: https://github.com/Skylion007 , https://github.com/malfet
2025-10-30 20:47:20 +00:00
80ba6e458f
Add warning when users have incomplete setup for type checking ( #166603 )
...
Looking for feedback on this approach.
Received user reports of spurious pyrefly errors for users using hg instead of git. I think this was due to the fact that when using a venv and git, `make setup-env` installs requirements and pulls from a nightly torch wheel, which is needed for pyrefly to type check properly.
Initial documentation for `make setup-env` I found here: https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md#developing-pytorch
Testing:
```
hg clone --git ssh://git@github.com/pytorch/pytorch.git
conda create -n pytorch_env python=3.10 # (or manually create venv instead of using script)
cd pytorch
pip install -r requirements.txt
pip install -r requirements-build.txt
lintrunner init
# check how many pyrefly errors - 15,709 errors (11,693 ignored)
lintrunner # confirm error message / warning appears
>>> General linter failure:
Warning (PYREFLY) nightly-wheel-not-run
pytorch-nightly.pth not found. You may need to run make setup-env or make
setup-env-conda to install nightly binaries and type stubs.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166603
Approved by: https://github.com/aorenste
2025-10-30 20:37:44 +00:00
0d50e5d8d4
[3/N] Fix unused loop variables ( #166509 )
...
This PR removes unused loop variables in tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166509
Approved by: https://github.com/Lucaskabela , https://github.com/Skylion007
2025-10-30 20:13:51 +00:00
99b05d1b78
Better 1x128, 128x128 error handling on non-Hopper ( #166639 )
...
Summary:
Blockwise 1x128 and 128x128 scaling is only available on CUDA >= 12.9
and only on Hopper GPUs. Attempting to run on B200 would give a
hard-to-debug `CUBLAS_STATUS_NOT_SUPPORTED`.
Add a more helpful `NotImplementedError` to catch this case.
Also more explicitly disable ROCm builds for relevant methods, based on
lack of support per [hipBLASlt
docs](https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/reference/datatypes.html#_CPPv4N28hipblasLtMatmulMatrixScale_t40HIPBLASLT_MATMUL_MATRIX_SCALE_VEC128_32FE ).
Signed-off-by: Simon Layton <simonlayton@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166639
Approved by: https://github.com/drisspg
2025-10-30 20:13:06 +00:00
f911d64750
[CUDA] xFail max-autotune grouped gemm tests on devices with insufficient SM count ( #165921 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165921
Approved by: https://github.com/ngimel
2025-10-30 20:05:07 +00:00
52db60170d
Enable verify_dynamo on Python 3.13 ( #166497 )
...
Dynamo now supports Python 3.13.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166497
Approved by: https://github.com/Lucaskabela , https://github.com/williamwen42
2025-10-30 19:52:32 +00:00
56838bad5f
[CP][BE][1/2] Refactor the code structure ( #166456 )
...
Our CP codebase now contains several files and we are adding more. This PR refactors the code to consolidate the files into a context_parallel folder but keeps the imports so that existing users of CP won't be affected.
Unfortunately, we have to split this PR into two PRs, as the PyTorch infra cannot accept a 3000+ LoC change in one PR and git does not recognize that _context_parallel/_attention.py was moved from _attention.py, because we want to keep BC.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166456
Approved by: https://github.com/Skylion007
2025-10-30 19:46:49 +00:00
ad3a56ab98
Add a compile-time flag to trigger verbose logging for device-side asserts ( #166171 )
...
Summary:
Using `CUDA_KERNEL_ASSERT_PRINTF` inside kernels allows us to log invalid values to the console (which can in turn be used to surface hopefully clearer error messages).
This does have an impact on the number of registers needed for the values being logged (I confirmed via diffing PTX that there is no other impact relative to using `__assert_fail`).
To avoid causing perf bottlenecks, this change adds a compile-time switch to enable more verbose errors in some of the common kernels that cause DSAs. There is also a Buck config that can be used to configure this switch more conveniently.
## Alternatives considered
I considered making the behavior of `CUDA_KERNEL_ASSERT_PRINTF` controllable via a compile-time macro instead of writing another wrapper for it, but there are kernels where the extra register pressure is not as severe, and in those cases having more useful error messages by default is pretty useful.
Test Plan:
## Simple Python Driver:
```
# scatter_errors.py
import torch
def main() -> None:
    a = torch.rand(128, device="cuda:0")
    idx = torch.randint(0, 128, (100,), device="cuda:0")
    idx[0] = 9999
    b = torch.scatter(a, 0, idx, 555.0)
    print(b)
```
When running normally via:
```
$ buck2 run @//mode/opt :scatter_errors
```
we see the following DSA message:
```
fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
```
Running via:
```
$ buck2 run @//mode/opt -c fbcode.c10_enable_verbose_assert=1 :scatter_errors
```
however produces:
```
[CUDA_KERNEL_ASSERT] fbcode/caffe2/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:410: operator(): block: [0,0,0], thread: [0,0,0]: Assertion failed: `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"`: Expected 0 <= idx_dim < index_size (128), but got idx_dim = 9999
```
Differential Revision: D85185987
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166171
Approved by: https://github.com/ngimel
2025-10-30 19:43:46 +00:00
a7fd0b4001
[ROCm][CI] fix disk space message ( #166645 )
...
Fixes the disk-space cutoff message to report that the machine does not have the required space (difference = 100 - diskspace_cutoff_int) available.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166645
Approved by: https://github.com/jeffdaily
2025-10-30 19:38:34 +00:00
181ee3bd42
fix: Add missing signals_to_handle to launcher logging ( #166631 )
...
Fixes #166630
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166631
Approved by: https://github.com/Skylion007
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com >
2025-10-30 19:31:25 +00:00
0ec0549823
Introduce a new API torch.xpu.get_per_process_memory_fraction ( #165511 )
...
# Motivation
Aligned with other backends, this PR introduces a new API torch.xpu.get_per_process_memory_fraction to allow users to retrieve the allowed memory fraction for a single process.
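A minimal usage sketch (illustrative only; it assumes the companion setter introduced earlier in this stack is available):
```python
import torch

if torch.xpu.is_available():
    # Assumption: the setter from the preceding PRs in this stack exists.
    torch.xpu.set_per_process_memory_fraction(0.5)
    frac = torch.xpu.get_per_process_memory_fraction()
    print(frac)  # expected: 0.5
```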
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165511
Approved by: https://github.com/EikanWang , https://github.com/ezyang
ghstack dependencies: #165508 , #165509 , #165510
2025-10-30 19:30:09 +00:00
8221ee6db9
[xpu] Fix type annotation for ProcessGroupXCCL ( #166418 )
...
After #163049 , this PR fixes the type annotations to match the actual implementation for ProcessGroupXCCL::Options.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166418
Approved by: https://github.com/guangyey , https://github.com/ezyang
2025-10-30 19:29:06 +00:00
b939de26d1
Avoid writing temporary modules to disk ( #157713 )
...
In some cases the warning from #147744 still gets emitted because [atexit hooks aren't called](https://github.com/python/cpython/pull/114279 ).
Even in those cases, if the atexit hooks _were_ called you could end up with issues due to the directory being deleted in one process, but still being used elsewhere.
It's better all round to load these modules entirely in-memory.
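A hedged illustration of the in-memory approach (not the PR's actual code): build a module object directly from source instead of writing a temporary .py file.
```python
import importlib.util

src = "def answer():\n    return 42\n"
# Create a module without any file-backed loader, then exec the source into it.
spec = importlib.util.spec_from_loader("inmem_mod", loader=None)
mod = importlib.util.module_from_spec(spec)
exec(compile(src, "<in-memory>", "exec"), mod.__dict__)
print(mod.answer())  # 42, and nothing was written to disk
```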
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157713
Approved by: https://github.com/xush6528
2025-10-30 19:11:16 +00:00
694db5f549
Use 'is' in callable comparisons ( #166624 )
...
Just like we use `is/is not` for class comparisons, it is generally advised to use `is/is not` for comparisons against torch functions.
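A minimal illustration (not taken from the PR) of why identity comparison is preferred for callables:
```python
import torch

def uses_relu(fn) -> bool:
    # `is` checks object identity; a function object is a singleton, so this
    # is both cheaper and unambiguous compared to `==`.
    return fn is torch.nn.functional.relu

assert uses_relu(torch.nn.functional.relu)
assert not uses_relu(torch.sigmoid)
```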
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166624
Approved by: https://github.com/Lucaskabela , https://github.com/Skylion007
2025-10-30 19:00:09 +00:00
639a0b1239
Remove torch.distributed.tensor.OpSchema.has_symints ( #163667 )
...
It appears to be unused based on `cd torch; rg has_symints`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163667
Approved by: https://github.com/xmfan , https://github.com/azahed98 , https://github.com/albanD
ghstack dependencies: #162990
2025-10-30 18:57:17 +00:00
398775a43e
[CodeClean] Replace std::runtime_error with TORCH_CHECK ( #165119 )
...
As the title stated.
**Changes**:
- torch/csrc/inductor(Part 2)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165119
Approved by: https://github.com/janeyx99
ghstack dependencies: #165139
2025-10-30 18:43:58 +00:00
fcd5f8c352
[CodeClean] Remove the Unused MACRO for AOT Inductor Runtime ( #165139 )
...
As the title stated.
- AOTI_TORCH_CHECK depend on TORCH_CHECK_MSG which located in c10/util/Exception.h, which maybe break BC
- AOTI_TORCH_CHECK is not used everywhere
- STD_TORCH_CHECK have ABI check tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165139
Approved by: https://github.com/Skylion007 , https://github.com/janeyx99
2025-10-30 18:43:58 +00:00
4acc66f119
Make PT2 compile backprop through custom op without autograd key a hard error ( #166367 )
...
Signed-off-by: Edward Z. Yang <ezyang@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166367
Approved by: https://github.com/bdhirsh
2025-10-30 18:43:07 +00:00
8f40a0c634
Revert "address DDE in matmul decomp ( #166541 )"
...
This reverts commit 90519402c2006237f891289a0afdec804515aa73.
Reverted https://github.com/pytorch/pytorch/pull/166541 on behalf of https://github.com/atalman due to breaks internal test ([comment](https://github.com/pytorch/pytorch/pull/166541#issuecomment-3469382334 ))
2025-10-30 18:11:33 +00:00
a5c3c08d10
[Pytorch] Use exp_u20 for aarch64's erf ( #166594 )
...
Summary:
After a precision study, we concluded it is ok to use ACL's exp function on f32's erf()
We can keep erf inline this way.
Benchmarks show about 91% higher throughput when processing a tensor of 1M elements, compiling with clang-19:
Before:
f32 erf: 2539.179us
After:
f32 erf: 1329.063us
Test Plan:
Correctness:
buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch
Performance:
buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test
Differential Revision: D85730452
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166594
Approved by: https://github.com/mcfi , https://github.com/fadara01
2025-10-30 18:09:05 +00:00
a553ea9ea4
Fix missing symbol when printing guards ( #165723 )
...
Fixes #165177
When converting guards to sources, if we are unable to get the expected symbol from symbol_to_source, try to get it from var_to_sources.
I was unable to make a simpler repro than what was described in the issue (which relies on llama3 - so inappropriate for a unit test).
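A hedged sketch of the fallback described above; the names below are illustrative, not the actual dynamo internals:
```python
def source_for(symbol, symbol_to_source, var_to_sources):
    # Prefer symbol_to_source; fall back to var_to_sources when the expected
    # symbol is missing (the situation hit in #165177).
    sources = symbol_to_source.get(symbol) or var_to_sources.get(symbol)
    return sources[0] if sources else None
```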
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165723
Approved by: https://github.com/bobrenjc93
2025-10-30 18:03:51 +00:00
ba71e9ca9a
[DeviceMesh] Isolate pg creation logic in Device Mesh into a separate func _init_one_process_group ( #166614 )
...
To make the upcoming pg cache changes easier and to modularize the code, we isolate the logic of process group creation into a separate function named `_init_one_process_group`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166614
Approved by: https://github.com/lw
2025-10-30 17:57:41 +00:00
694d205143
Revert "shrink_group implementation to expose ncclCommShrink API ( #164518 )"
...
This reverts commit 311ea0dec0c50f395e6dac7b3875e81ee243fceb.
Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/atalman due to breaks internal builds Error: from logging_utils import ( ModuleNotFoundError: No module named 'logging_utils' ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3469308568 ))
2025-10-30 17:52:29 +00:00
629293f568
bucket all reduce ( #166528 )
...
Bucket all reduce in bucketer, thanks to @IvanKobzarev's earlier pr.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166528
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #166527
2025-10-30 17:12:34 +00:00
c37802a8c4
use multi-dtype bucketing ( #166527 )
...
Make the bucketer use multi-dtype bucketing for all gathers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166527
Approved by: https://github.com/IvanKobzarev , https://github.com/ezyang
2025-10-30 16:54:49 +00:00
0a3ac47c0a
Revert "[user-streams] Fix stream graph output semantics ( #164819 )"
...
This reverts commit f5cb9a4c68d9271c58ef4d3257210984b8e85099.
Reverted https://github.com/pytorch/pytorch/pull/164819 on behalf of https://github.com/atalman due to breaks CI ([comment](https://github.com/pytorch/pytorch/pull/164819#issuecomment-3469018283 ))
2025-10-30 16:53:32 +00:00
e83be7042e
Fix pyrefly errors on main ( #166548 )
...
Fixes existing errors to keep noise from lintrunner to a minimum
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166548
Approved by: https://github.com/Lucaskabela , https://github.com/mlazos
2025-10-30 16:47:27 +00:00
fb545fb068
Add MXFP4 grouped gemm support via. FBGEMM kernels ( #166530 )
...
Summary:
* Extend `_scaled_grouped_mm_v2` to include MXFP4 support
* Add testing to existing grouped routines
Test Plan:
```
pytest -svv -k "mxfp4 and group" test/test_scaled_matmul_cuda.py
```
Reviewers:
Subscribers:
Tasks:
Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com >
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166530
Approved by: https://github.com/drisspg
2025-10-30 16:46:11 +00:00
2df2c316e2
[devx] Fix invalid symbol definition emitted in fx_graph_runnable.py ( #166529 )
...
Summary: When emitting symbolic variable definition in fx_graph_runnable.py, we need to check if a SymNode is actually an expression, so that we won't generate something like "s27*s53**2 = 36".
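A hedged illustration of the check (not the actual Inductor code), using sympy since that is what the symbolic-shapes layer is built on:
```python
import sympy

def emit_symbol_def(expr, value):
    # Only emit "name = value" for a bare symbol; an expression such as
    # s27*s53**2 is not a valid assignment target.
    if isinstance(expr, sympy.Symbol):
        return f"{expr} = {value}"
    return None

s27, s53 = sympy.symbols("s27 s53")
print(emit_symbol_def(s27, 36))           # "s27 = 36"
print(emit_symbol_def(s27 * s53**2, 36))  # None -- skip emitting a definition
```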
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166529
Approved by: https://github.com/mlazos
ghstack dependencies: #166432
2025-10-30 16:40:12 +00:00
08b0a8f11a
[Inductor] Fix an inductor_provenance bug ( #166432 )
...
Summary: Fix an inductor_provenance related error seen when running TORCH_COMPILE_DEBUG generated fx_graph_runnable.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166432
Approved by: https://github.com/mlazos
2025-10-30 16:40:12 +00:00
3f1824742c
Revert "Fix comparing inductor actual strides vs bw graph for activations should not throw DDE. ( #166277 )"
...
This reverts commit b2a0f90501dd3a16a6ccaf4c49e1c10f6df4ce1d.
Reverted https://github.com/pytorch/pytorch/pull/166277 on behalf of https://github.com/atalman due to Breaks internal executorch tests ([comment](https://github.com/pytorch/pytorch/pull/166277#issuecomment-3468696623 ))
2025-10-30 15:49:23 +00:00
bbb7d2270b
[inductor] print 0.0 as 0 for triton ( #164291 )
...
Fixes https://github.com/pytorch/pytorch/issues/164157
Fixes https://github.com/pytorch/pytorch/issues/164086
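A hedged illustration of the constant-printing behavior (an illustrative helper, not the actual Inductor codegen):
```python
def print_float_for_triton(x: float) -> str:
    # Emit integral float constants without the trailing ".0".
    return str(int(x)) if float(x).is_integer() else repr(x)

print(print_float_for_triton(0.0))   # "0"
print(print_float_for_triton(0.25))  # "0.25"
```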
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164291
Approved by: https://github.com/bobrenjc93 , https://github.com/mlazos
2025-10-30 15:15:25 +00:00
6a5a436624
DTensor: C++ compute_global_tensor_info ( #162990 )
...
compute_global_tensor_info is on the hot path for DTensor.{from,to}_local. More incremental progress toward C++.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162990
Approved by: https://github.com/ezyang
2025-10-30 15:10:54 +00:00
ad559072db
[triton][sigmoid] Fix kernel cache and serialization issue for triton sigmoid + CUDA kernel bug ( #166568 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166568
Approved by: https://github.com/minjang
2025-10-30 14:54:54 +00:00
ad02bd13df
Revert "[user-streams] Add current stream source ( #165211 )"
...
This reverts commit 79aee77381b21d41c77148e5ff84c4b351aaf144.
Reverted https://github.com/pytorch/pytorch/pull/165211 on behalf of https://github.com/atalman due to failure: test/test_python_dispatch.py::TestPythonDispatch::test_return_stream [GH job link](https://github.com/pytorch/pytorch/actions/runs/18942517662/job/54086481693 ) [HUD commit link](7563f61cc8 ) ([comment](https://github.com/pytorch/pytorch/pull/165211#issuecomment-3468332362 ))
2025-10-30 14:34:43 +00:00
7563f61cc8
Make bucketing aware of collective LIFO semantics ( #166324 )
...
In the initial pr for overlapping preserving bucketing, for a graph like:
```
def foo(...):
ag = all_gather(...)
hiding_compute = mm(...)
wait(ag)
```
We would add dependencies from mm -> ag, and from wait -> hiding_compute, to prevent bucketing from reordering these collectives so that overlap no longer occurred. However, there is an additional way for bucketing to prevent overlap.
If we were to reorder another collective so the graph looked like:
```
def foo(...):
ag = all_gather(...)
ar = all_reduce(...)
wait(ar)
hiding_compute = mm(...)
wait(ag)
```
Overlap would not occur, because the wait for the all reduce would also force realization of every collective enqueued on the same stream prior to the all reduce. NCCL uses a single stream per process group.
To model this, we initially establish a strict ordering of all collective starts, waits, and hiding compute when bucketing. Then, when trying to add a collective to a bucket, we check whether we would interfere with overlap for all of the following possible bucketings:
[move collective start to bucket start, move bucket start to collective start] x [move collective wait to bucket wait, move bucket wait to collective wait].
For each of these positions, we check whether overlap would be broken because of stream-queue semantics. If not, we remove the moving start and wait from the constrained ordering of collectives and check whether it is topologically valid to merge the nodes.
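A hedged toy model of the stream-queue semantics described above (illustrative only, not the bucketing pass itself): waiting on a collective also forces completion of everything enqueued before it on the same stream.
```python
def completed_after_wait(enqueue_order, waited_on):
    # enqueue_order: collectives in the order they were enqueued on the
    # (single) process-group stream.
    idx = enqueue_order.index(waited_on)
    return set(enqueue_order[: idx + 1])

enqueue_order = ["ag", "ar"]
# wait(ar) before the hiding mm also completes ag, so the mm can no longer
# hide ag's communication latency.
assert completed_after_wait(enqueue_order, "ar") == {"ag", "ar"}
```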
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166324
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #166309
2025-10-30 13:37:00 +00:00
fa8e073a4e
Revert "[triton][sigmoid] Fix kernel cache and serialization issue for triton sigmoid + CUDA kernel bug ( #166568 )"
...
This reverts commit d46d8d6f54b15ded4f2483c7bde31be124281ab8.
Reverted https://github.com/pytorch/pytorch/pull/166568 on behalf of https://github.com/atalman due to Failed test/test_extension_utils.py::TestExtensionUtils::test_external_module_register_with_renamed_backend [GH job link](https://github.com/pytorch/pytorch/actions/runs/18931754443/job/54050880312 ) [HUD commit link](d46d8d6f54 ) ([comment](https://github.com/pytorch/pytorch/pull/166568#issuecomment-3468008894 ))
2025-10-30 13:31:47 +00:00
95b5534773
Revert "[user-streams] Track symbolic current stream ( #165212 )"
...
This reverts commit a5335263d32b5be2b2647661334d81225c3cc3fc.
Reverted https://github.com/pytorch/pytorch/pull/165212 on behalf of https://github.com/atalman due to test/test_rename_privateuse1_to_existing_device.py::TestRenamePrivateuseoneToExistingBackend::test_external_module_register_with_existing_backend [GH job link](https://github.com/pytorch/pytorch/actions/runs/18930365446/job/54046768884 ) [HUD commit link](a5335263d3 ) ([comment](https://github.com/pytorch/pytorch/pull/165212#issuecomment-3467968796 ))
2025-10-30 13:24:56 +00:00
9ee1afbf66
Revert "[user-streams] Handle returning the current stream with/without device index ( #165356 )"
...
This reverts commit f1af679270392c83e03808c8af5e2cbe3cdf16ce.
Reverted https://github.com/pytorch/pytorch/pull/165356 on behalf of https://github.com/atalman due to test/test_rename_privateuse1_to_existing_device.py::TestRenamePrivateuseoneToExistingBackend::test_external_module_register_with_existing_backend [GH job link](https://github.com/pytorch/pytorch/actions/runs/18930365446/job/54046768884 ) [HUD commit link](a5335263d3 ) ([comment](https://github.com/pytorch/pytorch/pull/165356#issuecomment-3467967061 ))
2025-10-30 13:22:24 +00:00
f60751024e
Revert "[2/N] Add strict parameter to Python zip calls ( #166257 )"
...
This reverts commit 39e5cdddf7e57881c52473d1288a66f0222527e1.
Reverted https://github.com/pytorch/pytorch/pull/166257 on behalf of https://github.com/atalman due to Failing: test/distributed/fsdp/test_fsdp_mixed_precision.py::TestFSDPTrainEval::test_train_ema_eval_flow [GH job link](https://github.com/pytorch/pytorch/actions/runs/18934047991/job/54057218160 ) [HUD commit link](39e5cdddf7 ) ([comment](https://github.com/pytorch/pytorch/pull/166257#issuecomment-3467955332 ))
2025-10-30 13:20:00 +00:00
2de4cf2102
[1/N] Remove unused loop variables ( #166258 )
...
This PR removes unused loop variables.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166258
Approved by: https://github.com/Lucaskabela , https://github.com/mlazos
2025-10-30 12:22:25 +00:00
369f2d6951
[3/N] fix typo in other folders ( #166606 )
...
fix typo in other folders
#166374
#166126
_typos.toml
```toml
[files]
extend-exclude = ["tools/linter/dictionary.txt"]
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
Sur = "Sur"
nin = "nin"
tme = "tme"
inpt = "inpt"
mis = "mis"
Raison = "Raison"
ouput = "ouput"
nto = "nto"
Onwer = "Onwer"
callibrate = "callibrate"
ser = "ser"
Metdata = "Metdata"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166606
Approved by: https://github.com/ezyang
2025-10-30 10:30:40 +00:00
32920926f0
[xpu][fix] [Inductor] Avoid using tl.sqrt_rn on XPU before triton is ready ( #165740 )
...
Fixes #165738
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165740
Approved by: https://github.com/etaf , https://github.com/EikanWang , https://github.com/chuanqi129 , https://github.com/desertfire
2025-10-30 09:24:24 +00:00
39e5cdddf7
[2/N] Add strict parameter to Python zip calls ( #166257 )
...
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.
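For reference, the behavior this enables in the test utilities (plain Python 3.10+ semantics):
```python
a = [1, 2, 3]
b = ["x", "y"]
try:
    list(zip(a, b, strict=True))  # mismatched lengths now fail loudly
except ValueError as e:
    print(e)  # zip() argument 2 is shorter than argument 1
```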
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
2025-10-30 08:10:10 +00:00
2829d48bd1
[xpu][test][1/N] Port 3 fsdp distributed test cases to Intel GPU ( #161476 )
...
For https://github.com/pytorch/pytorch/issues/114850 , we will port 3 distributed tests to Intel GPU.
We could enable Intel GPU with the following methods and try the best to keep the original code styles:
- use "torch.accelerator.current_accelerator()" to determine the accelerator backend
- use "requires_accelerator_dist_backend" to enable "xccl"
- enabled XPU for some test path
- skip some test cases that Intel GPU does not support
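A minimal sketch of the device-agnostic pattern described in the bullets above (illustrative; assumes a recent PyTorch with the torch.accelerator module):
```python
import torch

acc = torch.accelerator.current_accelerator()
device_type = acc.type if acc is not None else "cpu"  # e.g. "cuda" or "xpu"
x = torch.ones(4, device=device_type)
print(x.device)
```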
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161476
Approved by: https://github.com/weifengpy , https://github.com/guangyey
2025-10-30 07:30:04 +00:00
f1af679270
[user-streams] Handle returning the current stream with/without device index ( #165356 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165356
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304 , #164522 , #164819 , #165211 , #165212
2025-10-30 07:20:25 +00:00
d46d8d6f54
[triton][sigmoid] Fix kernel cache and serialization issue for triton sigmoid + CUDA kernel bug ( #166568 )
...
Differential Revision: D85792537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166568
Approved by: https://github.com/minjang
2025-10-30 06:17:39 +00:00
a5335263d3
[user-streams] Track symbolic current stream ( #165212 )
...
merge into stream tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165212
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304 , #164522 , #164819 , #165211
2025-10-30 04:58:53 +00:00
79aee77381
[user-streams] Add current stream source ( #165211 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165211
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304 , #164522 , #164819
2025-10-30 04:58:53 +00:00
f5cb9a4c68
[user-streams] Fix stream graph output semantics ( #164819 )
...
Previously, we would stash a single stream value constructed at trace time in a global and return that same value from repeated calls to the graph.
With this PR, we construct the stream value in advance, reference the constructed value in the graph via the lookup table, and if that value is returned as an output, read the value from the lookup table and return it (in bytecode, not as a graph output, since we don't support arbitrary stream outputs).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304 , #164522
2025-10-30 04:58:46 +00:00
f20bf77874
[audio hash update] update the pinned audio hash ( #166597 )
...
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml ).
Update the pinned audio hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166597
Approved by: https://github.com/pytorchbot
2025-10-30 04:28:30 +00:00
75f798e05b
[inductor][mi350] add tech specs for MI350 ( #166576 )
...
Summary:
While digging through matmul padding for other work, I noticed that the compute-bound check won't work on MI350 since we haven't supplied the tech specs yet.
I added the MI350 specs following the predefined format.
Test Plan: CI
Differential Revision: D85804980
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166576
Approved by: https://github.com/leitian
2025-10-30 03:46:52 +00:00
476b149a00
bwd pass ( #164504 )
...
**Summary**
This implements the backward pass for the Varlen API and registers `_varlen_attn()` as a custom op.
**Benchmarking**
To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.
Settings:
- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs
| | Variable Length API | SDPA |
|--------|--------------------|----------|
| Runtime | 0.8189142608642578 ms | 3.263883056640625 ms |
| TFLOPs | 268.652 | 158.731 |
We can see that the Varlen runtime is >3x faster than SDPA with padding.
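For reference, a hedged sketch of how the variable-length sequence lengths above could be generated (illustrative; not the benchmark script, and the cumulative-length tensor is only shown as a common varlen convention):
```python
import torch

batch_size, max_seq_len = 8, 2048
# Random multiples of 64 up to max_seq_len, as in the benchmark settings.
seq_lens = torch.randint(1, max_seq_len // 64 + 1, (batch_size,)) * 64
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
print(seq_lens.tolist(), cu_seqlens.tolist())
```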
**Testing**
Run `python test/test_varlen_attention.py` for unit tests that verify basic functionality and confirm that varlen gradients numerically match SDPA.
For custom op testing, `test_custom_op_registration` uses logging mode to verify that `_varlen_attn()` was called and tests with `torch.compile`. `test_custom_op_compliances` uses `torch.library.opcheck()` to verify.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164504
Approved by: https://github.com/drisspg
2025-10-30 03:46:37 +00:00
845da9c817
[ONNX] Ignore pyrefly errors in torchlib ( #166588 )
...
Fixes #166475
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166588
Approved by: https://github.com/titaiwangms
2025-10-30 03:43:52 +00:00
0918bf321c
[xpu][test] Reuse native_mm and mix_order_reduction for Intel GPU. ( #166384 )
...
This PR reuses native_mm and mix_order_reduction for Intel GPU and enables the corresponding tests.
Fixes #165370
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166384
Approved by: https://github.com/jansel
2025-10-30 03:38:35 +00:00
90519402c2
address DDE in matmul decomp ( #166541 )
...
Address https://github.com/pytorch/pytorch/issues/165081
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166541
Approved by: https://github.com/mlazos
2025-10-30 03:19:29 +00:00
791ca80d3a
Enable local tensor mode for DTensor attention and convolution tests ( #166406 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166406
Approved by: https://github.com/ezyang
2025-10-30 02:48:02 +00:00
5cbdade914
Fix a syntactic error in test_indexing.py ( #166390 )
...
This PR fixes a syntax error in test_indexing.py caused by a misplaced `if ... else` expression.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166390
Approved by: https://github.com/jerryzh168
2025-10-30 02:28:01 +00:00
0187db88d4
[ROCm][CI] Create periodic-rocm-mi200.yml ( #166544 )
...
* We are separating out the rocm jobs of the periodic workflow
* We are introducing a new label `ciflow/periodic-rocm-mi200` to allow us to run distributed tests only on ROCm runners, without triggering many other jobs on the `periodic.yml` workflow (via `ciflow/periodic`)
* This new workflow will also be triggered via the `ciflow/periodic`, thus maintaining the old status quo.
* We are reverting to the `linux.rocm.gpu.4` label since it targets a lot more CI nodes at this point than the K8s/ARC-based `linux.rocm.gpu.mi250.4` label, as that is still having some network/scaling issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166544
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com >
2025-10-30 02:08:07 +00:00
311ea0dec0
shrink_group implementation to expose ncclCommShrink API ( #164518 )
...
Closes #164529
To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink ) API to PyTorch.
This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization.
For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518
Approved by: https://github.com/kwen2501
2025-10-30 01:50:54 +00:00
47a035b3f4
Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case : mutates a clone of graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern match happened foo_inplace + bar -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern match happened foo_inplace + bar -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 16:11:38 -07:00
3624c4ab97
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-16 16:11:38 -07:00
3c9acddda1
Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
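To make the idea concrete, below is a minimal, self-contained sketch of the dependency-edge insertion on a plain `torch.fx` graph. It is not the actual Inductor pass: it ignores alias/storage tracking (the GraphAliasTracker step) and uses a toy `sketch::dep_token` op in place of the real dependency op; the names `dep_token`, `mutated_arg_indices`, and `add_implicit_edges_sketch` are illustrative only.
```python
import torch
from torch import fx

# Toy stand-in for the dependency op (illustrative only); it just forwards its input.
@torch.library.custom_op("sketch::dep_token", mutates_args=())
def dep_token(t: torch.Tensor) -> torch.Tensor:
    return t.clone()

def mutated_arg_indices(target):
    """Positional args the op mutates, read from the operator schema's alias_info."""
    schema = getattr(target, "_schema", None)
    if schema is None:
        return []
    return [
        i
        for i, a in enumerate(schema.arguments)
        if a.alias_info is not None and a.alias_info.is_write
    ]

def add_implicit_edges_sketch(graph: fx.Graph) -> None:
    order = {n: i for i, n in enumerate(graph.nodes)}
    for node in list(graph.nodes):
        if node.op != "call_function":
            continue
        for idx in mutated_arg_indices(node.target):
            if idx >= len(node.args):
                continue  # mutated kwargs are not handled in this sketch
            mutated = node.args[idx]
            if not isinstance(mutated, fx.Node):
                continue
            # Insert the dependency node right after the mutation...
            with graph.inserting_after(node):
                dep = graph.call_function(
                    torch.ops.sketch.dep_token.default, (mutated,)
                )
            # ...and reroute every *later* reader of the mutated value through it,
            # so the read-after-write ordering becomes an explicit graph edge.
            for user in list(mutated.users):
                if user in (node, dep):
                    continue
                if order.get(user, -1) > order[node]:
                    user.replace_input_with(mutated, dep)
    graph.lint()
```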
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
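As a quick sanity check of these toy ops (assuming the definitions above have been registered), their eager behaviour is:
```python
x = torch.zeros(3)
out = torch.empty(3)
foo_inplace(x)       # x is now [1., 1., 1.]
bar_out(x, out)      # out is now [3., 3., 3.]
foobar_out(x, out)   # x -> [2., 2., 2.], out -> [4., 4., 4.]
```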
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
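For completeness, registering the `pattern`/`replacement` pair above would look roughly like the following. `register_replacement`, `fwd_only`, and `PatternMatcherPass` do exist in `torch._inductor.pattern_matcher`, but the exact argument order shown here is from memory and may differ across versions, so treat this as a sketch rather than the PR's actual registration path:
```python
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

patterns = PatternMatcherPass()
register_replacement(
    pattern,                           # search function
    replacement,                       # replacement function
    [torch.empty(3), torch.empty(3)],  # example inputs used to trace the pattern
    fwd_only,                          # trace in inference mode
    patterns,                          # pass to register the replacement into
)
```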
Case: mutating a clone of a graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern matcher replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
foo_inplace(x.view(-1))
foo_inplace(y.view(-1))
bar_out(x, outx)
bar_out(y, outy)
return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern matcher replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 16:06:23 -07:00
05586b62cc
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-16 16:06:23 -07:00
4c18855de5
Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-16 13:25:47 -07:00
fc57ec26ba
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why current pattern matcher not support the raw custom mutating op ?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case : mutates a clone of graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern match happened foo_inplace + bar -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
foo_inplace(x.view(-1))
foo_inplace(y.view(-1))
bar_out(x, outx)
bar_out(y, outy)
return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 13:25:47 -07:00
a14ec2a79c
Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
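A minimal sketch of how such a pass could look. This is illustrative, not the PR's implementation: `op_for_dependencies` is a hypothetical plain-Python identity stand-in for DEP_OP (the real pass uses `torch.ops.pattern_matcher.op_for_dependencies`), and the alias/view handling done by GraphAliasTracker is omitted here.
```python
import torch.fx as fx

def op_for_dependencies(x, *, writer_token=None):
    # Hypothetical stand-in for DEP_OP: a pure identity node whose only job is
    # to make the "read x after writer_token ran" ordering explicit in the graph.
    return x

def schema_mutates(node: fx.Node) -> bool:
    # Identify mutation ops from the operator schema (any argument marked as written).
    target = node.target
    if node.op != "call_function" or not hasattr(target, "_schema"):
        return False
    return any(
        a.alias_info is not None and a.alias_info.is_write
        for a in target._schema.arguments
    )

def add_implicit_edges(gm: fx.GraphModule) -> None:
    # Sketch of the edge-insertion step; does not follow aliases/views.
    graph = gm.graph
    order = {n: i for i, n in enumerate(graph.nodes)}
    for node in list(graph.nodes):
        if not schema_mutates(node):
            continue
        for mutated in list(node.all_input_nodes):
            with graph.inserting_after(node):
                dep = graph.call_function(
                    op_for_dependencies, (mutated,), {"writer_token": node}
                )
            # Redirect every later reader of the mutated tensor through the dep
            # node so the writer -> reader dependency becomes an explicit edge.
            for user in list(mutated.users):
                if user not in (dep, node) and order.get(user, -1) > order[node]:
                    user.replace_input_with(mutated, dep)
    graph.lint()
    gm.recompile()
```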
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
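For reference, a pattern/replacement pair like the one above is normally wired into Inductor's pattern matcher; the exact registration used by the PR and by vLLM's fusion passes may differ, but a hedged sketch with the `torch._inductor.pattern_matcher` helpers looks roughly like this (`custom_pass` and `example_inputs` are illustrative names):
```python
import torch
from torch._inductor.pattern_matcher import PatternMatcherPass, fwd_only, register_replacement

custom_pass = PatternMatcherPass()

# Example inputs are only used to trace pattern/replacement into FX graphs.
example_inputs = [torch.empty(3), torch.empty(3)]

register_replacement(
    pattern,         # search_fn: the pattern defined in the snippet above
    replacement,     # replace_fn: the replacement defined above
    example_inputs,
    fwd_only,        # trace_fn used to trace the functions into FX graphs
    custom_pass,     # the PatternMatcherPass this replacement is registered on
)
```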
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
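Once a match has fired, the dependency nodes are stripped again. Under the same assumptions as the sketch above (the identity `op_for_dependencies` stand-in), the inverse step could look roughly like this:
```python
import torch.fx as fx

def remove_implicit_edges(gm: fx.GraphModule) -> None:
    # Assumes op_for_dependencies is the identity stand-in from the earlier sketch.
    graph = gm.graph
    for node in list(graph.nodes):
        if node.op == "call_function" and node.target is op_for_dependencies:
            # Point readers back at the underlying tensor and drop the dep node;
            # node order is untouched, so execution order is preserved.
            node.replace_all_uses_with(node.args[0])
            graph.erase_node(node)
    graph.lint()
    gm.recompile()
```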
Case: mutates a clone of a graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
foo_inplace(x.view(-1))
foo_inplace(y.view(-1))
bar_out(x, outx)
bar_out(y, outy)
return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 11:27:43 -07:00
b070bd0722
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
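The alias-tracking step can be approximated by grouping nodes that share a storage. A rough sketch of that idea follows; it is not the PR's GraphAliasTracker, and the names `aliases_input` / `storage_groups` are illustrative assumptions that identify view-producing ops via their schemas.
```python
import torch.fx as fx

def aliases_input(node: fx.Node) -> bool:
    # True if the op's schema says its output aliases (views) one of its inputs.
    target = node.target
    if node.op != "call_function" or not hasattr(target, "_schema"):
        return False
    return any(r.alias_info is not None for r in target._schema.returns)

def storage_groups(graph: fx.Graph) -> dict[fx.Node, fx.Node]:
    # Map each node to a representative "base" node for its storage, so that a
    # mutation of any alias can be propagated to all readers of that storage.
    base: dict[fx.Node, fx.Node] = {}
    for node in graph.nodes:
        if aliases_input(node) and node.all_input_nodes:
            inp = node.all_input_nodes[0]
            base[node] = base.get(inp, inp)  # views share storage with their input
        else:
            base[node] = node
    return base
```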
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutates a clone of a graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
foo_inplace(x.view(-1))
foo_inplace(y.view(-1))
bar_out(x, outx)
bar_out(y, outy)
return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 11:27:43 -07:00
331a3d5d1c
Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutates a clone of a graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
foo_inplace(x.view(-1))
foo_inplace(y.view(-1))
bar_out(x, outx)
bar_out(y, outy)
return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern match replaced foo_inplace + bar_out with foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 11:19:53 -07:00
170ecc66e9
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, pattern matchers cannot reliably detect such op sequences or guarantee correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
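A minimal sketch of these steps over an `fx.Graph` (hypothetical and simplified: it only inspects positional args and skips the view/alias tracking that GraphAliasTracker provides, and `add_dependency_edges` / `mutated_args` are illustrative names, not the pass in this PR). `torch.ops.pattern_matcher.op_for_dependencies` is the DEP_OP that appears in the graphs below.
```python
import torch
from torch import fx


def mutated_args(node: fx.Node):
    # Positional args whose schema argument is marked as written (mutated).
    schema = node.target._schema
    for arg, schema_arg in zip(node.args, schema.arguments):
        if (
            isinstance(arg, fx.Node)
            and schema_arg.alias_info is not None
            and schema_arg.alias_info.is_write
        ):
            yield arg


def add_dependency_edges(graph: fx.Graph) -> None:
    # Topological position of the original nodes, used to find "later" readers.
    order = {n: i for i, n in enumerate(graph.nodes)}
    for node in list(graph.nodes):
        if node.op != "call_function" or not isinstance(node.target, torch._ops.OpOverload):
            continue
        for arg in mutated_args(node):
            # Insert the dependency op right after the mutation ...
            with graph.inserting_after(node):
                dep = graph.call_function(
                    torch.ops.pattern_matcher.op_for_dependencies,  # DEP_OP from this PR
                    args=(arg,),
                    kwargs={"writer_token": node},
                )
            # ... and redirect every later reader of the mutated tensor to it,
            # so the write -> read ordering becomes an explicit data edge.
            for user in list(arg.users):
                if user not in (node, dep) and order.get(user, -1) > order[node]:
                    user.replace_input_with(arg, dep)
```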
Example:
Custom op definitions
```python
import torch

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)

# pattern registration
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out

def replacement(x, out):
    foobar_out(x, out)
    return x, out
```
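As a hedged illustration only (the exact registration hook for mutable-op patterns in this PR may differ), such a pattern/replacement pair could in principle be registered through Inductor's generic `register_replacement` API; the pass container and example-input shapes below are placeholders.
```python
import torch
from torch._inductor.pattern_matcher import PatternMatcherPass, fwd_only, register_replacement

mutable_op_pass = PatternMatcherPass()              # hypothetical pass container
example_inputs = [torch.empty(3), torch.empty(3)]   # shapes are illustrative only

register_replacement(
    pattern,         # foo_inplace(x); bar_out(x, out)
    replacement,     # foobar_out(x, out)
    example_inputs,
    fwd_only,        # trace the search/replace functions as inference-only graphs
    mutable_op_pass,
)
```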
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutating a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before the mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decomposing auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern match applied: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before the mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decomposing auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern match applied: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 11:19:53 -07:00
bb7a32364f
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 02:16:04 -07:00
df278c8a0e
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 02:16:04 -07:00
a227f7d5ee
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 02:10:57 -07:00
b5801b5449
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 02:10:57 -07:00
45b4521b83
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-13 12:20:49 -07:00
a5b1ef2d1d
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-13 12:20:49 -07:00
702e868c31
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 14:01:51 -07:00
36f4550411
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 14:01:51 -07:00
1c02006361
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 11:54:13 -07:00
b33759a6c5
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 11:54:13 -07:00
d3e3e504cf
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 09:10:09 -07:00
ca1160f112
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 09:10:09 -07:00
dd749f54c9
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 08:59:22 -07:00
2bf5728496
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-10 08:59:22 -07:00
452575e225
Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-06 14:54:22 -07:00
970c40c3c0
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-06 14:54:22 -07:00
fd4b78b142
Update on "[WIP][Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos
[ghstack-poisoned]
2025-10-03 16:13:16 -07:00
e6772939b0
Update on "[WIP][Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos
[ghstack-poisoned]
2025-10-03 14:37:05 -07:00
a2d2747597
Update on "[WIP][Inductor] Mutable custom op pattern matching"
...
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
[ghstack-poisoned]
2025-09-30 11:59:45 -07:00
8704f71a52
[WIP][Inductor] Mutable custom op pattern matching
...
[ghstack-poisoned]
2025-09-30 11:57:14 -07:00