mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Add DeviceAssert op to enable device-side assertion in torch.compile (#160677)
This PR introduces a device_assert op to trigger device-side assertions within torch.compile. This implementation is based on the suggestion in [this comment](https://github.com/pytorch/pytorch/issues/147282#issuecomment-2756056084). Changes Included - Implemented device_assert op and overrides has_side_effect to return True to avoid removal by dead code elimination. - Commented out the assert_async_msg_decomp and functional_assert_async_msg_decomp decompositions to disable the default assert decomposition inside Inductor. - Added lowering for torch.ops.aten._assert_async.msg to convert assert calls into the ops_handler. - Implemented the codegen method for the device_assert op. This supports generating C++ and Triton code. - Added test cases to verify both "should throw" and "should not throw" scenarios. Fixes #147282 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160677 Approved by: https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
d2db6c86b0
commit
378edb047f
204
test/inductor/test_device_assert.py
Normal file
204
test/inductor/test_device_assert.py
Normal file
@ -0,0 +1,204 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch._inductor.config
|
||||
from torch._inductor import metrics
|
||||
from torch._inductor.compiler_bisector import BisectionResult, CompilerBisector
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.common_utils import skipIfRocm
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
class TestTorchDeviceAssertTrigger(TestCase):
|
||||
def _run_assert_should_throw(self, device):
|
||||
def func():
|
||||
a = torch.tensor([1.0, -2.0], device=device)
|
||||
result = torch.all(a > 0)
|
||||
assert result, "should throw"
|
||||
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
f_c = torch.compile(func)
|
||||
|
||||
try:
|
||||
f_c()
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
bisect_result = CompilerBisector.do_bisect(test_fn)
|
||||
# do_bisect return None if all system is passed else return BisectionResult
|
||||
self.assertNotIsInstance(bisect_result, BisectionResult)
|
||||
|
||||
def _run_assert_should_not_throw(self, device):
|
||||
def func():
|
||||
a = torch.tensor([1.0, 2.0], device=device)
|
||||
result = torch.all(a > 0)
|
||||
assert result, "should throw"
|
||||
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
f_c = torch.compile(func)
|
||||
|
||||
try:
|
||||
f_c()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
bisect_result = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertNotIsInstance(bisect_result, BisectionResult)
|
||||
|
||||
def _run_assert_inline_expression_should_throw(self, device):
|
||||
def func():
|
||||
a = torch.tensor([1.0, -2.0], device=device)
|
||||
assert torch.all(a > 0), "should throw"
|
||||
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
f_c = torch.compile(func)
|
||||
|
||||
try:
|
||||
f_c()
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
bisect_result = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertNotIsInstance(bisect_result, BisectionResult)
|
||||
|
||||
def _run_assert_inline_expression_should_not_throw(self, device):
|
||||
def func():
|
||||
a = torch.tensor([1.0, 2.0], device=device)
|
||||
assert torch.all(a > 0), "should throw"
|
||||
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
f_c = torch.compile(func)
|
||||
|
||||
try:
|
||||
f_c()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
bisect_result = CompilerBisector.do_bisect(test_fn)
|
||||
self.assertNotIsInstance(bisect_result, BisectionResult)
|
||||
|
||||
@torch._inductor.config.patch(force_disable_caches=True)
|
||||
def test_assert_should_throw(self):
|
||||
device = "cpu"
|
||||
self._run_assert_should_throw(device)
|
||||
self._run_assert_inline_expression_should_throw(device)
|
||||
|
||||
@torch._inductor.config.patch(force_disable_caches=True)
|
||||
def test_assert_should_not_throw(self):
|
||||
device = "cpu"
|
||||
self._run_assert_should_not_throw(device)
|
||||
self._run_assert_inline_expression_should_not_throw(device)
|
||||
|
||||
@torch._inductor.config.patch(force_disable_caches=True, cpp_wrapper=True)
|
||||
def test_assert_should_throw_cpp_wrapper(self):
|
||||
device = "cpu"
|
||||
self._run_assert_should_throw(device)
|
||||
self._run_assert_inline_expression_should_throw(device)
|
||||
|
||||
@torch._inductor.config.patch(force_disable_caches=True, cpp_wrapper=True)
|
||||
def test_assert_should_not_throw_cpp_wrapper(self):
|
||||
device = "cpu"
|
||||
self._run_assert_should_not_throw(device)
|
||||
self._run_assert_inline_expression_should_not_throw(device)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@skipIfRocm
|
||||
@torch._inductor.config.patch(force_disable_caches=True)
|
||||
def test_assert_fusion(self):
|
||||
torch._logging.set_logs(inductor_metrics=True)
|
||||
|
||||
def func():
|
||||
a = torch.tensor([1.0, 2.0], device="cuda")
|
||||
result = torch.all(a > 0)
|
||||
assert result, "should throw"
|
||||
|
||||
torch._dynamo.reset()
|
||||
f_c = torch.compile(func, backend="inductor")
|
||||
metrics.reset()
|
||||
self.assertEqual(metrics.generated_kernel_count, 0)
|
||||
f_c()
|
||||
self.assertEqual(metrics.generated_kernel_count, 1)
|
||||
torch._logging.set_logs()
|
||||
|
||||
@requires_cuda_and_triton
|
||||
@skipIfRocm
|
||||
@torch._inductor.config.patch(force_disable_caches=True)
|
||||
def test_run_assert_triton(self):
|
||||
should_throw = """
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
def func_should_throw():
|
||||
a = torch.tensor([1.0, -2.0], device='cuda')
|
||||
result = torch.all(a > 0)
|
||||
assert result, "should throw"
|
||||
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
f_c = torch.compile(func_should_throw, backend="inductor")
|
||||
|
||||
try:
|
||||
f_c()
|
||||
torch.cuda.synchronize()
|
||||
return False
|
||||
except Exception as e:
|
||||
return True
|
||||
|
||||
result = test_fn()
|
||||
print(f"Test result: {result}")
|
||||
"""
|
||||
|
||||
should_not_throw = """
|
||||
import torch
|
||||
import torch._dynamo
|
||||
|
||||
def func_should_not_throw():
|
||||
a = torch.tensor([1.0, 2.0], device='cuda')
|
||||
result = torch.all(a > 0)
|
||||
assert result, "should throw"
|
||||
|
||||
def test_fn():
|
||||
torch._dynamo.reset()
|
||||
f_c = torch.compile(func_should_not_throw, backend="inductor")
|
||||
|
||||
try:
|
||||
f_c()
|
||||
torch.cuda.synchronize()
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
result = test_fn()
|
||||
print(f"Test result: {result}")
|
||||
"""
|
||||
for script in [should_not_throw, should_throw]:
|
||||
p = subprocess.run(
|
||||
[sys.executable, "-c", script],
|
||||
cwd=os.path.dirname(os.path.realpath(__file__)),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
output = p.stdout + "\n" + p.stderr
|
||||
|
||||
self.assertIn("Test result: True", output)
|
||||
|
||||
if p.returncode != 0:
|
||||
self.fail(
|
||||
f"Subprocess failed with return code {p.returncode}. Output: {output}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -1119,6 +1119,10 @@ class CppOverrides(OpOverrides):
|
||||
code.writeline("()")
|
||||
return code
|
||||
|
||||
@staticmethod
|
||||
def device_assert_async(cond, msg):
|
||||
return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))'
|
||||
|
||||
|
||||
CppOverrides._initialize_pointwise_overrides("cpp")
|
||||
|
||||
|
@ -566,6 +566,10 @@ class HalideOverrides(OpOverrides):
|
||||
def frexp(x):
|
||||
raise NotImplementedError("frexp")
|
||||
|
||||
@staticmethod
|
||||
def device_assert_async(cond, msg):
|
||||
raise NotImplementedError("device_assert_async")
|
||||
|
||||
|
||||
HalideOverrides._initialize_pointwise_overrides("halide")
|
||||
|
||||
|
@ -1592,6 +1592,10 @@ class TritonKernelOverrides(TritonOverrides):
|
||||
V.kernel.cse.put(cache_key, (mantissa, exponent))
|
||||
return (mantissa, exponent)
|
||||
|
||||
@staticmethod
|
||||
def device_assert_async(cond, msg):
|
||||
return f"tl.device_assert({cond}, {repr(msg)})"
|
||||
|
||||
|
||||
class HelperFunctions:
|
||||
"""An ordered set of helper functions."""
|
||||
|
@ -158,19 +158,6 @@ def _embedding_dense_backward(
|
||||
)
|
||||
|
||||
|
||||
# TODO: for now, inductor doesn't handle asserts
|
||||
# because the condition is symbol -> tensor in the graph.
|
||||
@register_decomposition([aten._assert_async.msg])
|
||||
def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
# Following `assert_async_msg_decomp` and implement as non-op.
|
||||
@register_decomposition([aten._functional_assert_async.msg])
|
||||
def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
@register_decomposition([aten.sym_constrain_range_for_size.default])
|
||||
def sym_constrain_range_for_size(
|
||||
symbol: torch.SymInt,
|
||||
|
@ -373,6 +373,10 @@ class DtypePropagationOpsHandler:
|
||||
f"{type(self).__name__}: ops.placeholder should not appear here"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def device_assert_async(cond, msg: str) -> torch.dtype:
|
||||
return torch.bool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
|
@ -1094,7 +1094,10 @@ class Pointwise(Loops):
|
||||
loader = self.make_loader()
|
||||
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
|
||||
return Pointwise(
|
||||
device=device, dtype=self.dtype, inner_fn=loader, ranges=self.ranges
|
||||
device=device,
|
||||
dtype=self.dtype,
|
||||
inner_fn=loader,
|
||||
ranges=self.ranges,
|
||||
)
|
||||
|
||||
|
||||
@ -4423,6 +4426,17 @@ class ComputedBuffer(OperationBuffer):
|
||||
"""
|
||||
|
||||
data: Loops
|
||||
_force_realize: ClassVar[bool] = False
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
def force_realize() -> Iterator[None]:
|
||||
old_value = ComputedBuffer._force_realize
|
||||
try:
|
||||
ComputedBuffer._force_realize = True
|
||||
yield
|
||||
finally:
|
||||
ComputedBuffer._force_realize = old_value
|
||||
|
||||
def get_computed_buffer_name(self) -> Optional[str]:
|
||||
"""
|
||||
@ -4497,6 +4511,7 @@ class ComputedBuffer(OperationBuffer):
|
||||
not self.get_reduction_type()
|
||||
and self.name not in V.graph.mutated_buffers
|
||||
and self.num_reads() == 0
|
||||
and not self._force_realize
|
||||
):
|
||||
# inline this op rather than generating ops.load()
|
||||
return self.data.make_loader()
|
||||
|
@ -1329,6 +1329,39 @@ def quantized_decomposed_quantize_per_channel(
|
||||
)
|
||||
|
||||
|
||||
def _assert_async(cond, msg):
|
||||
cond.realize()
|
||||
cond = to_dtype(cond, torch.bool)
|
||||
|
||||
def inner_fn(index):
|
||||
if hasattr(cond.data, "data") and hasattr(cond.data.data, "force_realize"):
|
||||
with cond.data.data.force_realize():
|
||||
cond_loader = cond.make_loader()
|
||||
return ops.device_assert_async(cond_loader(index), msg)
|
||||
else:
|
||||
cond_loader = cond.make_loader()
|
||||
return ops.device_assert_async(cond_loader(index), msg)
|
||||
|
||||
assertion_op = Pointwise.create(
|
||||
device=cond.get_device(),
|
||||
dtype=cond.get_dtype(),
|
||||
inner_fn=inner_fn,
|
||||
ranges=list(cond.get_size()),
|
||||
)
|
||||
assertion_op.realize()
|
||||
return assertion_op
|
||||
|
||||
|
||||
@register_lowering(aten._assert_async.msg)
|
||||
def lower_assert_async(cond, msg):
|
||||
return _assert_async(cond, msg)
|
||||
|
||||
|
||||
@register_lowering(aten._functional_assert_async.msg)
|
||||
def lower_assert_functional_async(cond, msg):
|
||||
return _assert_async(cond, msg)
|
||||
|
||||
|
||||
@register_lowering(
|
||||
quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
|
||||
)
|
||||
|
@ -706,6 +706,9 @@ class OpsHandler(Generic[T]):
|
||||
"""This is a fake op used in analysis but not codegen"""
|
||||
raise NotImplementedError
|
||||
|
||||
def device_assert_async(self, cond: T, msg: str) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
_ignore_op_re = re.compile(r"_.*|paren").fullmatch
|
||||
|
||||
@ -788,6 +791,9 @@ class DefaultHandler(OpsHandler[Any]):
|
||||
if target in OP_NAMES:
|
||||
setattr(cls, target, impl)
|
||||
|
||||
def device_assert_async(self, cond, msg):
|
||||
return None
|
||||
|
||||
|
||||
DefaultHandler._init_cls()
|
||||
|
||||
@ -933,6 +939,9 @@ class MockHandler(BasicMathOpsMixin, DefaultHandler):
|
||||
def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
|
||||
return sympy_index_symbol(str(index_var))
|
||||
|
||||
def device_assert_async(self, cond, msg):
|
||||
return None
|
||||
|
||||
|
||||
class KernelFormatterHandler(DefaultHandler):
|
||||
def __init__(self, parent_handler: OpsHandler[Any]):
|
||||
@ -999,6 +1008,9 @@ class KernelFormatterHandler(DefaultHandler):
|
||||
self._output.writeline(f"return {result}")
|
||||
return self._output.getvalue()
|
||||
|
||||
def device_assert_async(self, cond, msg: str):
|
||||
return f"ops.device_assert_async({cond}, {msg})"
|
||||
|
||||
|
||||
class WrapperHandler(DefaultHandler):
|
||||
def __init__(self, inner: OpsHandler[Any]):
|
||||
|
@ -1276,6 +1276,13 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
)
|
||||
return buffers_store_as_atomic_add
|
||||
|
||||
@cache_on_self
|
||||
def has_side_effects(self) -> bool:
|
||||
# self._body is None sometimes that's why this check was added
|
||||
if self._body is not None and self._body.has_op("device_assert_async"):
|
||||
return True
|
||||
return super().has_side_effects()
|
||||
|
||||
|
||||
def refresh_group_node_dependencies(
|
||||
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
|
||||
@ -1545,6 +1552,12 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
|
||||
return buf.getrawvalue().rstrip()
|
||||
|
||||
@cache_on_self
|
||||
def has_side_effects(self) -> bool:
|
||||
if self.snodes is not None:
|
||||
return any(node.has_side_effects() for node in self.snodes)
|
||||
return super().has_side_effects()
|
||||
|
||||
|
||||
class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
"""
|
||||
@ -3874,7 +3887,6 @@ class Scheduler:
|
||||
Determine if it is possible to combine node1 and node2 into a
|
||||
single fused node.
|
||||
"""
|
||||
|
||||
if node1 is node2:
|
||||
return False
|
||||
|
||||
@ -3978,7 +3990,6 @@ class Scheduler:
|
||||
):
|
||||
why("fusion for buffer explicit disabled")
|
||||
return False
|
||||
|
||||
device = node1.get_device()
|
||||
device2 = node2.get_device()
|
||||
if device != device2:
|
||||
|
@ -139,3 +139,7 @@ class ShapePropagationOpsHandler:
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., BlockShapeType]:
|
||||
return lambda *args, **kwargs: broadcast_shapes_for_args(args)
|
||||
|
||||
@staticmethod
|
||||
def device_assert_async(cond: ShapeArg, msg: str) -> None:
|
||||
return None
|
||||
|
Reference in New Issue
Block a user