mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084). Changes Included - Implemented device_assert op and overrides has_side_effect to return True to avoid removal by dead code elimination. - Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor. - Added lowering for torch.ops.aten._assert_async.msg to convert assert calls into the ops_handler. - Implemented the codegen method for the device_assert op. This supports generating C++ and Triton code. - Added test cases to verify both "should throw" and "should not throw" scenarios. Fixes #147282 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
1e4dfeeb06
commit
cddcaa1903
@ -1276,6 +1276,13 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
)
|
||||
return buffers_store_as_atomic_add
|
||||
|
||||
@cache_on_self
|
||||
def has_side_effects(self) -> bool:
|
||||
# self._body is None sometimes that's why this check was added
|
||||
if self._body is not None and self._body.has_op("device_assert_async"):
|
||||
return True
|
||||
return super().has_side_effects()
|
||||
|
||||
|
||||
def refresh_group_node_dependencies(
|
||||
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
|
||||
@ -1545,6 +1552,12 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
|
||||
return buf.getrawvalue().rstrip()
|
||||
|
||||
@cache_on_self
|
||||
def has_side_effects(self) -> bool:
|
||||
if self.snodes is not None:
|
||||
return any(node.has_side_effects() for node in self.snodes)
|
||||
return super().has_side_effects()
|
||||
|
||||
|
||||
class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
"""
|
||||
@ -3874,7 +3887,6 @@ class Scheduler:
|
||||
Determine if it is possible to combine node1 and node2 into a
|
||||
single fused node.
|
||||
"""
|
||||
|
||||
if node1 is node2:
|
||||
return False
|
||||
|
||||
@ -3978,7 +3990,6 @@ class Scheduler:
|
||||
):
|
||||
why("fusion for buffer explicit disabled")
|
||||
return False
|
||||
|
||||
device = node1.get_device()
|
||||
device2 = node2.get_device()
|
||||
if device != device2:
|
||||
|
Reference in New Issue
Block a user