[Dynamo] Replace torch._dynamo.optimize() with torch.compile() [5/N] (#140663)

Related commits:

- #139706
- #140238
- #140247
- #140253
- #140663
- #140688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140663
Approved by: https://github.com/williamwen42
Author: Yuanhao Ji
Date: 2024-11-18 04:11:53 +00:00
Committed by: PyTorch MergeBot
Parent: 16bc82a015
Commit: a1327fac45
6 changed files with 97 additions and 97 deletions
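
Every hunk below applies the same mechanical mapping from the private torch._dynamo.optimize() entry point to the public torch.compile() API: the backend moves into the backend= keyword argument, nopython=True becomes fullgraph=True, and dynamic= carries over unchanged. A minimal before/after sketch of the mapping (fn and cnt are illustrative placeholders, not code from this commit):

    import torch
    import torch._dynamo.testing

    def fn(x):
        return x * 2

    # CompileCounter is a callable test backend that counts compiled frames.
    cnt = torch._dynamo.testing.CompileCounter()

    # Before (private API): backend is positional, nopython requests full-graph capture.
    #   opt_fn = torch._dynamo.optimize(cnt, nopython=True, dynamic=True)(fn)
    # After (public API): the same semantics spelled with keyword arguments.
    opt_fn = torch.compile(fn, backend=cnt, fullgraph=True, dynamic=True)

    opt_fn(torch.randn(3))
    assert cnt.frame_count == 1

    # The decorator form changes the same way:
    #   @torch._dynamo.optimize(cnt)  ->  @torch.compile(backend=cnt)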


@@ -76,7 +76,7 @@ s0""",
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x * 2
@@ -105,7 +105,7 @@ def forward(self, L_x_ : torch.Tensor):
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x * 2
@@ -149,7 +149,7 @@ def forward(self, L_x_ : torch.Tensor):
return x
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x + g(x)
@@ -169,7 +169,7 @@ def forward(self, L_x_ : torch.Tensor):
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x * 2
@@ -195,7 +195,7 @@ y = FakeTensor(..., size=(2,))
def test_print_direct(self):
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x, z):
y = x * 2
lambda: z
@@ -208,7 +208,7 @@ y = FakeTensor(..., size=(2,))
sleep_time = 5
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x, z, should_sleep):
if should_sleep:
comptime.sleep(sleep_time)
@@ -233,7 +233,7 @@ y = FakeTensor(..., size=(2,))
SELF = self
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
z = 3
@@ -265,7 +265,7 @@ y = FakeTensor(..., size=(2,))
return x + 3
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x * 2
y = g(y)
@@ -284,7 +284,7 @@ y = FakeTensor(..., size=(2,))
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x * 2
@@ -349,7 +349,7 @@ y = FakeTensor(..., size=(2,))
def test_graph_break(self):
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x * 2
@@ -363,7 +363,7 @@ y = FakeTensor(..., size=(2,))
self.assertEqual(cnt.frame_count, 1)
cnt.frame_count = 0
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def g(x):
y = x * 2
@@ -386,7 +386,7 @@ y = FakeTensor(..., size=(2,))
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
- @torch._dynamo.optimize(cnt)
+ @torch.compile(backend=cnt)
def f(x):
y = x * 2
lit = 2


@@ -20,7 +20,7 @@ class ConfigTests(torch._dynamo.test_case.TestCase):
with torch._dynamo.config.patch(
automatic_dynamic_shapes=False, assume_static_by_default=True
):
- opt_fn = torch._dynamo.optimize(cnt_static)(fn)
+ opt_fn = torch.compile(fn, backend=cnt_static)
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
self.assertEqual(cnt_static.frame_count, 10)
@@ -35,7 +35,7 @@ class ConfigTests(torch._dynamo.test_case.TestCase):
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=True
):
- opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
+ opt_fn = torch.compile(fn, backend=cnt_dynamic)
# NB: must not do 0, 1 as they specialized
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))
@@ -52,7 +52,7 @@ class ConfigTests(torch._dynamo.test_case.TestCase):
with torch._dynamo.config.patch(
automatic_dynamic_shapes=True, assume_static_by_default=False
):
- opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
+ opt_fn = torch.compile(fn, backend=cnt_dynamic)
# NB: must not do 0, 1 as they specialized
for i in range(2, 12):
opt_fn(torch.randn(i), torch.randn(i))


@@ -73,7 +73,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
)
return pre_attention_state_ops(i, mems, state)
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func()
torch._dynamo.reset()
@@ -106,7 +106,7 @@ class ExportTests(torch._dynamo.test_case.TestCase):
def func(x, y):
return x
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -134,7 +134,7 @@ def forward(self, x, y):
def func(x, y):
return y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -159,7 +159,7 @@ def forward(self, x, y):
y = x + 1
return ([x, x], (y, y))
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
torch._dynamo.reset()
@@ -177,7 +177,7 @@ def forward(self, x, y):
return x.cos()
return x.sin()
- opt_func = torch._dynamo.optimize("eager")(func)
+ opt_func = torch.compile(func, backend="eager")
real_result = opt_func(torch.ones(6, 4))
torch._dynamo.reset()
@@ -236,7 +236,7 @@ def forward(self, x, y):
second = x[2]
return first * second
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -260,7 +260,7 @@ def forward(self, x, y):
second = x[2]
return x[0], first * second, x[1], x[2]
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -304,7 +304,7 @@ def forward(self, x, y):
y = x + 1
return ([x, x], (y, y))
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
torch._dynamo.reset()
@@ -329,7 +329,7 @@ def forward(self, x, y):
second = x[2]
return first * second, x
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -355,7 +355,7 @@ def forward(self, x, y):
third = x[2]
return third, first, second, first * second, first * third
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -374,7 +374,7 @@ def forward(self, x, y):
y = x + 1
return y, y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -393,7 +393,7 @@ def forward(self, x, y):
y = x + 1
return y, y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -414,7 +414,7 @@ def forward(self, x, y):
y = x + 1
return y, y, z
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -436,7 +436,7 @@ def forward(self, x, y):
y = x + k
return y, y, z
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -458,7 +458,7 @@ def forward(self, x, y):
y = x + k
return z, y, y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -481,7 +481,7 @@ def forward(self, x, y):
y = x + k
return y[0].item(), y, z
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -504,7 +504,7 @@ def forward(self, x, y):
def func(a, b, c):
return [[a], [b, c], [a + b], [[c + c]]]
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -528,7 +528,7 @@ def forward(self, x, y):
def func(a, b, c):
return a[0].item() + b[0].item() + c[0].item()
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -552,7 +552,7 @@ def forward(self, x, y):
def func(a, b, c):
return b[0].item() + c[0].item() + a[0].item() + a[0].item()
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -576,7 +576,7 @@ def forward(self, x, y):
def func(a, b, c):
return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -604,7 +604,7 @@ def forward(self, x, y):
return func2(x)
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -628,7 +628,7 @@ def forward(self, x, y):
x = a + b + c
return {"a": x}
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -670,7 +670,7 @@ def forward(self, x, y):
)
return pre_attention_state_ops(i, mems, state)
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func()
torch._dynamo.reset()
@@ -689,7 +689,7 @@ def forward(self, x, y):
def func(x, y):
return x
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -717,7 +717,7 @@ def forward(self, x, y):
def func(x, y):
return y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -742,7 +742,7 @@ def forward(self, x, y):
y = x + 1
return ([x, x], (y, y))
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
torch._dynamo.reset()
@@ -768,7 +768,7 @@ def forward(self, x, y):
second = x[2]
return first * second
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -792,7 +792,7 @@ def forward(self, x, y):
second = x[2]
return x[0], first * second, x[1], x[2]
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -809,7 +809,7 @@ def forward(self, x, y):
y = x + 1
return ([x, x], (y, y))
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))
torch._dynamo.reset()
@@ -836,7 +836,7 @@ def forward(self, x, y):
second = x[2]
return first * second, x
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -862,7 +862,7 @@ def forward(self, x, y):
third = x[2]
return third, first, second, first * second, first * third
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -881,7 +881,7 @@ def forward(self, x, y):
y = x + 1
return y, y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -900,7 +900,7 @@ def forward(self, x, y):
y = x + 1
return y, y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(inp)
torch._dynamo.reset()
@@ -921,7 +921,7 @@ def forward(self, x, y):
y = x + 1
return y, y, z
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -943,7 +943,7 @@ def forward(self, x, y):
y = x + k
return y, y, z
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -965,7 +965,7 @@ def forward(self, x, y):
y = x + k
return z, y, y
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -988,7 +988,7 @@ def forward(self, x, y):
y = x + k
return y[0].item(), y, z
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -1011,7 +1011,7 @@ def forward(self, x, y):
def func(a, b, c):
return [[a], [b, c], [a + b], [[c + c]]]
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -1039,7 +1039,7 @@ def forward(self, x, y):
return func2(x)
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -1063,7 +1063,7 @@ def forward(self, x, y):
x = a + b + c
return {"a": x}
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps_rand)
torch._dynamo.reset()
@@ -1182,7 +1182,7 @@ def forward(self, x, y):
return fw
- opt_func = torch._dynamo.optimize(compiler, nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend=compiler, fullgraph=True, dynamic=True)
make_fx_result_through_backend = opt_func(inp)
fx_g = make_fx(func)(inp)
@@ -2939,7 +2939,7 @@ def forward(self, x):
return x + x
inps = (torch.randn(1, 5),)
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -2958,7 +2958,7 @@ def forward(self, x):
return x + x
inps = (torch.randn(1, 5),)
- opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
real_result = opt_func(*inps)
torch._dynamo.reset()
@@ -3435,7 +3435,7 @@ def forward(self, x):
return tensor + tensor
text = "".join(chr(a % 90 + 40) for a in range(111))
- opt_func = torch._dynamo.optimize("eager", dynamic=True)(func)
+ opt_func = torch.compile(func, backend="eager", dynamic=True)
for i in [99, 100]:
input = text[:i]
opt_func(input)
@@ -4378,7 +4378,7 @@ def forward(self, x):
gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
do_export = torch._dynamo.export(gm)
- torch._dynamo.optimize("eager")(fn)(torch.randn(3, 3))
+ torch.compile(fn, backend="eager")(torch.randn(3, 3))
gm1, _ = do_export(torch.randn(3, 3))
gm2, _ = do_export(torch.randn(5, 3))


@@ -2509,7 +2509,7 @@ class GraphModule(torch.nn.Module):
lambda1 = functools.partial(multiply, y=2)
cnts = torch._dynamo.testing.CompileCounter()
- torch._dynamo.optimize(cnts, nopython=True)(fn)(
+ torch.compile(fn, backend=cnts, fullgraph=True)(
lambda0, lambda1, torch.randn(2, 2)
)
self.assertEqual(cnts.frame_count, 1)
@@ -2523,7 +2523,7 @@ class GraphModule(torch.nn.Module):
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
- dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
+ dynamo_result = torch.compile(fn, backend=cnts, fullgraph=True)(
lambda0, lambda1, x
)
self.assertEqual(cnts.frame_count, 1)
@@ -2540,7 +2540,7 @@ class GraphModule(torch.nn.Module):
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
- dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
+ dynamo_result = torch.compile(fn, backend=cnts, fullgraph=True)(
lambda0, lambda1, x
)
self.assertEqual(cnts.frame_count, 1)
@@ -2559,7 +2559,7 @@ class GraphModule(torch.nn.Module):
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
- dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_mul, x)
+ dynamo_result = torch.compile(fn, backend=cnts)(udf_mul, udf_mul, x)
eager_result = fn(udf_mul, udf_mul, x)
gm = backend.graphs[0]
@@ -2606,7 +2606,7 @@ class GraphModule(torch.nn.Module):
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
- dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_add, x)
+ dynamo_result = torch.compile(fn, backend=cnts)(udf_mul, udf_add, x)
eager_result = fn(udf_mul, udf_add, x)
gm = backend.graphs[0]
@@ -2657,7 +2657,7 @@ class GraphModule(torch.nn.Module):
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
- dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, x)
+ dynamo_result = torch.compile(fn, backend=cnts)(udf_mul, x)
eager_result = fn(udf_mul, x)
gm = backend.graphs[0]
@@ -2705,7 +2705,7 @@ class GraphModule(torch.nn.Module):
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
- dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul2, x)
+ dynamo_result = torch.compile(fn, backend=cnts)(udf_mul2, x)
eager_result = fn(udf_mul2, x)
gm = backend.graphs[0]
@@ -2753,7 +2753,7 @@ class GraphModule(torch.nn.Module):
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
- fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ fn = torch.compile(fn, backend=cnts, fullgraph=True)
dynamo_result = fn(lambda0, lambda1, x)
self.assertEqual(cnts.frame_count, 1)
@@ -2780,7 +2780,7 @@ class GraphModule(torch.nn.Module):
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
- fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
+ fn2 = torch.compile(fn2, backend=cnts, fullgraph=True)
dynamo_result = fn2(lambda0, lambda1, [x])
self.assertEqual(cnts.frame_count, 1) # start over
@@ -2838,7 +2838,7 @@ class GraphModule(torch.nn.Module):
return g(torch.rand([1]))
cnts = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnts)(forward)
+ opt_fn = torch.compile(forward, backend=cnts)
input = torch.rand([2])
_ = opt_fn(input)
@@ -2854,7 +2854,7 @@ class GraphModule(torch.nn.Module):
return g(torch.rand([1]))
cnts = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnts)(forward)
+ opt_fn = torch.compile(forward, backend=cnts)
input = torch.rand([2])
_ = opt_fn(input)
@@ -2945,7 +2945,7 @@ class GraphModule(torch.nn.Module):
np.dtype(typ.__name__),
]
cnts_1 = torch._dynamo.testing.CompileCounter()
- opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype)
+ opt_fn_dtype = torch.compile(func_dtype, backend=cnts_1)
a = torch.zeros(3, dtype=typ)
for arg in dt_args:
r = opt_fn_dtype(a, arg)
@@ -2953,7 +2953,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(cnts_1.frame_count, 1)
cnts_2 = torch._dynamo.testing.CompileCounter()
- opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info)
+ opt_fn_info = torch.compile(func_info, backend=cnts_2)
info_args = [info_func(dt) for dt in dt_args]
for arg in info_args:
r = opt_fn_info(a, arg)
@@ -3024,7 +3024,7 @@ class GraphModule(torch.nn.Module):
for dynamic in [True, False]:
torch._dynamo.reset()
- opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn)
+ opt_fn = torch.compile(fn, dynamic=dynamic)
t = torch.ones(1)
test(10, t)
test(-100, t)
@@ -3057,7 +3057,7 @@ class GraphModule(torch.nn.Module):
a = range(-10, 10)
return list(map(op, a))
- opt_fn = torch._dynamo.optimize(nopython=True)(fn)
+ opt_fn = torch.compile(fn, fullgraph=True)
self.assertEqual(opt_fn(), fn())
def test_unary_fold_op_seq(self):
@@ -3068,7 +3068,7 @@ class GraphModule(torch.nn.Module):
a = [tuple(range(-10, i)) for i in range(10)]
return tuple(map(op, a))
- opt_fn = torch._dynamo.optimize(nopython=True)(fn)
+ opt_fn = torch.compile(fn, fullgraph=True)
self.assertEqual(opt_fn(), fn())
def gen_random_range_args(self):
@@ -3972,7 +3972,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
return param in tensor_list
cnts = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
self.assertEqual(opt_fn(param, param2), fn(param, param2))
self.assertEqual(cnts.frame_count, 1)
# Test aliased
@@ -3994,7 +3994,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
return y in tensor_list and z in tensor_list
cnts = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
+ opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)
self.assertEqual(opt_fn(param, param2), fn(param, param2))
self.assertEqual(cnts.frame_count, 1)
# Test aliased
@@ -4015,7 +4015,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
disallowed(g)
f_opt = torch._dynamo
- opt_f = torch._dynamo.optimize(backend="eager")(f)
+ opt_f = torch.compile(f, backend="eager")
opt_f()
f()
self.assertEqual(len(lst), 2)
@@ -4032,8 +4032,8 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
x += y * z
return x
- opt_fn = torch._dynamo.optimize(backend="eager")(fn)
- nopython_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+ opt_fn = torch.compile(fn, backend="eager")
+ nopython_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.ones(3)
ys = [1.0, 2.0, 3.0]


@@ -17,7 +17,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
x = torch.randn([2])
y = torch.randn([2])
- opt = torch._dynamo.optimize(cnt, dynamic=dynamic)(foo)
+ opt = torch.compile(foo, backend=cnt, dynamic=dynamic)
opt(x, y)
x = torch.randn([3])
y = torch.randn([3])
@@ -78,7 +78,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
def run_foo_6_times_and_count_recompiles():
cnt = torch._dynamo.testing.CompileCounter()
- opt = torch._dynamo.optimize(cnt, nopython=True)(foo)
+ opt = torch.compile(foo, backend=cnt, fullgraph=True)
x = True
y = torch.randn([2])
@@ -126,7 +126,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
x = torch.randn([2])
y = torch.randn([2])
- opt = torch._dynamo.optimize(cnt)(foo)
+ opt = torch.compile(foo, backend=cnt)
opt(x, y)
x = torch.randn([3])
y = 3
@@ -169,7 +169,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
return c + 1
cnt = torch._dynamo.testing.CompileCounter()
- compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
+ compiled_foo = torch.compile(foo, backend=cnt, fullgraph=True)
x = torch.randn([3])
y = torch.randn([3])
@@ -206,7 +206,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
return g2 + 1
cnt = torch._dynamo.testing.CompileCounter()
- compiled_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
+ compiled_foo = torch.compile(foo, backend=cnt, fullgraph=True)
z = torch.randn([3])
cmp_result = compiled_foo(z.detach().clone())
@@ -235,7 +235,7 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
def run_foo_6_times_and_count_recompiles():
cnt = torch._dynamo.testing.CompileCounter()
- opt = torch._dynamo.optimize(cnt, nopython=True)(foo)
+ opt = torch.compile(foo, backend=cnt, fullgraph=True)
x = torch.nn.Parameter(torch.randn(1, 3))
opt(x)


@@ -24,7 +24,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
correct1 = fn(v1, v2)
correct2 = fn(v2, v1)
cnt = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt)(fn)
+ opt_fn = torch.compile(fn, backend=cnt)
r1 = opt_fn(v1, v2)
r2 = opt_fn(v2, v1)
self.assertTrue(torch._dynamo.testing.same(r1, correct1))
@@ -284,7 +284,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
correct1 = fn(v1, v2, t)
correct2 = fn(v1, v2, f)
cnt = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt)(fn)
+ opt_fn = torch.compile(fn, backend=cnt)
r1 = opt_fn(v1, v2, t)
r2 = opt_fn(v1, v2, f)
self.assertTrue(torch._dynamo.testing.same(r1, correct1))
@@ -354,7 +354,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
return a[b.size(0) - 1]
cnt = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt)(fn)
+ opt_fn = torch.compile(fn, backend=cnt)
for i in range(3, 12):
opt_fn(torch.randn(i), torch.randn(i))
# just one graph
@@ -366,7 +366,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
+ opt_fn = torch.compile(fn, backend=cnt_dynamic, dynamic=True)
start = 2
end = 12
steps = end - start
@@ -384,7 +384,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
+ opt_fn = torch.compile(fn, backend=cnt_dynamic, dynamic=True)
x = torch.randn(2)
y = torch.randn(3)
self.assertEqual(opt_fn(x, x), fn(x, x))
@@ -397,7 +397,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
+ opt_fn = torch.compile(fn, backend=cnt_dynamic)
x = torch.randn(2)
y = torch.randn(3)
self.assertEqual(opt_fn(x, y), fn(x, y))
@@ -423,7 +423,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
torch._dynamo.reset()
cnt_dynamic = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
+ opt_fn = torch.compile(fn, backend=cnt_dynamic, dynamic=True)
x = torch.randn(0)
y = torch.randn(2)
self.assertEqual(opt_fn(y), fn(y))
@@ -469,7 +469,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
t = torch.Tensor([True])
f = torch.Tensor([False])
cnt = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt)(fn)
+ opt_fn = torch.compile(fn, backend=cnt)
for a in (t, f):
for b in (t, f):
for c in (t, f):
@@ -555,7 +555,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
v1 = torch.randn(10)
v2, it2 = fn(v1)
cnt = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt)(fn)
+ opt_fn = torch.compile(fn, backend=cnt)
v3, it3 = opt_fn(v1)
v4, it4 = opt_fn(v1)
self.assertEqual(v2.tolist(), v3.tolist())
@@ -575,7 +575,7 @@ class SubGraphTests(torch._dynamo.test_case.TestCase):
v1 = torch.randn(10)
it1 = iter(tuple(range(10)))
cnt = torch._dynamo.testing.CompileCounter()
- opt_fn = torch._dynamo.optimize(cnt)(fn)
+ opt_fn = torch.compile(fn, backend=cnt)
self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])