mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Dynamo] Replace torch._dynamo.optimize()
with torch.compile()
[2/N] (#140238)
related commits: - #139706 - #140238 - #140247 - #140253 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140238 Approved by: https://github.com/soulitzer
This commit is contained in:
committed by
PyTorch MergeBot
parent
42ad54c71b
commit
d6b3ad4de2
@ -229,7 +229,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
for i in range(1, 5):
|
||||
torch._dynamo.reset()
|
||||
model = globals()[f"Module{i}"]()
|
||||
opt_model = torch._dynamo.optimize("eager")(model)
|
||||
opt_model = torch.compile(model, backend="eager")
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
opt_model(torch.ones(2, 3, requires_grad=grad)),
|
||||
@ -243,7 +243,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
for model in [Module5(), Module6()]:
|
||||
torch._dynamo.reset()
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_model = torch._dynamo.optimize(cnts)(model)
|
||||
opt_model = torch.compile(model, backend=cnts)
|
||||
for _ in range(3):
|
||||
ref = model(x)
|
||||
res = opt_model(x)
|
||||
@ -252,7 +252,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
def test_linear_setup_context(self):
|
||||
model = ModuleLinear()
|
||||
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
||||
opt_model = torch.compile(model, backend="eager", fullgraph=True)
|
||||
input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
|
||||
eager_result = model(input, weight)
|
||||
@ -261,7 +261,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
def test_materialize_grad(self):
|
||||
model = MaterializingGradModule()
|
||||
opt_model = torch._dynamo.optimize("eager")(model)
|
||||
opt_model = torch.compile(model, backend="eager")
|
||||
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
optim_result = opt_model(x)
|
||||
eager_result = model(x)
|
||||
@ -269,7 +269,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
def test_print_in_bwd(self):
|
||||
model = CustomFuncBwdPrintModule()
|
||||
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
||||
opt_model = torch.compile(model, backend="eager", fullgraph=True)
|
||||
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
|
||||
opt_model(x)
|
||||
@ -323,7 +323,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
def test_save_for_bwd(self):
|
||||
model = SaveForBwdModule()
|
||||
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
||||
opt_model = torch.compile(model, backend="eager", fullgraph=True)
|
||||
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
opt_model(x)
|
||||
|
||||
@ -402,7 +402,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
before = mod(*args, **kwargs)
|
||||
|
||||
torch._dynamo.reset()
|
||||
compiled_model = torch._dynamo.optimize("eager")(mod)
|
||||
compiled_model = torch.compile(mod, backend="eager")
|
||||
after = compiled_model(*args, **kwargs)
|
||||
self.assertEqual(before, after)
|
||||
|
||||
@ -412,7 +412,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
before = mod(*args, **kwargs)
|
||||
|
||||
torch._dynamo.reset()
|
||||
compiled_model = torch._dynamo.optimize("eager")(mod)
|
||||
compiled_model = torch.compile(mod, backend="eager")
|
||||
after = compiled_model(*args, **kwargs)
|
||||
self.assertEqual(before, after)
|
||||
|
||||
@ -691,7 +691,7 @@ class GraphModule(torch.nn.Module):
|
||||
args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
|
||||
before = mod(*args, **kwargs)
|
||||
|
||||
compiled_model = torch._dynamo.optimize("eager")(mod)
|
||||
compiled_model = torch.compile(mod, backend="eager")
|
||||
after = compiled_model(*args, **kwargs)
|
||||
self.assertEqual(before, after)
|
||||
|
||||
@ -859,7 +859,7 @@ class GraphModule(torch.nn.Module):
|
||||
foo = MyFn3.apply(base, False)
|
||||
|
||||
test()
|
||||
opt_test = torch._dynamo.optimize("eager")(test)
|
||||
opt_test = torch.compile(test, backend="eager")
|
||||
opt_test()
|
||||
|
||||
def test_tensor_subclass_intermediary_input(self):
|
||||
|
@ -160,7 +160,7 @@ class NormalizeIRTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
ref = fn(a, b)
|
||||
|
||||
optimized_fn = torch._dynamo.optimize("aot_eager")(fn)
|
||||
optimized_fn = torch.compile(fn, backend="aot_eager")
|
||||
res = optimized_fn(a, b)
|
||||
self.assertTrue(same(ref, res))
|
||||
|
||||
|
@ -47,7 +47,7 @@ class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
x.register_hook(_multiply_invoke)
|
||||
return x * y
|
||||
|
||||
fn = torch._dynamo.optimize(backend)(fn)
|
||||
fn = torch.compile(fn, backend=backend)
|
||||
out = fn(x, y)
|
||||
grad_out = torch.tensor([2.0, 2.0])
|
||||
out.backward(grad_out)
|
||||
@ -114,7 +114,7 @@ class _multiply_invoke(torch.nn.Module):
|
||||
x.register_hook(_multiply_invoke)
|
||||
return x + y
|
||||
|
||||
fn = torch._dynamo.optimize(backend)(fn)
|
||||
fn = torch.compile(fn, backend=backend)
|
||||
out = fn(x, y)
|
||||
grad_out = torch.tensor([2.0, 2.0])
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
@ -179,7 +179,7 @@ class GraphModule(torch.nn.Module):
|
||||
def fn(x, y):
|
||||
return x + y
|
||||
|
||||
fn = torch._dynamo.optimize(backend, nopython=True)(fn)
|
||||
fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
out = fn(x, y)
|
||||
grad_out = torch.tensor([2.0, 2.0])
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
@ -237,7 +237,7 @@ class GraphModule(torch.nn.Module):
|
||||
x.register_hook(_graph_break_invoke)
|
||||
return x + y
|
||||
|
||||
fn = torch._dynamo.optimize(backend, nopython=True)(fn)
|
||||
fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
out = fn(x, y)
|
||||
grad_out = torch.tensor([2.0, 2.0])
|
||||
with self.assertRaisesRegex(
|
||||
|
@ -122,7 +122,7 @@ def fn():
|
||||
z *= 3
|
||||
return z
|
||||
|
||||
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
|
||||
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||
self.assertEqual(opt_f(None, torch.ones(2)), 6)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@ -226,7 +226,7 @@ def fn():
|
||||
dummy_fn.__code__ = code
|
||||
self.assertEqual(dummy_fn(), test[3])
|
||||
|
||||
dummy_opt = torch._dynamo.optimize("eager")(dummy_fn)
|
||||
dummy_opt = torch.compile(dummy_fn, backend="eager")
|
||||
self.assertEqual(dummy_opt(), test[3])
|
||||
|
||||
def test_exception_table_encode_varint(self):
|
||||
|
@ -35,7 +35,7 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
|
||||
class mylist(list):
|
||||
pass
|
||||
|
||||
@torch._dynamo.optimize(cnt, dynamic=True)
|
||||
@torch.compile(backend=cnt, dynamic=True)
|
||||
def f(x):
|
||||
y = x * 2
|
||||
comptime_print(y)
|
||||
|
@ -61,7 +61,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
def model(x, y):
|
||||
return (x + y) * y
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(x, y):
|
||||
for i in range(N_ITERS):
|
||||
loss = model(x, y).sum()
|
||||
@ -78,7 +78,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
b = a.cpu() * 3
|
||||
return b
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(x, y):
|
||||
for i in range(N_ITERS):
|
||||
loss = model(x, y).sum()
|
||||
@ -94,7 +94,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
a = x + y
|
||||
return a * 3
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(x, y):
|
||||
for i in range(N_ITERS):
|
||||
loss = model(x, y).sum()
|
||||
@ -109,7 +109,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
y.add_(3)
|
||||
return x * y
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(x, y):
|
||||
for i in range(N_ITERS):
|
||||
with self.subTest(i):
|
||||
@ -129,7 +129,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
c.add_(2)
|
||||
return x * y * 0 + c
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(x, y):
|
||||
for i in range(N_ITERS):
|
||||
with self.subTest(i):
|
||||
@ -148,7 +148,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
x.add_(3)
|
||||
return x * y
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(y):
|
||||
for i in range(N_ITERS):
|
||||
with self.subTest(i):
|
||||
@ -168,7 +168,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
x.fill_(2)
|
||||
return x
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(x):
|
||||
for i in range(N_ITERS):
|
||||
with self.subTest(i):
|
||||
@ -187,7 +187,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
|
||||
y.fill_(3)
|
||||
return x, y
|
||||
|
||||
@torch._dynamo.optimize("cudagraphs")
|
||||
@torch.compile(backend="cudagraphs")
|
||||
def fn(x):
|
||||
for i in range(N_ITERS):
|
||||
with self.subTest(i):
|
||||
|
@ -20,7 +20,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
def test_disallow_in_graph(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.optimize(cnts)
|
||||
@torch.compile(backend=cnts)
|
||||
def fn(a):
|
||||
x = torch.add(a, 1)
|
||||
x = torch.add(x, 1)
|
||||
@ -63,7 +63,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
ref = fn(x)
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(ref, res)
|
||||
@ -187,7 +187,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
def test_allow_in_graph(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.optimize(cnts)
|
||||
@torch.compile(backend=cnts)
|
||||
def fn(a):
|
||||
x = torch.add(a, 1)
|
||||
x = torch.add(x, 1)
|
||||
@ -214,7 +214,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
def test_graph_break(self):
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.optimize(cnts)
|
||||
@torch.compile(backend=cnts)
|
||||
def fn(x):
|
||||
x = torch.cos(x)
|
||||
x = torch.cos(x)
|
||||
@ -243,7 +243,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
return fn1(x.tan())
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
opt_fn(torch.randn(4))
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
|
||||
@ -254,7 +254,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
# out of the box
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = operator.indexOf
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
out = fn([1, 2, 3, 4, 5], 3)
|
||||
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
|
||||
self.assertEqual(out, opt_out)
|
||||
@ -282,7 +282,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = operator.indexOf
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
|
||||
out = fn([1, 2, 3, 4, 5], 3)
|
||||
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
|
||||
self.assertEqual(out, opt_out)
|
||||
@ -294,7 +294,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = polyfill
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
|
||||
out = fn([1, 2, 3, 4, 5], 3)
|
||||
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
|
||||
self.assertEqual(out, opt_out)
|
||||
@ -309,7 +309,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
def fn1(x):
|
||||
return torch.sin(x) * 10
|
||||
|
||||
@torch._dynamo.optimize(cnts)
|
||||
@torch.compile(backend=cnts)
|
||||
def fn2(x):
|
||||
x = x + 1
|
||||
x = x + 1
|
||||
@ -318,7 +318,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
@torch._dynamo.optimize(cnts, nopython=True)
|
||||
@torch.compile(backend=cnts, fullgraph=True)
|
||||
def fn3(x):
|
||||
return fn2(x)
|
||||
|
||||
@ -335,14 +335,14 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
def test_disable_optimize(self):
|
||||
cnt = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
@torch._dynamo.optimize(cnt, disable=True)
|
||||
@torch.compile(backend=cnt, disable=True)
|
||||
def f1(x):
|
||||
return x + 1
|
||||
|
||||
f1(torch.ones(6))
|
||||
self.assertEqual(cnt.frame_count, 0)
|
||||
|
||||
@torch._dynamo.optimize(cnt, disable=True)
|
||||
@torch.compile(backend=cnt, disable=True)
|
||||
def f2(x):
|
||||
return x + 1
|
||||
|
||||
@ -351,7 +351,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}):
|
||||
|
||||
@torch._dynamo.optimize(cnt)
|
||||
@torch.compile(backend=cnt)
|
||||
def f3(x):
|
||||
return x + 1
|
||||
|
||||
@ -389,7 +389,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
"torch._guards.TracingContext.current_frame",
|
||||
side_effect=global_context_capture_fn,
|
||||
):
|
||||
torch._dynamo.optimize("eager")(e)(x)
|
||||
torch.compile(e, backend="eager")(x)
|
||||
|
||||
self.assertEqual(len(seen_frames), 0)
|
||||
|
||||
@ -463,7 +463,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
|
||||
compiles += 1
|
||||
return gm
|
||||
|
||||
@torch._dynamo.optimize(backend=debug_compiler)
|
||||
@torch.compile(backend=debug_compiler)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
|
@ -982,7 +982,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
|
||||
x_inference = torch.randn(2, 2)
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
|
||||
|
||||
x = torch.randn(2, 2)
|
||||
|
||||
|
@ -56,7 +56,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
res2 = fn(x)
|
||||
self.assertTrue(same(res2 - res1, torch.ones(10)))
|
||||
@ -71,10 +71,10 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
"""Wrap the second call with torch._dynamo as well"""
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res2 = opt_fn(x)
|
||||
self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))
|
||||
|
||||
@ -87,7 +87,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
self.assertTrue(same(res1, x + x + 1))
|
||||
|
||||
@ -104,7 +104,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
res2 = fn(x)
|
||||
self.assertTrue(same(res2 - res1, torch.ones(10)))
|
||||
@ -118,7 +118,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
res2 = fn(x)
|
||||
self.assertTrue(same(res2 - res1, torch.ones(10)))
|
||||
@ -136,7 +136,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
res2 = fn(x)
|
||||
self.assertTrue(same(res2 - res1, torch.ones(10)))
|
||||
@ -150,7 +150,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
res2 = fn(x)
|
||||
self.assertTrue(same(res2 - res1, torch.ones(10)))
|
||||
@ -164,7 +164,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
res2 = fn(x)
|
||||
self.assertTrue(same(res2 - res1, torch.ones(10)))
|
||||
@ -177,7 +177,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
res1 = opt_fn(x)
|
||||
res2 = fn(x)
|
||||
self.assertTrue(same(res2 - res1, torch.ones(10)))
|
||||
@ -197,7 +197,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
a = torch.randn(10)
|
||||
b = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
v0, s0 = opt_fn(a, b)
|
||||
self.assertEqual(s0, "v0v1")
|
||||
reset_name()
|
||||
@ -221,7 +221,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
|
||||
a = torch.randn(10)
|
||||
b = torch.randn(10)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
||||
opt_fn = torch.compile(fn, backend=cnts)
|
||||
v0, s0 = opt_fn(a, b)
|
||||
self.assertEqual(s0, "v0v1")
|
||||
reset_name()
|
||||
|
@ -15,7 +15,7 @@ from torch.utils.hooks import RemovableHandle
|
||||
|
||||
|
||||
def compiler_fn(gm):
|
||||
return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)
|
||||
return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)
|
||||
|
||||
|
||||
def global_hook_0(grad):
|
||||
@ -45,7 +45,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v)
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -58,7 +58,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, y * y, z * z
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -74,7 +74,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, y * y, z
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -89,7 +89,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, y * y, z, handle, h2
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -107,7 +107,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, y * y, z, handle, handle
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -142,7 +142,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, y * y, z
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
||||
fn = torch.compile(fn, backend=cnts, fullgraph=True)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
|
||||
mod = torch.nn.Module()
|
||||
@ -165,7 +165,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v)
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -183,7 +183,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, z
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v)
|
||||
v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -199,7 +199,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, y * y, z * z
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -221,7 +221,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, y * y, z
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -234,7 +234,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, x * x
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v)[0]
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -249,7 +249,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, x * x
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v = fn(v)[0]
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -264,7 +264,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, x * x, h0, h1, h2
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v, r, handle_0, handle_1, handle_2 = fn(v)
|
||||
v.backward(torch.tensor([1.0, 2.0, 3.0]))
|
||||
@ -286,7 +286,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
return x, x * x
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts)(fn)
|
||||
fn = torch.compile(fn, backend=cnts)
|
||||
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
|
||||
v, r = fn(v)
|
||||
|
||||
@ -315,7 +315,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
out = torch.randn(1, requires_grad=True)
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
fn = torch._dynamo.optimize(cnts, nopython=False)(f)
|
||||
fn = torch.compile(f, backend=cnts, fullgraph=False)
|
||||
res = fn(out)
|
||||
res.backward()
|
||||
self.assertEqual(res, f(out))
|
||||
@ -348,7 +348,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
|
||||
dynamo_out = torch.compile(mod, backend="aot_eager", fullgraph=True)(x2)
|
||||
dynamo_out[0].backward(torch.ones(4))
|
||||
|
||||
self.assertEqual(dynamo_out, aot_out)
|
||||
@ -384,7 +384,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
aot_out[0].backward(torch.ones(4))
|
||||
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
|
||||
dynamo_out = torch.compile(mod, backend=backend, fullgraph=True)(x2)
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
dynamo_out[0].backward(torch.ones(4))
|
||||
|
||||
@ -420,7 +420,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
|
||||
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2)
|
||||
dynamo_out[0].backward(torch.ones(4))
|
||||
|
||||
self.assertEqual(dynamo_out, aot_out)
|
||||
@ -464,7 +464,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertEqual(obj.count, 2)
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
|
||||
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
|
||||
dynamo_out[0].backward(torch.ones(4))
|
||||
|
||||
self.assertEqual(dynamo_out, eager_out)
|
||||
@ -511,7 +511,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
|
||||
x2 = torch.ones(4, requires_grad=True)
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
|
||||
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
|
||||
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
|
||||
dynamo_out[0].backward(torch.ones(4))
|
||||
|
||||
@ -661,7 +661,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
x1 = torch.ones(4, requires_grad=True)
|
||||
with compiled_autograd.enable(compiler_fn):
|
||||
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
|
||||
comp_mod = torch.compile(mod, backend=cnts, fullgraph=True)
|
||||
comp_out = comp_mod(x1)
|
||||
comp_out[0].backward(torch.ones(4))
|
||||
|
||||
@ -736,7 +736,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
|
||||
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
|
||||
compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)
|
||||
compiled_fn = torch.compile(reg_and_mul, backend=cnts, fullgraph=True)
|
||||
|
||||
compiled_bwd_ctx = (
|
||||
compiled_autograd.enable(
|
||||
|
@ -13,7 +13,7 @@ class MinifierTests(MinifierTestBase):
|
||||
# Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
|
||||
def _test_after_dynamo(self, device, backend, expected_error):
|
||||
run_code = f"""\
|
||||
@torch._dynamo.optimize({backend!r})
|
||||
@torch.compile(backend={backend!r})
|
||||
def inner(x):
|
||||
for _ in range(10):
|
||||
x = torch.sin(x)
|
||||
|
@ -52,7 +52,7 @@ class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
|
||||
x = torch.tensor([3.0])
|
||||
with RewriteAddToMul():
|
||||
eager_res = fn(x)
|
||||
compiled_res = torch._dynamo.optimize(cnt)(fn)(x)
|
||||
compiled_res = torch.compile(fn, backend=cnt)(x)
|
||||
|
||||
self.assertEqual(eager_res, compiled_res)
|
||||
self.assertEqual(cnt.frame_count, 0)
|
||||
|
@ -57,7 +57,7 @@ class End2EndTests(torch._dynamo.test_case.TestCase):
|
||||
optimizer = torch.optim.Adam([input2], lr=0.1)
|
||||
|
||||
cnts = torch._dynamo.testing.CompileCounter()
|
||||
opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
|
||||
opt_training_iter_fn = torch.compile(training_iter_fn, backend=cnts)
|
||||
batch = {"x": input1, "y": input2}
|
||||
for _ in range(2):
|
||||
opt_training_iter_fn(batch, net, optimizer)
|
||||
|
@ -176,7 +176,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
|
||||
_variable_2 = 0
|
||||
|
||||
mod = MyModule(mode=mode)
|
||||
model = torch._dynamo.optimize(backend="eager", nopython=mode != 6)(mod)
|
||||
model = torch.compile(mod, backend="eager", fullgraph=mode != 6)
|
||||
assert _variable == 0
|
||||
assert _variable_2 == 0
|
||||
|
||||
|
@ -179,7 +179,7 @@ class TorchRecTests(TestCase):
|
||||
|
||||
counter = CompileCounter()
|
||||
|
||||
@torch._dynamo.optimize(counter, nopython=True)
|
||||
@torch.compile(backend=counter, fullgraph=True)
|
||||
def f(jag_tensor):
|
||||
# The indexing here requires more symbolic reasoning
|
||||
# and doesn't work right now
|
||||
|
@ -91,7 +91,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
|
||||
s = Seq()
|
||||
i = torch.randn(10)
|
||||
r1 = s(i)
|
||||
opt_s = torch._dynamo.optimize("ts")(s)
|
||||
opt_s = torch.compile(s, backend="ts")
|
||||
r2 = opt_s(i)
|
||||
self.assertTrue(same(r1, r2))
|
||||
|
||||
@ -110,7 +110,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
|
||||
|
||||
toy_example(i1, i2)
|
||||
try:
|
||||
opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
|
||||
opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
|
||||
opt_toy_example(i1, i2)
|
||||
except RuntimeError:
|
||||
pass
|
||||
@ -132,7 +132,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
|
||||
return transform(gm).forward
|
||||
|
||||
r1 = toy_example(i1, i2)
|
||||
opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
|
||||
opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
|
||||
r2 = opt_toy_example(i1, i2)
|
||||
self.assertTrue(not same(r1, r2))
|
||||
|
||||
|
@ -788,7 +788,7 @@ class TestDynamoAOT(JitTestCase):
|
||||
mod = Seq()
|
||||
|
||||
import torch._dynamo
|
||||
aot_mod = torch._dynamo.optimize("aot_ts", nopython=True)(mod)
|
||||
aot_mod = torch.compile(mod, backend="aot_ts", fullgraph=True)
|
||||
|
||||
for _ in range(10):
|
||||
with torch.jit.fuser("fuser3"):
|
||||
|
@ -6372,7 +6372,7 @@ torch.cuda.synchronize()
|
||||
|
||||
compile_counter = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
|
||||
compiled_fn = torch.compile(fn, backend=compile_counter, fullgraph=True)
|
||||
check_results(fn, compiled_fn, generate_inp(18))
|
||||
self.assertEqual(compile_counter.frame_count, 1)
|
||||
|
||||
|
Reference in New Issue
Block a user