mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert "Ensure large tensor int32 -> int64 indexing is enabled (#157767)"
This reverts commit fc69c2bc67672c3b2d0c62c1821895f09288f1c0. Reverted https://github.com/pytorch/pytorch/pull/157767 on behalf of https://github.com/atalman due to internal failure, sorry will revert ([comment](https://github.com/pytorch/pytorch/pull/157767#issuecomment-3224341111))
This commit is contained in:
@ -1514,21 +1514,17 @@ class TritonTemplate(KernelTemplate):
|
||||
|
||||
for name, val in kwargs.items():
|
||||
defines.write(f"{name} : tl.constexpr = {val}\n")
|
||||
defines = defines.getvalue()
|
||||
|
||||
fake_out = ir.Buffer(name="buf_out", layout=layout)
|
||||
kernel_name = f"triton_{self.name}"
|
||||
|
||||
numel = sympy_product(layout.size)
|
||||
buffers = itertools.chain(input_nodes, (fake_out,))
|
||||
|
||||
if TritonScheduling.can_use_32bit_indexing(numel, buffers):
|
||||
index_dtype = "tl.int32"
|
||||
else:
|
||||
index_dtype = "tl.int64"
|
||||
|
||||
# Add index dtype to defines so it's available in the template
|
||||
defines.write(f"INDEX_DTYPE : tl.constexpr = {index_dtype}\n")
|
||||
defines = defines.getvalue()
|
||||
if not TritonScheduling.can_use_32bit_indexing(numel, buffers):
|
||||
raise NotImplementedError(
|
||||
"64-bit indexing is not yet implemented for triton templates"
|
||||
)
|
||||
|
||||
kernel_options = {
|
||||
"input_nodes": input_nodes,
|
||||
|
Reference in New Issue
Block a user