mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Updated the DeviceAssert operation to match the behavior of Store; this fixes the issue mentioned in [this PR](https://github.com/pytorch/pytorch/pull/163023) and updates the test cases as Elias [suggested](https://github.com/pytorch/pytorch/pull/160677#discussion_r2353834646). Pull Request resolved: https://github.com/pytorch/pytorch/pull/163696 Approved by: https://github.com/mlazos
102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
# Owner(s): ["module: inductor"]
|
|
|
|
import torch
|
|
import torch._inductor.config
|
|
from torch._inductor import metrics
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
from torch._inductor.utils import run_and_get_code
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
skipIfRocm,
|
|
)
|
|
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
|
|
|
|
|
@instantiate_parametrized_tests
class TestTorchDeviceAssertTrigger(TestCase):
    """Tests that a Python ``assert`` on a tensor result survives compilation.

    Covers the ``eager``, ``aot_eager``, and ``inductor`` backends on CPU
    (assert must raise / must not raise as appropriate), and on CUDA checks
    that inductor fuses the assert into a single generated kernel that emits
    ``tl.device_assert`` in the Triton code.
    """

    @parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_assert_should_throw(self, backend):
        # A failing tensor assert must surface as a RuntimeError carrying the
        # assert message, both when the asserted value is first bound to a
        # local and when the expression is inlined in the assert itself.
        def func():
            a = torch.tensor([1.0, -2.0], device="cpu")
            result = torch.all(a > 0)
            assert result, "should throw"

        def func_inline():
            a = torch.tensor([1.0, -2.0], device="cpu")
            assert torch.all(a > 0), "should throw"

        with self.assertRaisesRegex(RuntimeError, "should throw"):
            # Reset dynamo so each parametrized backend compiles from scratch.
            torch._dynamo.reset()
            f_c = torch.compile(func, backend=backend)
            f_c()

        with self.assertRaisesRegex(RuntimeError, "should throw"):
            torch._dynamo.reset()
            f_c = torch.compile(func_inline, backend=backend)
            f_c()

    @parametrize("backend", ["eager", "aot_eager", "inductor"])
    def test_assert_should_not_throw(self, backend):
        # With all-positive inputs the assert condition holds, so the compiled
        # functions must run to completion without raising.
        def func():
            a = torch.tensor([1.0, 2.0], device="cpu")
            result = torch.all(a > 0)
            assert result, "should throw"

        def func_inline():
            a = torch.tensor([1.0, 2.0], device="cpu")
            assert torch.all(a > 0), "should throw"

        torch._dynamo.reset()
        f_c = torch.compile(func, backend=backend)
        f_c()

        torch._dynamo.reset()
        f_c = torch.compile(func_inline, backend=backend)
        f_c()

    @requires_cuda_and_triton
    @skipIfRocm
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_assert_fusion(self):
        # Enable metrics logging so generated_kernel_count is tracked.
        torch._logging.set_logs(inductor_metrics=True)

        def func():
            a = torch.tensor([1.0, 2.0], device="cuda")
            result = torch.all(a > 0)
            assert result, "should throw"

        torch._dynamo.reset()
        f_c = torch.compile(func, backend="inductor")
        # Reset metrics after compile setup; no kernels generated yet.
        metrics.reset()
        self.assertEqual(metrics.generated_kernel_count, 0)
        f_c()
        # Exactly one kernel: the assert fuses with the reduction rather than
        # being emitted as a separate kernel.
        self.assertEqual(metrics.generated_kernel_count, 1)
        # Restore default logging so later tests are unaffected.
        torch._logging.set_logs()

    @requires_cuda_and_triton
    @skipIfRocm
    @torch._inductor.config.patch(force_disable_caches=True)
    def test_run_assert_triton(self):
        @torch.compile(backend="inductor")
        def fn():
            a = torch.tensor([1.0, 2.0], device="cuda")
            result = torch.all(a > 0)
            assert result, "should throw"

        # Helper: True iff fn() runs without raising any exception.
        def should_not_throw(fn):
            try:
                fn()
                return True
            except Exception:
                return False

        # Passing assert: compiled function must not raise at runtime.
        self.assertEqual(should_not_throw(fn), True)

        # The generated Triton source must contain exactly one device assert.
        _, code = run_and_get_code(fn)
        self.assertEqual(code[0].count("tl.device_assert"), 1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Standard inductor test-suite entry point.
    run_tests()
|