[Inductor] Update DeviceAssert op to behave like store (#163696)
Updated the DeviceAssert operation to match the behavior of Store; this fixes the issue mentioned in [this PR](https://github.com/pytorch/pytorch/pull/163023) and updates the test cases as Elias [suggested](https://github.com/pytorch/pytorch/pull/160677#discussion_r2353834646). Pull Request resolved: https://github.com/pytorch/pytorch/pull/163696 Approved by: https://github.com/mlazos
This commit is contained in:
committed by PyTorch MergeBot
parent d927e55498
commit 8c98aee436
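For context, here is the behavior the updated tests pin down, as a minimal sketch distilled from the new test code below (the loop over backends is illustrative):

```python
import torch

def func():
    a = torch.tensor([1.0, -2.0], device="cpu")
    assert torch.all(a > 0), "should throw"

# With this change the inline assert survives compilation and raises a
# RuntimeError at runtime on every backend, per the parametrized tests.
for backend in ["eager", "aot_eager", "inductor"]:
    torch._dynamo.reset()
    f_c = torch.compile(func, backend=backend)
    try:
        f_c()
    except RuntimeError as e:
        print(f"{backend}: raised as expected: {e}")
```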
@@ -3,100 +3,57 @@
 import torch
 import torch._inductor.config
 from torch._inductor import metrics
 from torch._inductor.compiler_bisector import BisectionResult, CompilerBisector
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_utils import skipIfRocm
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    skipIfRocm,
+)
 from torch.testing._internal.triton_utils import requires_cuda_and_triton


+@instantiate_parametrized_tests
 class TestTorchDeviceAssertTrigger(TestCase):
-    def _run_assert_should_throw(self, device):
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_assert_should_throw(self, backend):
         def func():
-            a = torch.tensor([1.0, -2.0], device=device)
+            a = torch.tensor([1.0, -2.0], device="cpu")
             result = torch.all(a > 0)
             assert result, "should throw"

-        def test_fn():
+        def func_inline():
+            a = torch.tensor([1.0, -2.0], device="cpu")
+            assert torch.all(a > 0), "should throw"
+
+        with self.assertRaisesRegex(RuntimeError, "should throw"):
             torch._dynamo.reset()
-            f_c = torch.compile(func)
+            f_c = torch.compile(func, backend=backend)
+            f_c()

-            try:
-                f_c()
-                return False
-            except Exception:
-                return True
+        with self.assertRaisesRegex(RuntimeError, "should throw"):
+            torch._dynamo.reset()
+            f_c = torch.compile(func_inline, backend=backend)
+            f_c()

-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        # do_bisect return None if all system is passed else return BisectionResult
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_should_not_throw(self, device):
+    @parametrize("backend", ["eager", "aot_eager", "inductor"])
+    def test_assert_should_not_throw(self, backend):
         def func():
-            a = torch.tensor([1.0, 2.0], device=device)
+            a = torch.tensor([1.0, 2.0], device="cpu")
             result = torch.all(a > 0)
             assert result, "should throw"

-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
-
-            try:
-                f_c()
-                return True
-            except Exception:
-                return False
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_inline_expression_should_throw(self, device):
-        def func():
-            a = torch.tensor([1.0, -2.0], device=device)
+        def func_inline():
+            a = torch.tensor([1.0, 2.0], device="cpu")
             assert torch.all(a > 0), "should throw"

-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
+        torch._dynamo.reset()
+        f_c = torch.compile(func, backend=backend)
+        f_c()

-            try:
-                f_c()
-                return False
-            except Exception:
-                return True
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    def _run_assert_inline_expression_should_not_throw(self, device):
-        def func():
-            a = torch.tensor([1.0, 2.0], device=device)
-            assert torch.all(a > 0), "should throw"
-
-        def test_fn():
-            torch._dynamo.reset()
-            f_c = torch.compile(func)
-
-            try:
-                f_c()
-                return True
-            except Exception:
-                return False
-
-        bisect_result = CompilerBisector.do_bisect(test_fn)
-        self.assertNotIsInstance(bisect_result, BisectionResult)
-
-    @torch._inductor.config.patch(force_disable_caches=True)
-    def test_assert_should_throw(self):
-        device = "cpu"
-        self._run_assert_should_throw(device)
-        self._run_assert_inline_expression_should_throw(device)
-
-    @torch._inductor.config.patch(force_disable_caches=True)
-    def test_assert_should_not_throw(self):
-        device = "cpu"
-        self._run_assert_should_not_throw(device)
-        self._run_assert_inline_expression_should_not_throw(device)
+        torch._dynamo.reset()
+        f_c = torch.compile(func_inline, backend=backend)
+        f_c()

     @requires_cuda_and_triton
     @skipIfRocm
@@ -1016,6 +1016,11 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
             f"{type(self).__name__}: store should be handled by CSEProxy"
         )

+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
     def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
         raise NotImplementedError(
             f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
@@ -2119,6 +2124,11 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     ) -> None:
         raise NotImplementedError

+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )
+
     def reduction(
         self,
         dtype: torch.dtype,
@@ -2704,6 +2714,9 @@ class CSEProxy(DefaultHandler):
         self.kernel.store(name, index, value, mode=mode)
         self.kernel.num_store += 1

+    def device_assert_async(self, cond: CSEVariable, msg: str) -> None:
+        self.kernel.device_assert_async(cond, msg)
+
     def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
         self.kernel.store_buffer_names.add(name)
         self._update_store_cache(name, value)
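Together, the three hunks above give `device_assert_async` the same plumbing as `store`: the abstract `OpOverrides` and `Kernel` methods refuse direct calls, and `CSEProxy` forwards the op to the active kernel. A stripped-down sketch of that dispatch pattern (simplified; the real classes carry far more state):

```python
class Kernel:
    def device_assert_async(self, cond, msg):
        # Backends (C++, Triton) override this to emit an assert statement.
        raise NotImplementedError(
            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
        )

class CSEProxy:
    """Routes ops.* calls to the kernel currently being generated."""

    def __init__(self, kernel):
        self.kernel = kernel

    def device_assert_async(self, cond, msg):
        # Exactly like store: forward to the kernel; no CSE value is produced.
        self.kernel.device_assert_async(cond, msg)
```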
@@ -1119,10 +1119,6 @@ class CppOverrides(OpOverrides):
         code.writeline("()")
         return code

-    @staticmethod
-    def device_assert_async(cond, msg):
-        return f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0))'
-

 CppOverrides._initialize_pointwise_overrides("cpp")
@@ -2138,6 +2134,11 @@ class CppKernel(Kernel):
             raise NotImplementedError(f"store mode={mode}")
         self.stores.writeline(DeferredLine(name, line))

+    def device_assert_async(self, cond, msg):
+        self.compute.writeline(
+            f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0));'
+        )
+
     def _gen_reduction_prefix(
         self,
         acc: Union[CSEVariable, str],
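Moving the C++ lowering from a `CppOverrides` static method (deleted above) into `CppKernel.device_assert_async` turns the assert from an expression returned to the CSE machinery into a statement written directly into the kernel's `compute` buffer, the same way stores are emitted. For a hypothetical condition variable `tmp3`, the writeline above produces:

```python
# Sketch: reproducing the f-string from the diff with example inputs.
cond, msg = "tmp3", "should throw"
line = f'({cond} ? 0 : (throw std::runtime_error("{msg}"), 0));'
print(line)  # (tmp3 ? 0 : (throw std::runtime_error("should throw"), 0));
```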
@@ -1597,10 +1597,6 @@ class TritonKernelOverrides(TritonOverrides):
         V.kernel.cse.put(cache_key, (mantissa, exponent))
         return (mantissa, exponent)

-    @staticmethod
-    def device_assert_async(cond, msg):
-        return f"tl.device_assert({cond}, {repr(msg)})"
-

 class HelperFunctions:
     """An ordered set of helper functions."""
@@ -2845,6 +2841,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

         exit_stack.close()

+    def device_assert_async(self, cond, msg) -> None:
+        self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})")
+
     def guard_cooperative_store(self, name, buffer):
         """
         For cooperative reductions only one thread block should write out the result.
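The Triton lowering makes the same move: `TritonKernel.device_assert_async` writes the assert into the `compute` buffer instead of `TritonKernelOverrides` returning it as an expression. For a hypothetical CSE variable `tmp1`, the emitted kernel line looks like:

```python
# Sketch: the line TritonKernel writes, with example inputs.
cond, msg = "tmp1", "should throw"
print(f"tl.device_assert({cond}, {repr(msg)})")  # tl.device_assert(tmp1, 'should throw')
```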
@@ -374,8 +374,8 @@ class DtypePropagationOpsHandler:
         )

     @staticmethod
-    def device_assert_async(cond, msg: str) -> torch.dtype:
-        return torch.bool
+    def device_assert_async(cond, msg: str) -> None:
+        return None


 if TYPE_CHECKING:
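A one-line consequence of the store-like treatment: the dtype propagation handler above now reports no result type (`None` instead of `torch.bool`), since an assert, like a store, produces no value for downstream ops to consume.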
@@ -791,9 +791,6 @@ class DefaultHandler(OpsHandler[Any]):
         if target in OP_NAMES:
             setattr(cls, target, impl)

-    def device_assert_async(self, cond, msg):
-        return None
-

 DefaultHandler._init_cls()
@@ -939,9 +936,6 @@ class MockHandler(BasicMathOpsMixin, DefaultHandler):
     def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
         return sympy_index_symbol(str(index_var))

-    def device_assert_async(self, cond, msg):
-        return None
-

 class KernelFormatterHandler(DefaultHandler):
     def __init__(self, parent_handler: OpsHandler[Any]):
@@ -1008,9 +1002,6 @@ class KernelFormatterHandler(DefaultHandler):
         self._output.writeline(f"return {result}")
         return self._output.getvalue()

-    def device_assert_async(self, cond, msg: str):
-        return f"ops.device_assert_async({cond}, {msg})"
-

 class WrapperHandler(DefaultHandler):
     def __init__(self, inner: OpsHandler[Any]):
@@ -1158,3 +1149,8 @@ class SimpleCSEHandler(WrapperHandler):
         val = getattr(self._inner, name)(*args, **kwargs)
         self.cse_cache[key] = val
         return val
+
+    def device_assert_async(self, *args, **kwargs) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: device_assert_async should be handled by CSEProxy"
+        )