[Dynamo] Replace torch._dynamo.optimize() with torch.compile() [2/N] (#140238)
related commits:

- #139706
- #140238
- #140247
- #140253

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140238
Approved by: https://github.com/soulitzer
Committed by: PyTorch MergeBot
Parent: 42ad54c71b
Commit: d6b3ad4de2
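The substance of the change is a one-for-one rewrite of the legacy private entry point onto the public compile API: torch._dynamo.optimize(backend)(fn) becomes torch.compile(fn, backend=backend), and nopython=True becomes fullgraph=True. A minimal sketch of the correspondence follows; the toy function and inputs are illustrative only, not taken from the diff:

    import torch
    import torch._dynamo

    def fn(x):
        return torch.sin(x) + 1

    # Old spelling used throughout these tests: the private helper returns a
    # wrapper that compiles `fn` with the given backend on first call.
    opt_old = torch._dynamo.optimize("eager", nopython=True)(fn)

    # New spelling: the public API, with nopython=True renamed to fullgraph=True.
    opt_new = torch.compile(fn, backend="eager", fullgraph=True)

    x = torch.randn(4)
    assert torch.allclose(opt_old(x), opt_new(x))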
@@ -229,7 +229,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
         for i in range(1, 5):
             torch._dynamo.reset()
             model = globals()[f"Module{i}"]()
-            opt_model = torch._dynamo.optimize("eager")(model)
+            opt_model = torch.compile(model, backend="eager")
             self.assertTrue(
                 torch.allclose(
                     opt_model(torch.ones(2, 3, requires_grad=grad)),
@@ -243,7 +243,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
         for model in [Module5(), Module6()]:
             torch._dynamo.reset()
             cnts = torch._dynamo.testing.CompileCounter()
-            opt_model = torch._dynamo.optimize(cnts)(model)
+            opt_model = torch.compile(model, backend=cnts)
             for _ in range(3):
                 ref = model(x)
                 res = opt_model(x)
@@ -252,7 +252,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

     def test_linear_setup_context(self):
         model = ModuleLinear()
-        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+        opt_model = torch.compile(model, backend="eager", fullgraph=True)
         input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
         eager_result = model(input, weight)
@@ -261,7 +261,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

     def test_materialize_grad(self):
         model = MaterializingGradModule()
-        opt_model = torch._dynamo.optimize("eager")(model)
+        opt_model = torch.compile(model, backend="eager")
         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         optim_result = opt_model(x)
         eager_result = model(x)
@@ -269,7 +269,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

     def test_print_in_bwd(self):
         model = CustomFuncBwdPrintModule()
-        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+        opt_model = torch.compile(model, backend="eager", fullgraph=True)
         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
             opt_model(x)
@@ -323,7 +323,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):

     def test_save_for_bwd(self):
         model = SaveForBwdModule()
-        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+        opt_model = torch.compile(model, backend="eager", fullgraph=True)
         x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
         opt_model(x)

@@ -402,7 +402,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
         before = mod(*args, **kwargs)

         torch._dynamo.reset()
-        compiled_model = torch._dynamo.optimize("eager")(mod)
+        compiled_model = torch.compile(mod, backend="eager")
         after = compiled_model(*args, **kwargs)
         self.assertEqual(before, after)

@@ -412,7 +412,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
         before = mod(*args, **kwargs)

         torch._dynamo.reset()
-        compiled_model = torch._dynamo.optimize("eager")(mod)
+        compiled_model = torch.compile(mod, backend="eager")
         after = compiled_model(*args, **kwargs)
         self.assertEqual(before, after)

@@ -691,7 +691,7 @@ class GraphModule(torch.nn.Module):
         args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
         before = mod(*args, **kwargs)

-        compiled_model = torch._dynamo.optimize("eager")(mod)
+        compiled_model = torch.compile(mod, backend="eager")
         after = compiled_model(*args, **kwargs)
         self.assertEqual(before, after)

@@ -859,7 +859,7 @@ class GraphModule(torch.nn.Module):
             foo = MyFn3.apply(base, False)

         test()
-        opt_test = torch._dynamo.optimize("eager")(test)
+        opt_test = torch.compile(test, backend="eager")
         opt_test()

     def test_tensor_subclass_intermediary_input(self):

@@ -160,7 +160,7 @@ class NormalizeIRTests(torch._dynamo.test_case.TestCase):

         ref = fn(a, b)

-        optimized_fn = torch._dynamo.optimize("aot_eager")(fn)
+        optimized_fn = torch.compile(fn, backend="aot_eager")
         res = optimized_fn(a, b)
         self.assertTrue(same(ref, res))

@@ -47,7 +47,7 @@ class BackwardHigherOrderOpTests(torch._dynamo.test_case.TestCase):
             x.register_hook(_multiply_invoke)
             return x * y

-        fn = torch._dynamo.optimize(backend)(fn)
+        fn = torch.compile(fn, backend=backend)
         out = fn(x, y)
         grad_out = torch.tensor([2.0, 2.0])
         out.backward(grad_out)
@@ -114,7 +114,7 @@ class _multiply_invoke(torch.nn.Module):
             x.register_hook(_multiply_invoke)
             return x + y

-        fn = torch._dynamo.optimize(backend)(fn)
+        fn = torch.compile(fn, backend=backend)
         out = fn(x, y)
         grad_out = torch.tensor([2.0, 2.0])
         with compiled_autograd.enable(compiler_fn):
@@ -179,7 +179,7 @@ class GraphModule(torch.nn.Module):
         def fn(x, y):
             return x + y

-        fn = torch._dynamo.optimize(backend, nopython=True)(fn)
+        fn = torch.compile(fn, backend=backend, fullgraph=True)
         out = fn(x, y)
         grad_out = torch.tensor([2.0, 2.0])
         with compiled_autograd.enable(compiler_fn):
@@ -237,7 +237,7 @@ class GraphModule(torch.nn.Module):
             x.register_hook(_graph_break_invoke)
             return x + y

-        fn = torch._dynamo.optimize(backend, nopython=True)(fn)
+        fn = torch.compile(fn, backend=backend, fullgraph=True)
         out = fn(x, y)
         grad_out = torch.tensor([2.0, 2.0])
         with self.assertRaisesRegex(

@@ -122,7 +122,7 @@ def fn():
             z *= 3
             return z

-        opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
+        opt_f = torch.compile(f, backend="eager", fullgraph=True)
         self.assertEqual(opt_f(None, torch.ones(2)), 6)

     if sys.version_info >= (3, 11):
@@ -226,7 +226,7 @@ def fn():
             dummy_fn.__code__ = code
             self.assertEqual(dummy_fn(), test[3])

-            dummy_opt = torch._dynamo.optimize("eager")(dummy_fn)
+            dummy_opt = torch.compile(dummy_fn, backend="eager")
             self.assertEqual(dummy_opt(), test[3])

     def test_exception_table_encode_varint(self):

@@ -35,7 +35,7 @@ class ComptimeTests(torch._dynamo.test_case.TestCase):
         class mylist(list):
             pass

-        @torch._dynamo.optimize(cnt, dynamic=True)
+        @torch.compile(backend=cnt, dynamic=True)
         def f(x):
             y = x * 2
             comptime_print(y)

@@ -61,7 +61,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
         def model(x, y):
             return (x + y) * y

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
         def fn(x, y):
             for i in range(N_ITERS):
                 loss = model(x, y).sum()
@@ -78,7 +78,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
             b = a.cpu() * 3
             return b

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
         def fn(x, y):
             for i in range(N_ITERS):
                 loss = model(x, y).sum()
@@ -94,7 +94,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
             a = x + y
             return a * 3

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
        def fn(x, y):
             for i in range(N_ITERS):
                 loss = model(x, y).sum()
@@ -109,7 +109,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
             y.add_(3)
             return x * y

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
         def fn(x, y):
             for i in range(N_ITERS):
                 with self.subTest(i):
@@ -129,7 +129,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
             c.add_(2)
             return x * y * 0 + c

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
         def fn(x, y):
             for i in range(N_ITERS):
                 with self.subTest(i):
@@ -148,7 +148,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
             x.add_(3)
             return x * y

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
         def fn(y):
             for i in range(N_ITERS):
                 with self.subTest(i):
@@ -168,7 +168,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
             x.fill_(2)
             return x

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
         def fn(x):
             for i in range(N_ITERS):
                 with self.subTest(i):
@@ -187,7 +187,7 @@ class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
             y.fill_(3)
             return x, y

-        @torch._dynamo.optimize("cudagraphs")
+        @torch.compile(backend="cudagraphs")
         def fn(x):
             for i in range(N_ITERS):
                 with self.subTest(i):

@@ -20,7 +20,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
     def test_disallow_in_graph(self):
         cnts = torch._dynamo.testing.CompileCounter()

-        @torch._dynamo.optimize(cnts)
+        @torch.compile(backend=cnts)
         def fn(a):
             x = torch.add(a, 1)
             x = torch.add(x, 1)
@@ -63,7 +63,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         ref = fn(x)

         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res = opt_fn(x)
         self.assertEqual(cnts.frame_count, 2)
         self.assertEqual(ref, res)
@@ -187,7 +187,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
     def test_allow_in_graph(self):
         cnts = torch._dynamo.testing.CompileCounter()

-        @torch._dynamo.optimize(cnts)
+        @torch.compile(backend=cnts)
         def fn(a):
             x = torch.add(a, 1)
             x = torch.add(x, 1)
@@ -214,7 +214,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
     def test_graph_break(self):
         cnts = torch._dynamo.testing.CompileCounter()

-        @torch._dynamo.optimize(cnts)
+        @torch.compile(backend=cnts)
         def fn(x):
             x = torch.cos(x)
             x = torch.cos(x)
@@ -243,7 +243,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
             return fn1(x.tan())

         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         opt_fn(torch.randn(4))
         self.assertEqual(cnts.frame_count, 2)

@@ -254,7 +254,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         # out of the box
         cnts = torch._dynamo.testing.CompileCounter()
         fn = operator.indexOf
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         out = fn([1, 2, 3, 4, 5], 3)
         opt_out = opt_fn([1, 2, 3, 4, 5], 3)
         self.assertEqual(out, opt_out)
@@ -282,7 +282,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):

         cnts = torch._dynamo.testing.CompileCounter()
         fn = operator.indexOf
-        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
         out = fn([1, 2, 3, 4, 5], 3)
         opt_out = opt_fn([1, 2, 3, 4, 5], 3)
         self.assertEqual(out, opt_out)
@@ -294,7 +294,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):

         cnts = torch._dynamo.testing.CompileCounter()
         fn = polyfill
-        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
         out = fn([1, 2, 3, 4, 5], 3)
         opt_out = opt_fn([1, 2, 3, 4, 5], 3)
         self.assertEqual(out, opt_out)
@@ -309,7 +309,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         def fn1(x):
             return torch.sin(x) * 10

-        @torch._dynamo.optimize(cnts)
+        @torch.compile(backend=cnts)
         def fn2(x):
             x = x + 1
             x = x + 1
@@ -318,7 +318,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
             x = x + 1
             return x

-        @torch._dynamo.optimize(cnts, nopython=True)
+        @torch.compile(backend=cnts, fullgraph=True)
         def fn3(x):
             return fn2(x)

@@ -335,14 +335,14 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
     def test_disable_optimize(self):
         cnt = torch._dynamo.testing.CompileCounter()

-        @torch._dynamo.optimize(cnt, disable=True)
+        @torch.compile(backend=cnt, disable=True)
         def f1(x):
             return x + 1

         f1(torch.ones(6))
         self.assertEqual(cnt.frame_count, 0)

-        @torch._dynamo.optimize(cnt, disable=True)
+        @torch.compile(backend=cnt, disable=True)
         def f2(x):
             return x + 1

@@ -351,7 +351,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):

         with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}):

-            @torch._dynamo.optimize(cnt)
+            @torch.compile(backend=cnt)
             def f3(x):
                 return x + 1

@@ -389,7 +389,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
             "torch._guards.TracingContext.current_frame",
             side_effect=global_context_capture_fn,
         ):
-            torch._dynamo.optimize("eager")(e)(x)
+            torch.compile(e, backend="eager")(x)

         self.assertEqual(len(seen_frames), 0)

@@ -463,7 +463,7 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
             compiles += 1
             return gm

-        @torch._dynamo.optimize(backend=debug_compiler)
+        @torch.compile(backend=debug_compiler)
         def fn(x):
             return x + 1

@@ -982,7 +982,7 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
         x_inference = torch.randn(2, 2)

         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)

         x = torch.randn(2, 2)

@@ -56,7 +56,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         res2 = fn(x)
         self.assertTrue(same(res2 - res1, torch.ones(10)))
@@ -71,10 +71,10 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         """Wrap the second call with torch._dynamo as well"""
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res2 = opt_fn(x)
         self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

@@ -87,7 +87,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         self.assertTrue(same(res1, x + x + 1))

@@ -104,7 +104,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         res2 = fn(x)
         self.assertTrue(same(res2 - res1, torch.ones(10)))
@@ -118,7 +118,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         res2 = fn(x)
         self.assertTrue(same(res2 - res1, torch.ones(10)))
@@ -136,7 +136,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         res2 = fn(x)
         self.assertTrue(same(res2 - res1, torch.ones(10)))
@@ -150,7 +150,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         res2 = fn(x)
         self.assertTrue(same(res2 - res1, torch.ones(10)))
@@ -164,7 +164,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         res2 = fn(x)
         self.assertTrue(same(res2 - res1, torch.ones(10)))
@@ -177,7 +177,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):

         x = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         res1 = opt_fn(x)
         res2 = fn(x)
         self.assertTrue(same(res2 - res1, torch.ones(10)))
@@ -197,7 +197,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
         a = torch.randn(10)
         b = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         v0, s0 = opt_fn(a, b)
         self.assertEqual(s0, "v0v1")
         reset_name()
@@ -221,7 +221,7 @@ class TestGlobals(torch._dynamo.test_case.TestCase):
         a = torch.randn(10)
         b = torch.randn(10)
         cnts = torch._dynamo.testing.CompileCounter()
-        opt_fn = torch._dynamo.optimize(cnts)(fn)
+        opt_fn = torch.compile(fn, backend=cnts)
         v0, s0 = opt_fn(a, b)
         self.assertEqual(s0, "v0v1")
         reset_name()

@@ -15,7 +15,7 @@ from torch.utils.hooks import RemovableHandle


 def compiler_fn(gm):
-    return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)
+    return torch.compile(gm, backend="inductor", fullgraph=True, dynamic=True)


 def global_hook_0(grad):
@@ -45,7 +45,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v)
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -58,7 +58,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, y * y, z * z

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -74,7 +74,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, y * y, z

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -89,7 +89,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, y * y, z, handle, h2

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -107,7 +107,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, y * y, z, handle, handle

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -142,7 +142,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, y * y, z

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+        fn = torch.compile(fn, backend=cnts, fullgraph=True)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)

         mod = torch.nn.Module()
@@ -165,7 +165,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v)
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -183,7 +183,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, z

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v)
         v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -199,7 +199,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, y * y, z * z

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -221,7 +221,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, y * y, z

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -234,7 +234,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, x * x

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v)[0]
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -249,7 +249,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, x * x

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v = fn(v)[0]
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -264,7 +264,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, x * x, h0, h1, h2

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v, r, handle_0, handle_1, handle_2 = fn(v)
         v.backward(torch.tensor([1.0, 2.0, 3.0]))
@@ -286,7 +286,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
             return x, x * x

         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts)(fn)
+        fn = torch.compile(fn, backend=cnts)
         v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
         v, r = fn(v)

@@ -315,7 +315,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):

         out = torch.randn(1, requires_grad=True)
         cnts = torch._dynamo.testing.CompileCounter()
-        fn = torch._dynamo.optimize(cnts, nopython=False)(f)
+        fn = torch.compile(f, backend=cnts, fullgraph=False)
         res = fn(out)
         res.backward()
         self.assertEqual(res, f(out))
@@ -348,7 +348,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):

         x2 = torch.ones(4, requires_grad=True)
         with compiled_autograd.enable(compiler_fn):
-            dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
+            dynamo_out = torch.compile(mod, backend="aot_eager", fullgraph=True)(x2)
             dynamo_out[0].backward(torch.ones(4))

         self.assertEqual(dynamo_out, aot_out)
@@ -384,7 +384,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
         aot_out[0].backward(torch.ones(4))

         x2 = torch.ones(4, requires_grad=True)
-        dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
+        dynamo_out = torch.compile(mod, backend=backend, fullgraph=True)(x2)
         with compiled_autograd.enable(compiler_fn):
             dynamo_out[0].backward(torch.ones(4))

@@ -420,7 +420,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):

         x2 = torch.ones(4, requires_grad=True)
         with compiled_autograd.enable(compiler_fn):
-            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
+            dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2)
             dynamo_out[0].backward(torch.ones(4))

         self.assertEqual(dynamo_out, aot_out)
@@ -464,7 +464,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(obj.count, 2)
         x2 = torch.ones(4, requires_grad=True)
         with compiled_autograd.enable(compiler_fn):
-            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
+            dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
             dynamo_out[0].backward(torch.ones(4))

         self.assertEqual(dynamo_out, eager_out)
@@ -511,7 +511,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):

         x2 = torch.ones(4, requires_grad=True)
         with compiled_autograd.enable(compiler_fn):
-            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
+            dynamo_out = torch.compile(mod, backend="inductor", fullgraph=True)(x2, obj)
             with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
                 dynamo_out[0].backward(torch.ones(4))

@@ -661,7 +661,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
         x1 = torch.ones(4, requires_grad=True)
         with compiled_autograd.enable(compiler_fn):
             cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
-            comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
+            comp_mod = torch.compile(mod, backend=cnts, fullgraph=True)
             comp_out = comp_mod(x1)
             comp_out[0].backward(torch.ones(4))

@@ -736,7 +736,7 @@ class HooksTests(torch._dynamo.test_case.TestCase):
         y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

         cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
-        compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)
+        compiled_fn = torch.compile(reg_and_mul, backend=cnts, fullgraph=True)

         compiled_bwd_ctx = (
             compiled_autograd.enable(

@@ -13,7 +13,7 @@ class MinifierTests(MinifierTestBase):
     # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
     def _test_after_dynamo(self, device, backend, expected_error):
        run_code = f"""\
-@torch._dynamo.optimize({backend!r})
+@torch.compile(backend={backend!r})
 def inner(x):
     for _ in range(10):
         x = torch.sin(x)

@@ -52,7 +52,7 @@ class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
         x = torch.tensor([3.0])
         with RewriteAddToMul():
             eager_res = fn(x)
-            compiled_res = torch._dynamo.optimize(cnt)(fn)(x)
+            compiled_res = torch.compile(fn, backend=cnt)(x)

         self.assertEqual(eager_res, compiled_res)
         self.assertEqual(cnt.frame_count, 0)

@@ -57,7 +57,7 @@ class End2EndTests(torch._dynamo.test_case.TestCase):
         optimizer = torch.optim.Adam([input2], lr=0.1)

         cnts = torch._dynamo.testing.CompileCounter()
-        opt_training_iter_fn = torch._dynamo.optimize(cnts)(training_iter_fn)
+        opt_training_iter_fn = torch.compile(training_iter_fn, backend=cnts)
         batch = {"x": input1, "y": input2}
         for _ in range(2):
             opt_training_iter_fn(batch, net, optimizer)

@@ -176,7 +176,7 @@ class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
         _variable_2 = 0

         mod = MyModule(mode=mode)
-        model = torch._dynamo.optimize(backend="eager", nopython=mode != 6)(mod)
+        model = torch.compile(mod, backend="eager", fullgraph=mode != 6)
         assert _variable == 0
         assert _variable_2 == 0

@@ -179,7 +179,7 @@ class TorchRecTests(TestCase):

         counter = CompileCounter()

-        @torch._dynamo.optimize(counter, nopython=True)
+        @torch.compile(backend=counter, fullgraph=True)
         def f(jag_tensor):
             # The indexing here requires more symbolic reasoning
             # and doesn't work right now

@@ -91,7 +91,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
         s = Seq()
         i = torch.randn(10)
         r1 = s(i)
-        opt_s = torch._dynamo.optimize("ts")(s)
+        opt_s = torch.compile(s, backend="ts")
         r2 = opt_s(i)
         self.assertTrue(same(r1, r2))

@@ -110,7 +110,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):

         toy_example(i1, i2)
         try:
-            opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
+            opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
             opt_toy_example(i1, i2)
         except RuntimeError:
             pass
@@ -132,7 +132,7 @@ class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
             return transform(gm).forward

         r1 = toy_example(i1, i2)
-        opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
+        opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
         r2 = opt_toy_example(i1, i2)
         self.assertTrue(not same(r1, r2))

@@ -788,7 +788,7 @@ class TestDynamoAOT(JitTestCase):
         mod = Seq()

         import torch._dynamo
-        aot_mod = torch._dynamo.optimize("aot_ts", nopython=True)(mod)
+        aot_mod = torch.compile(mod, backend="aot_ts", fullgraph=True)

         for _ in range(10):
             with torch.jit.fuser("fuser3"):

@@ -6372,7 +6372,7 @@ torch.cuda.synchronize()

         compile_counter = torch._dynamo.testing.CompileCounter()

-        compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
+        compiled_fn = torch.compile(fn, backend=compile_counter, fullgraph=True)
         check_results(fn, compiled_fn, generate_inp(18))
         self.assertEqual(compile_counter.frame_count, 1)

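A recurring pattern in the hunks above is passing a CompileCounter as the backend so a test can assert on how many frames were compiled; torch.compile accepts such a callable through backend= exactly as torch._dynamo.optimize accepted it positionally. A small sketch of that usage, with an illustrative toy function and an expected count that holds only for this simple case:

    import torch
    import torch._dynamo.testing

    cnts = torch._dynamo.testing.CompileCounter()

    # The decorator form with keyword arguments mirrors the rewritten tests.
    @torch.compile(backend=cnts)
    def fn(x):
        return torch.relu(x) + 1

    fn(torch.randn(8))
    # A single frame is compiled for this graph-break-free function; the
    # graph-break tests above assert frame_count == 2 instead.
    assert cnts.frame_count == 1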