From 4967ad8baa724b8b1acc123698bb1265723feb87 Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Fri, 19 Sep 2025 17:01:36 +0000
Subject: [PATCH] [Graph Partition] improve custom op output alias (#163227)

For a custom op with multiple outputs, we see generated code like the following:

```
buf1 = op1(arg0)
buf3 = buf1[0]
buf4 = buf1[1]
del buf1  # <--- emitted if buf1 is not accessed in the future
```

If `buf1` is not accessed later, it is good to deallocate it early, so we do not delay the `del` until both `buf3` and `buf4` are dead. Note that `buf3` and `buf4` hold references to the underlying data, so `del buf1` does not prevent their use.

However, when the op has mutating args, `del buf1` is not emitted immediately:

```python
@torch.library.custom_op(
    "mylib::op1",
    mutates_args=["x"],
    schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)
```

[screenshot of the generated code omitted]

Why? Because `buf3` is a MultiOutput with `buf1` as input and believes `buf1` (an output of FallbackKernel `op1`) has inputs that alias its output:

https://github.com/pytorch/pytorch/blob/72fedf05752069c9e8b97c64397aedf6ee2bf5ec/torch/_inductor/ir.py#L7976-L7982

According to `[Note: FallbackKernel supported operators]`, the outputs of a mutating op that is auto-functionalizable should NOT alias any of its inputs.
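As a quick illustration of that invariant (a sketch, not part of this patch; the `mylib::op1_demo` name is made up, and `can_auto_functionalize` is imported from the same module the inductor change below relies on):

```python
import torch
from torch._higher_order_ops.auto_functionalize import can_auto_functionalize

# Hypothetical demo op, mirroring op1 above: it mutates `x` in place
# yet returns fresh tensors that do not alias the input.
@torch.library.custom_op(
    "mylib::op1_demo",
    mutates_args=["x"],
    schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
    device_types="cuda",
)
def op1_demo(x) -> tuple[torch.Tensor, torch.Tensor]:
    x = x + 1
    return (x + 1, x + 2)

op = torch.ops.mylib.op1_demo.default
assert op._schema.is_mutable       # the op mutates an input...
assert can_auto_functionalize(op)  # ...but is auto-functionalizable,
# so its outputs cannot alias its inputs and FallbackKernel may report
# no input/output aliasing for it.
```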
This PR improves `get_inputs_that_alias_output` of FallbackKernel to return no aliases in that case.

Use case: [moe custom op in vllm](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/layer.py#L2057-L2064)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163227
Approved by: https://github.com/zou3519
---
 test/inductor/test_cudagraph_trees.py | 54 +++++++++++++++++++++++++++
 test/inductor/test_perf.py            |  8 ++--
 torch/_inductor/ir.py                 | 20 +++++++++-
 3 files changed, 78 insertions(+), 4 deletions(-)

diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 91e65cad8299..cab895c3bcc1 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -3231,6 +3231,60 @@ if HAS_CUDA_AND_TRITON:
             # splitting on 1 custom gives 2 cudagraphs
             self.assertEqual(self.get_manager().new_graph_id().id, 2)
 
+        @config.patch(implicit_fallbacks=True)
+        @torch._inductor.config.patch("graph_partition", True)
+        def test_graph_partition_custom_op_mutation_late_free(self):
+            @torch.library.custom_op(
+                "mylib::op1",
+                mutates_args=["x"],
+                schema="(Tensor(a!)? x) -> (Tensor, Tensor)",
+                device_types="cuda",
+            )
+            def op1(x) -> tuple[torch.Tensor, torch.Tensor]:
+                x = x + 1
+                return (x + 1, x + 2)
+
+            @op1.register_fake
+            def _(x) -> tuple[torch.Tensor, torch.Tensor]:
+                return (torch.empty_like(x), torch.empty_like(x))
+
+            @torch.library.custom_op(
+                "mylib::cg_unsafe_op",
+                mutates_args=[],
+                schema="(Tensor x, Tensor y, Tensor x1, Tensor y1) -> Tensor",
+                device_types="cuda",
+                tags=(torch._C.Tag.cudagraph_unsafe,),
+            )
+            def cg_unsafe_op(x0, x1, y0, y1) -> torch.Tensor:
+                return x0 + x1 + y0 + y1
+
+            @cg_unsafe_op.register_fake
+            def _(x0, x1, y0, y1) -> torch.Tensor:
+                return torch.empty_like(x0)
+
+            def f(x):
+                x = x + 1
+                x = op1(x)
+                x0, x1 = x[0], x[1]
+                y0 = x0 + 1
+                y1 = x1 + 1
+                y = cg_unsafe_op(x0, x1, y0, y1)
+                z = y + x0 + x1
+                z0, z1 = op1(z)
+                z2 = z0 + z1
+                res = cg_unsafe_op(z2, z2, y, y)
+                return res
+
+            x = torch.randn(2, 2, device="cuda")
+            x_cloned = x.clone()
+            eager_out = f(x)
+
+            f_compiled = torch.compile(f, mode="reduce-overhead")
+
+            for _ in range(5):
+                compiled_out = f_compiled(x_cloned)
+                self.assertEqual(eager_out, compiled_out)
+
         @config.patch(implicit_fallbacks=True)
         @torch._inductor.config.patch("graph_partition", True)
         def test_graph_partition_custom_op_dynamoc_shapes(self):
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
index 83cd236875f4..2dd6d498936f 100644
--- a/test/inductor/test_perf.py
+++ b/test/inductor/test_perf.py
@@ -1156,11 +1156,13 @@ class InplacingTests(TestCase):
             torch.compile(f, fullgraph=True),
         )
 
-        # Check that we are allocating the minimum number of intermediate buffers
+        # Check that we do not allocate intermediate buffers
+        # which can be reused.
         matches = re.findall(r"empty_strided_\w+\(", code)
-        self.assertEqual(len(matches), 1)
+        self.assertEqual(len(matches), 0)
+        self.assertEqual("in_out" in code, True)
 
-        self.assertExpectedInline(count_numel(f), """39""")
+        self.assertExpectedInline(count_numel(f), """45""")
 
     @requires_cuda_and_triton
     def test_inplace_triton_kernel_v1(self):
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 77500888d223..5d27a0335600 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -7550,7 +7550,25 @@ class FallbackKernel(ExternKernelAlloc):
         return get_schema_info(self.op_overload).is_mutable()
 
     def get_inputs_that_alias_output(self) -> Sequence[str]:
-        return self.alias_names
+        assert isinstance(
+            self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
+        ), (
+            f"Failed to create FallbackKernel for {self.op_overload}: "
+            f"{type(self.op_overload)} not supported"
+        )
+
+        # See [Note: FallbackKernel supported operators]: for a mutating
+        # op that is auto-functionalizable, its outputs do NOT
+        # alias any of the inputs.
+        if (
+            not isinstance(self.op_overload, torch._ops.HigherOrderOperator)
+            and "_c10d_functional" not in self.op_overload.name()
+            and self.op_overload._schema.is_mutable
+            and can_auto_functionalize(self.op_overload)
+        ):
+            return []
+        else:
+            return self.alias_names
 
     def get_mutation_names(self) -> Sequence[str]:
         assert len(self.mutation_names) <= 1