2e61e1e740
Update base for Update on "[Inductor] Mutable custom op pattern matching and safety checker"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops (such as `rms_norm` and `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently these patterns match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
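There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Both nodes only take `%x` as an argument, so the data-flow graph does not record that one must run before the other. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.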
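To make the missing edge concrete, here is a minimal sketch on a toy graph IR. The `Node` class and `depends_on` helper are hypothetical stand-ins written for illustration, not `torch.fx` APIs; the point is only that following argument edges from `bar_out` never reaches `foo_inplace`.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity-based hashing, so nodes can go in a set
class Node:
    """Hypothetical toy node: a name, an op target, and argument nodes."""
    name: str
    target: str
    args: tuple = ()


def depends_on(node, other):
    """Return True if `other` is reachable from `node` via argument edges."""
    stack = list(node.args)
    seen = set()
    while stack:
        cur = stack.pop()
        if cur is other:
            return True
        if cur not in seen:
            seen.add(cur)
            stack.extend(cur.args)
    return False


x = Node("x", "placeholder")
out = Node("out", "placeholder")
foo = Node("foo_result", "foo_inplace", (x,))
bar = Node("bar_result", "bar_out", (x, out))

# bar reads x, but nothing in the graph links bar to foo.
bar_sees_x = depends_on(bar, x)
bar_sees_foo = depends_on(bar, foo)
```

Here `bar_sees_x` is true while `bar_sees_foo` is false, which is exactly the ordering information the dependency op is meant to restore.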
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
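The steps above can be sketched on a toy IR. Everything here is an assumption made for illustration: the `Node` class, the `MUTATED_ARGS` table (standing in for operator schemas), and the `storage` field (standing in for what `GraphAliasTracker` computes) are hypothetical, and the real pass operates on FX graphs rather than Python lists.

```python
from dataclasses import dataclass, field

DEP_OP = "op_for_dependencies"

# Hypothetical stand-in for operator schemas: op name -> indices of mutated args.
MUTATED_ARGS = {"foo_inplace": [0], "bar_out": [1]}


@dataclass(eq=False)
class Node:
    name: str
    target: str
    args: list = field(default_factory=list)
    kwargs: dict = field(default_factory=dict)
    storage: object = None  # assumed: shared by a tensor and all of its views


def add_implicit_edges(nodes):
    """Insert a DEP_OP after each mutation and reroute later users to it."""
    latest = {}  # storage id -> most recent DEP_OP node for that storage
    result = []
    dep_count = 0
    for node in nodes:
        # Redirect any argument whose storage was mutated earlier.
        node.args = [latest.get(a.storage, a) for a in node.args]
        result.append(node)
        for index in MUTATED_ARGS.get(node.target, []):
            mutated = node.args[index]
            dep = Node(f"dep_{dep_count}", DEP_OP, [mutated],
                       {"writer_token": node}, mutated.storage)
            dep_count += 1
            latest[mutated.storage] = dep
            result.append(dep)
    return result


x = Node("x", "placeholder", storage="sx")
out = Node("out", "placeholder", storage="so")
x_view = Node("x_view", "view", [x], storage="sx")  # aliases x
foo = Node("foo", "foo_inplace", [x_view])
bar = Node("bar", "bar_out", [x, out])

graph = add_implicit_edges([x, out, x_view, foo, bar])
names = [n.name for n in graph]
```

In this sketch `bar` originally read `x`, but because `x` shares storage with the mutated `x_view`, its first argument is rewritten to `dep_0`, whose `writer_token` is `foo`. That gives the matcher an explicit `foo -> dep_0 -> bar` chain.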
Example:
Custom ops definitions
```python
import torch

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)

# pattern registration
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out

def replacement(x, out):
    foobar_out(x, out)
    return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutating a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
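In the graphs above, once the dependency ops are removed, each consumer points at the tensor that was the dependency op's first argument. A minimal sketch of that cleanup on a toy IR follows. The `Node` class and this `remove_implicit_edges` function are hypothetical illustrations of the idea, not the actual pass, and the demo runs on an unfused graph for simplicity.

```python
from dataclasses import dataclass, field

DEP_OP = "op_for_dependencies"


@dataclass(eq=False)
class Node:
    name: str
    target: str
    args: list = field(default_factory=list)
    kwargs: dict = field(default_factory=dict)


def resolve(node):
    """Follow a (possibly chained) dependency op back to the real tensor."""
    while node.target == DEP_OP:
        node = node.args[0]
    return node


def remove_implicit_edges(nodes):
    """Drop DEP_OP nodes and point their users at the underlying tensor."""
    kept = []
    for node in nodes:
        if node.target == DEP_OP:
            continue
        node.args = [resolve(a) for a in node.args]
        kept.append(node)
    return kept


x = Node("x", "placeholder")
out = Node("out", "placeholder")
foo = Node("foo", "foo_inplace", [x])
dep0 = Node("dep_0", DEP_OP, [x], {"writer_token": foo})
bar = Node("bar", "bar_out", [dep0, out])
dep1 = Node("dep_1", DEP_OP, [out], {"writer_token": bar})
ret = Node("output", "output", [dep1])

cleaned = remove_implicit_edges([x, out, foo, dep0, bar, dep1, ret])
cleaned_names = [n.name for n in cleaned]
```

After the call, `bar` reads `x` directly again and the output node returns `out`, so the graph no longer contains any pattern-matcher-only ops.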
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 16:07:25 -08:00
60f98d5a27
Update base for Update on "[Inductor] Mutable custom op pattern matching and safety checker"
...
2025-11-06 15:49:12 -08:00
67dbbe261d
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops (such as `rms_norm` and `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently these passes pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutable custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, the pattern matcher cannot reliably detect op sequences or guarantee correct execution order.
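The missing edge can be illustrated with a toy graph. This is only a sketch: `Node` and `ancestors` are hypothetical stand-ins for FX nodes and a use-def walk, not the torch.fx API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy stand-in for an FX node: a name plus the nodes it takes as inputs."""
    name: str
    args: tuple = ()


def ancestors(node):
    """Collect the names of every node reachable by following args (use-def edges)."""
    seen = set()
    stack = list(node.args)
    while stack:
        cur = stack.pop()
        if cur.name not in seen:
            seen.add(cur.name)
            stack.extend(cur.args)
    return seen


x = Node("x")
out = Node("out")
foo = Node("foo_inplace", (x,))   # mutates x, returns nothing
bar = Node("bar_out", (x, out))   # reads x after the mutation

# bar_out only reaches the placeholders; foo_inplace is not among its inputs.
bar_inputs = ancestors(bar)
print(sorted(bar_inputs))  # ['out', 'x']
```

Because `foo_inplace` has no users, a matcher that walks inputs backwards from `bar_out` never visits it.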
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
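The steps above can be sketched on a toy IR. Everything here is an assumption made for illustration: `Node`, `storage_root`, `add_dep_edges` and the `DEP_OP` target are hypothetical names, mutation info is given explicitly instead of being read from operator schemas, and aliases are treated as interchangeable with their base (which is only safe for same-shape aliases such as the `aten.alias` nodes in the examples in this description).

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    name: str
    target: str
    args: list = field(default_factory=list)   # input nodes
    mutates: tuple = ()                         # indices into args that are written
    alias_of: "Node | None" = None              # set for view-like nodes
    writer_token: "Node | None" = None          # only set on dep nodes


def storage_root(node):
    """Follow alias links back to the node that owns the storage."""
    while node.alias_of is not None:
        node = node.alias_of
    return node


def add_dep_edges(nodes):
    """Return a new node list with a dep node inserted after every mutation.

    Note: the args of the input nodes are rewritten in place.
    """
    result = []
    latest = {}   # storage root -> most recent dep node for that storage
    counter = 0
    for node in nodes:
        original_args = list(node.args)
        # Redirect reads of already-mutated storage to the latest dep node.
        node.args = [latest.get(storage_root(a), a) for a in original_args]
        result.append(node)
        for idx in node.mutates:
            dep = Node(
                name=f"dep_{counter}",
                target="DEP_OP",
                args=[node.args[idx]],
                writer_token=node,
            )
            counter += 1
            result.append(dep)
            latest[storage_root(original_args[idx])] = dep
    return result


x = Node("x", "placeholder")
out = Node("out", "placeholder")
x_alias = Node("x_alias", "alias", [x], alias_of=x)
foo = Node("foo_inplace", "mylib.foo_inplace", [x_alias], mutates=(0,))
bar = Node("bar_out", "mylib.bar_out", [x, out], mutates=(1,))

new_nodes = add_dep_edges([x, out, x_alias, foo, bar])
for n in new_nodes:
    token = n.writer_token.name if n.writer_token else None
    print(n.name, [a.name for a in n.args], token)
```

In this sketch `bar_out` ends up reading `dep_0`, whose `writer_token` is `foo_inplace`, so the writer is now reachable from the reader through ordinary use-def edges.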
Example:
Custom ops definitions
```python
import torch

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

# Fused op: same effect as foo_inplace followed by bar_out
@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)

# pattern registration
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out

def replacement(x, out):
    foobar_out(x, out)
    return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutates a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
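The removal step seen in the graph above (the `copy_` nodes going back to reading `alias_default_1` instead of `op_for_dependencies`) can be sketched on a toy IR. This is a sketch under assumptions: `Node`, `DEP` and `remove_dep_edges` are hypothetical names, and the real pass may differ.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    name: str
    target: str
    args: list = field(default_factory=list)


DEP = "DEP_OP"  # hypothetical marker for the dependency op


def remove_dep_edges(nodes):
    """Drop every dep node and rewire its users to the tensor it wrapped."""

    def resolve(node):
        # A dep node forwards its first argument; chains of dep nodes collapse.
        while node.target == DEP:
            node = node.args[0]
        return node

    kept = []
    for node in nodes:
        if node.target == DEP:
            continue
        node.args = [resolve(a) for a in node.args]
        kept.append(node)
    return kept


x = Node("x", "placeholder")
x_alias = Node("x_alias", "alias", [x])
foo = Node("foo_inplace", "mylib.foo_inplace", [x_alias])
dep = Node("dep_0", DEP, [x_alias])
copy = Node("copy_", "aten.copy_", [dep, dep])

cleaned = remove_dep_edges([x, x_alias, foo, dep, copy])
print([n.name for n in cleaned])    # ['x', 'x_alias', 'foo_inplace', 'copy_']
print([a.name for a in copy.args])  # ['x_alias', 'x_alias']
```

After this step the graph contains only real ops again, so the dependency op never has to be lowered.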
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 15:08:43 -08:00
1015260e07
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops (such as `rms_norm` and `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently these passes pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutable custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, the pattern matcher cannot reliably detect op sequences or guarantee correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
import torch

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

# Fused op: same effect as foo_inplace followed by bar_out
@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)

# pattern registration
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out

def replacement(x, out):
    foobar_out(x, out)
    return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutates a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-11-06 14:34:08 -08:00
3624c4ab97
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as `rms_norm` and `static_scaled_fp8_quant` that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). These passes currently pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
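The missing edge can be reproduced on a toy def-use graph. This is a minimal sketch in plain Python, not torch.fx itself: the `Node` class and the `make` and `ancestors` helpers are hypothetical stand-ins used only for illustration.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    """Toy stand-in for a graph node; not the torch.fx Node class."""
    name: str
    op: str
    args: tuple = ()
    users: list = field(default_factory=list)


def make(name, op, *args):
    """Create a node and record it as a user of each of its arguments."""
    node = Node(name, op, args)
    for arg in args:
        arg.users.append(node)
    return node


def ancestors(node):
    """Names of all nodes reachable from `node` by following argument edges."""
    seen = []
    stack = list(node.args)
    while stack:
        current = stack.pop()
        if current.name not in seen:
            seen.append(current.name)
            stack.extend(current.args)
    return sorted(seen)


x = make("x", "placeholder")
out = make("out", "placeholder")
foo = make("foo_result", "foo_inplace", x)    # mutates x, returns nothing useful
bar = make("bar_result", "bar_out", x, out)   # reads x directly, not foo_result

print([u.name for u in foo.users])  # [] : nothing consumes foo_result
print(ancestors(bar))               # ['out', 'x'] : foo_result is unreachable
```

Walking argument edges backwards from `bar_result` never visits `foo_result`, so a matcher that follows only those edges cannot tell that the two ops form a sequence.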
High-level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
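The four steps above can be sketched on a toy IR. This is an illustration under stated assumptions, not the PR's implementation: `Node`, `SCHEMAS`, `mutated_arg_indices`, `storage_root`, and `add_dependency_edges` are hypothetical names, the schema strings are written by hand, `dep_op` stands in for DEP_OP, and the only view op modeled is a full `alias` (same shape and storage as its input).

```python
import itertools
import re
from dataclasses import dataclass


@dataclass(eq=False)
class Node:
    """Toy graph node (hypothetical; not torch.fx.Node)."""
    name: str
    target: str
    args: list


# Hand-written schema strings (assumed); "(a!)" marks an argument that is written to.
SCHEMAS = {
    "foo_inplace": "foo_inplace(Tensor(a!) x) -> ()",
    "bar_out": "bar_out(Tensor x, Tensor(a!) out) -> ()",
}


def mutated_arg_indices(schema):
    """Step 1: positions of arguments whose alias annotation ends with '!'."""
    if not schema:
        return []
    params = schema.split(" -> ")[0]
    params = params[params.index("(") + 1 : -1]
    return [i for i, p in enumerate(params.split(", ")) if re.search(r"\(\w+!\)", p)]


def storage_root(node):
    """Step 2: follow alias/dep_op chains back to the node that owns the storage."""
    while node.target in ("alias", "dep_op"):
        node = node.args[0]
    return node


def add_dependency_edges(graph):
    new_graph = []
    latest_dep = {}  # id(storage root) -> most recent dep_op for that storage
    counter = itertools.count()
    for node in graph:
        # Step 4: later users of a mutated storage read through its dep_op.
        node.args = [latest_dep.get(id(storage_root(a)), a) for a in node.args]
        new_graph.append(node)
        for i in mutated_arg_indices(SCHEMAS.get(node.target, "")):
            mutated = node.args[i]
            # Step 3: insert dep_op(tensor, writer_token) right after the mutation.
            dep = Node(f"dep_{next(counter)}", "dep_op", [mutated, node])
            latest_dep[id(storage_root(mutated))] = dep
            new_graph.append(dep)
    return new_graph


x = Node("x", "placeholder", [])
out = Node("out", "placeholder", [])
alias = Node("alias", "alias", [x])
foo = Node("foo", "foo_inplace", [alias])
bar = Node("bar", "bar_out", [x, out])

result = add_dependency_edges([x, out, alias, foo, bar])
print([n.name for n in result])
print([a.name for a in bar.args])
```

In this sketch `bar` originally reads `x`, which shares storage with the mutated `alias`, so its first argument is redirected to the `dep_op` inserted after `foo`. That gives the matcher an explicit `foo` -> `dep_op` -> `bar` chain to follow.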
Example:
Custom ops definitions
```python
import torch

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)

# pattern registration: `pattern` is matched and rewritten to `replacement`
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out

def replacement(x, out):
    foobar_out(x, out)
    return x, out
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutating a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
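Comparing the graphs before and after this step, every dependency op has disappeared and its users read the underlying tensor again. A minimal sketch of such a cleanup on a toy IR follows. It assumes the dependency op is an identity on its first argument; `Node`, `dep_op`, and `remove_dependency_edges` are hypothetical names, and how the PR actually implements this step is not shown here.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Node:
    """Toy graph node (hypothetical; not torch.fx.Node)."""
    name: str
    target: str
    args: list


def remove_dependency_edges(graph):
    """Drop every dep_op and point its users back at the tensor it wrapped."""

    def resolve(node):
        # A dep_op forwards its first argument; the second is only a writer token.
        while node.target == "dep_op":
            node = node.args[0]
        return node

    kept = []
    for node in graph:
        if node.target == "dep_op":
            continue
        node.args = [resolve(a) for a in node.args]
        kept.append(node)
    return kept


x = Node("x", "placeholder", [])
out = Node("out", "placeholder", [])
foo = Node("foo", "foo_inplace", [x])
dep0 = Node("dep_0", "dep_op", [x, foo])
bar = Node("bar", "bar_out", [dep0, out])
dep1 = Node("dep_1", "dep_op", [out, bar])
output = Node("output", "output", [dep1])

cleaned = remove_dependency_edges([x, out, foo, dep0, bar, dep1, output])
print([n.name for n in cleaned])
print([a.name for a in bar.args], [a.name for a in output.args])
```

After cleanup the toy graph returns `out` directly, in the same way the real graph above returns `as_strided_default_1` rather than a dependency op.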
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 16:11:38 -07:00
05586b62cc
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as `rms_norm` and `static_scaled_fp8_quant` that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). These passes currently pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High-level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
import torch

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)

# pattern registration: `pattern` is matched and rewritten to `replacement`
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out

def replacement(x, out):
    foobar_out(x, out)
    return x, out
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutating a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 16:06:23 -07:00
fc57ec26ba
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops, such as `rms_norm` and `static_scaled_fp8_quant`, that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently these passes pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
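To make the missing edge concrete, here is a minimal sketch that uses a toy node class instead of real `torch.fx` nodes. `Node`, `users_of`, and `respects_data_edges` are hypothetical helpers invented for this illustration; they are not part of PyTorch.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    """Toy stand-in for an FX node: a name plus the nodes it reads."""
    name: str
    args: list = field(default_factory=list)


def users_of(node, graph):
    """Names of nodes that take `node` as a direct argument."""
    return [n.name for n in graph if any(a is node for a in n.args)]


def respects_data_edges(order):
    """True if every node appears after all of its arguments."""
    seen = set()
    for n in order:
        if any(id(a) not in seen for a in n.args):
            return False
        seen.add(id(n))
    return True


x = Node("x")
out = Node("out")
foo = Node("foo_inplace", [x])    # mutates x, returns nothing
bar = Node("bar_out", [x, out])   # reads the mutated x

# Nothing consumes foo_inplace, so no edge ties bar_out to it.
print(users_of(foo, [x, out, foo, bar]))        # []
# Data edges alone accept an order that runs bar_out before foo_inplace.
print(respects_data_edges([x, out, bar, foo]))  # True
```

In this toy model, any analysis that relies only on argument edges sees `foo_inplace` and `bar_out` as unrelated siblings of `x`.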
High-level idea:
- Identify mutating ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert a DEP_OP (`op_for_dependencies` in the graphs below) after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
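The steps above can be sketched on a toy graph. This is only an illustration under simplifying assumptions, not the pass itself: `Node`, `storage_root`, and `add_dependency_edges` are hypothetical names, the `mutates` tuple stands in for schema information, and alias tracking is reduced to following `alias` nodes back to their base.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    name: str
    op: str                                   # "placeholder", "alias", "call" or "dep"
    args: list = field(default_factory=list)
    mutates: tuple = ()                       # indices of args this op writes to


def storage_root(node):
    """Follow alias and dependency nodes back to the node owning the storage."""
    while node.op in ("alias", "dep"):
        node = node.args[0]
    return node


def add_dependency_edges(graph):
    latest_dep = {}   # id(storage root) -> most recent dependency node
    result = []
    counter = 0
    for node in graph:
        # Redirect reads of already-mutated storage to the latest dependency node.
        node.args = [latest_dep.get(id(storage_root(a)), a) for a in node.args]
        result.append(node)
        # Insert one dependency node after the op for each mutated argument.
        for i in node.mutates:
            tensor = node.args[i]
            dep = Node(f"dep_{counter}", "dep", [tensor, node])
            counter += 1
            latest_dep[id(storage_root(tensor))] = dep
            result.append(dep)
    return result


x = Node("x", "placeholder")
out = Node("out", "placeholder")
alias_x = Node("alias_x", "alias", [x])
foo = Node("foo", "call", [alias_x], mutates=(0,))
bar = Node("bar", "call", [x, out], mutates=(1,))

new_graph = add_dependency_edges([x, out, alias_x, foo, bar])
print([n.name for n in new_graph])
# ['x', 'out', 'alias_x', 'foo', 'dep_0', 'bar', 'dep_1']
print([a.name for a in bar.args])
# ['dep_0', 'out']
```

Here `bar` originally read `x` directly; because `foo` mutated an alias of `x`, the sketch rewires `bar` to read `dep_0`, which in turn references both the tensor and its writer. Graph outputs are not modeled.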
Example:
Custom ops definitions
```python
import torch

@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)

# Fused op: performs both mutations in one call
@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)

# Pattern and replacement functions used for pattern registration
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out

def replacement(x, out):
    foobar_out(x, out)
    return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
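With the dependency node in place, the writer/reader relationship becomes an ordinary argument chain that a matcher can follow. The sketch below mirrors the pattern graph above with a toy `Node` class. `find_writer_reader_pairs` is a hypothetical helper, and the writer is stored as a second positional argument here rather than as the `writer_token` kwarg shown in the graph.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    name: str
    target: str
    args: list = field(default_factory=list)


def find_writer_reader_pairs(graph, writer_target, reader_target):
    """Return (writer, reader) name pairs linked through a dependency node."""
    pairs = []
    for node in graph:
        if node.target != reader_target:
            continue
        for arg in node.args:
            if arg.target != "op_for_dependencies":
                continue
            _tensor, writer = arg.args
            if writer.target == writer_target:
                pairs.append((writer.name, node.name))
    return pairs


x = Node("x_1", "placeholder")
out = Node("out_1", "placeholder")
foo = Node("foo_inplace", "mylib.foo_inplace", [x])
dep = Node("op_for_dependencies", "op_for_dependencies", [x, foo])
bar = Node("bar_out", "mylib.bar_out", [dep, out])
dep_1 = Node("op_for_dependencies_1", "op_for_dependencies", [out, bar])
graph = [x, out, foo, dep, bar, dep_1]

print(find_writer_reader_pairs(graph, "mylib.foo_inplace", "mylib.bar_out"))
# [('foo_inplace', 'bar_out')]
```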
Case: mutating a clone of a graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
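The removal step visible in the graph above can be pictured as the inverse rewrite: every use of a dependency node is pointed back at the tensor it wraps, and the dependency nodes are dropped. The following is a rough sketch on a toy graph; `Node`, `strip_deps`, and `remove_dependency_edges` are hypothetical names, and the real pass may differ in its details.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    name: str
    op: str                                   # "placeholder", "call" or "dep"
    args: list = field(default_factory=list)


def strip_deps(node):
    """Unwrap (possibly chained) dependency nodes down to the real tensor node."""
    while node.op == "dep":
        node = node.args[0]
    return node


def remove_dependency_edges(graph):
    kept = []
    for node in graph:
        if node.op == "dep":
            continue                          # drop the dependency node itself
        node.args = [strip_deps(a) for a in node.args]
        kept.append(node)
    return kept


x = Node("x", "placeholder")
out = Node("out", "placeholder")
foo = Node("foo", "call", [x])
dep_0 = Node("dep_0", "dep", [x, foo])
bar = Node("bar", "call", [dep_0, out])
dep_1 = Node("dep_1", "dep", [out, bar])

restored = remove_dependency_edges([x, out, foo, dep_0, bar, dep_1])
print([n.name for n in restored])     # ['x', 'out', 'foo', 'bar']
print([a.name for a in bar.args])     # ['x', 'out']
```

The toy graph here is an unmatched `foo` + `bar` sequence, so the result is simply the original graph again. In the matched case shown above, the fused `foobar_out` node would be the one whose arguments are restored.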
Case: multiple writers and readers
```python
def f(
x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
foo_inplace(x.view(-1))
foo_inplace(y.view(-1))
bar_out(x, outx)
bar_out(y, outy)
return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
after remove_implict_edges (the pattern matched: foo_inplace + bar_out -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 13:25:47 -07:00
b070bd0722
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops such as (`rms_norm`, `static_scaled_fp8_quant`) that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). Currently they pattern match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3) with incompatible semantics that break existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why current pattern matcher not support the raw custom mutating op ?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out` even though `bar_out` depends on `foo_inplace` mutation. Without explicit edges, pattern matchers cannot reliably detect op sequences or ensure correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
Example:
Custom ops definitions
```python
torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
x.add_(1)
torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x + 2)
torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
x.add_(1)
out.copy_(x + 2)
# pattern registration
def pattern(x, out):
foo_inplace(x)
bar_out(x, out)
return x, out
def replacement(x, out):
foobar_out(x, out)
return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case : mutates a clone of graph input
```python
def f(x, out):
x = x.clone()
out = out.clone()
foo_inplace(x)
bar_out(x, out)
return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
after remove_implict_edges (pattern match happened foo_inplace + bar -> foobar_out)
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
Case: multiple writers and readers
```python
def f(
x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
foo_inplace(x.view(-1))
foo_inplace(y.view(-1))
bar_out(x, outx)
bar_out(y, outy)
return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
After remove_implict_edges (the pattern matched, so each `foo_inplace` + `bar_out` pair was replaced by `foobar_out`):
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 11:27:43 -07:00
170ecc66e9
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
TL;DR
TorchInductor now supports pattern matching mutable custom ops directly by unwrapping auto_functionalized wrappers and inserting explicit dependency edges. This enables stable fusion patterns across PyTorch versions.
Problem:
vLLM has mutable custom ops, such as `rms_norm` and `static_scaled_fp8_quant`, that require pattern matching for [fusion passes](824a3f403f/vllm/compilation/fusion.py (L122-L131) ). These passes currently match against `auto_functionalized(mutable_op)` wrappers, but vLLM is upgrading to `auto_functionalized_v2` (soon v3), whose incompatible semantics break the existing patterns.
`auto_functionalized_v2` decomposes to: view + clone + functional_op + copy_. The specific view operations vary based on which inputs are mutated, making it difficult to write stable patterns that match view+op combinations.
Why doesn't the current pattern matcher support raw mutating custom ops?
Consider this mutable op sequence:
```python
foo_inplace(x) # Mutates tensor x
bar_out(x, out) # Uses mutated x, produces out
```
FX Graph Representation:
```python
%x = placeholder()
%out = placeholder()
%foo_result = call_function(foo_inplace, (%x,))
%bar_result = call_function(bar_out, (%x, %out)) # Missing dependency!
```
There is no explicit edge from `foo_inplace` to `bar_out`, even though `bar_out` depends on the mutation performed by `foo_inplace`. Without explicit edges, the pattern matcher cannot reliably detect op sequences or guarantee correct execution order.
High level idea:
- Identify mutation ops using operator schemas
- For each mutated tensor, find all storages (including views/aliases) via GraphAliasTracker
- Insert DEP_OP after each mutation
- Redirect later users of aliased storages to depend on DEP_OP
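The steps above can be pictured with a small, torch-free sketch. Everything in it is hypothetical and for illustration only: `Node`, `storage_root`, and `add_dependency_edges` are toy stand-ins, the `alias_of` field plays the role that GraphAliasTracker plays in the real pass, and `mutated` stands in for what the operator schema reports. The sketch only tracks ordering and ignores view shapes and strides.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class Node:
    """Toy stand-in for an FX node (hypothetical, not torch.fx.Node)."""
    name: str
    target: str
    args: List[str] = field(default_factory=list)  # names of input nodes
    mutated: Tuple[int, ...] = ()                  # indices of args written in place
    alias_of: Optional[str] = None                 # node whose storage this one views


def storage_root(name: str, by_name: Dict[str, Node]) -> str:
    """Follow alias links back to the node that owns the storage."""
    while by_name[name].alias_of is not None:
        name = by_name[name].alias_of
    return name


def add_dependency_edges(nodes: List[Node]) -> List[Node]:
    """Insert a DEP_OP after every mutation and reroute later readers to it."""
    by_name = {n.name: n for n in nodes}
    latest: Dict[str, str] = {}  # storage root -> newest DEP_OP covering it
    result: List[Node] = []
    dep_count = 0
    for node in nodes:
        # Readers of a mutated storage (through any alias) use the DEP_OP instead.
        args = [latest.get(storage_root(a, by_name), a) for a in node.args]
        result.append(Node(node.name, node.target, args, node.mutated, node.alias_of))
        for i in node.mutated:
            # The DEP_OP consumes the mutated value and the writer,
            # so anything that reads it is ordered after the write.
            dep = Node(f"dep_{dep_count}", "DEP_OP", [args[i], node.name])
            dep_count += 1
            result.append(dep)
            latest[storage_root(node.args[i], by_name)] = dep.name
    return result


graph = [
    Node("x", "placeholder"),
    Node("out", "placeholder"),
    Node("foo", "foo_inplace", ["x"], mutated=(0,)),
    Node("bar", "bar_out", ["x", "out"], mutated=(1,)),
]
rewritten = add_dependency_edges(graph)
for n in rewritten:
    print(n.name, n.target, n.args)
```

In the rewritten toy graph, `bar` reads `dep_0` rather than `x`, which gives the matcher an explicit `foo` -> `dep_0` -> `bar` chain to follow.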
Example:
Custom ops definitions
```python
import torch


@torch.library.custom_op("mylib::foo_inplace", mutates_args={"x"})
def foo_inplace(x: torch.Tensor) -> None:
    x.add_(1)


@torch.library.custom_op("mylib::bar_out", mutates_args={"out"})
def bar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x + 2)


# Fused op: performs both mutations in a single call.
@torch.library.custom_op("mylib::foobar_out", mutates_args={"x", "out"})
def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
    x.add_(1)
    out.copy_(x + 2)


# pattern registration
def pattern(x, out):
    foo_inplace(x)
    bar_out(x, out)
    return x, out


def replacement(x, out):
    foobar_out(x, out)
    return x, out
```
Pattern graph after add_implict_edges (used for matching)
```python
graph():
%x_1 : [num_users=2] = placeholder[target=x_1]
%out_1 : [num_users=2] = placeholder[target=out_1]
%foo_inplace : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%x_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%x_1,), kwargs = {writer_token: %foo_inplace})
%bar_out : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %out_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%out_1,), kwargs = {writer_token: %bar_out})
return (op_for_dependencies, op_for_dependencies_1)
```
Case: mutating a clone of a graph input
```python
def f(x, out):
    x = x.clone()
    out = out.clone()
    foo_inplace(x)
    bar_out(x, out)
    return out
```
before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
return (getitem_3,)
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%as_strided_default_3,), kwargs = {})
%op_for_dependencies : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_3,), kwargs = {writer_token: %foo_inplace_default})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %as_strided_default_1), kwargs = {})
%op_for_dependencies_1 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%as_strided_default_1,), kwargs = {writer_token: %bar_out_default})
return (op_for_dependencies_1,)
```
After remove_implict_edges (the pattern matched, so `foo_inplace` + `bar_out` was replaced by `foobar_out`):
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%as_strided_default_2 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg0_1, [3], [1], 0), kwargs = {})
%clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default_2,), kwargs = {})
%as_strided_default_3 : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default_1, [3], [1], 0), kwargs = {})
%as_strided_default : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%arg1_1, [3], [1], 0), kwargs = {})
%clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%as_strided_default,), kwargs = {})
%as_strided_default_1 : [num_users=2] = call_function[target=torch.ops.aten.as_strided.default](args = (%clone_default, [3], [1], 0), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%as_strided_default_3, %as_strided_default_1), kwargs = {})
return (as_strided_default_1,)
```
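The cleanup step in the last graph can be sketched in the same spirit. This is a toy model, not the code in this PR: `Node` and `remove_dependency_edges` are hypothetical names, and the sketch assumes each `DEP_OP` carries the wrapped value as its first argument, as in the dumps above.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    """Toy stand-in for an FX node (hypothetical, not torch.fx.Node)."""
    name: str
    target: str
    args: List[str] = field(default_factory=list)  # names of input nodes


def remove_dependency_edges(nodes: List[Node]) -> List[Node]:
    """Drop every DEP_OP and point its users back at the value it wrapped."""
    forward: Dict[str, str] = {}  # DEP_OP name -> underlying value name
    kept: List[Node] = []
    for node in nodes:
        # Resolve args first so chained DEP_OPs collapse to the original value.
        args = [forward.get(a, a) for a in node.args]
        if node.target == "DEP_OP":
            forward[node.name] = args[0]  # args[0] is the wrapped value
            continue
        kept.append(Node(node.name, node.target, args))
    return kept


with_deps = [
    Node("x", "placeholder"),
    Node("out", "placeholder"),
    Node("foo", "foo_inplace", ["x"]),
    Node("dep_0", "DEP_OP", ["x", "foo"]),
    Node("bar", "bar_out", ["dep_0", "out"]),
    Node("dep_1", "DEP_OP", ["out", "bar"]),
    Node("ret", "output", ["dep_1"]),
]
cleaned = remove_dependency_edges(with_deps)
for n in cleaned:
    print(n.name, n.target, n.args)
```

The toy graph here is one where no fusion happened, which shows that the cleanup restores the original argument wiring when the dependency ops are stripped.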
Case: multiple writers and readers
```python
def f(
    x: torch.Tensor, y: torch.Tensor, outx: torch.Tensor, outy: torch.Tensor
):
    foo_inplace(x.view(-1))
    foo_inplace(y.view(-1))
    bar_out(x, outx)
    bar_out(y, outy)
    return outx, outy
```
Before mutable custom op pass
```python
graph():
%arg0_1 : [num_users=2] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%auto_functionalized_v2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg0_1]})
%getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2, 1), kwargs = {})
%auto_functionalized_v2_1 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.foo_inplace.default,), kwargs = {_x_base_index: 0, _x_alias: True, _all_bases: [%arg1_1]})
%getitem_3 : [num_users=2] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_1, 1), kwargs = {})
%auto_functionalized_v2_2 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_1, _out_base_index: 0, _all_bases: [%arg2_1]})
%getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_2, 1), kwargs = {})
%auto_functionalized_v2_3 : [num_users=1] = call_function[target=torch.ops.higher_order.auto_functionalized_v2](args = (mylib.bar_out.default,), kwargs = {x: %getitem_3, _out_base_index: 0, _all_bases: [%arg3_1]})
%getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%auto_functionalized_v2_3, 1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %getitem_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %getitem_3), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %getitem_5), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %getitem_7), kwargs = {})
return ()
```
after decompose auto_functionalized
```python
graph():
%arg0_1 : [num_users=3] = placeholder[target=arg0_1]
%arg1_1 : [num_users=3] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%alias_default : [num_users=1] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=0] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%bar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg0_1, %arg2_1), kwargs = {})
%bar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.bar_out.default](args = (%arg1_1, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %arg0_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg1_1, %arg1_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
after add_implict_edges
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%foo_inplace_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default_1,), kwargs = {})
%op_for_dependencies : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default_1,), kwargs = {writer_token: %foo_inplace_default_1})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foo_inplace_default : [num_users=1] = call_function[target=torch.ops.mylib.foo_inplace.default](args = (%alias_default,), kwargs = {})
%op_for_dependencies_1 : [num_users=2] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%alias_default,), kwargs = {writer_token: %foo_inplace_default})
%bar_out_default_1 : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies, %arg2_1), kwargs = {})
%op_for_dependencies_2 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg2_1,), kwargs = {writer_token: %bar_out_default_1})
%bar_out_default : [num_users=1] = call_function[target=torch.ops.mylib.bar_out.default](args = (%op_for_dependencies_1, %arg3_1), kwargs = {})
%op_for_dependencies_3 : [num_users=1] = call_function[target=torch.ops.pattern_matcher.op_for_dependencies](args = (%arg3_1,), kwargs = {writer_token: %bar_out_default})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies, %op_for_dependencies), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_1, %op_for_dependencies_1), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_2, %op_for_dependencies_2), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%op_for_dependencies_3, %op_for_dependencies_3), kwargs = {})
return ()
```
After remove_implict_edges (the pattern matched, so each `foo_inplace` + `bar_out` pair was replaced by `foobar_out`):
```python
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%alias_default_1 : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg0_1,), kwargs = {})
%alias_default : [num_users=2] = call_function[target=torch.ops.aten.alias.default](args = (%arg1_1,), kwargs = {})
%foobar_out_default_1 : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default_1, %arg2_1), kwargs = {})
%foobar_out_default : [num_users=0] = call_function[target=torch.ops.mylib.foobar_out.default](args = (%alias_default, %arg3_1), kwargs = {})
%copy_ : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default_1, %alias_default_1), kwargs = {})
%copy__1 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%alias_default, %alias_default), kwargs = {})
%copy__2 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg2_1, %arg2_1), kwargs = {})
%copy__3 : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%arg3_1, %arg3_1), kwargs = {})
return ()
```
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos choijon5
[ghstack-poisoned]
2025-10-16 11:19:53 -07:00
df278c8a0e
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-16 02:16:04 -07:00
b5801b5449
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-16 02:10:57 -07:00
a5b1ef2d1d
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-13 12:20:49 -07:00
36f4550411
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-10 14:01:51 -07:00
b33759a6c5
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-10 11:54:13 -07:00
ca1160f112
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-10 09:10:09 -07:00
2bf5728496
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-10 08:59:22 -07:00
970c40c3c0
Update base for Update on "[Inductor] Mutable custom op pattern matching"
...
2025-10-06 14:54:22 -07:00