pytorch/torch/_inductor/runtime/hints.py

# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import functools
import typing
from enum import auto, Enum

from torch.utils._triton import has_triton_package


# The following maximums only apply to runtime autotuning; when using FixedTritonConfig
# one may see larger values.
# NOTE: if these fail asserts, submit a PR to increase them.
TRITON_MAX_BLOCK = {
    "X": 4096,
    "Y": 1024,
    "Z": 1024,
    "R0_": 4096 * 16,  # * 16 is multi-kernel only
    "R1_": 2048 * 16,  # * 16 is multi-kernel only
}
TRITON_MAX_RSPLIT = 64
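
# Illustrative sketch (not part of this module): these caps are the values a runtime
# autotuner is expected to stay within. For a `triton.Config`-like object `cfg` that
# exposes a `kwargs` dict (an assumption for this example), a check might look like:
#   assert cfg.kwargs.get("XBLOCK", 1) <= TRITON_MAX_BLOCK["X"]
#   assert cfg.kwargs.get("R0_BLOCK", 1) <= TRITON_MAX_BLOCK["R0_"]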


class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


# Define `AttrsDescriptorWrapper` function with clear conditional handling
if has_triton_package():
    import triton
    import triton.backends.compiler
    import triton.compiler.compiler

    if hasattr(triton.backends.compiler, "AttrsDescriptor"):
        # Triton 3.2.0 - the second implementation
        from triton.backends.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "tt.divisibility": divisible_by_16,
                "tt.equal_to": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            res = AttrsDescriptor.from_dict(
                {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
            )
            assert res.property_values["tt.divisibility"] == 16
            assert res.property_values["tt.equal_to"] == 1
            return res

    elif hasattr(triton.compiler.compiler, "AttrsDescriptor"):
        # Triton 3.0.0 - the original implementation
        from triton.compiler.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "divisible_by_16": divisible_by_16,
                "equal_to_1": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            return AttrsDescriptor(**kwargs)

    else:
        # Triton in 2025:
        # note: there's also a range of triton commits not currently supported
        # from ~Dec 9, 2024 to Jan 1 2025, in which AttrsDescriptors are still
        # used, but the contents are different.
        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16}

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
        "AttrsDescriptor",
        ["divisible_by_16", "equal_to_1"],
        defaults=[(), ()],
    )
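
# Illustrative calls (a sketch; the argument values below are made up, not taken from
# inductor codegen):
#   * the Triton 3.0/3.2 branches above return an AttrsDescriptor instance, e.g.
#       AttrsDescriptorWrapper(divisible_by_16=(0, 1), equal_to_1=(3,))
#   * the newest-Triton branch instead returns a plain dict keyed by argument index and
#     ignores equal_to_1, e.g.
#       {(0,): [["tt.divisibility", 16]], (1,): [["tt.divisibility", 16]]}
#   * without triton installed, the namedtuple fallback simply records the two tuples.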


_NUM_THREADS_PER_WARP = 32


class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()
    FIXED = auto()


class AutotuneHint(Enum):
    ONE_ELEMENT_PER_THREAD = 0

    # Triton codegen tries to codegen a set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
    # which isn't valid Python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__
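
    # For example (illustrative), with this override
    #   repr({AutotuneHint.ONE_ELEMENT_PER_THREAD})
    # yields "{AutotuneHint.ONE_ELEMENT_PER_THREAD}", which round-trips as Python source.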


class DeviceProperties(typing.NamedTuple):
    """Copy device properties into a data structure not requiring torch to be imported"""

    type: str  # type: ignore[assignment]
    index: int  # type: ignore[assignment]
    multi_processor_count: int
    cc: int
    major: int | None = None
    regs_per_multiprocessor: int | None = None
    max_threads_per_multi_processor: int | None = None
    warp_size: int | None = None

    @classmethod
    @functools.cache
    def create(cls, device) -> DeviceProperties:
        import torch
        from torch._dynamo.device_interface import get_interface_for_device

        device_type = device.type
        if torch.version.hip and device_type == "cuda":
            device_type = "hip"

        device_interface = get_interface_for_device(device)
        props = device_interface.get_device_properties(device)
        try:
            multi_processor_count = props.multi_processor_count
        except AttributeError:
            if device_type == "xpu":
                multi_processor_count = props.gpu_subslice_count
            elif device_type == "mtia":
                multi_processor_count = 64
            else:
                raise
        return cls(
            type=device_type,
            index=device.index,
            multi_processor_count=multi_processor_count,
            cc=device_interface.get_compute_capability(device),
            major=getattr(props, "major", None),
            regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None),
            max_threads_per_multi_processor=getattr(
                props, "max_threads_per_multi_processor", None
            ),
            warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None),
        )
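
# Illustrative usage (a sketch; assumes a CUDA build with at least one visible GPU):
#   import torch
#   props = DeviceProperties.create(torch.device("cuda", 0))
#   props.type, props.multi_processor_count, props.warp_size  # e.g. ("cuda", 108, 32) on an A100
# Repeated calls for the same device hit the functools.cache and return the same tuple.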


class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    shape: list[str] | None = None
    stride: list[str] | None = None
    offset: str | None = None
    alias_of: str | None = None

    def bindings_type(self) -> str:
        if self.ctype in ("at::Half*", "at::BFloat16*"):
            return "uint16_t*"  # half not defined
        return self.ctype

    def halide_type(self) -> str:
        if self.ctype == "at::Half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        if self.ctype == "at::BFloat16*":
            return "halide_type_t(halide_type_bfloat, 16)"  # half not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"

    def is_scalar(self) -> bool:
        return self.shape is None

    def is_buffer(self) -> bool:
        return self.shape is not None
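
# Illustrative mapping (a sketch; the spec values below are made up):
#   spec = HalideInputSpec(ctype="at::Half*", name="in_ptr0", shape=["s0"], stride=["1"])
#   spec.bindings_type()  -> "uint16_t*"
#   spec.halide_type()    -> "halide_type_t(halide_type_float, 16)"
#   spec.is_buffer()      -> True (shape is given); a scalar spec would leave shape=None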


class HalideMeta(typing.NamedTuple):
    argtypes: list[HalideInputSpec]
    target: str
    scheduler: str | None = None
    scheduler_flags: dict[str, int | str] | None = None
    cuda_device: int | None = None

    def args(self) -> list[str]:
        """Command line args to pass to halide generator"""
        args = [f"target={self.target}"]
        if self.scheduler:
            args.append(f"autoscheduler={self.scheduler}")
        if self.scheduler_flags:
            assert self.scheduler
            for k, v in self.scheduler_flags.items():
                args.append(f"autoscheduler.{k}={v}")
        return args

    def is_cuda(self) -> bool:
        return self.cuda_device is not None
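
# Illustrative output (a sketch; the target/scheduler values are made up, not inductor
# defaults):
#   HalideMeta(
#       argtypes=[],
#       target="host-cuda",
#       scheduler="Anderson2021",
#       scheduler_flags={"parallelism": 82},
#       cuda_device=0,
#   ).args()
#   -> ["target=host-cuda", "autoscheduler=Anderson2021", "autoscheduler.parallelism=82"]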