Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Graph Partition] fix partition x memory plan issue (#165514)
For `test_graph_partition_with_memory_plan_reuse`, before this PR, compiling with graph partition enabled would error ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):

```
def partition_0(args):
    ...
    del buf0
    return (buf3, buf4, buf5, buf2, primals_4, )
...
  File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
    return (buf3, buf4, buf5, buf2, primals_4, )
                                    ^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```

Without graph partition, the same test works and generates the following code ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):

```
def call(self, args):
    ...
    buf2 = buf0; del buf0  # reuse
    ...
```

The issue is that buf0 is not reused for buf2 when graph partition is enabled. Why? Codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pops trailing `MemoryPlanningLine`s unless they correspond to a graph output, as determined by `V.graph.get_output_names()`. For graph partition, however, we should check against the outputs of the current partition rather than the outputs of the whole, un-partitioned graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison
This commit is contained in:
committed by PyTorch MergeBot
parent fa1539594b
commit f071f17911
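To make the failure mode concrete, here is a minimal, self-contained sketch of the pruning behavior described in the commit message. It is not Inductor's actual implementation: `PlanLine`, `prune_trailing_plan_lines`, and the output sets are made up for illustration, with buffer names borrowed from the error above.

```python
from dataclasses import dataclass


@dataclass
class PlanLine:
    """Toy stand-in for a trailing wrapper memory-planning entry (e.g. a buffer reuse)."""

    buffer: str
    text: str


def prune_trailing_plan_lines(lines: list, out_names: set[str]) -> list:
    """Pop trailing planning entries whose buffer is not in the set of outputs to keep."""
    lines = list(lines)
    while lines and isinstance(lines[-1], PlanLine) and lines[-1].buffer not in out_names:
        lines.pop()
    return lines


# Tail of one graph partition's wrapper code: buf2 reuses buf0's storage.
partition_lines = [
    "... kernel calls ...",
    PlanLine("buf2", "buf2 = buf0; del buf0  # reuse"),
]

# buf2 is returned by this *partition*, but it is not an output of the whole graph
# (the membership below is illustrative, not taken from the real test).
whole_graph_outputs = {"buf5"}
partition_outputs = {"buf3", "buf4", "buf5", "buf2", "primals_4"}

# Checking the whole graph's outputs (old behavior) drops the reuse line, so the
# generated partition later returns a buf2 that was never defined -> NameError.
print(prune_trailing_plan_lines(partition_lines, whole_graph_outputs))

# Checking the current partition's outputs (this PR) keeps the reuse line.
print(prune_trailing_plan_lines(partition_lines, partition_outputs))
```

Pruning against the whole graph's outputs discards the `buf2 = buf0` reuse line, while pruning against the partition's own outputs keeps it, which is the behavior the diff below restores.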
@@ -974,6 +974,125 @@ if HAS_CUDA_AND_TRITON:
             num_partitions = get_num_partitions(code)
             self.assertEqual(num_partitions, 1)
 
+        @torch._inductor.config.patch("graph_partition", True)
+        @torch._inductor.config.patch("implicit_fallbacks", True)
+        def test_graph_partition_with_memory_plan_reuse(self):
+            BATCH_SIZE = 16
+            MLP_SIZE = 128
+            HIDDEN_SIZE = 128
+            RANDOM_SEED = 0
+
+            @torch.library.custom_op(
+                "silly::attention",
+                mutates_args=["out"],
+                tags=(torch._C.Tag.cudagraph_unsafe,),
+            )
+            def attention(
+                q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+            ) -> None:
+                out.copy_(q + k + v)
+
+            @attention.register_fake
+            def _(q, k, v, out):
+                return None
+
+            class ParentModel(torch.nn.Module):
+                def __init__(self) -> None:
+                    super().__init__()
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return x
+
+            class Attention(torch.nn.Module):
+                def __init__(self, mlp_size: int, hidden_size: int) -> None:
+                    super().__init__()
+                    self.pre_attn = torch.nn.Linear(mlp_size, hidden_size, bias=False)
+                    self.post_attn = torch.nn.Linear(hidden_size, mlp_size, bias=False)
+                    self.rms_norm_weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+                def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
+                    x_f32 = x.float()
+                    return (
+                        x_f32
+                        * torch.rsqrt(
+                            torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6
+                        )
+                        * self.rms_norm_weight
+                    ).to(x.dtype)
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    x = self.pre_attn(x)
+                    x = self.rms_norm_ref(x)
+                    attn_output = torch.empty_like(x)
+                    torch.ops.silly.attention(x, x, x, attn_output)
+                    x = attn_output
+                    x = self.rms_norm_ref(x)
+                    x = self.post_attn(x)
+                    return x
+
+            class CompiledAttention(torch.nn.Module):
+                def __init__(
+                    self,
+                    *,
+                    mlp_size: int,
+                    hidden_size: int,
+                ) -> None:
+                    super().__init__()
+                    self.attn = Attention(mlp_size, hidden_size)
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return self.attn(x)
+
+            class CompiledAttentionTwo(CompiledAttention):
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    return self.attn(x) + x
+
+            class SimpleModelWithTwoGraphs(ParentModel):
+                def __init__(
+                    self,
+                    *,
+                    mlp_size: int,
+                    hidden_size: int,
+                ) -> None:
+                    super().__init__()
+                    self.attn_one = CompiledAttention(
+                        mlp_size=mlp_size,
+                        hidden_size=hidden_size,
+                    )
+                    self.attn_two = CompiledAttentionTwo(
+                        mlp_size=mlp_size,
+                        hidden_size=hidden_size,
+                    )
+
+                    self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()
+
+                def forward(self, x: torch.Tensor) -> torch.Tensor:
+                    bsz = x.shape[0]
+                    # CUDAGraph expects same tensor addresses for each run
+                    self.hidden_states[:bsz].copy_(x)
+                    x = self.attn_one(self.hidden_states[:bsz])
+                    self.hidden_states[:bsz].copy_(x)
+                    x = self.attn_two(self.hidden_states[:bsz])
+                    return x
+
+            eager_model = (
+                SimpleModelWithTwoGraphs(
+                    mlp_size=MLP_SIZE,
+                    hidden_size=HIDDEN_SIZE,
+                )
+                .eval()
+                .cuda()
+            )
+
+            compiled_model = torch.compile(eager_model, mode="reduce-overhead")
+
+            inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
+
+            for _ in range(3):
+                eager_out = eager_model(inputs)
+                compiled_out = compiled_model(inputs)
+                self.assertEqual(eager_out, compiled_out)
+
         @torch._inductor.config.patch("graph_partition", True)
         @torch._inductor.config.patch("triton.cudagraph_trees", False)
         def test_graph_partition_gc(self):
@@ -1808,7 +1808,8 @@ class PythonWrapperCodegen(CodeGen):
         self.lines = MemoryPlanner(self).plan(self.lines)
 
     def memory_plan_reuse(self):
-        out_names = V.graph.get_output_names()
+        outputs = self.get_graph_outputs()
+        out_names = V.graph._get_output_names(outputs)
 
         while (
             self.lines
@@ -2479,11 +2479,11 @@ class GraphLowering(torch.fx.Interpreter):
 
         return mod
 
-    def get_output_names(self) -> list[str]:
+    def _get_output_names(self, graph_outputs: list[ir.IRNode]) -> list[str]:
         names = []
         shape_counter = itertools.count(0)
         none_counter = itertools.count(0)
-        for node in self.graph_outputs:
+        for node in graph_outputs:
             if isinstance(node, ir.NoneAsConstantBuffer):
                 names.append(f"{self.name}_none{next(none_counter)}")
             elif isinstance(node, ir.ShapeAsConstantBuffer):
@@ -2492,6 +2492,9 @@ class GraphLowering(torch.fx.Interpreter):
                 names.append(node.get_name())
         return names
 
+    def get_output_names(self) -> list[str]:
+        return self._get_output_names(self.graph_outputs)
+
     def is_unspec_arg(self, name: str) -> bool:
         # dynamo wraps unspec variable as 0d CPU tensor,
         # need to convert to scalar during codegen (triton only)
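The graph-side change is a small refactor: the name-collection loop now takes an explicit output list, so partition codegen can ask for the names of a partition's own outputs while `get_output_names()` keeps its old whole-graph behavior. Below is a minimal sketch of that shape under stated assumptions; `ToyBuffer`, `ToyNoneOutput`, and `ToyGraph` are toy classes standing in for the real `ir` nodes and `GraphLowering`, and the handling of shape outputs is omitted.

```python
import itertools


class ToyBuffer:
    """Toy stand-in for an IR output node that owns a named buffer."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name


class ToyNoneOutput:
    """Toy stand-in for a None-valued output (ir.NoneAsConstantBuffer in the diff)."""


class ToyGraph:
    def __init__(self, name: str, graph_outputs: list) -> None:
        self.name = name
        self.graph_outputs = graph_outputs

    def _get_output_names(self, graph_outputs: list) -> list[str]:
        # Same walk as before the PR, but over whichever outputs the caller passes in.
        names = []
        none_counter = itertools.count(0)
        for node in graph_outputs:
            if isinstance(node, ToyNoneOutput):
                names.append(f"{self.name}_none{next(none_counter)}")
            else:
                names.append(node.get_name())
        return names

    def get_output_names(self) -> list[str]:
        # Unchanged public entry point: names of the whole graph's outputs.
        return self._get_output_names(self.graph_outputs)


graph = ToyGraph("graph", [ToyBuffer("buf5")])
partition_outputs = [ToyBuffer("buf3"), ToyBuffer("buf2"), ToyNoneOutput()]

print(graph.get_output_names())                    # ['buf5']
print(graph._get_output_names(partition_outputs))  # ['buf3', 'buf2', 'graph_none0']
```

With this split, `memory_plan_reuse` can compute `out_names` from the current partition's outputs, which is exactly what the wrapper hunk above does.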