pytorch/torch/_inductor/shape_propagation.py
Karthick Panner Selvam 130e50afff [Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).

Changes Included

- Implemented the device_assert op and overrode has_side_effect to return True so the op is not removed by dead code elimination.
- Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor.
- Added a lowering for torch.ops.aten._assert_async.msg that routes assert calls through the ops_handler (see the usage sketch after this list).
- Implemented the codegen method for the device_assert op, supporting both C++ and Triton code generation.
- Added test cases to verify both "should throw" and "should not throw" scenarios.
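A minimal usage sketch (not part of the PR; `checked_sqrt` and the tensor values are illustrative), assuming a CUDA device and a build that includes this change. The assert is expected to be lowered to a device-side check in the generated kernel rather than decomposed into a host-side assert:

```python
import torch

@torch.compile
def checked_sqrt(x):
    # Routed through the new lowering for torch.ops.aten._assert_async.msg,
    # so the check is intended to run on the device inside the generated kernel.
    torch.ops.aten._assert_async.msg((x > 0).all(), "expected all-positive input")
    return x.sqrt()

x = torch.rand(8, device="cuda") + 0.1
print(checked_sqrt(x))                           # should not throw
# checked_sqrt(torch.zeros(8, device="cuda"))    # should trigger the device-side assert
```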

Fixes #147282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos, https://github.com/atalman
2025-08-28 18:57:34 +00:00


import functools
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, Union

import sympy

import torch

from .virtualized import OpsValue, V


BlockShapeType = Optional[Sequence[Union[int, str]]]


class ShapeVar(Protocol):
    @property
    def shape(self) -> BlockShapeType: ...


ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue, torch.dtype]


# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective
# So first decompose CSEVars -> tuple before calling this
@functools.lru_cache(None)
def get_broadcasted_shape(a: BlockShapeType, b: BlockShapeType) -> BlockShapeType:
    assert isinstance(a, Sequence)
    assert isinstance(b, Sequence)
    if len(a) > len(b):
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    elif len(a) < len(b):
        b, a = a, b
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    else:

        def _get_broadcasted_dim(
            d1: Union[int, str], d2: Union[int, str]
        ) -> Union[int, str]:
            if str(d1) == "1":
                return d2
            elif str(d2) == "1":
                return d1
            assert str(d1) == str(d2)
            return d1

        return tuple(_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b))


def broadcast_shapes_for_args(args: Sequence[ShapeArg]) -> BlockShapeType:
    result_shape: BlockShapeType = None
    for arg in args:
        if hasattr(arg, "shape"):
            shape = arg.shape
            if shape is None:
                return None
            elif result_shape is None:
                result_shape = tuple(shape)
            else:
                result_shape = get_broadcasted_shape(result_shape, tuple(shape))
        elif isinstance(arg, (int, float)):
            if result_shape is None:
                result_shape = ()
        elif isinstance(arg, torch.dtype):
            continue
        else:
            from torch._inductor.loop_body import LoopBody, LoopBodyBlock

            if isinstance(arg, (LoopBodyBlock, LoopBody, OpsValue)):
                # TODO: fix me
                return None
            raise TypeError(f"Unknown type: {type(arg)}")
    return result_shape


class ShapePropagationOpsHandler:
    """
    Propagate shape from args to output
    """

    @staticmethod
    def constant(value: torch.types.Number, dtype: torch.dtype) -> BlockShapeType:
        # See implementation of constant for triton for the reason
        from torch._inductor.codegen.triton import TritonKernel

        if isinstance(V.kernel, TritonKernel):
            ndim = V.kernel.triton_tensor_ndim()
            return tuple([1] * ndim)
        else:
            return ()

    @staticmethod
    def store_reduction(name: str, index: int, value: ShapeArg) -> None:
        return None

    @staticmethod
    def reduction(
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: str,
        value: Union[ShapeArg, tuple[ShapeArg, ...]],
    ) -> Union[BlockShapeType, tuple[BlockShapeType, ...]]:
        raise NotImplementedError

    @staticmethod
    def store(
        name: str, index: int, value: ShapeArg, mode: Optional[str] = None
    ) -> None:
        return None

    @staticmethod
    def to_dtype(
        value: ShapeVar,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = True,
    ) -> BlockShapeType:
        return value.shape

    @staticmethod
    def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> BlockShapeType:
        # shape is implicitly embedded in expr.
        return None

    @staticmethod
    def load_seed(name: str, offset: int) -> BlockShapeType:
        return ()

    @staticmethod
    def indirect_indexing(
        var: ShapeArg,
        size: Union[sympy.Expr, int],
        check: bool = True,
        wrap_neg: bool = True,
    ) -> None:
        return None

    def __getattr__(self, name: str) -> Callable[..., BlockShapeType]:
        return lambda *args, **kwargs: broadcast_shapes_for_args(args)

    @staticmethod
    def device_assert_async(cond: ShapeArg, msg: str) -> None:
        return None
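For reference, a small sketch of how the broadcasting helpers above behave (illustrative only; `_FakeVar` is a hypothetical stand-in, and the import assumes a PyTorch checkout containing this module):

```python
from torch._inductor.shape_propagation import (
    broadcast_shapes_for_args,
    get_broadcasted_shape,
)

# Dims equal to 1 broadcast against the other operand's dims.
print(get_broadcasted_shape((4, 1), (1, 8)))     # (4, 8)
# Shorter shapes are left-padded with 1s before broadcasting.
print(get_broadcasted_shape((2, 3, 4), (4,)))    # (2, 3, 4)


class _FakeVar:
    # Hypothetical stand-in for any value exposing a .shape attribute.
    def __init__(self, shape):
        self.shape = shape


# Scalars contribute no dimensions; shapes of the remaining args are broadcast.
print(broadcast_shapes_for_args([_FakeVar((4, 1)), _FakeVar((1, 8)), 3.0]))  # (4, 8)
```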