[inductor] use int64 for large index (#154575)

Split reduction may need to add an extra mask to avoid invalid indices. Previously we always used the torch.int32 dtype, which causes problems when the tensor numel exceeds 2^31.

Fix https://github.com/pytorch/pytorch/issues/154168
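
As a rough standalone illustration (not Inductor's actual codegen), the sizes exercised by the new regression test already overflow a signed 32-bit index, so an int32 mask comparison against the reduction numel can silently go wrong:

import torch

# Shape used by the new test below: 30000 * 100000 = 3,000,000,000 elements.
numel = 30000 * 100000
print(numel > torch.iinfo(torch.int32).max)  # True

# Narrowing such a value to int32 wraps to a negative number (two's complement),
# so an int32 comparison like `index < reduction_numel` no longer means what we want.
print(torch.tensor(numel).to(torch.int32))  # tensor(-1294967296, dtype=torch.int32)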

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154575
Approved by: https://github.com/ngimel, https://github.com/jansel
Author: Shunting Zhang
Date: 2025-06-09 14:10:25 -07:00
Committed by: PyTorch MergeBot
Parent: 0f47e76937
Commit: 0b677560e6
8 changed files with 91 additions and 7 deletions

View File

@@ -78,7 +78,18 @@ class TestCase(InductorTestCase):
args = (sample_input.input,) + sample_input.args
kwargs = sample_input.kwargs
out = run(op.get_op(), args, kwargs)
out_c = torch.compile(run)(op.get_op(), args, kwargs)
# test_configs.runtime_triton_dtype_assert does not work well with dynamic shapes so far.
# Consider the following case for torch.add:
# both lhs/rhs are int32 tensors, and there is also an integer alpha argument.
# In the dynamic shape case, alpha is passed in as a ks0 argument. To be safe,
# we use tl.int64 for ks0's dtype.
# But the dtype for alpha is also decided to be tl.int32 during lowering when
# we promote alpha to an ir.Constant.
# Ideally, to resolve this problem, we should track assignments like
# alpha = ks0
# so that we know alpha is actually tl.int64 rather than tl.int32.
out_c = torch.compile(run, dynamic=False)(op.get_op(), args, kwargs)
self.assertEqual(out, out_c)
@requires_gpu()

View File

@@ -13476,6 +13476,35 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
ref = pad_same(x, (5, 5), (2, 2))
self.assertEqual(res, ref, atol=0, rtol=0)
@skip_if_halide # only 32-bit indexing
@largeTensorTest("16GB", inductor=True)
def test_split_reduction_with_int64_size(self):
if torch._inductor.config.cpu_backend == "triton":
raise unittest.SkipTest(
"Fail for triton cpu backend with error: https://gist.github.com/shunting314/a873fb32b6b7b5a437f44280ae86839f"
)
if self.device == "cpu":
raise unittest.SkipTest(
"The test fails some times on CI: "
"https://github.com/pytorch/pytorch/actions/runs/15333913377/job/43153170162. "
"Skip for now."
)
size = (30000, 100000)
# Use rand rather than randn, since the mean of the latter is close to 0,
# which happens to be close to the value produced by the bug.
t = torch.rand(size, dtype=torch.float, device=self.device)
op = torch.mean
expected = op(t)
actual = torch.compile(op)(t)
# self.common takes more GPU memory. Do the check directly.
self.assertTrue(
torch.allclose(expected, actual, atol=1e-2, rtol=1e-2),
f"{expected=} {actual=}",
)
def test_remove_noop_view_default(self):
def f(x):
batch_size = x.shape[0]

View File

@@ -58,7 +58,7 @@ test_failures = {
"test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(
("cpu", "cuda", "xpu")
),
"test_randint_distribution_dynamic_shapes": TestFailure(("cuda", "xpu")),
"test_randint_distribution_dynamic_shapes": TestFailure(("xpu",)),
}
if not torch._inductor.config.cpp_wrapper:
test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure(

View File

@@ -227,7 +227,12 @@ class IndexingOptions:
@property
def mask_str(self) -> str:
return " & ".join(map(str, self.mask_vars)) if self.mask_vars else "None"
# The sorted call is added to make sure the order stays
# deterministic if self.mask_vars contains a mix of strings
# and TritonCSEVariable objects
return (
" & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None"
)
@dataclasses.dataclass
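
The comment in this hunk notes that ordering can vary when mask_vars mixes strings with TritonCSEVariable objects. A tiny standalone sketch of how sorting the stringified entries stabilizes the join (FakeCSEVar and the literal mask names are made up, and a plain Python set stands in for the real container):

class FakeCSEVar:
    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name

# Iteration order of a plain set with mixed element types can differ between runs,
# but sorting the stringified entries always yields "rmask & xmask".
mask_vars = {"xmask", FakeCSEVar("rmask")}
print(" & ".join(sorted(map(str, mask_vars))))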

View File

@@ -111,11 +111,34 @@ def signature_to_meta(
size_dtype: Optional[str],
argdefs: list[ArgName],
indices: Optional[list[int]] = None,
is_template: bool = False,
) -> dict[str, str]:
if indices is None:
indices = list(range(len(signature)))
def _decide_tl_dtype(arg):
# Even if the ks0 symbol itself is within tl.int32 range, it's
# risky to use the tl.int32 dtype since we may have ks0*ks1 later
# for kernels like torch.mean when dynamic shapes are enabled.
#
# Check config.triton.use_block_ptr, since Triton block pointers
# do not support 64-bit indexing:
# https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6
#
# If the triton metadata is for a template, don't use a tl.int64 index.
# Templates like flex attention/decoding use block pointers, which
# do not support 64-bit indexing.
if (
not config.triton.use_block_ptr
and not is_template
and isinstance(arg, SizeArg)
and arg.name.startswith("ks")
):
return "tl.int64"
return size_dtype
return {
argdefs[i].name: signature_of(arg, size_dtype=size_dtype)
argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg))
for i, arg in zip(indices, signature)
}
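
A quick arithmetic check of the comment above, with illustrative sizes (matching the regression test, not pulled from this hunk):

int32_max = 2**31 - 1  # 2,147,483,647

# Each symbolic size on its own fits easily in int32 ...
ks0, ks1 = 30_000, 100_000
assert ks0 <= int32_max and ks1 <= int32_max

# ... but an index expression built from their product does not,
# which is why "ks*" size args conservatively get tl.int64 here.
assert ks0 * ks1 > int32_max  # 3,000,000,000 > 2,147,483,647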

View File

@@ -90,6 +90,7 @@ from .utils import (
convert_shape_to_symint,
developer_warning,
do_bench_using_profiling,
dtype_from_size,
get_dtype_size,
get_kernel_metadata,
GPU_ALIGN_BYTES,
@@ -1678,9 +1679,10 @@ class Reduction(Loops):
return loader(new_index, reindex([indices]))
if need_mask:
index_dtype = dtype_from_size(reduction_numel)
mask = ops.lt(
ops.index_expr(indices, torch.int32),
ops.index_expr(reduction_numel, torch.int32),
ops.index_expr(indices, index_dtype),
ops.index_expr(reduction_numel, index_dtype),
)
return ops.masked(mask, body, default)
else:

View File

@@ -494,7 +494,10 @@ class TritonTemplateKernel(TritonKernel):
argdefs, _, signature, _ = self.args.python_argdefs()
triton_meta: dict[str, Any] = {
"signature": signature_to_meta(
signature, size_dtype=self.index_dtype, argdefs=argdefs
signature,
size_dtype=self.index_dtype,
argdefs=argdefs,
is_template=True,
),
"device": DeviceProperties.create(self.output_node.get_device()),
"constants": {},

View File

@@ -3120,3 +3120,14 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
isinstance(wrapper, SubgraphPythonWrapperCodegen)
and wrapper.partition_signatures is not None
)
def dtype_from_size(size: int) -> torch.dtype:
from .virtualized import V
if V.graph.sizevars.statically_known_lt(
size, 2**31
) and V.graph.sizevars.statically_known_geq(size, -(2**31)):
return torch.int32
else:
return torch.int64
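
For concrete sizes, the new helper boils down to a signed 32-bit range check. A simplified standalone version for plain ints (the real dtype_from_size goes through V.graph.sizevars so it can also reason about symbolic sizes):

import torch

def dtype_from_concrete_size(size: int) -> torch.dtype:
    # Stay on int32 only while the value provably fits in a signed 32-bit
    # integer; otherwise fall back to int64, mirroring dtype_from_size above.
    return torch.int32 if -(2**31) <= size < 2**31 else torch.int64

print(dtype_from_concrete_size(2**31 - 1))       # torch.int32
print(dtype_from_concrete_size(30000 * 100000))  # torch.int64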