mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. The implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084).
Changes included:
- Implemented the device_assert op and overrode has_side_effect to return True so that the op is not removed by dead code elimination.
- Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor.
- Added a lowering for torch.ops.aten._assert_async.msg that converts assert calls into the ops_handler.
- Implemented the codegen method for the device_assert op. This supports generating both C++ and Triton code.
- Added test cases to verify both the "should throw" and "should not throw" scenarios.
Fixes #147282
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677
Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
d2db6c86b0
commit
378edb047f
@ -1329,6 +1329,39 @@ def quantized_decomposed_quantize_per_channel(
|
||||
)
|
||||
|
||||
|
||||
def _assert_async(cond, msg):
    """Lower an async assertion on ``cond`` into a realized ``Pointwise`` node.

    ``cond`` is realized and converted to ``torch.bool`` with ``to_dtype``.
    A ``Pointwise`` node is then built over ``cond``'s size; for every index
    it loads the condition value and passes it, together with ``msg``, to
    ``ops.device_assert_async``.

    Args:
        cond: IR node holding the condition to check. It must provide
            ``realize``, ``make_loader``, ``get_device``, ``get_dtype`` and
            ``get_size``.
        msg: Message forwarded unchanged to ``ops.device_assert_async``.

    Returns:
        The ``Pointwise`` node created for the assertion, after ``realize()``
        has been called on it.
    """
    cond.realize()
    cond = to_dtype(cond, torch.bool)

    def _load_and_assert(index):
        # Build the loader at call time and feed the loaded value to the
        # device-side assert op.
        load_cond = cond.make_loader()
        return ops.device_assert_async(load_cond(index), msg)

    def inner_fn(index):
        supports_force_realize = hasattr(cond.data, "data") and hasattr(
            cond.data.data, "force_realize"
        )
        if not supports_force_realize:
            return _load_and_assert(index)
        # When the underlying storage exposes ``force_realize``, the loader
        # is created and invoked inside that context manager.
        with cond.data.data.force_realize():
            return _load_and_assert(index)

    assert_device = cond.get_device()
    assert_dtype = cond.get_dtype()
    assert_ranges = list(cond.get_size())
    assertion_op = Pointwise.create(
        device=assert_device,
        dtype=assert_dtype,
        inner_fn=inner_fn,
        ranges=assert_ranges,
    )
    # NOTE(review): presumably realized eagerly so the assertion is
    # materialized even though nothing consumes its output; confirm.
    assertion_op.realize()
    return assertion_op
|
||||
|
||||
|
||||
@register_lowering(aten._assert_async.msg)
def lower_assert_async(cond, msg):
    """Lowering registered for ``aten._assert_async.msg``.

    Forwards ``cond`` and ``msg`` to ``_assert_async`` and returns its result.
    """
    assertion_node = _assert_async(cond, msg)
    return assertion_node
|
||||
|
||||
|
||||
@register_lowering(aten._functional_assert_async.msg)
def lower_assert_functional_async(cond, msg):
    """Lowering registered for ``aten._functional_assert_async.msg``.

    Forwards ``cond`` and ``msg`` to ``_assert_async`` and returns its result.
    """
    assertion_node = _assert_async(cond, msg)
    return assertion_node
|
||||
|
||||
|
||||
@register_lowering(
|
||||
quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
|
||||
)
|
||||
|
Reference in New Issue
Block a user