mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][CPP] Fix issue in CPP GEMM Template Prune Tensor (#141798)
**Summary** When addressing [issue #134998](https://github.com/pytorch/pytorch/issues/134998), we will verify if any node in the current graph shares the same storage as the node we intend to prune. In the implementation, we assumed that when creating the `GraphLowering` in post-grad phase, there would be no `submodules`, and all `get_attr` nodes would correspond to a `torch.Tensor`. However, this assumption proves incorrect when enabling `FlexAttention`. In this scenario, `submodules` are present as `get_attr` node in post-grad phase. For example: ``` V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs] class sdpa_score30(torch.nn.Module): V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs] def forward(self, arg0_1: "bf16[][]cpu", arg1_1: "i32[][]cpu", arg2_1: "i32[][]cpu", arg3_1: "i32[][]cpu", arg4_1: "i32[][]cpu"): V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs] return arg0_1 V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] sdpa_score30 = self.sdpa_score30 V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] sdpa_mask30 = self.sdpa_mask30 V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] flex_attention_30 = torch.ops.higher_order.flex_attention(add_276, index_put_60, index_put_61, sdpa_score30, (_frozen_param293, _frozen_param295, _frozen_param296, _frozen_param297, _frozen_param298, _frozen_param299, _frozen_param300, _frozen_param301, 64, 64, sdpa_mask30), 0.08838834764831843, {'SKIP_MASK_SCORE': True, 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': False}, (), (_frozen_param294,)); add_276 = sdpa_score30 = sdpa_mask30 = None V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1] getitem_60: "bf16[1, 32, 1, 128]" = flex_attention_30[0]; flex_attention_30 = None ``` We added an extra check in the implementation to ensure only comparing the `get_attr` node with `torch.Tensor`. It is difficult to reproduce this issue using pure high-order operators. Adding a unit test after https://github.com/pytorch/pytorch/pull/141453 lands would be more straightforward. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141798 Approved by: https://github.com/jgong5
This commit is contained in:
committed by
PyTorch MergeBot
parent
90f4d60672
commit
96d2a511ce
@ -435,7 +435,9 @@ def prune_tensors(input_nodes: List[ir.TensorBox], new_input_nodes: List[ir.Tens
|
||||
V.graph.module, node.name
|
||||
): # candidate tensor might already be deleted
|
||||
comp_tensor = getattr(V.graph.module, node.name)
|
||||
if share_storage(candidate_tensor, comp_tensor):
|
||||
if isinstance(comp_tensor, torch.Tensor) and share_storage(
|
||||
candidate_tensor, comp_tensor
|
||||
):
|
||||
candidate_tensor_users += 1
|
||||
|
||||
for node in reversed(V.graph.graph.nodes):
|
||||
|
Reference in New Issue
Block a user