mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] use int64 for large index (#154575)
Split reduction may need add an extra mask to avoid invalid index. Previously we always uses torch.int32 dtype. That causes problem when the tensor numel exceeds 2^31. Fix https://github.com/pytorch/pytorch/issues/154168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154575 Approved by: https://github.com/ngimel, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
f6e18bc105
commit
2596e3d061
@ -3129,3 +3129,14 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
|
||||
isinstance(wrapper, SubgraphPythonWrapperCodegen)
|
||||
and wrapper.partition_signatures is not None
|
||||
)
|
||||
|
||||
|
||||
def dtype_from_size(size: int) -> torch.dtype:
|
||||
from .virtualized import V
|
||||
|
||||
if V.graph.sizevars.statically_known_lt(
|
||||
size, 2**31
|
||||
) and V.graph.sizevars.statically_known_geq(size, -(2**31)):
|
||||
return torch.int32
|
||||
else:
|
||||
return torch.int64
|
||||
|
Reference in New Issue
Block a user