[Inductor] Update DeviceAssert op to behave like store (#163696)

Updated the DeviceAssert operation to match the behavior of Store; this fixes the issue mentioned in [this PR](https://github.com/pytorch/pytorch/pull/163023) and updates the test cases as Elias [suggested](https://github.com/pytorch/pytorch/pull/160677#discussion_r2353834646).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163696
Approved by: https://github.com/mlazos
karthickai
2025-09-24 09:33:39 -07:00
committed by PyTorch MergeBot
parent d927e55498
commit 8c98aee436
6 changed files with 59 additions and 93 deletions
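
In practice, an inline tensor assert now survives torch.compile and raises at runtime, which is what the updated tests below check. A minimal repro, assuming a build that includes this change:

import torch

def func_inline():
    a = torch.tensor([1.0, -2.0], device="cpu")
    assert torch.all(a > 0), "should throw"

f_c = torch.compile(func_inline, backend="inductor")
try:
    f_c()
except RuntimeError as e:
    print(e)  # the compiled assert fires with "should throw"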

View File

@@ -3,100 +3,57 @@
 import torch
 import torch._inductor.config
 from torch._inductor import metrics
-from torch._inductor.compiler_bisector import BisectionResult, CompilerBisector
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_utils import skipIfRocm
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    skipIfRocm,
+)
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
 
 
+@instantiate_parametrized_tests
 class TestTorchDeviceAssertTrigger(TestCase):
-    def _run_assert_should_throw(self, device):
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_assert_should_throw(self, backend):
         def func():
-            a = torch.tensor([1.0, -2.0], device=device)
+            a = torch.tensor([1.0, -2.0], device="cpu")
             result = torch.all(a > 0)
             assert result, "should throw"
 
-        def test_fn():
+        def func_inline():
+            a = torch.tensor([1.0, -2.0], device="cpu")
+            assert torch.all(a > 0), "should throw"
+
+        with self.assertRaisesRegex(RuntimeError, "should throw"):
             torch._dynamo.reset()
-            f_c = torch.compile(func)
+            f_c = torch.compile(func, backend=backend)
+            f_c()
-            try:
-                f_c()
-                return False
-            except Exception:
-                return True
 
+        with self.assertRaisesRegex(RuntimeError, "should throw"):
+            torch._dynamo.reset()
+            f_c = torch.compile(func_inline, backend=backend)
+            f_c()
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        # do_bisect return None if all system is passed else return BisectionResult
-        self.assertNotIsInstance(bisect_result, BisectionResult)
 
-    def _run_assert_should_not_throw(self, device):
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_assert_should_not_throw(self, backend):
         def func():
-            a = torch.tensor([1.0, 2.0], device=device)
+            a = torch.tensor([1.0, 2.0], device="cpu")
             result = torch.all(a > 0)
             assert result, "should throw"
 
-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
-            try:
-                f_c()
-                return True
-            except Exception:
-                return False
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_inline_expression_should_throw(self, device):
-        def func():
-            a = torch.tensor([1.0, -2.0], device=device)
+        def func_inline():
+            a = torch.tensor([1.0, 2.0], device="cpu")
             assert torch.all(a > 0), "should throw"
 
-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
+        torch._dynamo.reset()
+        f_c = torch.compile(func, backend=backend)
+        f_c()
-            try:
-                f_c()
-                return False
-            except Exception:
-                return True
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_inline_expression_should_not_throw(self, device):
-        def func():
-            a = torch.tensor([1.0, 2.0], device=device)
-            assert torch.all(a > 0), "should throw"
-
-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
-            try:
-                f_c()
-                return True
-            except Exception:
-                return False
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    @torch._inductor.config.patch(force_disable_caches=True)
-    def test_assert_should_throw(self):
-        device = "cpu"
-        self._run_assert_should_throw(device)
-        self._run_assert_inline_expression_should_throw(device)
-
-    @torch._inductor.config.patch(force_disable_caches=True)
-    def test_assert_should_not_throw(self):
-        device = "cpu"
-        self._run_assert_should_not_throw(device)
-        self._run_assert_inline_expression_should_not_throw(device)
+
+        torch._dynamo.reset()
+        f_c = torch.compile(func_inline, backend=backend)
+        f_c()
 
     @requires_cuda_and_triton
     @skipIfRocm

View File

@@ -1016,6 +1016,11 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
             f"{type(self).__name__}: store should be handled by CSEProxy"
         )
 
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
     def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
         raise NotImplementedError(
             f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
@@ -2119,6 +2124,11 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     ) -> None:
         raise NotImplementedError
 
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
     def reduction(
         self,
         dtype: torch.dtype,
@@ -2704,6 +2714,9 @@ class CSEProxy(DefaultHandler):
         self.kernel.store(name, index, value, mode=mode)
         self.kernel.num_store += 1
 
+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        self.kernel.device_assert_async(cond, msg)
+
     def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         self.kernel.store_buffer_names.add(name)
         self._update_store_cache(name, value)
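
As with store, CSEProxy is the single entry point and simply forwards to the active kernel, which owns code emission; no CSE-able value is produced. A simplified sketch of that dispatch pattern (illustrative class bodies, not the real Inductor types):

class Kernel:
    # each backend (cpp, triton) overrides this
    def device_assert_async(self, cond, msg) -> None:
        raise NotImplementedError

class TritonLikeKernel(Kernel):
    def __init__(self):
        self.compute = []  # stands in for the kernel's compute buffer

    def device_assert_async(self, cond, msg) -> None:
        self.compute.append(f"tl.device_assert({cond}, {msg!r})")

class CSEProxy:
    def __init__(self, kernel):
        self.kernel = kernel

    def device_assert_async(self, cond, msg) -> None:
        # like store: forward straight to the kernel, no value returned
        self.kernel.device_assert_async(cond, msg)

proxy = CSEProxy(TritonLikeKernel())
proxy.device_assert_async("tmp0", "should throw")
print(proxy.kernel.compute)  # ["tl.device_assert(tmp0, 'should throw')"]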

View File

@@ -1119,10 +1119,6 @@ class CppOverrides(OpOverrides):
         code.writeline("()")
         return code
 
-    @staticmethod
-    def device_assert_async(cond, msg):
-        return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))'
-
 
 CppOverrides._initialize_pointwise_overrides("cpp")
@@ -2138,6 +2134,11 @@ class CppKernel(Kernel):
             raise NotImplementedError(f"store mode={mode}")
         self.stores.writeline(DeferredLine(name, line))
 
+    def device_assert_async(self, cond, msg):
+        self.compute.writeline(
+            f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0));'
+        )
+
     def _gen_reduction_prefix(
         self,
         acc: Union[CSEVariable, str],
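
The emitted C++ relies on the comma operator so that throw can appear in expression context. A quick illustration of the line CppKernel now writes into its compute buffer (the condition name and message are example values):

# Illustration only: the C++ statement CppKernel.device_assert_async emits.
cond, msg = "tmp3", "should throw"
print(f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0));')
# (tmp3 ? 0 : (throw std::runtime_error("should throw"), 0));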

View File

@@ -1597,10 +1597,6 @@ class TritonKernelOverrides(TritonOverrides):
         V.kernel.cse.put(cache_key, (mantissa, exponent))
         return (mantissa, exponent)
 
-    @staticmethod
-    def device_assert_async(cond, msg):
-        return f"tl.device_assert({cond}, {repr(msg)})"
-
 
 class HelperFunctions:
     """An ordered set of helper functions."""
@@ -2845,6 +2841,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
             exit_stack.close()
 
+    def device_assert_async(self, cond, msg) -> None:
+        self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})")
+
     def guard_cooperative_store(self, name, buffer):
         """
         For cooperative reductions only one thread block should write out the result.
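
On the Triton side the op lowers to tl.device_assert written into the kernel's compute buffer. A standalone sketch of what such a generated check boils down to, assuming CUDA and Triton are available (the kernel and variable names are made up for illustration):

import torch
import triton
import triton.language as tl

@triton.jit
def _assert_positive(in_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(in_ptr + offs, mask=mask, other=1.0)
    # the kind of line TritonKernel.device_assert_async emits
    tl.device_assert(x > 0, "should throw")

x = torch.tensor([1.0, -2.0], device="cuda")
_assert_positive[(1,)](x, x.numel(), BLOCK=8)
torch.cuda.synchronize()  # the device-side assert surfaces here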

View File

@@ -374,8 +374,8 @@ class DtypePropagationOpsHandler:
         )
 
     @staticmethod
-    def device_assert_async(cond, msg: str) -> torch.dtype:
-        return torch.bool
+    def device_assert_async(cond, msg: str) -> None:
+        return None
 
 
 if TYPE_CHECKING:

View File

@@ -791,9 +791,6 @@ class DefaultHandler(OpsHandler[Any]):
         if target in OP_NAMES:
             setattr(cls, target, impl)
 
-    def device_assert_async(self, cond, msg):
-        return None
-
 
 DefaultHandler._init_cls()
@@ -939,9 +936,6 @@ class MockHandler(BasicMathOpsMixin, DefaultHandler):
     def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
         return sympy_index_symbol(str(index_var))
 
-    def device_assert_async(self, cond, msg):
-        return None
-
 
 class KernelFormatterHandler(DefaultHandler):
     def __init__(self, parent_handler: OpsHandler[Any]):
@@ -1008,9 +1002,6 @@ class KernelFormatterHandler(DefaultHandler):
         self._output.writeline(f"return {result}")
         return self._output.getvalue()
 
-    def device_assert_async(self, cond, msg: str):
-        return f"ops.device_assert_async({cond}, {msg})"
-
 
 class WrapperHandler(DefaultHandler):
     def __init__(self, inner: OpsHandler[Any]):
@@ -1158,3 +1149,8 @@ class SimpleCSEHandler(WrapperHandler):
         val = getattr(self._inner, name)(*args, **kwargs)
         self.cse_cache[key] = val
         return val
+
+    def device_assert_async(self, *args, **kwargs) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )