pytorch/torch/testing/_internal/inductor_utils.py
Mwiza Kunda d3d9bc1c31 [inductor] Allow backends to register their own custom config object (#158254)
An out-of-tree backend can have its own configuration options that the user can set to control inductor compilation. These config options need to be taken into account when calculating the key that is used to determine cache misses/hits. This PR allows out-of-tree backends to register a custom config module, of the same type as `torch._inductor.config`, that can be used to control codegen in addition to the default config and is included when creating the cache key.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158254
Approved by: https://github.com/eellison
2025-07-23 15:56:06 +00:00
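
For illustration, a minimal sketch of such a registration (every `MyDevice*` name and the `my_backend.config` module are hypothetical; the six-argument call mirrors the `register_backend_for_device` pattern used by `patch_inductor_backend` in this file):

    import my_backend.config  # assumed to be a ConfigModule, like torch._inductor.config
    from torch._inductor.codegen.common import register_backend_for_device

    register_backend_for_device(
        "my_device",                   # hypothetical device key
        MyDeviceScheduling,            # backend scheduling class
        MyDevicePythonWrapperCodegen,  # Python wrapper codegen
        MyDeviceCppWrapperCodegen,     # C++ wrapper codegen
        MyDeviceGraphPass(),           # optional CustomGraphModulePass
        my_backend.config,             # custom config folded into the cache key
    )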


# mypy: ignore-errors
import logging
import torch
import re
import unittest
import functools
import contextlib
import os
from subprocess import CalledProcessError
import sys
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch.fx.experimental.proxy_tensor import make_fx
from torch._inductor.graph import GraphLowering
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.codecache import CppCodeCache
from torch._inductor.custom_graph_pass import CustomGraphModulePass
from torch._inductor.codegen.common import (
get_custom_backend_config_for_device,
get_custom_backend_pass_for_device,
get_scheduling_for_device,
get_wrapper_codegen_for_device,
init_backend_registration,
register_backend_for_device
)
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
from torch.utils._helion import has_helion
from torch.utils._triton import has_triton
from torch.utils._config_module import ConfigModule
from torch.testing._internal.common_device_type import (
get_desired_device_type_test_bases,
)
from torch.testing._internal.common_utils import (
LazyVal,
IS_FBCODE,
)
from torch.testing._internal.common_utils import (
TestCase,
IS_CI,
IS_WINDOWS,
)
log: logging.Logger = logging.getLogger(__name__)
def test_cpu():
try:
CppCodeCache.load("")
return not IS_FBCODE
except (
CalledProcessError,
OSError,
torch._inductor.exc.InvalidCxxCompiler,
torch._inductor.exc.CppCompileError,
):
return False
HAS_CPU = LazyVal(test_cpu)
HAS_TRITON = has_triton()
HAS_HELION = has_helion()
if HAS_TRITON:
import triton
TRITON_HAS_CPU = "cpu" in triton.backends.backends
else:
TRITON_HAS_CPU = False
HAS_CUDA = torch.cuda.is_available() and HAS_TRITON
HAS_XPU = torch.xpu.is_available() and HAS_TRITON
HAS_MPS = torch.mps.is_available()
HAS_GPU = HAS_CUDA or HAS_XPU
GPU_TYPE = get_gpu_type()
HAS_MULTIGPU = any(
getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2
for gpu in GPU_TYPES
)
_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
RUN_GPU = (
HAS_GPU
and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
)
RUN_CPU = (
HAS_CPU
and any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
)
def _check_has_dynamic_shape(
self: TestCase,
code,
):
for_loop_found = False
has_dynamic = False
lines = code.split("\n")
for line in lines:
if "for(" in line:
for_loop_found = True
if re.search(r";.*ks.*;", line) is not None:
has_dynamic = True
break
self.assertTrue(
has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}"
)
self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}")
def skipDeviceIf(cond, msg, *, device):
if cond:
def decorate_fn(fn):
@functools.wraps(fn)
def inner(self, *args, **kwargs):
if not hasattr(self, "device"):
warn_msg = "Expect the test class to have attribute device but not found. "
if hasattr(self, "device_type"):
warn_msg += "Consider using the skip device decorators in common_device_type.py"
log.warning(warn_msg)
if self.device == device:
raise unittest.SkipTest(msg)
return fn(self, *args, **kwargs)
return inner
else:
def decorate_fn(fn):
return fn
return decorate_fn
def skip_windows_ci(name: str, file: str) -> None:
if IS_WINDOWS and IS_CI:
        module = os.path.splitext(os.path.basename(file))[0]
        sys.stderr.write(
            f"Windows CI does not have the necessary dependencies for {module} tests yet\n"
        )
if name == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS
requires_gpu = functools.partial(unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu")
requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")
requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion")
def requires_cuda_with_enough_memory(min_mem_required):
def inner(fn):
        if not torch.cuda.is_available() or torch.cuda.get_device_properties().total_memory < min_mem_required:
            return unittest.skip(f"Requires a CUDA device with at least {min_mem_required / 1e9:.3f} GB of memory")(fn)
else:
return fn
return inner
skipCUDAIf = functools.partial(skipDeviceIf, device="cuda")
skipXPUIf = functools.partial(skipDeviceIf, device="xpu")
skipCPUIf = functools.partial(skipDeviceIf, device="cpu")
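# Illustrative usage of the skip decorators above (the test class and method are
# hypothetical; the class is expected to expose a `device` attribute, as checked
# in skipDeviceIf):
#
#     class MyInductorTest(TestCase):
#         device = GPU_TYPE
#
#         @skipCUDAIf(IS_FBCODE, "not supported in fbcode")
#         def test_something(self):
#             ...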
IS_A100 = LazyVal(
lambda: HAS_CUDA
and get_gpu_shared_memory() == 166912
)
IS_H100 = LazyVal(
lambda: HAS_CUDA
and get_gpu_shared_memory() == 232448
)
IS_BIG_GPU = LazyVal(lambda: HAS_CUDA and is_big_gpu())
def dummy_graph() -> GraphLowering:
"""
Create a graph. This is useful for unit testing code which accesses
V.graph.sizevars.
"""
example_inputs = [torch.randn(10) for _ in range(2)]
gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs)
shape_env = shape_env_from_inputs(example_inputs)
graph = GraphLowering(
gm,
shape_env=shape_env,
)
return graph
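# Illustrative usage of dummy_graph (hypothetical test body): install the graph as
# the active graph handler so code under test can reach V.graph.sizevars. Assumes
# torch._inductor.virtualized.V.set_graph_handler, as commonly used in inductor tests.
#
#     from torch._inductor.virtualized import V
#     with V.set_graph_handler(dummy_graph()):
#         ...  # code that reads V.graph.sizevars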
def maybe_skip_size_asserts(op):
"""
For certain ops, there meta and eager implementation returns different
strides. This cause size/strides assert fail. Skip adding those
asserts for now.
"""
if (
op.aten_name
in (
"fft_hfftn",
"fft_hfft",
"fft_hfft2",
"fft_ihfftn",
"fft_fft",
"fft_fft2",
"fft_fftn",
"fft_ifft",
"fft_ifft2",
"fft_ifftn",
"fft_irfft",
"fft_irfft2",
"fft_irfftn",
"fft_ihfft",
"fft_ihfft2",
"fft_rfft",
"fft_rfft2",
"fft_rfftn",
"linalg_eig",
"linalg_eigvals",
)
and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ
):
return torch._inductor.config.patch(size_asserts=False)
else:
return contextlib.nullcontext()
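# Illustrative usage of maybe_skip_size_asserts (op is assumed to be an OpInfo,
# matching the op.aten_name lookup above; sample is a hypothetical SampleInput):
#
#     with maybe_skip_size_asserts(op):
#         torch.compile(op.op)(sample.input, *sample.args, **sample.kwargs)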
def get_func_call() -> str:
return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call("
def get_kernel_launch() -> str:
return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run("
def clone_preserve_strides_offset(x, device=None):
if not isinstance(x, torch.Tensor):
return x
buffer = torch.as_strided(
x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
)
if not device:
buffer = buffer.clone()
else:
buffer = buffer.to(device, copy=True)
out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
return out
# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS: float = 1e-12
Tensor = torch.Tensor
def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
# The default behavior in PyTorch for casting to `float8_e4m3fn`
# and `e5m2` is to not saturate. In this context, we should saturate.
# A common case where we want to saturate is when the history of a
# tensor has a maximum value of `amax1`, and the current amax value
# is `amax2`, where `amax1 < amax2`. This is common when using delayed
# scaling.
if float8_dtype == torch.float8_e4m3fn:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
elif float8_dtype == torch.float8_e5m2:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
elif float8_dtype == torch.float8_e4m3fnuz:
x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
elif float8_dtype == torch.float8_e5m2fnuz:
x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
else:
raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
return x.to(float8_dtype)
@torch.no_grad()
def _amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
    # Keep the scale in fp32 for accuracy.
amax = amax.float()
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else: # e5m2
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    # Ensure that the scale is representable in float16; this helps when amax
    # is small. We are assuming that we don't need to care about this for
    # float32/bfloat16.
if orig_dtype is torch.float16:
res = torch.clamp(res, max=FP16_MAX_POS)
return res
def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x))
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
scale = _amax_to_scale(amax, float8_dtype, x.dtype)
x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
inverse_scale = scale.reciprocal()
return x_fp8, inverse_scale
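# Illustrative usage of the fp8 helpers above (hypothetical tensor): quantize with
# saturation and recover an approximation via the returned inverse scale.
#
#     x = torch.randn(16, 32)
#     x_fp8, x_inv_scale = _quantize_tensorwise(x, torch.float8_e4m3fn)
#     x_approx = x_fp8.to(torch.float32) * x_inv_scale
#     # _quantize_rowwise does the same per row (amax over dim=1).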
@contextlib.contextmanager
def patch_inductor_backend(
device: str,
python_wrapper_codegen: PythonWrapperCodegen = None,
custom_pass: CustomGraphModulePass = None,
custom_backend_config: ConfigModule = None
):
"""
Patch the inductor backend for a specific device.
"""
# Make sure the backend is already registered
init_backend_registration()
# Get the original registration parameters
original_scheduling = get_scheduling_for_device(device)
original_python_wrapper = get_wrapper_codegen_for_device(device, False)
original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
original_custom_pass = get_custom_backend_pass_for_device(device)
original_custom_backend_config = get_custom_backend_config_for_device(device)
try:
# Register modified backend for the device
register_backend_for_device(
device,
original_scheduling,
python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
original_cpp_wrapper,
custom_pass if custom_pass is not None else original_custom_pass,
custom_backend_config if custom_backend_config is not None else original_custom_backend_config
)
yield
finally:
# Restore the original backend
register_backend_for_device(
device,
original_scheduling,
original_python_wrapper,
original_cpp_wrapper,
original_custom_pass,
original_custom_backend_config
)
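# Illustrative usage of patch_inductor_backend (my_pass and my_config are hypothetical;
# my_pass would subclass CustomGraphModulePass and my_config would be a ConfigModule
# like torch._inductor.config): temporarily swap in a custom graph pass and backend
# config for a device inside a test, restoring the original registration on exit.
#
#     with patch_inductor_backend("cpu", custom_pass=my_pass, custom_backend_config=my_config):
#         torch.compile(fn)(*example_inputs)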