[Graph Partition] improve custom op output alias (#163227)

For a custom op with multiple outputs, we will see the following generated code:
```
buf1 = op1(arg0)
buf3 = buf1[0]
buf4 = buf1[1]
del buf1 # <--- if buf1 is not accessed in the future
```

If `buf1` is not accessed again, it's good to deallocate it early rather than delay the `del` until both `buf3` and `buf4` are dead. Note that `buf3` and `buf4` hold references to the underlying data, so `del buf1` does not prevent their later use.
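
A minimal sketch in plain Python (not actual Inductor codegen; `op_like` is a hypothetical stand-in) of why the early `del` is safe: each tuple element keeps its own storage alive, so dropping the tuple does not invalidate the extracted views.

```python
import torch

def op_like(x):
    # hypothetical stand-in for a multi-output custom op
    return (x + 1, x + 2)

buf1 = op_like(torch.randn(4))
buf3, buf4 = buf1[0], buf1[1]
del buf1  # safe: buf3 and buf4 still reference their own storages
print(buf3 + buf4)
```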

However, when there are mutating args, we don't see `del buf1` immediately.

```python
import torch

@torch.library.custom_op(
    "mylib::op1",
    mutates_args=["x"],
    schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)
```
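
For context, a sketch of how this op is exercised (mirroring the test added in this PR): a fake impl is registered so the op can be traced, and the calling function is compiled with `reduce-overhead`, which enables cudagraphs and thus graph partition.

```python
@op1.register_fake
def _(x) -> tuple[torch.Tensor, torch.Tensor]:
    return (torch.empty_like(x), torch.empty_like(x))

def f(x):
    a, b = op1(x)
    return a + b

# reduce-overhead enables cudagraphs, where graph partition applies
f_compiled = torch.compile(f, mode="reduce-overhead")
out = f_compiled(torch.randn(2, 2, device="cuda"))
```

The screenshot below shows the code Inductor generates for this case: `del buf1` is delayed rather than emitted right after the `MultiOutput` extraction.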

<img width="661" height="821" alt="image" src="https://github.com/user-attachments/assets/3d1d1f5a-9749-4652-bb02-da593c78702d" />

Why? Because `buf3` is a `MultiOutput` node with `buf1` as input, and it believes that `buf1` (the output of the `FallbackKernel` for op1) has inputs that alias its output.
72fedf0575/torch/_inductor/ir.py (L7976-L7982)

According to `[NOTE: FallbackKernel supported operators]`, for a mutating op that is auto-functionalizable, the outputs should NOT alias any of the inputs. This PR improves `get_inputs_that_alias_output` of `FallbackKernel` accordingly.
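
A hedged sketch of the property the fix keys on, using names that appear in the diff below (it assumes `op1` from the example above has been registered): a mutable op that is auto-functionalizable returns fresh tensors, so none of its outputs alias an input.

```python
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

op = torch.ops.mylib.op1.default
assert op._schema.is_mutable       # the op mutates `x`
assert can_auto_functionalize(op)  # outputs are fresh, no input aliasing
```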

Use case: [moe custom op in vllm](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/layer.py#L2057-L2064)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163227
Approved by: https://github.com/zou3519
Commit: 4967ad8baa
Parent: e631d76002
Author: Boyuan Feng
Date: 2025-09-19 17:01:36 +00:00
Committed by: PyTorch MergeBot
3 changed files with 78 additions and 4 deletions


```diff
@@ -3231,6 +3231,60 @@ if HAS_CUDA_AND_TRITON:
             # splitting on 1 custom gives 2 cudagraphs
             self.assertEqual(self.get_manager().new_graph_id().id, 2)
 
+        @config.patch(implicit_fallbacks=True)
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_custom_op_mutation_late_free(self):
+            @torch.library.custom_op(
+                "mylib::op1",
+                mutates_args=["x"],
+                schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
+                device_types="cuda",
+            )
+            def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
+                x = x + 1
+                return (x + 1, x + 2)
+
+            @op1.register_fake
+            def _(x) -> tuple[torch.Tensor, torch.Tensor]:
+                return (torch.empty_like(x), torch.empty_like(x))
+
+            @torch.library.custom_op(
+                "mylib::cg_unsafe_op",
+                mutates_args=[],
+                schema="(Tensor x, Tensor y, Tensor x1, Tensor y1) -> Tensor",
+                device_types="cuda",
+                tags=(torch._C.Tag.cudagraph_unsafe,),
+            )
+            def cg_unsafe_op(x0, x1, y0, y1) -> torch.Tensor:
+                return x0 + x1 + y0 + y1
+
+            @cg_unsafe_op.register_fake
+            def _(x0, x1, y0, y1) -> torch.Tensor:
+                return torch.empty_like(x0)
+
+            def f(x):
+                x = x + 1
+                x = op1(x)
+                x0, x1 = x[0], x[1]
+                y0 = x0 + 1
+                y1 = x1 + 1
+                y = cg_unsafe_op(x0, x1, y0, y1)
+                z = y + x0 + x1
+                z0, z1 = op1(z)
+                z2 = z0 + z1
+                res = cg_unsafe_op(z2, z2, y, y)
+                return res
+
+            x = torch.randn(2, 2, device="cuda")
+            x_cloned = x.clone()
+            eager_out = f(x)
+
+            f_compiled = torch.compile(f, mode="reduce-overhead")
+
+            for _ in range(5):
+                compiled_out = f_compiled(x_cloned)
+                self.assertEqual(eager_out, compiled_out)
+
         @config.patch(implicit_fallbacks=True)
         @torch._inductor.config.patch("graph_partition", True)
         def test_graph_partition_custom_op_dynamoc_shapes(self):
```


```diff
@@ -1156,11 +1156,13 @@ class InplacingTests(TestCase):
             torch.compile(f, fullgraph=True),
         )
 
-        # Check that we are allocating the minimum number of intermediate buffers
+        # Check that we do not allocate intermediate buffers
+        # which can be reused.
         matches = re.findall(r"empty_strided_\w+\(", code)
-        self.assertEqual(len(matches), 1)
+        self.assertEqual(len(matches), 0)
+        self.assertEqual("in_out" in code, True)
 
-        self.assertExpectedInline(count_numel(f), """39""")
+        self.assertExpectedInline(count_numel(f), """45""")
 
     @requires_cuda_and_triton
     def test_inplace_triton_kernel_v1(self):
```


```diff
@@ -7550,7 +7550,25 @@ class FallbackKernel(ExternKernelAlloc):
         return get_schema_info(self.op_overload).is_mutable()
 
     def get_inputs_that_alias_output(self) -> Sequence[str]:
-        return self.alias_names
+        assert isinstance(
+            self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
+        ), (
+            f"Fails to create FallbackKernel for {self.op_overload}: "
+            f"{type(self.op_overload)} not supported"
+        )
+
+        # See [Note: FallbackKernel supported operators]: for a mutating
+        # op that is auto-functionalizable, its outputs do NOT
+        # alias any of the inputs.
+        if (
+            not isinstance(self.op_overload, torch._ops.HigherOrderOperator)
+            and "_c10d_functional" not in self.op_overload.name()
+            and self.op_overload._schema.is_mutable
+            and can_auto_functionalize(self.op_overload)
+        ):
+            return []
+        else:
+            return self.alias_names
 
     def get_mutation_names(self) -> Sequence[str]:
         assert len(self.mutation_names) <= 1
```