Revert "Ensure large tensor int32 -> int64 indexing is enabled (#157767)"

This reverts commit fc69c2bc67672c3b2d0c62c1821895f09288f1c0.

Reverted https://github.com/pytorch/pytorch/pull/157767 on behalf of https://github.com/atalman due to internal failure, sorry will revert ([comment](https://github.com/pytorch/pytorch/pull/157767#issuecomment-3224341111))
This commit is contained in:
PyTorch MergeBot
2025-08-26 14:12:06 +00:00
parent ae8d319fd4
commit 818ba434c7
6 changed files with 120 additions and 98 deletions

View File

@@ -1514,21 +1514,17 @@ class TritonTemplate(KernelTemplate):
for name, val in kwargs.items():
defines.write(f"{name} : tl.constexpr = {val}\n")
defines = defines.getvalue()
fake_out = ir.Buffer(name="buf_out", layout=layout)
kernel_name = f"triton_{self.name}"
numel = sympy_product(layout.size)
buffers = itertools.chain(input_nodes, (fake_out,))
if TritonScheduling.can_use_32bit_indexing(numel, buffers):
index_dtype = "tl.int32"
else:
index_dtype = "tl.int64"
# Add index dtype to defines so it's available in the template
defines.write(f"INDEX_DTYPE : tl.constexpr = {index_dtype}\n")
defines = defines.getvalue()
if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
raise NotImplementedError(
"64-bit indexing is not yet implemented for triton templates"
)
kernel_options = {
"input_nodes": input_nodes,