[Inductor][CPP] Fix issue in CPP GEMM Template Prune Tensor (#141798)

**Summary**
When addressing [issue #134998](https://github.com/pytorch/pytorch/issues/134998), we will verify if any node in the current graph shares the same storage as the node we intend to prune. In the implementation, we assumed that when creating the `GraphLowering` in post-grad phase, there would be no `submodules`, and all `get_attr` nodes would correspond to a `torch.Tensor`. However, this assumption proves incorrect when enabling `FlexAttention`. In this scenario, `submodules` are present as `get_attr` node in post-grad phase. For example:

```
V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs]     class sdpa_score30(torch.nn.Module):
V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs]         def forward(self, arg0_1: "bf16[][]cpu", arg1_1: "i32[][]cpu", arg2_1: "i32[][]cpu", arg3_1: "i32[][]cpu", arg4_1: "i32[][]cpu"):
V1128 23:23:47.071000 1965794 torch/_inductor/compile_fx.py:875] [0/1] [__post_grad_graphs]             return arg0_1

V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1]         sdpa_score30 = self.sdpa_score30
V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1]         sdpa_mask30 = self.sdpa_mask30
V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1]         flex_attention_30 = torch.ops.higher_order.flex_attention(add_276, index_put_60, index_put_61, sdpa_score30, (_frozen_param293, _frozen_param295, _frozen_param296, _frozen_param297, _frozen_param298, _frozen_param299, _frozen_param300, _frozen_param301, 64, 64, sdpa_mask30), 0.08838834764831843, {'SKIP_MASK_SCORE': True, 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': False}, (), (_frozen_param294,));  add_276 = sdpa_score30 = sdpa_mask30 = None
V1128 23:23:45.482000 1965794 torch/_inductor/freezing.py:118] [0/1]         getitem_60: "bf16[1, 32, 1, 128]" = flex_attention_30[0];  flex_attention_30 = None
```
We added an extra check in the implementation to ensure only comparing the `get_attr` node with `torch.Tensor`. It is difficult to reproduce this issue using pure high-order operators. Adding a unit test after https://github.com/pytorch/pytorch/pull/141453 lands would be more straightforward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141798
Approved by: https://github.com/jgong5
This commit is contained in:
leslie-fang-intel
2024-11-29 00:46:09 -08:00
committed by PyTorch MergeBot
parent 90f4d60672
commit 96d2a511ce

View File

@ -435,7 +435,9 @@ def prune_tensors(input_nodes: List[ir.TensorBox], new_input_nodes: List[ir.Tens
V.graph.module, node.name
): # candidate tensor might already be deleted
comp_tensor = getattr(V.graph.module, node.name)
if share_storage(candidate_tensor, comp_tensor):
if isinstance(comp_tensor, torch.Tensor) and share_storage(
candidate_tensor, comp_tensor
):
candidate_tensor_users += 1
for node in reversed(V.graph.graph.nodes):