From 8c98aee436c07d6341e71dbb8284e32bd7fac278 Mon Sep 17 00:00:00 2001
From: karthickai
Date: Wed, 24 Sep 2025 09:33:39 -0700
Subject: [PATCH] [Inductor] Update DeviceAssert op to behave like store (#163696)

Updated the DeviceAssert operation to match the behavior of Store. This fixes
the issue mentioned in [this PR](https://github.com/pytorch/pytorch/pull/163023)
and updates the test cases as Elias
[suggested](https://github.com/pytorch/pytorch/pull/160677#discussion_r2353834646).
(A brief usage sketch is appended after the diff.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163696
Approved by: https://github.com/mlazos
---
 test/inductor/test_device_assert.py  | 105 ++++++++-------------
 torch/_inductor/codegen/common.py    |  13 ++++
 torch/_inductor/codegen/cpp.py       |   9 ++-
 torch/_inductor/codegen/triton.py    |   7 +-
 torch/_inductor/dtype_propagation.py |   4 +-
 torch/_inductor/ops_handler.py       |  14 ++--
 6 files changed, 59 insertions(+), 93 deletions(-)

diff --git a/test/inductor/test_device_assert.py b/test/inductor/test_device_assert.py
index ddf85f9d88da..f3c142299501 100644
--- a/test/inductor/test_device_assert.py
+++ b/test/inductor/test_device_assert.py
@@ -3,100 +3,57 @@
 import torch
 import torch._inductor.config
 from torch._inductor import metrics
-from torch._inductor.compiler_bisector import BisectionResult, CompilerBisector
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_utils import skipIfRocm
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    skipIfRocm,
+)
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
 
 
+@instantiate_parametrized_tests
 class TestTorchDeviceAssertTrigger(TestCase):
-    def _run_assert_should_throw(self, device):
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_assert_should_throw(self, backend):
         def func():
-            a = torch.tensor([1.0, -2.0], device=device)
+            a = torch.tensor([1.0, -2.0], device="cpu")
             result = torch.all(a > 0)
             assert result, "should throw"
 
-        def test_fn():
+        def func_inline():
+            a = torch.tensor([1.0, -2.0], device="cpu")
+            assert torch.all(a > 0), "should throw"
+
+        with self.assertRaisesRegex(RuntimeError, "should throw"):
             torch._dynamo.reset()
-            f_c = torch.compile(func)
+            f_c = torch.compile(func, backend=backend)
+            f_c()
 
-            try:
-                f_c()
-                return False
-            except Exception:
-                return True
+        with self.assertRaisesRegex(RuntimeError, "should throw"):
+            torch._dynamo.reset()
+            f_c = torch.compile(func_inline, backend=backend)
+            f_c()
 
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        # do_bisect return None if all system is passed else return BisectionResult
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_should_not_throw(self, device):
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_assert_should_not_throw(self, backend):
         def func():
-            a = torch.tensor([1.0, 2.0], device=device)
+            a = torch.tensor([1.0, 2.0], device="cpu")
             result = torch.all(a > 0)
             assert result, "should throw"
 
-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
-
-            try:
-                f_c()
-                return True
-            except Exception:
-                return False
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_inline_expression_should_throw(self, device):
-        def func():
-            a = torch.tensor([1.0, -2.0], device=device)
+        def func_inline():
+            a = torch.tensor([1.0, 2.0], device="cpu")
+            assert torch.all(a > 0), "should throw"
 
-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
+        torch._dynamo.reset()
+        f_c = torch.compile(func, backend=backend)
+        f_c()
 
-            try:
-                f_c()
-                return False
-            except Exception:
-                return True
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_inline_expression_should_not_throw(self, device):
-        def func():
-            a = torch.tensor([1.0, 2.0], device=device)
-            assert torch.all(a > 0), "should throw"
-
-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
-
-            try:
-                f_c()
-                return True
-            except Exception:
-                return False
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    @torch._inductor.config.patch(force_disable_caches=True)
-    def test_assert_should_throw(self):
-        device = "cpu"
-        self._run_assert_should_throw(device)
-        self._run_assert_inline_expression_should_throw(device)
-
-    @torch._inductor.config.patch(force_disable_caches=True)
-    def test_assert_should_not_throw(self):
-        device = "cpu"
-        self._run_assert_should_not_throw(device)
-        self._run_assert_inline_expression_should_not_throw(device)
+        torch._dynamo.reset()
+        f_c = torch.compile(func_inline, backend=backend)
+        f_c()
 
     @requires_cuda_and_triton
     @skipIfRocm
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index ebea8e3a6339..fee086ee6db5 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -1016,6 +1016,11 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
             f"{type(self).__name__}: store should be handled by CSEProxy"
         )
 
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
     def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
         raise NotImplementedError(
             f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
@@ -2119,6 +2124,11 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     ) -> None:
         raise NotImplementedError
 
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
     def reduction(
         self,
         dtype: torch.dtype,
@@ -2704,6 +2714,9 @@ class CSEProxy(DefaultHandler):
         self.kernel.store(name, index, value, mode=mode)
         self.kernel.num_store += 1
 
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        self.kernel.device_assert_async(cond, msg)
+
     def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         self.kernel.store_buffer_names.add(name)
         self._update_store_cache(name, value)
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 18103c9aab01..15bc7d283b1d 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -1119,10 +1119,6 @@ class CppOverrides(OpOverrides):
         code.writeline("()")
         return code
 
-    @staticmethod
-    def device_assert_async(cond, msg):
-        return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))'
-
 
 CppOverrides._initialize_pointwise_overrides("cpp")
 
@@ -2138,6 +2134,11 @@ class CppKernel(Kernel):
             raise NotImplementedError(f"store mode={mode}")
         self.stores.writeline(DeferredLine(name, line))
 
+    def device_assert_async(self, cond, msg):
+        self.compute.writeline(
+            f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0));'
+        )
+
     def _gen_reduction_prefix(
         self,
         acc: Union[CSEVariable, str],
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index d6601569a047..3367ee758c01 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1597,10 +1597,6 @@ class TritonKernelOverrides(TritonOverrides):
         V.kernel.cse.put(cache_key, (mantissa, exponent))
         return (mantissa, exponent)
 
-    @staticmethod
-    def device_assert_async(cond, msg):
-        return f"tl.device_assert({cond}, {repr(msg)})"
-
 
 class HelperFunctions:
     """An ordered set of helper functions."""
@@ -2845,6 +2841,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
 
         exit_stack.close()
 
+    def device_assert_async(self, cond, msg) -> None:
+        self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})")
+
     def guard_cooperative_store(self, name, buffer):
         """
         For cooperative reductions only one thread block should write out the result.
diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py
index d80caa1e2b72..4c30079549c5 100644
--- a/torch/_inductor/dtype_propagation.py
+++ b/torch/_inductor/dtype_propagation.py
@@ -374,8 +374,8 @@ class DtypePropagationOpsHandler:
         )
 
     @staticmethod
-    def device_assert_async(cond, msg: str) -> torch.dtype:
-        return torch.bool
+    def device_assert_async(cond, msg: str) -> None:
+        return None
 
 
 if TYPE_CHECKING:
diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py
index cccb0e294362..d24111978bdb 100644
--- a/torch/_inductor/ops_handler.py
+++ b/torch/_inductor/ops_handler.py
@@ -791,9 +791,6 @@ class DefaultHandler(OpsHandler[Any]):
             if target in OP_NAMES:
                 setattr(cls, target, impl)
 
-    def device_assert_async(self, cond, msg):
-        return None
-
 
 DefaultHandler._init_cls()
 
@@ -939,9 +936,6 @@ class MockHandler(BasicMathOpsMixin, DefaultHandler):
     def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
         return sympy_index_symbol(str(index_var))
 
-    def device_assert_async(self, cond, msg):
-        return None
-
 
 class KernelFormatterHandler(DefaultHandler):
     def __init__(self, parent_handler: OpsHandler[Any]):
@@ -1008,9 +1002,6 @@ class KernelFormatterHandler(DefaultHandler):
         self._output.writeline(f"return {result}")
         return self._output.getvalue()
 
-    def device_assert_async(self, cond, msg: str):
-        return f"ops.device_assert_async({cond}, {msg})"
-
 
 class WrapperHandler(DefaultHandler):
     def __init__(self, inner: OpsHandler[Any]):
@@ -1158,3 +1149,8 @@ class SimpleCSEHandler(WrapperHandler):
         val = getattr(self._inner, name)(*args, **kwargs)
         self.cse_cache[key] = val
         return val
+
+    def device_assert_async(self, *args, **kwargs) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
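
Usage sketch (illustrative, not part of the patch): with this change, a tensor
assert inside a compiled function is expected to surface as a runtime error
under the inductor backend as well, mirroring the parametrized tests above.
The function name below is made up for illustration, and the exact exception
type and message formatting may differ by backend and build.

    import torch

    @torch.compile(backend="inductor")
    def check_all_positive(a):
        # torch.all(...) yields a 0-dim bool tensor; asserting on it inside a
        # compiled function is lowered via the device_assert_async op.
        assert torch.all(a > 0), "should throw"
        return a * 2

    check_all_positive(torch.tensor([1.0, 2.0]))       # passes
    try:
        check_all_positive(torch.tensor([1.0, -2.0]))  # assert fires
    except Exception as e:
        print(type(e).__name__, ":", e)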