[inductor] use int64 for large index (#154575)

Split reduction may need add an extra mask to avoid invalid index. Previously we always uses torch.int32 dtype. That causes problem when the tensor numel exceeds 2^31.

Fix https://github.com/pytorch/pytorch/issues/154168

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154575
Approved by: https://github.com/ngimel, https://github.com/jansel
This commit is contained in:
Shunting Zhang
2025-06-06 11:46:52 -07:00
committed by PyTorch MergeBot
parent f6e18bc105
commit 2596e3d061
7 changed files with 79 additions and 6 deletions

View File

@ -3129,3 +3129,14 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
isinstance(wrapper, SubgraphPythonWrapperCodegen)
and wrapper.partition_signatures is not None
)
def dtype_from_size(size: int) -> torch.dtype:
from .virtualized import V
if V.graph.sizevars.statically_known_lt(
size, 2**31
) and V.graph.sizevars.statically_known_geq(size, -(2**31)):
return torch.int32
else:
return torch.int64