[Graph Partition] fix partition x memory plan issue (#165514)

Before this PR, `test_graph_partition_with_memory_plan_reuse` failed with the following error when graph partition was enabled ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):

```
def partition_0(args):
    ...
    del buf0
    return (buf3, buf4, buf5, buf2, primals_4, )

...

  File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
    return (buf3, buf4, buf5, buf2, primals_4, )
                              ^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```

Without graph partition, the test passes and the generated code reuses the buffer ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):

```
def call(self, args):
    ...
    buf2 = buf0; del buf0  # reuse
    ...
```

Note that the issue is that `buf0` is not reused for `buf2` when graph partition is enabled.

Why? Because codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pops trailing `MemoryPlanningLine`s unless their buffers appear in the graph outputs, as determined by `V.graph.get_output_names()`. For graph partition, however, we should check the outputs of the current partition rather than the outputs of the graph before partitioning.
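
As a rough illustration, here is a minimal sketch of the pruning behavior (not the actual Inductor implementation; `PlanningLine`, `prune_trailing_planning_lines`, and the output-name lists below are simplified stand-ins):

```
from dataclasses import dataclass


# Simplified stand-in for Inductor's MemoryPlanningLine wrapper-IR lines.
@dataclass
class PlanningLine:
    buffer_name: str


def prune_trailing_planning_lines(lines, out_names):
    # Pop trailing planning lines whose buffer is not an output of the code
    # being generated -- they would be dead code in the emitted wrapper.
    while (
        lines
        and isinstance(lines[-1], PlanningLine)
        and lines[-1].buffer_name not in out_names
    ):
        lines.pop()
    return lines


# buf2 is returned by partition_0 but is not an output of the full
# pre-partition graph, so pruning against the graph's output names (old
# behavior) drops the planning line that would have emitted
# "buf2 = buf0  # reuse", while pruning against the partition's own output
# names (this PR) keeps it. (Hypothetical output-name lists for illustration.)
lines = ["<kernel launches>", PlanningLine("buf2")]
assert prune_trailing_planning_lines(list(lines), ["buf3", "buf2"]) == lines
assert prune_trailing_planning_lines(list(lines), ["buf5"]) == ["<kernel launches>"]
```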

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison
Boyuan Feng
2025-10-15 21:52:16 +00:00
committed by PyTorch MergeBot
parent fa1539594b
commit f071f17911
3 changed files with 126 additions and 3 deletions

```
@@ -974,6 +974,125 @@ if HAS_CUDA_AND_TRITON:
             num_partitions = get_num_partitions(code)
             self.assertEqual(num_partitions, 1)

+        @torch._inductor.config.patch("graph_partition", True)
+        @torch._inductor.config.patch("implicit_fallbacks", True)
+        def test_graph_partition_with_memory_plan_reuse(self):
+            BATCH_SIZE = 16
+            MLP_SIZE = 128
+            HIDDEN_SIZE = 128
+            RANDOM_SEED = 0
+
+            @torch.library.custom_op(
+                "silly::attention",
+                mutates_args=["out"],
+                tags=(torch._C.Tag.cudagraph_unsafe,),
+            )
+            def attention(
+                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+            ) -> None:
+                out.copy_(q + k + v)
+
+            @attention.register_fake
+            def _(q, k, v, out):
+                return None
+
+            class ParentModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return x
+
+            class Attention(torch.nn.Module):
+                def __init__(self, mlp_size: int, hidden_size: int) -> None:
+                    super().__init__()
+                    self.pre_attn = torch.nn.Linear(mlp_size, hidden_size, bias=False)
+                    self.post_attn = torch.nn.Linear(hidden_size, mlp_size, bias=False)
+                    self.rms_norm_weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+                def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
+                    x_f32 = x.float()
+                    return (
+                        x_f32
+                        * torch.rsqrt(
+                            torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6
+                        )
+                        * self.rms_norm_weight
+                    ).to(x.dtype)
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    x = self.pre_attn(x)
+                    x = self.rms_norm_ref(x)
+                    attn_output = torch.empty_like(x)
+                    torch.ops.silly.attention(x, x, x, attn_output)
+                    x = attn_output
+                    x = self.rms_norm_ref(x)
+                    x = self.post_attn(x)
+                    return x
+
+            class CompiledAttention(torch.nn.Module):
+                def __init__(
+                    self,
+                    *,
+                    mlp_size: int,
+                    hidden_size: int,
+                ) -> None:
+                    super().__init__()
+                    self.attn = Attention(mlp_size, hidden_size)
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return self.attn(x)
+
+            class CompiledAttentionTwo(CompiledAttention):
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return self.attn(x) + x
+
+            class SimpleModelWithTwoGraphs(ParentModel):
+                def __init__(
+                    self,
+                    *,
+                    mlp_size: int,
+                    hidden_size: int,
+                ) -> None:
+                    super().__init__()
+                    self.attn_one = CompiledAttention(
+                        mlp_size=mlp_size,
+                        hidden_size=hidden_size,
+                    )
+                    self.attn_two = CompiledAttentionTwo(
+                        mlp_size=mlp_size,
+                        hidden_size=hidden_size,
+                    )
+                    self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    bsz = x.shape[0]
+                    # CUDAGraph expects same tensor addresses for each run
+                    self.hidden_states[:bsz].copy_(x)
+                    x = self.attn_one(self.hidden_states[:bsz])
+                    self.hidden_states[:bsz].copy_(x)
+                    x = self.attn_two(self.hidden_states[:bsz])
+                    return x
+
+            eager_model = (
+                SimpleModelWithTwoGraphs(
+                    mlp_size=MLP_SIZE,
+                    hidden_size=HIDDEN_SIZE,
+                )
+                .eval()
+                .cuda()
+            )
+            compiled_model = torch.compile(eager_model, mode="reduce-overhead")
+
+            inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
+
+            for _ in range(3):
+                eager_out = eager_model(inputs)
+                compiled_out = compiled_model(inputs)
+                self.assertEqual(eager_out, compiled_out)
+
         @torch._inductor.config.patch("graph_partition", True)
         @torch._inductor.config.patch("triton.cudagraph_trees", False)
         def test_graph_partition_gc(self):
```

```
@@ -1808,7 +1808,8 @@ class PythonWrapperCodegen(CodeGen):
         self.lines = MemoryPlanner(self).plan(self.lines)

     def memory_plan_reuse(self):
-        out_names = V.graph.get_output_names()
+        outputs = self.get_graph_outputs()
+        out_names = V.graph._get_output_names(outputs)

         while (
             self.lines
```

```
@@ -2479,11 +2479,11 @@ class GraphLowering(torch.fx.Interpreter):

         return mod

-    def get_output_names(self) -> list[str]:
+    def _get_output_names(self, graph_outputs: list[ir.IRNode]) -> list[str]:
         names = []
         shape_counter = itertools.count(0)
         none_counter = itertools.count(0)
-        for node in self.graph_outputs:
+        for node in graph_outputs:
             if isinstance(node, ir.NoneAsConstantBuffer):
                 names.append(f"{self.name}_none{next(none_counter)}")
             elif isinstance(node, ir.ShapeAsConstantBuffer):
@@ -2492,6 +2492,9 @@ class GraphLowering(torch.fx.Interpreter):
                 names.append(node.get_name())
         return names

+    def get_output_names(self) -> list[str]:
+        return self._get_output_names(self.graph_outputs)
+
     def is_unspec_arg(self, name: str) -> bool:
         # dynamo wraps unspec variable as 0d CPU tensor,
         # need to convert to scalar during codegen (triton only)
```
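
To tie the two hunks together: the existing `get_output_names()` behavior is kept as a thin wrapper, while code that is generated for a single partition can pass that partition's outputs to `_get_output_names`. A hypothetical, self-contained sketch of the pattern (`GraphLoweringLike` and the buffer lists are illustrative, not PyTorch source):

```
# Hypothetical sketch of the refactor pattern: keep the old zero-argument API,
# and add a parameterized helper that a partition's wrapper codegen can call
# with the partition's own outputs instead of the whole graph's.
class GraphLoweringLike:
    def __init__(self, graph_outputs: list[str]) -> None:
        self.graph_outputs = graph_outputs

    def _get_output_names(self, graph_outputs: list[str]) -> list[str]:
        return list(graph_outputs)

    def get_output_names(self) -> list[str]:
        # Unchanged behavior for the non-partitioned path.
        return self._get_output_names(self.graph_outputs)


graph = GraphLoweringLike(graph_outputs=["buf5", "primals_4"])
partition_outputs = ["buf3", "buf4", "buf5", "buf2", "primals_4"]

assert "buf2" not in graph.get_output_names()  # whole-graph view: buf2 looks dead
assert "buf2" in graph._get_output_names(partition_outputs)  # partition view keeps it
```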