[Graph Partition] improve custom op output alias (#163227)
For a custom op with multiple outputs, we will see the following generated code:
```
buf1 = op1(arg0)
buf3 = buf1[0]
buf4 = buf1[1]
del buf1 # <--- if buf1 is not accessed in the future
```
If `buf1` is not accessed in the future, it is good to deallocate it early, so we do not delay the `del` until both `buf3` and `buf4` are no longer used. Note that `buf3` and `buf4` hold references to the underlying data, so `del buf1` does not prevent their use.
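As a minimal illustration (plain Python/PyTorch, not Inductor-generated code; `op_like` is a hypothetical stand-in for the custom op), deleting the tuple returned by an op does not invalidate the tensors already extracted from it:

```python
import torch

def op_like(x):
    # Hypothetical stand-in for a custom op that returns two tensors.
    return (x + 1, x + 2)

buf1 = op_like(torch.randn(4))
buf3, buf4 = buf1[0], buf1[1]
del buf1            # frees only the tuple; buf3 and buf4 keep their tensors alive
print(buf3 + buf4)  # still valid after the early del
```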
However, when the op has mutated args, we do not see `del buf1` immediately:
```python
@torch.library.custom_op(
    "mylib::op1",
    mutates_args=["x"],
    schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)
```
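To exercise this op under `torch.compile`, a fake (meta) implementation is also registered; this mirrors the test added in this PR:

```python
@op1.register_fake
def _(x) -> tuple[torch.Tensor, torch.Tensor]:
    # Shape/dtype-only implementation used during tracing.
    return (torch.empty_like(x), torch.empty_like(x))
```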
![Generated code showing that del buf1 is delayed when the op has mutated args](https://github.com/user-attachments/assets/3d1d1f5a-9749-4652-bb02-da593c78702d)
Why? Because `buf3` is a `MultiOutput` node that takes `buf1` as input and believes that `buf1` (the output of the `FallbackKernel` for `op1`) has inputs that alias its output.
72fedf0575/torch/_inductor/ir.py (L7976-L7982)
According to `[NOTE: FallbackKernel supported operators]`, for a mutating op that is auto-functionalizable, the outputs should NOT alias any of the inputs. This PR improves `get_inputs_that_alias_output` of `FallbackKernel` accordingly.
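One way to check whether the early `del` is emitted is to dump the Inductor-generated wrapper code while compiling a small function that calls the op. A minimal sketch (the function `f` and the file name `repro.py` are illustrative, and `mylib::op1` from the snippet above is assumed to be registered):

```python
# Run with: TORCH_LOGS="output_code" python repro.py
# and inspect where `del buf...` appears in the printed wrapper code.
import torch

@torch.compile(mode="reduce-overhead")
def f(x):
    a, b = torch.ops.mylib.op1(x)
    return a + b

f(torch.randn(2, 2, device="cuda"))
```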
Use case: [moe custom op in vllm](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/layer.py#L2057-L2064)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163227
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: e631d76002
Commit: 4967ad8baa
```diff
@@ -3231,6 +3231,60 @@ if HAS_CUDA_AND_TRITON:
             # splitting on 1 custom gives 2 cudagraphs
             self.assertEqual(self.get_manager().new_graph_id().id, 2)
 
+        @config.patch(implicit_fallbacks=True)
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_custom_op_mutation_late_free(self):
+            @torch.library.custom_op(
+                "mylib::op1",
+                mutates_args=["x"],
+                schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
+                device_types="cuda",
+            )
+            def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
+                x = x + 1
+                return (x + 1, x + 2)
+
+            @op1.register_fake
+            def _(x) -> tuple[torch.Tensor, torch.Tensor]:
+                return (torch.empty_like(x), torch.empty_like(x))
+
+            @torch.library.custom_op(
+                "mylib::cg_unsafe_op",
+                mutates_args=[],
+                schema="(Tensor x, Tensor y, Tensor x1, Tensor y1) -> Tensor",
+                device_types="cuda",
+                tags=(torch._C.Tag.cudagraph_unsafe,),
+            )
+            def cg_unsafe_op(x0, x1, y0, y1) -> torch.Tensor:
+                return x0 + x1 + y0 + y1
+
+            @cg_unsafe_op.register_fake
+            def _(x0, x1, y0, y1) -> torch.Tensor:
+                return torch.empty_like(x0)
+
+            def f(x):
+                x = x + 1
+                x = op1(x)
+                x0, x1 = x[0], x[1]
+                y0 = x0 + 1
+                y1 = x1 + 1
+                y = cg_unsafe_op(x0, x1, y0, y1)
+                z = y + x0 + x1
+                z0, z1 = op1(z)
+                z2 = z0 + z1
+                res = cg_unsafe_op(z2, z2, y, y)
+                return res
+
+            x = torch.randn(2, 2, device="cuda")
+            x_cloned = x.clone()
+            eager_out = f(x)
+
+            f_compiled = torch.compile(f, mode="reduce-overhead")
+
+            for _ in range(5):
+                compiled_out = f_compiled(x_cloned)
+                self.assertEqual(eager_out, compiled_out)
+
         @config.patch(implicit_fallbacks=True)
         @torch._inductor.config.patch("graph_partition", True)
         def test_graph_partition_custom_op_dynamoc_shapes(self):
```
```diff
@@ -1156,11 +1156,13 @@ class InplacingTests(TestCase):
             torch.compile(f, fullgraph=True),
         )
 
-        # Check that we are allocating the minimum number of intermediate buffers
+        # Check that we are not allocate intermediate buffers
+        # which can be reused.
         matches = re.findall(r"empty_strided_\w+\(", code)
-        self.assertEqual(len(matches), 1)
+        self.assertEqual(len(matches), 0)
+        self.assertEqual("in_out" in code, True)
 
-        self.assertExpectedInline(count_numel(f), """39""")
+        self.assertExpectedInline(count_numel(f), """45""")
 
     @requires_cuda_and_triton
     def test_inplace_triton_kernel_v1(self):
```
```diff
@@ -7550,7 +7550,25 @@ class FallbackKernel(ExternKernelAlloc):
         return get_schema_info(self.op_overload).is_mutable()
 
     def get_inputs_that_alias_output(self) -> Sequence[str]:
-        return self.alias_names
+        assert isinstance(
+            self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
+        ), (
+            f"Fails to create FallbackKernel for {self.op_overload}: "
+            f"{type(self.op_overload)} not supported"
+        )
+
+        # See [Note: FallbackKernel supported operators]: for a mutating
+        # op that is auto-functionalizable, its outputs does NOT
+        # alias any of the inputs.
+        if (
+            not isinstance(self.op_overload, torch._ops.HigherOrderOperator)
+            and "_c10d_functional" not in self.op_overload.name()
+            and self.op_overload._schema.is_mutable
+            and can_auto_functionalize(self.op_overload)
+        ):
+            return []
+        else:
+            return self.alias_names
 
     def get_mutation_names(self) -> Sequence[str]:
         assert len(self.mutation_names) <= 1
```