Files
pytorch/torch/testing/_internal/inductor_utils.py
drisspg 1750cc8037 Updates to CuTe DSL template renderer (#161117)
# Summary
This adds a few more render functions available to template writers, specifically `get_output` and `modification`. The reasons why become clearer in the next PR in this stack.

<img width="1645" height="364" alt="Screenshot 2025-08-21 at 1 48 50 PM" src="https://github.com/user-attachments/assets/2d508fda-4273-43ef-9edf-086e592e9249" />

The majority of the new code is the OpOverrides for the CuTe DSL. There is a lot to test, and most of the actual testing I have been doing is via score_mods passed to the flash_attention integration at the next layer of this stack.

Below is a set of score mods that Claude and I came up with to exercise the actual ops; a sketch of how they are driven follows the block.
```python
import math

import torch


def causal_mask(score, b, h, q_idx, kv_idx):
    """Causal attention mask."""
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

def relative_bias(score, b, h, token_q, token_kv):
    """Relative position bias."""
    return score + torch.abs(token_q - token_kv)

def relative_bias_v2(score, b, h, token_q, token_kv):
    """Relative position bias with factor of 2."""
    return score + 2 * torch.abs(token_q - token_kv)

def times_two(score, b, h, q_idx, kv_idx):
    """Simple score modification that doubles the score."""
    return score * 2

def alibi_bias(score, b, h, q_idx, kv_idx):
    """ALiBi (Attention with Linear Biases) - used in some modern models."""
    # Different slopes for different heads
    slope = 2 ** (-8 * (h + 1) / 8)  # Simplified version
    return score - slope * torch.abs(q_idx - kv_idx)

def sliding_window(score, b, h, q_idx, kv_idx, window_size=256):
    """Sliding window attention - only attend to nearby tokens."""
    return torch.where(
        torch.abs(q_idx - kv_idx) <= window_size,
        score,
        float("-inf")
    )

def block_diagonal(score, b, h, q_idx, kv_idx, block_size=64):
    """Block diagonal attention pattern."""
    q_block = q_idx // block_size
    kv_block = kv_idx // block_size
    return torch.where(q_block == kv_block, score, float("-inf"))

def additive_bias(score, b, h, q_idx, kv_idx):
    """Test simple addition with position-based bias."""
    return score + (q_idx + kv_idx) * 0.01

def multiplicative_decay(score, b, h, q_idx, kv_idx):
    """Test multiplication with distance-based decay."""
    distance = torch.abs(q_idx - kv_idx)
    return score * torch.exp(-0.1 * distance)

def sine_wave_bias(score, b, h, q_idx, kv_idx):
    """Test trigonometric functions."""
    return score + 0.1 * torch.sin(2 * math.pi * (q_idx - kv_idx) / 64)

def log_distance_penalty(score, b, h, q_idx, kv_idx):
    """Test logarithmic operations."""
    distance = torch.abs(q_idx - kv_idx).float()
    return score - torch.log(1 + distance)

def alternating_mask(score, b, h, q_idx, kv_idx):
    """Test with alternating pattern - good for branch prediction."""
    return torch.where((q_idx + kv_idx) % 2 == 0, score, float("-inf"))

def head_specific_pattern(score, b, h, q_idx, kv_idx):
    """Different behavior per attention head."""
    even_head = h % 2 == 0
    causal = q_idx >= kv_idx
    return torch.where(even_head & causal, score, float("-inf"))

def sparse_strided(score, b, h, q_idx, kv_idx, stride=4):
    """Sparse attention with strided pattern."""
    return torch.where(
        (kv_idx % stride == 0) | (q_idx == kv_idx),
        score,
        float("-inf")
    )

def causal_with_global(score, b, h, q_idx, kv_idx):
    """Causal mask but first few tokens are globally attended."""
    is_causal = q_idx >= kv_idx
    is_global = kv_idx < 4
    return torch.where(is_causal | is_global, score, float("-inf"))

def dilated_attention(score, b, h, q_idx, kv_idx, dilation_rate=2):
    """Dilated attention pattern - exponentially increasing gaps."""
    distance = torch.abs(q_idx - kv_idx)
    is_attended = (distance == 0) | ((distance > 0) & ((distance & (distance - 1)) == 0))
    return torch.where(is_attended, score, float("-inf"))

```
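For context, here is a minimal sketch of how these score mods are exercised. This is not the harness that produced the logs below; it only shows the public entry point the CuTe flash path plugs into, with shapes matching the test config.

```python
# Hypothetical driver, assuming the public flex_attention API. Whether the
# CuTe DSL flash template ("flash_attncute" in the logs below) is actually
# chosen is decided by the template/config work in the next PR of this stack.
import torch
from torch.nn.attention.flex_attention import flex_attention


def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


q, k, v = (
    torch.randn(4, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)
compiled_fa = torch.compile(flex_attention)
out = compiled_fa(q, k, v, score_mod=causal_mask)
```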

Example outputs:
```
[Test Suite]
Config: batch=4, heads=32, seq_q=8192, seq_kv=8192, dim=128

[Test 1: none]
[No score_mod, flash='enabled'] Found flash_attncute: True
[No score_mod, flash='disabled'] Found flash_attncute: False
✓ Outputs match between flash enabled/disabled
✓ Output matches eager SDPA (rtol=0.001, atol=0.001)

[Test 2: causal]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!

Mismatched elements: 17879 / 134217728 (0.0%)
Greatest absolute difference: 0.0078125 at index (0, 15, 15, 60) (up to 0.001 allowed)
Greatest relative difference: 2.5 at index (3, 22, 153, 126) (up to 0.001 allowed)

[Test 3: rel_bias]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!

Mismatched elements: 12836 / 134217728 (0.0%)
Greatest absolute difference: 0.015625 at index (0, 3, 2775, 84) (up to 0.001 allowed)
Greatest relative difference: 11.8125 at index (3, 28, 4095, 76) (up to 0.001 allowed)

[Test 4: rel_bias_v2]
```

These runs are in bfloat16, and there are no major numerical differences. The list of pointwise ops exercised here isn't exhaustive, but it covers a fair amount of the surface.
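For scale, `torch.finfo(torch.bfloat16).eps` is 0.0078125, which matches the largest absolute difference reported above, so the divergence is consistent with bf16 rounding. A sketch of the accuracy check (hypothetical helper, mirroring the rtol/atol shown in the logs):

```python
import torch

# One bf16 ulp at magnitude 1.0; the same value as the max abs diff in the logs.
assert torch.finfo(torch.bfloat16).eps == 0.0078125


def check_close(out_flash: torch.Tensor, out_reference: torch.Tensor) -> None:
    # Hypothetical comparison matching the reported rtol/atol of 0.001.
    torch.testing.assert_close(
        out_flash.float(), out_reference.float(), rtol=1e-3, atol=1e-3
    )
```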

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161117
Approved by: https://github.com/mlazos
2025-08-27 18:39:09 +00:00

pytorch/torch/testing/_internal/inductor_utils.py · 424 lines · 14 KiB · Python

# mypy: ignore-errors
import logging
import torch
import re
import unittest
import functools
import contextlib
import os
from subprocess import CalledProcessError
import sys
from typing import Any, Optional
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch.fx.experimental.proxy_tensor import make_fx
from torch._dynamo.device_interface import get_interface_for_device
from torch._inductor.graph import GraphLowering
from torch._inductor.compile_fx import shape_env_from_inputs
from torch._inductor.utils import OrderedSet
from torch._inductor.codecache import CppCodeCache
from torch._inductor.custom_graph_pass import CustomGraphModulePass
from torch._inductor.codegen.common import (
    get_custom_backend_config_for_device,
    get_custom_backend_pass_for_device,
    get_scheduling_for_device,
    get_wrapper_codegen_for_device,
    init_backend_registration,
    register_backend_for_device,
)
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.utils import get_gpu_shared_memory, is_big_gpu
from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu
from torch.utils._helion import has_helion
from torch.utils._triton import has_triton
from torch.utils._config_module import ConfigModule
from torch.testing._internal.common_device_type import (
    get_desired_device_type_test_bases,
)
from torch.testing._internal.common_utils import (
    LazyVal,
    IS_FBCODE,
    TestCase,
    IS_CI,
    IS_WINDOWS,
)
log: logging.Logger = logging.getLogger(__name__)


def test_cpu():
    try:
        CppCodeCache.load("")
        return not IS_FBCODE
    except (
        CalledProcessError,
        OSError,
        torch._inductor.exc.InvalidCxxCompiler,
        torch._inductor.exc.CppCompileError,
    ):
        return False


HAS_CPU = LazyVal(test_cpu)
HAS_TRITON = has_triton()
HAS_HELION = has_helion()
if HAS_TRITON:
    import triton

    TRITON_HAS_CPU = "cpu" in triton.backends.backends
else:
    TRITON_HAS_CPU = False

HAS_CUDA_AND_TRITON = torch.cuda.is_available() and HAS_TRITON
HAS_XPU_AND_TRITON = torch.xpu.is_available() and HAS_TRITON
HAS_MPS = torch.mps.is_available()
HAS_GPU = HAS_CUDA_AND_TRITON or HAS_XPU_AND_TRITON

GPU_TYPE = get_gpu_type()

HAS_MULTIGPU = any(
    getattr(torch, gpu).is_available() and getattr(torch, gpu).device_count() >= 2
    for gpu in GPU_TYPES
)

_desired_test_bases = get_desired_device_type_test_bases(allow_xpu=True)
RUN_GPU = (
    HAS_GPU
    and any(is_gpu(getattr(x, "device_type", "")) for x in _desired_test_bases)
)
RUN_CPU = (
    HAS_CPU
    and any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
)


def _check_has_dynamic_shape(
    self: TestCase,
    code,
):
    for_loop_found = False
    has_dynamic = False
    lines = code.split("\n")
    for line in lines:
        if "for(" in line:
            for_loop_found = True
            if re.search(r";.*ks.*;", line) is not None:
                has_dynamic = True
                break
    self.assertTrue(
        has_dynamic, msg=f"Failed to find dynamic for loop variable\n{code}"
    )
    self.assertTrue(for_loop_found, f"Failed to find for loop\n{code}")


def skipDeviceIf(cond, msg, *, device):
    if cond:

        def decorate_fn(fn):
            @functools.wraps(fn)
            def inner(self, *args, **kwargs):
                if not hasattr(self, "device"):
                    warn_msg = "Expect the test class to have attribute device but not found. "
                    if hasattr(self, "device_type"):
                        warn_msg += "Consider using the skip device decorators in common_device_type.py"
                    log.warning(warn_msg)
                if self.device == device:
                    raise unittest.SkipTest(msg)
                return fn(self, *args, **kwargs)

            return inner

    else:

        def decorate_fn(fn):
            return fn

    return decorate_fn


def skip_windows_ci(name: str, file: str) -> None:
    if IS_WINDOWS and IS_CI:
        module = os.path.basename(file).strip(".py")
        sys.stderr.write(
            f"Windows CI does not have necessary dependencies for {module} tests yet\n"
        )
        if name == "__main__":
            sys.exit(0)
        raise unittest.SkipTest("requires sympy/functorch/filelock")


# TODO: Remove HAS_MPS condition when `HAS_GPU` includes HAS_MPS
requires_gpu = functools.partial(unittest.skipIf, not (HAS_GPU or HAS_MPS), "requires gpu")
requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton")
requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion")


def requires_cuda_with_enough_memory(min_mem_required):
    def inner(fn):
        if not torch.cuda.is_available() or torch.cuda.get_device_properties().total_memory < min_mem_required:
            return unittest.skip(
                f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe"
            )(fn)
        else:
            return fn

    return inner


skipCUDAIf = functools.partial(skipDeviceIf, device="cuda")
skipXPUIf = functools.partial(skipDeviceIf, device="xpu")
skipCPUIf = functools.partial(skipDeviceIf, device="cpu")

IS_A100 = LazyVal(
    lambda: HAS_CUDA_AND_TRITON
    and get_gpu_shared_memory() == 166912
)

IS_H100 = LazyVal(
    lambda: HAS_CUDA_AND_TRITON
    and get_gpu_shared_memory() == 232448
)

IS_BIG_GPU = LazyVal(lambda: HAS_CUDA_AND_TRITON and is_big_gpu())


def dummy_graph() -> GraphLowering:
    """
    Create a graph. This is useful for unit testing code which accesses
    V.graph.sizevars.
    """
    example_inputs = [torch.randn(10) for _ in range(2)]
    gm = make_fx(torch.add, tracing_mode="fake")(*example_inputs)
    shape_env = shape_env_from_inputs(example_inputs)
    graph = GraphLowering(
        gm,
        shape_env=shape_env,
    )
    return graph


def maybe_skip_size_asserts(op):
    """
    For certain ops, the meta and eager implementations return different
    strides. That causes the size/stride asserts to fail. Skip adding those
    asserts for now.
    """
    if (
        op.aten_name
        in (
            "fft_hfftn",
            "fft_hfft",
            "fft_hfft2",
            "fft_ihfftn",
            "fft_fft",
            "fft_fft2",
            "fft_fftn",
            "fft_ifft",
            "fft_ifft2",
            "fft_ifftn",
            "fft_irfft",
            "fft_irfft2",
            "fft_irfftn",
            "fft_ihfft",
            "fft_ihfft2",
            "fft_rfft",
            "fft_rfft2",
            "fft_rfftn",
            "linalg_eig",
            "linalg_eigvals",
        )
        and "TORCHINDUCTOR_SIZE_ASSERTS" not in os.environ
    ):
        return torch._inductor.config.patch(size_asserts=False)
    else:
        return contextlib.nullcontext()


def get_func_call() -> str:
    return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call("


def get_kernel_launch() -> str:
    return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run("


def clone_preserve_strides_offset(x, device=None):
    if not isinstance(x, torch.Tensor):
        return x
    buffer = torch.as_strided(
        x, (x.untyped_storage().size() // x.element_size(),), (1,), 0
    )
    if not device:
        buffer = buffer.clone()
    else:
        buffer = buffer.to(device, copy=True)
    out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
    return out

# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS: float = 1e-12
Tensor = torch.Tensor


def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
    # The default behavior in PyTorch for casting to `float8_e4m3fn`
    # and `e5m2` is to not saturate. In this context, we should saturate.
    # A common case where we want to saturate is when the history of a
    # tensor has a maximum value of `amax1`, and the current amax value
    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
    # scaling.
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif float8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    elif float8_dtype == torch.float8_e4m3fnuz:
        x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
    elif float8_dtype == torch.float8_e5m2fnuz:
        x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
    else:
        raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
    return x.to(float8_dtype)


@torch.no_grad()
def _amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
    # Compute the scale in fp32 for accuracy
    amax = amax.float()
    if float8_dtype == torch.float8_e4m3fn:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    else:  # e5m2
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    # Ensure that the scale is representable in float16;
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=FP16_MAX_POS)
    return res


def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
    amax = torch.max(torch.abs(x))
    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
    inverse_scale = scale.reciprocal()
    return x_fp8, inverse_scale


def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
    amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
    inverse_scale = scale.reciprocal()
    return x_fp8, inverse_scale


@contextlib.contextmanager
def patch_inductor_backend(
    device: str,
    python_wrapper_codegen: PythonWrapperCodegen = None,
    custom_pass: CustomGraphModulePass = None,
    custom_backend_config: ConfigModule = None,
):
    """
    Patch the inductor backend for a specific device.
    """
    # Make sure the backend is already registered
    init_backend_registration()

    # Get the original registration parameters
    original_scheduling = get_scheduling_for_device(device)
    original_python_wrapper = get_wrapper_codegen_for_device(device, False)
    original_cpp_wrapper = get_wrapper_codegen_for_device(device, True)
    original_custom_pass = get_custom_backend_pass_for_device(device)
    original_custom_backend_config = get_custom_backend_config_for_device(device)

    try:
        # Register the modified backend for the device
        register_backend_for_device(
            device,
            original_scheduling,
            python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper,
            original_cpp_wrapper,
            custom_pass if custom_pass is not None else original_custom_pass,
            custom_backend_config if custom_backend_config is not None else original_custom_backend_config,
        )
        yield
    finally:
        # Restore the original backend
        register_backend_for_device(
            device,
            original_scheduling,
            original_python_wrapper,
            original_cpp_wrapper,
            original_custom_pass,
            original_custom_backend_config,
        )


def backend_for_device(device: str) -> Optional[str]:
    """Get the Inductor codegen backend used for the device ``device``."""
    if dev_int := get_interface_for_device(device):
        return dev_int.inductor_backend()
    return None


def try_patch_inductor_backend_config(
    device: str, key: str, value: Any
) -> contextlib.ContextDecorator:
    """
    Try to patch the backend-specific Inductor options, for the codegen backend
    corresponding to the given ``device``. If that config can't be found to
    patch, skip the test.

    Will patch the member of the global ``config.$BACKEND``, if it exists. If
    the given device also specifies a custom config module, will also try to
    patch its ``$BACKEND`` member if it exists.
    """
    device_backend = backend_for_device(device)
    if device_backend is None:
        return unittest.skip(
            f"Can't patch Inductor config {key} for device {device}"
        )

    config_modules = [torch._inductor.config]
    if custom_config_module := get_custom_backend_config_for_device(device):
        config_modules.append(custom_config_module)

    contexts: list[contextlib.ContextDecorator] = []
    for mod in config_modules:
        if (
            hasattr(mod, f"{device_backend}")
            and hasattr(mod, f"{device_backend}.{key}")
        ):
            contexts.append(mod.patch(f"{device_backend}.{key}", value))
    if len(contexts) == 0:
        return unittest.skip(
            f"Can't patch Inductor config {key} for device {device}"
        )

    class ContextStack(contextlib.ContextDecorator):
        def __init__(self, contexts: list[contextlib.ContextDecorator]) -> None:
            self.contexts: list[contextlib.ContextDecorator] = contexts

        def __enter__(self) -> None:
            for cd in self.contexts:
                cd.__enter__()

        def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore[no-untyped-def]
            for cd in self.contexts:
                cd.__exit__(exc_type, exc_val, exc_tb)

    return ContextStack(contexts)


class MockGraphHandler(GraphLowering):
    """Minimal mock graph handler for testing virtualized context."""

    def __init__(self, name_to_buffer=None):
        import torch._inductor.sizevars

        self.sizevars = torch._inductor.sizevars.SizeVarAllocator()
        self.name_to_buffer = name_to_buffer or {}
        self.graph_inputs = {}
        self.mutated_buffers = OrderedSet()
        self.removed_buffers = OrderedSet()
        self.constants = {}
        self.scheduler = None

    def get_dtype(self, buffer_name: str) -> torch.dtype:  # noqa: ARG002
        """Return default dtype for any buffer (for testing)."""
        return torch.float32
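
A brief usage sketch of these helpers (hypothetical test module; only the names imported from `inductor_utils` and `torch._inductor.test_case` are real, the rest is illustrative):

```python
import torch
from torch._inductor.test_case import TestCase, run_tests
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu


class MyInductorTest(TestCase):
    @requires_gpu()
    def test_add_compiles(self):
        # Compile a trivial pointwise op and compare against eager.
        fn = torch.compile(lambda a, b: a + b)
        a = torch.randn(8, device=GPU_TYPE)
        b = torch.randn(8, device=GPU_TYPE)
        self.assertEqual(fn(a, b), a + b)


if __name__ == "__main__":
    if HAS_GPU:
        run_tests()
```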