# pytorch/test/dynamo/test_skip_guard_eval_unsafe.py
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
def my_custom_function(x):
    """Return ``x + 1``.

    Trivial user-level function kept at module scope for dynamo tests that
    trace plain Python functions.
    """
    return x + 1
class RunDiffGuardTests(torch._dynamo.test_case.TestCase):
    """Tests for the ``skip_guard_eval_unsafe`` compiler stance.

    Each test first warms up the compile cache by triggering one or more
    recompilations, then re-runs the same inputs under
    ``torch.compiler.set_stance(skip_guard_eval_unsafe=True)`` and checks
    that the correct cached entry is still selected — or, for an input
    never seen during warm-up, that a ``RuntimeError`` is raised instead
    of silently running the wrong compiled artifact.
    """

    def test_bool_recompile(self):
        """Both cache entries from a bool-argument flip stay selectable under the stance."""

        def fn(x, y, c):
            if c:
                return x * y
            else:
                return x + y

        opt_fn = torch.compile(fn, backend="inductor")

        x = 2 * torch.ones(4)
        y = 3 * torch.ones(4)
        # Warm-up: populate two cache entries, one per value of c.
        ref1 = opt_fn(x, y, True)
        ref2 = opt_fn(x, y, False)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            # Query in the opposite order to the warm-up to exercise
            # cache-line selection rather than most-recent reuse.
            res2 = opt_fn(x, y, False)
            res1 = opt_fn(x, y, True)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)

    def test_tensor_recompile(self):
        """Entries recompiled on a dtype change stay selectable under the stance."""

        def fn(x, y):
            return x * y

        opt_fn = torch.compile(fn, backend="eager")

        x = torch.randn(4, dtype=torch.float32)
        y = torch.randn(4, dtype=torch.float32)
        ref1 = opt_fn(x, y)

        # Different dtype triggers a recompile and a second cache entry.
        x64 = torch.randn(4, dtype=torch.float64)
        y64 = torch.randn(4, dtype=torch.float64)
        ref2 = opt_fn(x64, y64)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res1 = opt_fn(x, y)
            res2 = opt_fn(x64, y64)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)

    def test_post_recompile(self):
        """Attribute-value guards still route to the right entry; normal guard
        evaluation (and recompilation) resumes once the stance is exited."""

        class Foo:
            def __init__(self):
                self.a = 4
                self.b = 5

        foo = Foo()

        def fn(x):
            return x + foo.a + foo.b

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)

        # Mutating a guarded attribute forces a second compilation.
        foo.a = 11
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            # Set it back to original value
            foo.a = 4
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

            foo.a = 11
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

        # Check that we are back to original behavior
        foo.b = 8
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 3)

    def test_fail_on_tensor_shape_change(self):
        """A shape never seen during warm-up must raise under the stance
        rather than run a mismatched compiled artifact."""

        def fn(dt):
            return dt["x"] + 1

        x = torch.randn(4)
        dt = {}
        dt["x"] = x

        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(dt)

        with self.assertRaisesRegex(
            RuntimeError, "Recompilation triggered with skip_guard_eval_unsafe stance"
        ):
            with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
                # (4, 4) was never compiled, so this would need a recompile,
                # which the stance forbids.
                x = torch.randn(4, 4)
                dt["x"] = x
                opt_fn(dt)

    def test_cache_line_pickup(self):
        """Each of three distinct kwarg combinations maps back to its own
        cache entry under the stance."""

        def fn(x, a=None, b=None):
            x = x * 3
            if a:
                x = x * 5
            if b:
                x = x * 7
            return x

        opt_fn = torch.compile(fn, backend="eager")

        x = torch.ones(4)
        # Warm-up: three cache entries, one per truthiness combination.
        ref1 = opt_fn(x, a=None, b=None)
        ref2 = opt_fn(x, a=1, b=None)
        ref3 = opt_fn(x, a=1, b=1)

        with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
            res1 = opt_fn(x, a=None, b=None)
            res2 = opt_fn(x, a=1, b=None)
            res3 = opt_fn(x, a=1, b=1)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref3, res3)
# Run the dynamo test harness only when executed as a script, so importing
# this module (e.g. for test collection) has no side effects.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()