[Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)

This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).

Changes Included

- Implemented the device_assert op and overrode has_side_effects to return True, so the op is not removed by dead code elimination.
- Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor.
- Added lowering for torch.ops.aten._assert_async.msg to convert assert calls into the ops_handler.
- Implemented the codegen method for the device_assert op. This supports generating C++ and Triton code.
- Added test cases to verify both "should throw" and "should not throw" scenarios.

Fixes #147282

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos, https://github.com/atalman
This commit is contained in:
Karthick Panner Selvam
2025-08-28 18:57:34 +00:00
committed by PyTorch MergeBot
parent 30ab87c884
commit 130e50afff
11 changed files with 233 additions and 16 deletions

View File

@ -1279,6 +1279,13 @@ class SchedulerNode(BaseSchedulerNode):
)
return buffers_store_as_atomic_add
@cache_on_self
def has_side_effects(self) -> bool:
    """Report whether this node must be kept alive.

    A node whose body contains a ``device_assert_async`` op has an
    observable side effect (a device-side assertion), so it must not be
    eliminated as dead code even if its outputs are unused.
    """
    body = self._body
    # NOTE: self._body can legitimately be None, so guard before probing it.
    if body is not None and body.has_op("device_assert_async"):
        return True
    # Fall back to the base-class notion of side effects.
    return super().has_side_effects()
def refresh_group_node_dependencies(
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
@ -1548,6 +1555,12 @@ class FusedSchedulerNode(BaseSchedulerNode):
return buf.getrawvalue().rstrip()
@cache_on_self
def has_side_effects(self) -> bool:
    """Return True if any fused sub-node has side effects.

    A fused node is side-effectful exactly when one of its constituent
    scheduler nodes is (e.g. it contains a device-side assertion), which
    keeps the whole fusion from being dead-code eliminated.
    """
    snodes = self.snodes
    if snodes is None:
        # No sub-nodes to inspect; defer to the base-class behavior.
        return super().has_side_effects()
    return any(snode.has_side_effects() for snode in snodes)
class ForeachKernelSchedulerNode(FusedSchedulerNode):
"""
@ -3877,7 +3890,6 @@ class Scheduler:
Determine if it is possible to combine node1 and node2 into a
single fused node.
"""
if node1 is node2:
return False
@ -3981,7 +3993,6 @@ class Scheduler:
):
why("fusion for buffer explicit disabled")
return False
device = node1.get_device()
device2 = node2.get_device()
if device != device2: