[inductor] use int64 for large index (#154575)
Split reduction may need to add an extra mask to avoid invalid indices. Previously we always used the torch.int32 dtype, which causes problems when the tensor numel exceeds 2^31.

Fixes https://github.com/pytorch/pytorch/issues/154168

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154575
Approved by: https://github.com/ngimel, https://github.com/jansel
commit 2596e3d061 (parent f6e18bc105)
committed by PyTorch MergeBot
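For context, a minimal sketch of the failure mode, adapted from the regression test added below (the shape, operator, and tolerances come from that test; the "cuda" device and the roughly 16 GB of device memory it needs are assumptions):

import torch

# 30000 * 100000 = 3_000_000_000 elements, which exceeds 2**31 - 1, so an
# int32 index in the split-reduction mask would wrap around.
t = torch.rand((30000, 100000), dtype=torch.float, device="cuda")

expected = torch.mean(t)
actual = torch.compile(torch.mean)(t)
assert torch.allclose(expected, actual, atol=1e-2, rtol=1e-2), f"{expected=} {actual=}"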
@@ -13477,6 +13477,35 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         ref = pad_same(x, (5, 5), (2, 2))
         self.assertEqual(res, ref, atol=0, rtol=0)
 
+    @skip_if_halide  # only 32-bit indexing
+    @largeTensorTest("16GB", inductor=True)
+    def test_split_reduction_with_int64_size(self):
+        if torch._inductor.config.cpu_backend == "triton":
+            raise unittest.SkipTest(
+                "Fail for triton cpu backend with error: https://gist.github.com/shunting314/a873fb32b6b7b5a437f44280ae86839f"
+            )
+
+        if self.device == "cpu":
+            raise unittest.SkipTest(
+                "The test sometimes fails on CI: "
+                "https://github.com/pytorch/pytorch/actions/runs/15333913377/job/43153170162. "
+                "Skip for now."
+            )
+
+        size = (30000, 100000)
+
+        # rand rather than randn since the mean for the latter is close to 0,
+        # which happens to be close to the value generated by the bug.
+        t = torch.rand(size, dtype=torch.float, device=self.device)
+        op = torch.mean
+        expected = op(t)
+        actual = torch.compile(op)(t)
+        # self.common takes more GPU memory. Do the check directly.
+        self.assertTrue(
+            torch.allclose(expected, actual, atol=1e-2, rtol=1e-2),
+            f"{expected=} {actual=}",
+        )
+
     def test_remove_noop_view_default(self):
         def f(x):
             batch_size = x.shape[0]
@@ -61,7 +61,7 @@ test_failures = {
     "test_AllenaiLongformerBase_repro_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu")
     ),
-    "test_randint_distribution_dynamic_shapes": TestFailure(("cuda", "xpu")),
+    "test_randint_distribution_dynamic_shapes": TestFailure(("xpu",)),
 }
 if not torch._inductor.config.cpp_wrapper:
     test_failures["test_conv_inference_heuristics_dynamic_shapes"] = TestFailure(
@@ -227,7 +227,12 @@ class IndexingOptions:
 
     @property
     def mask_str(self) -> str:
-        return " & ".join(map(str, self.mask_vars)) if self.mask_vars else "None"
+        # The sorted call is added to make sure the order is still
+        # deterministic if self.mask_vars contains a mix of strings
+        # and TritonCSEVariables.
+        return (
+            " & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None"
+        )
 
 
 @dataclasses.dataclass
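A quick illustration of why the sorted() call matters (the mask variable names below are hypothetical): without it, the joined string depends on set iteration order, which is not stable when the set mixes plain strings and TritonCSEVariable objects.

mask_vars = {"xmask", "r0_mask", "tmp3"}          # hypothetical mix of mask variables
# " & ".join(map(str, mask_vars)) would depend on set iteration order;
# sorting the string forms pins down the generated mask expression.
print(" & ".join(sorted(map(str, mask_vars))))    # always: r0_mask & tmp3 & xmask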
@@ -111,11 +111,34 @@ def signature_to_meta(
     size_dtype: Optional[str],
     argdefs: list[ArgName],
     indices: Optional[list[int]] = None,
+    is_template: bool = False,
 ) -> dict[str, str]:
     if indices is None:
         indices = list(range(len(signature)))
 
+    def _decide_tl_dtype(arg):
+        # Even if the ks0 symbol itself is within the tl.int32 range, it is
+        # risky to use the tl.int32 dtype, since we may have ks0*ks1 later
+        # for kernels like torch.mean when dynamic shapes are enabled.
+        #
+        # Check config.triton.use_block_ptr, since Triton block pointers
+        # do not support 64-bit indexing:
+        # https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6
+        #
+        # If the triton metadata is for a template, don't use the tl.int64 index.
+        # Templates like flex attention/decoding use block pointers, which
+        # do not support 64-bit indexing.
+        if (
+            not config.triton.use_block_ptr
+            and not is_template
+            and isinstance(arg, SizeArg)
+            and arg.name.startswith("ks")
+        ):
+            return "tl.int64"
+        return size_dtype
+
     return {
-        argdefs[i].name: signature_of(arg, size_dtype=size_dtype)
+        argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg))
         for i, arg in zip(indices, signature)
     }
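A tiny sketch of the overflow the ks0*ks1 comment above guards against (the sizes are hypothetical): each dynamic size fits comfortably in int32, but their product does not, and 32-bit index arithmetic would silently wrap it to a negative value.

ks0, ks1 = 30000, 100000                      # hypothetical dynamic sizes, each < 2**31
product = ks0 * ks1                           # 3_000_000_000, past 2**31 - 1
print(product < 2**31)                        # False
wrapped = (product & 0xFFFFFFFF) - (1 << 32)  # two's-complement reinterpretation (sign bit set here)
print(wrapped)                                # -1294967296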
@@ -90,6 +90,7 @@ from .utils import (
     convert_shape_to_symint,
     developer_warning,
     do_bench_using_profiling,
+    dtype_from_size,
     get_dtype_size,
     get_kernel_metadata,
     GPU_ALIGN_BYTES,
@@ -1678,9 +1679,10 @@ class Reduction(Loops):
             return loader(new_index, reindex([indices]))
 
         if need_mask:
+            index_dtype = dtype_from_size(reduction_numel)
            mask = ops.lt(
-                ops.index_expr(indices, torch.int32),
-                ops.index_expr(reduction_numel, torch.int32),
+                ops.index_expr(indices, index_dtype),
+                ops.index_expr(reduction_numel, index_dtype),
             )
             return ops.masked(mask, body, default)
         else:
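For readers unfamiliar with split reductions, here is a conceptual analogue in plain PyTorch (this is not Inductor's actual codegen; the helper name and split count are made up): the flattened reduction is padded to a multiple of the split size, and a mask keeps the padded tail from contributing. That mask compares the flattened index against the reduction numel, which is exactly the comparison that now has to use an int64 index dtype once the tensor has more than 2**31 elements.

import torch

def split_mean(t: torch.Tensor, num_splits: int = 1024) -> torch.Tensor:
    flat = t.reshape(-1)
    numel = flat.numel()
    chunk = -(-numel // num_splits)                         # ceil(numel / num_splits)
    # The flattened index is compared against numel; past 2**31 elements this
    # comparison is only valid with int64 indices.
    idx = torch.arange(num_splits * chunk, dtype=torch.int64, device=t.device)
    mask = idx < numel
    padded = torch.zeros(num_splits * chunk, dtype=t.dtype, device=t.device)
    padded[mask] = flat                                     # masked load with default 0
    partial = padded.reshape(num_splits, chunk).sum(dim=1)  # first-stage partial sums
    return partial.sum() / numel                            # second-stage reduction

t = torch.rand(1000, 1000)
assert torch.allclose(split_mean(t), t.mean(), atol=1e-4, rtol=1e-4)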
@@ -494,7 +494,10 @@ class TritonTemplateKernel(TritonKernel):
         argdefs, _, signature, _ = self.args.python_argdefs()
         triton_meta: dict[str, Any] = {
             "signature": signature_to_meta(
-                signature, size_dtype=self.index_dtype, argdefs=argdefs
+                signature,
+                size_dtype=self.index_dtype,
+                argdefs=argdefs,
+                is_template=True,
             ),
             "device": DeviceProperties.create(self.output_node.get_device()),
             "constants": {},
@@ -3129,3 +3129,14 @@ def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
         isinstance(wrapper, SubgraphPythonWrapperCodegen)
         and wrapper.partition_signatures is not None
     )
+
+
+def dtype_from_size(size: int) -> torch.dtype:
+    from .virtualized import V
+
+    if V.graph.sizevars.statically_known_lt(
+        size, 2**31
+    ) and V.graph.sizevars.statically_known_geq(size, -(2**31)):
+        return torch.int32
+    else:
+        return torch.int64
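For intuition, a standalone analogue of the new helper for plain Python ints (the name dtype_from_concrete_size is made up; the real dtype_from_size consults V.graph.sizevars so it also handles symbolic sizes that are only statically known to be in range):

import torch

def dtype_from_concrete_size(size: int) -> torch.dtype:
    # Same int32 range check as dtype_from_size above, for concrete integers.
    return torch.int32 if -(2**31) <= size < 2**31 else torch.int64

print(dtype_from_concrete_size(2**31 - 1))       # torch.int32
print(dtype_from_concrete_size(30000 * 100000))  # torch.int64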