Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
This PR introduces a `device_assert` op to trigger device-side assertions within `torch.compile`. The implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).

Changes included:
- Implemented the `device_assert` op and overrode `has_side_effect` to return `True`, so the op is not removed by dead code elimination.
- Commented out the `assert_async_msg_decomp` and `functional_assert_async_msg_decomp` decompositions to disable the default assert decomposition inside Inductor.
- Added a lowering for `torch.ops.aten._assert_async.msg` that converts assert calls into the ops handler.
- Implemented the `codegen` method for the `device_assert` op, with support for generating both C++ and Triton code.
- Added test cases to verify both "should throw" and "should not throw" scenarios (a usage sketch is shown after this description).

Fixes #147282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos, https://github.com/atalman
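As a rough illustration, the new assert path might be exercised from user code roughly as follows. This is a minimal sketch, not taken from the PR's test suite: the function, tensor values, and message are illustrative, and a CUDA device is assumed so the assertion can be emitted into the generated Triton kernel.

```python
import torch


@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    # The PR adds a lowering for aten._assert_async.msg, so this call is
    # routed into the new device_assert op rather than decomposed away.
    torch.ops.aten._assert_async.msg(torch.all(x > 0), "x must be positive")
    return x + 1


x = torch.ones(8, device="cuda")
f(x)        # "should not throw": all elements are positive
# f(x - 2)  # "should throw": triggers the device-side assertion
```

The sketch calls `torch.ops.aten._assert_async.msg` directly because that is the overload the PR adds a lowering for; with the default decompositions disabled, Inductor generates a device-side check instead of dropping the assert.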
146 lines · 4.1 KiB · Python
```python
import functools
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, Union

import sympy

import torch

from .virtualized import OpsValue, V


BlockShapeType = Optional[Sequence[Union[int, str]]]


class ShapeVar(Protocol):
    @property
    def shape(self) -> BlockShapeType: ...


ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue, torch.dtype]


# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective
# So first decompose CSEVars -> tuple before calling this
@functools.lru_cache(None)
def get_broadcasted_shape(a: BlockShapeType, b: BlockShapeType) -> BlockShapeType:
    assert isinstance(a, Sequence)
    assert isinstance(b, Sequence)
    if len(a) > len(b):
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    elif len(a) < len(b):
        b, a = a, b
        return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b))
    else:

        def _get_broadcasted_dim(
            d1: Union[int, str], d2: Union[int, str]
        ) -> Union[int, str]:
            if str(d1) == "1":
                return d2
            elif str(d2) == "1":
                return d1
            assert str(d1) == str(d2)
            return d1

        return tuple(_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b))


def broadcast_shapes_for_args(args: Sequence[ShapeArg]) -> BlockShapeType:
    result_shape: BlockShapeType = None

    for arg in args:
        if hasattr(arg, "shape"):
            shape = arg.shape
            if shape is None:
                return None
            elif result_shape is None:
                result_shape = tuple(shape)
            else:
                result_shape = get_broadcasted_shape(result_shape, tuple(shape))
        elif isinstance(arg, (int, float)):
            if result_shape is None:
                result_shape = ()
        elif isinstance(arg, torch.dtype):
            continue
        else:
            from torch._inductor.loop_body import LoopBody, LoopBodyBlock

            if isinstance(arg, (LoopBodyBlock, LoopBody, OpsValue)):
                # TODO: fix me
                return None
            raise TypeError(f"Unknown type: {type(arg)}")

    return result_shape


class ShapePropagationOpsHandler:
    """
    Propagate shape from args to output
    """

    @staticmethod
    def constant(value: torch.types.Number, dtype: torch.dtype) -> BlockShapeType:
        # See implementation of constant for triton for the reason
        from torch._inductor.codegen.triton import TritonKernel

        if isinstance(V.kernel, TritonKernel):
            ndim = V.kernel.triton_tensor_ndim()
            return tuple([1] * ndim)
        else:
            return ()

    @staticmethod
    def store_reduction(name: str, index: int, value: ShapeArg) -> None:
        return None

    @staticmethod
    def reduction(
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: str,
        value: Union[ShapeArg, tuple[ShapeArg, ...]],
    ) -> Union[BlockShapeType, tuple[BlockShapeType, ...]]:
        raise NotImplementedError

    @staticmethod
    def store(
        name: str, index: int, value: ShapeArg, mode: Optional[str] = None
    ) -> None:
        return None

    @staticmethod
    def to_dtype(
        value: ShapeVar,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = True,
    ) -> BlockShapeType:
        return value.shape

    @staticmethod
    def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> BlockShapeType:
        # shape is implicitly embedded in expr.
        return None

    @staticmethod
    def load_seed(name: str, offset: int) -> BlockShapeType:
        return ()

    @staticmethod
    def indirect_indexing(
        var: ShapeArg,
        size: Union[sympy.Expr, int],
        check: bool = True,
        wrap_neg: bool = True,
    ) -> None:
        return None

    def __getattr__(self, name: str) -> Callable[..., BlockShapeType]:
        return lambda *args, **kwargs: broadcast_shapes_for_args(args)

    @staticmethod
    def device_assert_async(cond: ShapeArg, msg: str) -> None:
        return None
```
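For intuition, here is a small illustration (not part of the file) of the broadcasting rules above and of the `__getattr__` fallback that handles any op without an explicit shape rule. The import path is an assumption about where this module lives in the tree.

```python
# Illustration only; assumes the file above is importable as
# torch._inductor.shape_propagation (the path is an assumption).
from torch._inductor.shape_propagation import (
    ShapePropagationOpsHandler,
    get_broadcasted_shape,
)

# Size-1 dimensions broadcast against concrete sizes or symbolic block names.
print(get_broadcasted_shape((1, "XBLOCK"), (8, 1)))  # (8, 'XBLOCK')
print(get_broadcasted_shape(("YBLOCK",), (1, 1)))    # (1, 'YBLOCK')


class _Var:
    """Stand-in for a CSE variable that carries a block shape."""

    def __init__(self, shape):
        self.shape = shape


# Ops without an explicit rule (e.g. "add") fall through __getattr__ and
# simply broadcast the shapes of their arguments.
handler = ShapePropagationOpsHandler()
print(handler.add(_Var((1, "XBLOCK")), _Var((8, 1))))  # (8, 'XBLOCK')
```

The `__getattr__` fallback is what gives pointwise ops shape propagation for free: any handler method without an explicit rule just broadcasts the block shapes of its arguments via `broadcast_shapes_for_args`.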