[Dynamo] Replace torch._dynamo.optimize() with torch.compile() [2/N] (#140238)

related commits:

- #139706
- #140238
- #140247
- #140253

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140238
Approved by: https://github.com/soulitzer
This commit is contained in:
Yuanhao Ji
2024-11-13 05:13:37 +00:00
committed by PyTorch MergeBot
parent 42ad54c71b
commit d6b3ad4de2
18 changed files with 87 additions and 87 deletions

View File

@ -229,7 +229,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
for i in range(1, 5):
torch._dynamo.reset()
model = globals()[f"Module{i}"]()
opt_model = torch._dynamo.optimize("eager")(model)
opt_model = torch.compile(model, backend="eager")
self.assertTrue(
torch.allclose(
opt_model(torch.ones(2, 3, requires_grad=grad)),
@ -243,7 +243,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
for model in [Module5(), Module6()]:
torch._dynamo.reset()
cnts = torch._dynamo.testing.CompileCounter()
opt_model = torch._dynamo.optimize(cnts)(model)
opt_model = torch.compile(model, backend=cnts)
for _ in range(3):
ref = model(x)
res = opt_model(x)
@ -252,7 +252,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
def test_linear_setup_context(self):
model = ModuleLinear()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
opt_model = torch.compile(model, backend="eager", fullgraph=True)
input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
eager_result = model(input, weight)
@ -261,7 +261,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
def test_materialize_grad(self):
model = MaterializingGradModule()
opt_model = torch._dynamo.optimize("eager")(model)
opt_model = torch.compile(model, backend="eager")
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
optim_result = opt_model(x)
eager_result = model(x)
@ -269,7 +269,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
def test_print_in_bwd(self):
model = CustomFuncBwdPrintModule()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
opt_model = torch.compile(model, backend="eager", fullgraph=True)
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
opt_model(x)
@ -323,7 +323,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
def test_save_for_bwd(self):
model = SaveForBwdModule()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
opt_model = torch.compile(model, backend="eager", fullgraph=True)
x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
opt_model(x)
@ -402,7 +402,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
before = mod(*args, **kwargs)
torch._dynamo.reset()
compiled_model = torch._dynamo.optimize("eager")(mod)
compiled_model = torch.compile(mod, backend="eager")
after = compiled_model(*args, **kwargs)
self.assertEqual(before, after)
@ -412,7 +412,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
before = mod(*args, **kwargs)
torch._dynamo.reset()
compiled_model = torch._dynamo.optimize("eager")(mod)
compiled_model = torch.compile(mod, backend="eager")
after = compiled_model(*args, **kwargs)
self.assertEqual(before, after)
@ -691,7 +691,7 @@ class GraphModule(torch.nn.Module):
args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
before = mod(*args, **kwargs)
compiled_model = torch._dynamo.optimize("eager")(mod)
compiled_model = torch.compile(mod, backend="eager")
after = compiled_model(*args, **kwargs)
self.assertEqual(before, after)
@ -859,7 +859,7 @@ class GraphModule(torch.nn.Module):
foo = MyFn3.apply(base, False)
test()
opt_test = torch._dynamo.optimize("eager")(test)
opt_test = torch.compile(test, backend="eager")
opt_test()
def test_tensor_subclass_intermediary_input(self):

View File

@ -160,7 +160,7 @@ class NormalizeIRTests(torch._dynamo.test_case.TestCase):
ref = fn(a, b)
optimized_fn = torch._dynamo.optimize("aot_eager")(fn)
optimized_fn = torch.compile(fn, backend="aot_eager")
res = optimized_fn(a, b)
self.assertTrue(same(ref, res))

View File

@ -47,7 +47,7 @@ class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
x.register_hook(_multiply_invoke)
return x * y
fn = torch._dynamo.optimize(backend)(fn)
fn = torch.compile(fn, backend=backend)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
out.backward(grad_out)
@ -114,7 +114,7 @@ class _multiply_invoke(torch.nn.Module):
x.register_hook(_multiply_invoke)
return x + y
fn = torch._dynamo.optimize(backend)(fn)
fn = torch.compile(fn, backend=backend)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
with compiled_autograd.enable(compiler_fn):
@ -179,7 +179,7 @@ class GraphModule(torch.nn.Module):
def fn(x, y):
return x + y
fn = torch._dynamo.optimize(backend, nopython=True)(fn)
fn = torch.compile(fn, backend=backend, fullgraph=True)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
with compiled_autograd.enable(compiler_fn):
@ -237,7 +237,7 @@ class GraphModule(torch.nn.Module):
x.register_hook(_graph_break_invoke)
return x + y
fn = torch._dynamo.optimize(backend, nopython=True)(fn)
fn = torch.compile(fn, backend=backend, fullgraph=True)
out = fn(x, y)
grad_out = torch.tensor([2.0, 2.0])
with self.assertRaisesRegex(

View File

@ -122,7 +122,7 @@ def fn():
z *= 3
return z
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
opt_f = torch.compile(f, backend="eager", fullgraph=True)
self.assertEqual(opt_f(None, torch.ones(2)), 6)
if sys.version_info >= (3, 11):
@ -226,7 +226,7 @@ def fn():
dummy_fn.__code__ = code
self.assertEqual(dummy_fn(), test[3])
dummy_opt = torch._dynamo.optimize("eager")(dummy_fn)
dummy_opt = torch.compile(dummy_fn, backend="eager")
self.assertEqual(dummy_opt(), test[3])
def test_exception_table_encode_varint(self):

View File

@ -35,7 +35,7 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
class mylist(list):
pass
@torch._dynamo.optimize(cnt, dynamic=True)
@torch.compile(backend=cnt, dynamic=True)
def f(x):
y = x * 2
comptime_print(y)

View File

@ -61,7 +61,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
def model(x, y):
return (x + y) * y
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(x, y):
for i in range(N_ITERS):
loss = model(x, y).sum()
@ -78,7 +78,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
b = a.cpu() * 3
return b
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(x, y):
for i in range(N_ITERS):
loss = model(x, y).sum()
@ -94,7 +94,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
a = x + y
return a * 3
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(x, y):
for i in range(N_ITERS):
loss = model(x, y).sum()
@ -109,7 +109,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
y.add_(3)
return x * y
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(x, y):
for i in range(N_ITERS):
with self.subTest(i):
@ -129,7 +129,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
c.add_(2)
return x * y * 0 + c
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(x, y):
for i in range(N_ITERS):
with self.subTest(i):
@ -148,7 +148,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
x.add_(3)
return x * y
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(y):
for i in range(N_ITERS):
with self.subTest(i):
@ -168,7 +168,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
x.fill_(2)
return x
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(x):
for i in range(N_ITERS):
with self.subTest(i):
@ -187,7 +187,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
y.fill_(3)
return x, y
@torch._dynamo.optimize("cudagraphs")
@torch.compile(backend="cudagraphs")
def fn(x):
for i in range(N_ITERS):
with self.subTest(i):

View File

@ -20,7 +20,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def test_disallow_in_graph(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnts)
@torch.compile(backend=cnts)
def fn(a):
x = torch.add(a, 1)
x = torch.add(x, 1)
@ -63,7 +63,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res = opt_fn(x)
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(ref, res)
@ -187,7 +187,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def test_allow_in_graph(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnts)
@torch.compile(backend=cnts)
def fn(a):
x = torch.add(a, 1)
x = torch.add(x, 1)
@ -214,7 +214,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def test_graph_break(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnts)
@torch.compile(backend=cnts)
def fn(x):
x = torch.cos(x)
x = torch.cos(x)
@ -243,7 +243,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
return fn1(x.tan())
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
opt_fn(torch.randn(4))
self.assertEqual(cnts.frame_count, 2)
@ -254,7 +254,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
# out of the box
cnts = torch._dynamo.testing.CompileCounter()
fn = operator.indexOf
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
@ -282,7 +282,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
cnts = torch._dynamo.testing.CompileCounter()
fn = operator.indexOf
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
@ -294,7 +294,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
cnts = torch._dynamo.testing.CompileCounter()
fn = polyfill
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
out = fn([1, 2, 3, 4, 5], 3)
opt_out = opt_fn([1, 2, 3, 4, 5], 3)
self.assertEqual(out, opt_out)
@ -309,7 +309,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def fn1(x):
return torch.sin(x) * 10
@torch._dynamo.optimize(cnts)
@torch.compile(backend=cnts)
def fn2(x):
x = x + 1
x = x + 1
@ -318,7 +318,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
x = x + 1
return x
@torch._dynamo.optimize(cnts, nopython=True)
@torch.compile(backend=cnts, fullgraph=True)
def fn3(x):
return fn2(x)
@ -335,14 +335,14 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
def test_disable_optimize(self):
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt, disable=True)
@torch.compile(backend=cnt, disable=True)
def f1(x):
return x + 1
f1(torch.ones(6))
self.assertEqual(cnt.frame_count, 0)
@torch._dynamo.optimize(cnt, disable=True)
@torch.compile(backend=cnt, disable=True)
def f2(x):
return x + 1
@ -351,7 +351,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}):
@torch._dynamo.optimize(cnt)
@torch.compile(backend=cnt)
def f3(x):
return x + 1
@ -389,7 +389,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
"torch._guards.TracingContext.current_frame",
side_effect=global_context_capture_fn,
):
torch._dynamo.optimize("eager")(e)(x)
torch.compile(e, backend="eager")(x)
self.assertEqual(len(seen_frames), 0)
@ -463,7 +463,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
compiles += 1
return gm
@torch._dynamo.optimize(backend=debug_compiler)
@torch.compile(backend=debug_compiler)
def fn(x):
return x + 1

View File

@ -982,7 +982,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
x_inference = torch.randn(2, 2)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
x = torch.randn(2, 2)

View File

@ -56,7 +56,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
res2 = fn(x)
self.assertTrue(same(res2 - res1, torch.ones(10)))
@ -71,10 +71,10 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
"""Wrap the second call with torch._dynamo as well"""
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res2 = opt_fn(x)
self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))
@ -87,7 +87,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
self.assertTrue(same(res1, x + x + 1))
@ -104,7 +104,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
res2 = fn(x)
self.assertTrue(same(res2 - res1, torch.ones(10)))
@ -118,7 +118,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
res2 = fn(x)
self.assertTrue(same(res2 - res1, torch.ones(10)))
@ -136,7 +136,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
res2 = fn(x)
self.assertTrue(same(res2 - res1, torch.ones(10)))
@ -150,7 +150,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
res2 = fn(x)
self.assertTrue(same(res2 - res1, torch.ones(10)))
@ -164,7 +164,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
res2 = fn(x)
self.assertTrue(same(res2 - res1, torch.ones(10)))
@ -177,7 +177,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
x = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
res1 = opt_fn(x)
res2 = fn(x)
self.assertTrue(same(res2 - res1, torch.ones(10)))
@ -197,7 +197,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
a = torch.randn(10)
b = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
v0, s0 = opt_fn(a, b)
self.assertEqual(s0, "v0v1")
reset_name()
@ -221,7 +221,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
a = torch.randn(10)
b = torch.randn(10)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
opt_fn = torch.compile(fn, backend=cnts)
v0, s0 = opt_fn(a, b)
self.assertEqual(s0, "v0v1")
reset_name()

View File

@ -15,7 +15,7 @@ from torch.utils.hooks import RemovableHandle
def compiler_fn(gm):
return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)
return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)
def global_hook_0(grad):
@ -45,7 +45,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -58,7 +58,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -74,7 +74,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -89,7 +89,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, y * y, z, handle, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -107,7 +107,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, y * y, z, handle, handle
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -142,7 +142,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
fn = torch.compile(fn, backend=cnts, fullgraph=True)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
mod = torch.nn.Module()
@ -165,7 +165,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -183,7 +183,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)
v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
@ -199,7 +199,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, y * y, z * z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -221,7 +221,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, y * y, z
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -234,7 +234,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -249,7 +249,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v = fn(v)[0]
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -264,7 +264,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, x * x, h0, h1, h2
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r, handle_0, handle_1, handle_2 = fn(v)
v.backward(torch.tensor([1.0, 2.0, 3.0]))
@ -286,7 +286,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
return x, x * x
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts)(fn)
fn = torch.compile(fn, backend=cnts)
v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
v, r = fn(v)
@ -315,7 +315,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
out = torch.randn(1, requires_grad=True)
cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts, nopython=False)(f)
fn = torch.compile(f, backend=cnts, fullgraph=False)
res = fn(out)
res.backward()
self.assertEqual(res, f(out))
@ -348,7 +348,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
dynamo_out = torch.compile(mod, backend="aot_eager", fullgraph=True)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
@ -384,7 +384,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
aot_out[0].backward(torch.ones(4))
x2 = torch.ones(4, requires_grad=True)
dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
dynamo_out = torch.compile(mod, backend=backend, fullgraph=True)(x2)
with compiled_autograd.enable(compiler_fn):
dynamo_out[0].backward(torch.ones(4))
@ -420,7 +420,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, aot_out)
@ -464,7 +464,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
self.assertEqual(obj.count, 2)
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
dynamo_out[0].backward(torch.ones(4))
self.assertEqual(dynamo_out, eager_out)
@ -511,7 +511,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
x2 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
dynamo_out[0].backward(torch.ones(4))
@ -661,7 +661,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
x1 = torch.ones(4, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
comp_mod = torch.compile(mod, backend=cnts, fullgraph=True)
comp_out = comp_mod(x1)
comp_out[0].backward(torch.ones(4))
@ -736,7 +736,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)
compiled_fn = torch.compile(reg_and_mul, backend=cnts, fullgraph=True)
compiled_bwd_ctx = (
compiled_autograd.enable(

View File

@ -13,7 +13,7 @@ class MinifierTests(MinifierTestBase):
# Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
def _test_after_dynamo(self, device, backend, expected_error):
run_code = f"""\
@torch._dynamo.optimize({backend!r})
@torch.compile(backend={backend!r})
def inner(x):
for _ in range(10):
x = torch.sin(x)

View File

@ -52,7 +52,7 @@ class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
x = torch.tensor([3.0])
with RewriteAddToMul():
eager_res = fn(x)
compiled_res = torch._dynamo.optimize(cnt)(fn)(x)
compiled_res = torch.compile(fn, backend=cnt)(x)
self.assertEqual(eager_res, compiled_res)
self.assertEqual(cnt.frame_count, 0)

View File

@ -57,7 +57,7 @@ class End2EndTests(torch._dynamo.test_case.TestCase):
optimizer = torch.optim.Adam([input2], lr=0.1)
cnts = torch._dynamo.testing.CompileCounter()
opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
opt_training_iter_fn = torch.compile(training_iter_fn, backend=cnts)
batch = {"x": input1, "y": input2}
for _ in range(2):
opt_training_iter_fn(batch, net, optimizer)

View File

@ -176,7 +176,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
_variable_2 = 0
mod = MyModule(mode=mode)
model = torch._dynamo.optimize(backend="eager", nopython=mode != 6)(mod)
model = torch.compile(mod, backend="eager", fullgraph=mode != 6)
assert _variable == 0
assert _variable_2 == 0

View File

@ -179,7 +179,7 @@ class TorchRecTests(TestCase):
counter = CompileCounter()
@torch._dynamo.optimize(counter, nopython=True)
@torch.compile(backend=counter, fullgraph=True)
def f(jag_tensor):
# The indexing here requires more symbolic reasoning
# and doesn't work right now

View File

@ -91,7 +91,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
s = Seq()
i = torch.randn(10)
r1 = s(i)
opt_s = torch._dynamo.optimize("ts")(s)
opt_s = torch.compile(s, backend="ts")
r2 = opt_s(i)
self.assertTrue(same(r1, r2))
@ -110,7 +110,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
toy_example(i1, i2)
try:
opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
opt_toy_example(i1, i2)
except RuntimeError:
pass
@ -132,7 +132,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
return transform(gm).forward
r1 = toy_example(i1, i2)
opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
r2 = opt_toy_example(i1, i2)
self.assertTrue(not same(r1, r2))

View File

@ -788,7 +788,7 @@ class TestDynamoAOT(JitTestCase):
mod = Seq()
import torch._dynamo
aot_mod = torch._dynamo.optimize("aot_ts", nopython=True)(mod)
aot_mod = torch.compile(mod, backend="aot_ts", fullgraph=True)
for _ in range(10):
with torch.jit.fuser("fuser3"):

View File

@ -6372,7 +6372,7 @@ torch.cuda.synchronize()
compile_counter = torch._dynamo.testing.CompileCounter()
compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
compiled_fn = torch.compile(fn, backend=compile_counter, fullgraph=True)
check_results(fn, compiled_fn, generate_inp(18))
self.assertEqual(compile_counter.frame_count, 1)