mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159534 Approved by: https://github.com/jansel
148 lines
3.8 KiB
Python
148 lines
3.8 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
|
|
|
|
def my_custom_function(x):
    """Return ``x + 1``."""
    return x + 1
|
|
|
|
|
|
class RunDiffGuardTests(torch._dynamo.test_case.TestCase):
|
|
def test_bool_recompile(self):
|
|
def fn(x, y, c):
|
|
if c:
|
|
return x * y
|
|
else:
|
|
return x + y
|
|
|
|
opt_fn = torch.compile(fn, backend="inductor")
|
|
x = 2 * torch.ones(4)
|
|
y = 3 * torch.ones(4)
|
|
|
|
ref1 = opt_fn(x, y, True)
|
|
ref2 = opt_fn(x, y, False)
|
|
|
|
with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
|
|
res2 = opt_fn(x, y, False)
|
|
res1 = opt_fn(x, y, True)
|
|
|
|
self.assertEqual(ref1, res1)
|
|
self.assertEqual(ref2, res2)
|
|
|
|
def test_tensor_recompile(self):
|
|
def fn(x, y):
|
|
return x * y
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
x = torch.randn(4, dtype=torch.float32)
|
|
y = torch.randn(4, dtype=torch.float32)
|
|
|
|
ref1 = opt_fn(x, y)
|
|
|
|
x64 = torch.randn(4, dtype=torch.float64)
|
|
y64 = torch.randn(4, dtype=torch.float64)
|
|
ref2 = opt_fn(x64, y64)
|
|
|
|
with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
|
|
res1 = opt_fn(x, y)
|
|
res2 = opt_fn(x64, y64)
|
|
|
|
self.assertEqual(ref1, res1)
|
|
self.assertEqual(ref2, res2)
|
|
|
|
def test_post_recompile(self):
|
|
class Foo:
|
|
def __init__(self):
|
|
self.a = 4
|
|
self.b = 5
|
|
|
|
foo = Foo()
|
|
|
|
def fn(x):
|
|
return x + foo.a + foo.b
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnts)
|
|
|
|
x = torch.randn(4)
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
|
|
foo.a = 11
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
|
|
with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
|
|
# Set it back to original value
|
|
foo.a = 4
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
foo.a = 11
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
# Check that we are back to original behavior
|
|
foo.b = 8
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
self.assertEqual(cnts.frame_count, 3)
|
|
|
|
def test_fail_on_tensor_shape_change(self):
|
|
def fn(dt):
|
|
return dt["x"] + 1
|
|
|
|
x = torch.randn(4)
|
|
dt = {}
|
|
dt["x"] = x
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
opt_fn(dt)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "Recompilation triggered with skip_guard_eval_unsafe stance"
|
|
):
|
|
with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
|
|
x = torch.randn(4, 4)
|
|
dt["x"] = x
|
|
opt_fn(dt)
|
|
|
|
def test_cache_line_pickup(self):
|
|
def fn(x, a=None, b=None):
|
|
x = x * 3
|
|
if a:
|
|
x = x * 5
|
|
if b:
|
|
x = x * 7
|
|
return x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
x = torch.ones(4)
|
|
|
|
ref1 = opt_fn(x, a=None, b=None)
|
|
ref2 = opt_fn(x, a=1, b=None)
|
|
ref3 = opt_fn(x, a=1, b=1)
|
|
|
|
with torch.compiler.set_stance(skip_guard_eval_unsafe=True):
|
|
res1 = opt_fn(x, a=None, b=None)
|
|
res2 = opt_fn(x, a=1, b=None)
|
|
res3 = opt_fn(x, a=1, b=1)
|
|
|
|
self.assertEqual(ref1, res1)
|
|
self.assertEqual(ref2, res2)
|
|
self.assertEqual(ref3, res3)
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported only when the file is executed directly; hands control to the
    # run_tests entry point from torch._dynamo.test_case.
    from torch._dynamo.test_case import run_tests

    run_tests()
|