[Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)

This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).
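Previously, Inductor decomposed these assertions into no-ops, so a data-dependent `assert` inside a compiled region could pass silently even when the condition was false (the behavior reported in the linked issue). Below is a minimal sketch of the user-facing difference; it assumes a CUDA build with Triton, and the function and variable names are illustrative:

```python
import torch

def check_positive(x):
    # Dynamo rewrites this data-dependent assert into aten._assert_async.msg in the graph
    assert torch.all(x > 0), "expected all elements to be positive"
    return x * 2

x = torch.tensor([1.0, -2.0], device="cuda")
# Eager mode raises an AssertionError. Before this PR, the compiled version
# silently dropped the check; with the device_assert op it now traps on device.
torch.compile(check_positive)(x)
```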

Changes Included

- Implemented the device_assert op and overrode has_side_effects to return True so that dead code elimination does not remove the assertion.
- Removed the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions, which previously turned asserts into no-ops inside Inductor.
- Added lowerings for torch.ops.aten._assert_async.msg and torch.ops.aten._functional_assert_async.msg that route assert calls through the ops handler.
- Implemented the codegen for the device_assert op, with support for both C++ and Triton backends (see the sketch after this list).
- Added test cases to verify both "should throw" and "should not throw" scenarios.
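
A quick way to confirm the new behavior end to end, sketched from the added tests (assumes a CUDA build with Triton; `run_and_get_code` is an internal Inductor test utility):

```python
import torch
from torch._inductor.utils import run_and_get_code

@torch.compile(backend="inductor")
def fn():
    a = torch.tensor([1.0, 2.0], device="cuda")
    result = torch.all(a > 0)
    assert result, "should throw"

# The condition holds, so the compiled call returns normally...
_, code = run_and_get_code(fn)
# ...and the generated Triton kernel carries the assertion instead of dropping it.
assert code[0].count("tl.device_assert") == 1
```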

Fixes #147282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos, https://github.com/atalman
Author: Karthick Panner Selvam
Date: 2025-08-28 18:57:34 +00:00
Committed by: PyTorch MergeBot
Parent: 30ab87c884
Commit: 130e50afff
11 changed files with 233 additions and 16 deletions

View File

@@ -0,0 +1,144 @@
# Owner(s): ["module: inductor"]
import torch
import torch._inductor.config
from torch._inductor import metrics
from torch._inductor.compiler_bisector import BisectionResult, CompilerBisector
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton


class TestTorchDeviceAssertTrigger(TestCase):
    def _run_assert_should_throw(self, device):
        def func():
            a = torch.tensor([1.0, -2.0], device=device)
            result = torch.all(a > 0)
            assert result, "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)
            try:
                f_c()
                return False
            except Exception:
                return True

        bisect_result = CompilerBisector.do_bisect(test_fn)
        # do_bisect returns None if every system passes, otherwise a BisectionResult
        self.assertNotIsInstance(bisect_result, BisectionResult)

    def _run_assert_should_not_throw(self, device):
        def func():
            a = torch.tensor([1.0, 2.0], device=device)
            result = torch.all(a > 0)
            assert result, "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)
            try:
                f_c()
                return True
            except Exception:
                return False

        bisect_result = CompilerBisector.do_bisect(test_fn)
        self.assertNotIsInstance(bisect_result, BisectionResult)

    def _run_assert_inline_expression_should_throw(self, device):
        def func():
            a = torch.tensor([1.0, -2.0], device=device)
            assert torch.all(a > 0), "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)
            try:
                f_c()
                return False
            except Exception:
                return True

        bisect_result = CompilerBisector.do_bisect(test_fn)
        self.assertNotIsInstance(bisect_result, BisectionResult)

    def _run_assert_inline_expression_should_not_throw(self, device):
        def func():
            a = torch.tensor([1.0, 2.0], device=device)
            assert torch.all(a > 0), "should throw"

        def test_fn():
            torch._dynamo.reset()
            f_c = torch.compile(func)
            try:
                f_c()
                return True
            except Exception:
                return False

        bisect_result = CompilerBisector.do_bisect(test_fn)
        self.assertNotIsInstance(bisect_result, BisectionResult)

    @torch._inductor.config.patch(force_disable_caches=True)
    def test_assert_should_throw(self):
        device = "cpu"
        self._run_assert_should_throw(device)
        self._run_assert_inline_expression_should_throw(device)

    @torch._inductor.config.patch(force_disable_caches=True)
    def test_assert_should_not_throw(self):
        device = "cpu"
        self._run_assert_should_not_throw(device)
        self._run_assert_inline_expression_should_not_throw(device)

    @requires_cuda_and_triton
    @skipIfRocm
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_assert_fusion(self):
        torch._logging.set_logs(inductor_metrics=True)

        def func():
            a = torch.tensor([1.0, 2.0], device="cuda")
            result = torch.all(a > 0)
            assert result, "should throw"

        torch._dynamo.reset()
        f_c = torch.compile(func, backend="inductor")
        metrics.reset()
        self.assertEqual(metrics.generated_kernel_count, 0)
        f_c()
        self.assertEqual(metrics.generated_kernel_count, 1)
        torch._logging.set_logs()

    @requires_cuda_and_triton
    @skipIfRocm
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_run_assert_triton(self):
        @torch.compile(backend="inductor")
        def fn():
            a = torch.tensor([1.0, 2.0], device="cuda")
            result = torch.all(a > 0)
            assert result, "should throw"

        def should_not_throw(fn):
            try:
                fn()
                return True
            except Exception:
                return False

        self.assertEqual(should_not_throw(fn), True)
        _, code = run_and_get_code(fn)
        self.assertEqual(code[0].count("tl.device_assert"), 1)


if __name__ == "__main__":
    run_tests()

View File

@@ -1119,6 +1119,10 @@ class CppOverrides(OpOverrides):
        code.writeline("()")
        return code

    @staticmethod
    def device_assert_async(cond, msg):
        return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))'


CppOverrides._initialize_pointwise_overrides("cpp")

View File

@@ -566,6 +566,10 @@ class HalideOverrides(OpOverrides):
    def frexp(x):
        raise NotImplementedError("frexp")

    @staticmethod
    def device_assert_async(cond, msg):
        raise NotImplementedError("device_assert_async")


HalideOverrides._initialize_pointwise_overrides("halide")

View File

@@ -1578,6 +1578,10 @@ class TritonKernelOverrides(TritonOverrides):
        V.kernel.cse.put(cache_key, (mantissa, exponent))
        return (mantissa, exponent)

    @staticmethod
    def device_assert_async(cond, msg):
        return f"tl.device_assert({cond}, {repr(msg)})"


class HelperFunctions:
    """An ordered set of helper functions."""

View File

@@ -158,19 +158,6 @@ def _embedding_dense_backward(
    )


# TODO: for now, inductor doesn't handle asserts
# because the condition is symbol -> tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return


# Following `assert_async_msg_decomp` and implement as non-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
    return


@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(
    symbol: torch.SymInt,

View File

@@ -373,6 +373,10 @@ class DtypePropagationOpsHandler:
            f"{type(self).__name__}: ops.placeholder should not appear here"
        )

    @staticmethod
    def device_assert_async(cond, msg: str) -> torch.dtype:
        return torch.bool


if TYPE_CHECKING:

View File

@@ -1094,7 +1094,10 @@ class Pointwise(Loops):
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(
            device=device, dtype=self.dtype, inner_fn=loader, ranges=self.ranges
            device=device,
            dtype=self.dtype,
            inner_fn=loader,
            ranges=self.ranges,
        )
@@ -4423,6 +4426,17 @@ class ComputedBuffer(OperationBuffer):
    """

    data: Loops
    _force_realize: ClassVar[bool] = False

    @staticmethod
    @contextlib.contextmanager
    def force_realize() -> Iterator[None]:
        old_value = ComputedBuffer._force_realize
        try:
            ComputedBuffer._force_realize = True
            yield
        finally:
            ComputedBuffer._force_realize = old_value

    def get_computed_buffer_name(self) -> Optional[str]:
        """
@@ -4497,6 +4511,7 @@ class ComputedBuffer(OperationBuffer):
            not self.get_reduction_type()
            and self.name not in V.graph.mutated_buffers
            and self.num_reads() == 0
            and not self._force_realize
        ):
            # inline this op rather than generating ops.load()
            return self.data.make_loader()

View File

@@ -1329,6 +1329,34 @@ def quantized_decomposed_quantize_per_channel(
    )


def _assert_async(cond, msg):
    cond.realize()
    cond = to_dtype(cond, torch.bool)

    def inner_fn(index):
        with ir.ComputedBuffer.force_realize():
            return ops.device_assert_async(cond.make_loader()(index), msg)

    assertion_op = Pointwise.create(
        device=cond.get_device(),
        dtype=cond.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(cond.get_size()),
    )
    assertion_op.realize()
    return assertion_op


@register_lowering(aten._assert_async.msg)
def lower_assert_async(cond, msg):
    return _assert_async(cond, msg)


@register_lowering(aten._functional_assert_async.msg)
def lower_assert_functional_async(cond, msg):
    return _assert_async(cond, msg)


@register_lowering(
    quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
)

View File

@@ -706,6 +706,9 @@ class OpsHandler(Generic[T]):
        """This is a fake op used in analysis but not codegen"""
        raise NotImplementedError

    def device_assert_async(self, cond: T, msg: str) -> T:
        raise NotImplementedError


_ignore_op_re = re.compile(r"_.*|paren").fullmatch
@@ -788,6 +791,9 @@ class DefaultHandler(OpsHandler[Any]):
            if target in OP_NAMES:
                setattr(cls, target, impl)

    def device_assert_async(self, cond, msg):
        return None


DefaultHandler._init_cls()
@@ -933,6 +939,9 @@ class MockHandler(BasicMathOpsMixin, DefaultHandler):
    def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
        return sympy_index_symbol(str(index_var))

    def device_assert_async(self, cond, msg):
        return None


class KernelFormatterHandler(DefaultHandler):
    def __init__(self, parent_handler: OpsHandler[Any]):
@@ -999,6 +1008,9 @@ class KernelFormatterHandler(DefaultHandler):
        self._output.writeline(f"return {result}")
        return self._output.getvalue()

    def device_assert_async(self, cond, msg: str):
        return f"ops.device_assert_async({cond}, {msg})"


class WrapperHandler(DefaultHandler):
    def __init__(self, inner: OpsHandler[Any]):

View File

@@ -1279,6 +1279,13 @@ class SchedulerNode(BaseSchedulerNode):
        )
        return buffers_store_as_atomic_add

    @cache_on_self
    def has_side_effects(self) -> bool:
        # self._body can be None, so guard before checking for the op
        if self._body is not None and self._body.has_op("device_assert_async"):
            return True
        return super().has_side_effects()


def refresh_group_node_dependencies(
    group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
@@ -1548,6 +1555,12 @@ class FusedSchedulerNode(BaseSchedulerNode):
        return buf.getrawvalue().rstrip()

    @cache_on_self
    def has_side_effects(self) -> bool:
        if self.snodes is not None:
            return any(node.has_side_effects() for node in self.snodes)
        return super().has_side_effects()


class ForeachKernelSchedulerNode(FusedSchedulerNode):
    """
@@ -3877,7 +3890,6 @@ class Scheduler:
        Determine if it is possible to combine node1 and node2 into a
        single fused node.
        """
        if node1 is node2:
            return False
@@ -3981,7 +3993,6 @@ class Scheduler:
        ):
            why("fusion for buffer explicit disabled")
            return False
        device = node1.get_device()
        device2 = node2.get_device()
        if device != device2:

View File

@@ -139,3 +139,7 @@ class ShapePropagationOpsHandler:
    def __getattr__(self, name: str) -> Callable[..., BlockShapeType]:
        return lambda *args, **kwargs: broadcast_shapes_for_args(args)

    @staticmethod
    def device_assert_async(cond: ShapeArg, msg: str) -> None:
        return None