# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import unsupported
from torch._dynamo.utils import ifdynstaticdefault


globalmod = torch.nn.ReLU()

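
# `unsupported` (imported from torch._dynamo.testing) is a function Dynamo
# cannot trace, so calling it forces a graph break; routing the call through
# this wrapper exercises breaks that happen inside an inlined callee rather
# than directly in the top-level frame.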
def indirectly_unsupported(a, b):
    c = a + b
    return unsupported(a, c)


class SubGraphTests(torch._dynamo.test_case.TestCase):
    def _common(self, fn, frame_count, op_count):
        # Compile `fn` with a CompileCounter backend, verify the results match
        # eager mode, then assert how many graphs (frames) and total ops
        # Dynamo captured.
        torch._dynamo.reset()
        v1 = torch.ones(10)
        v2 = torch.ones(10) * -2.0
        correct1 = fn(v1, v2)
        correct2 = fn(v2, v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2)
        r2 = opt_fn(v2, v1)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(
            cnt.frame_count,
            frame_count,
            f"actual {cnt.frame_count} != expected {frame_count}",
        )
        self.assertEqual(cnt.op_count, op_count)
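
    # The control-flow tests branch on values computed from tensors; the
    # expected frame and op counts encode where Dynamo has to split the
    # trace into separate graphs.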
    def test_control_flow1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            if c1.sum() > c2.sum():
                return c1
            else:
                return c2

        self._common(fn, 1, 5)

    def test_control_flow2(self):
        def fn(a, b):
            if a.sum() > b.sum():
                return 1
            else:
                return 2

        self._common(fn, 1, 3)

    def test_control_flow3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            m = globalmod
            if c1.sum() > c2.sum():
                return m(c1)
            else:
                return m(c2)

        self._common(fn, 3, 7)

    def test_control_flow4(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            if tmp1:
                return 1
            else:
                return 2

        self._common(fn, 3, 5)

    def test_control_flow5(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            tmp2 = a.sum() < b.sum() or b.sum() > 0
            if tmp1 and tmp2:
                return 1, tmp1, tmp2
            else:
                return 2, tmp1, tmp2

        self._common(fn, 6, 13)
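
    # Each call to `unsupported` splits the function into a graph compiled
    # before the call and a resume graph compiled after it; the counts below
    # reflect those splits.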
    def test_capi_call1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return unsupported(c1, c2)

        self._common(fn, 1, 2)

    def test_capi_call2(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return a - (b - unsupported(c1, c2))

        self._common(fn, 2, 4)

    def test_capi_call3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return torch._dynamo.testing.unsupported(c1, c2)

        self._common(fn, 1, 2)

    def test_indirect_unsupported1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return indirectly_unsupported(c1, c2)

        self._common(fn, 2, 3)

    def test_indirect_unsupported2(self):
        def fn(a, b):
            local_const1 = 7
            local_const2 = 22
            c1 = a - b
            c2 = b - a
            return local_const1 / (local_const2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 5)

    def test_indirect_unsupported3(self):
        def fn(a, b):
            args = [a - b, b - a]
            return indirectly_unsupported(*args)

        self._common(fn, 2, 3)
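
    # The stack_state tests break the graph in the middle of an expression,
    # so values that are still live at the break (t1, t2) must be saved and
    # restored by the generated resume function.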
    def test_stack_state1(self):
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - unsupported(c1, c2))

        self._common(fn, 2, 6)

    def test_stack_state2(self):
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 7)

    def test_multigraph(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            if x.sum() < 0:
                return x * -1.0
            return x

        self._common(fn, 2, 5)

    def test_extended_args(self):
        # 512 chained adds per branch make the bytecode large enough to
        # exercise EXTENDED_ARG instructions.
        too_many_adds = "+".join(["a", "b"] * 256)
        source = (
            f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
        )
        self._common(eval(source), 3, 1026)
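
    # The resume tests graph-break partway through a straight-line function
    # and check that execution continues correctly in the generated resume
    # function, including when the breaking call is reached through an
    # inlined callee or uses keyword arguments.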
    def test_resume1(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    def test_resume2(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume3(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume4(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(a=x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume5(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            print(x)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)
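
    # The start tests hit a graph break at the very first statement of the
    # function, exercising the case where no ops precede the break in the
    # top-level frame.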
    def test_start1(self):
        def fn(a, b):
            print(a)
            x = a + b
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start2(self):
        def fn(a, b):
            x = indirectly_unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 4)

    def test_start3(self):
        def fn(a, b):
            x = unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start4(self):
        def fn(a, b, check):
            if check:
                return a + b + 10
            else:
                return a + b - 10

        v1 = torch.randn(10)
        v2 = torch.randn(10)
        f = torch.zeros(1, dtype=torch.int32)
        t = torch.ones(1, dtype=torch.int32)
        correct1 = fn(v1, v2, t)
        correct2 = fn(v1, v2, f)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2, t)
        r2 = opt_fn(v1, v2, f)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 4)
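
    # Variables captured from an enclosing scope (free variables) must still
    # be visible when execution resumes after a graph break.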
    def test_resume_freevars(self):
        c1 = torch.randn(10)
        c2 = torch.randn(10)

        def fn(a, b):
            x = a + b + (c1 - c2)
            x = unsupported(x, x)
            return x + (c1 - c2)

        self._common(fn, 2, 5)

    def test_restore_state(self):
        def fn(a, b):
            len_ = len
            x = a + b
            x = torch.add(unsupported(x, x), 1)
            return a * x + len_(b)

        self._common(fn, 2, 4)

    def test_restore_range(self):
        def fn(a, b):
            x = a + b
            rng = range(3, 8, 2)
            x = unsupported(x, x)
            for i in rng:
                x = x + i
            return x

        # We don't specialize on range with dynamic shapes, which
        # means we fail to unroll the loop.
        # TODO: Consider forcing specialization when we iterate over
        # the loop
        self._common(fn, ifdynstaticdefault(2, 1), ifdynstaticdefault(4, 1))

    def test_restore_range_iter(self):
        def fn(a, b):
            x = a + b
            rng = iter(range(3, 8, 2))
            x = unsupported(x, x)
            x += next(rng)
            return x, list(rng)

        self._common(fn, 2, 2)

    def test_pop_after_resume(self):
        def fn(a, b):
            tmp = [a + 1, b + 2, a + b]
            x = a
            x = unsupported(x, x)
            for i in range(3):
                x += tmp.pop(-1)
            return x

        self._common(fn, 2, 6)
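
    # The dynamic-shape tests run the same function over several input sizes
    # and check that Dynamo compiles a single symbolic graph (or the expected
    # small number of graphs) rather than recompiling for every size.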
    @patch("torch._dynamo.config.assume_static_by_default", False)
    def test_dynamic_getitem(self):
        def fn(a, b):
            return a[b.size(0) - 1]

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for i in range(3, 12):
            opt_fn(torch.randn(i), torch.randn(i))
        # just one graph
        self.assertEqual(cnt.frame_count, 1)

    def test_dynamic_kwarg(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        start = 2
        end = 12
        steps = end - start
        for i in range(start, end):
            opt_fn(torch.randn(i), torch.randn(i))

        self.assertEqual(cnt_dynamic.frame_count, 1)

    def test_dynamic_duck_size(self):
        def fn(a, b):
            if a.size(0) == b.size(0):
                return a + b
            else:
                return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, x), fn(x, x))
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    def test_dynamic_order_dependence(self):
        def fn(a, b):
            return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(opt_fn(x, x), fn(x, x))
        # NB: This COULD validly be 2, but we don't test disjointness in the
        # guards for when x and y didn't duck size together, so we end up
        # with a generic graph that also works when x and y happen to duck
        # size together.
        self.assertEqual(cnt_dynamic.frame_count, 2)

        torch._dynamo.reset()
        cnt_dynamic.frame_count = 0
        self.assertEqual(opt_fn(x, x), fn(x, x))  # this overspecializes!
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    def test_dynamic_zero_inference(self):
        def fn(a):
            if a.size(0) != 0:
                return a * 2
            else:
                return a + 1

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(0)
        y = torch.randn(2)
        self.assertEqual(opt_fn(y), fn(y))
        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_no_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 1, 5)  # item gets DCE'd

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 2, 5)
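
    # With three independent data-dependent branches, specializing every path
    # would need 2^3 graphs; shared resume functions keep the total well
    # below that (see the frame_count assertion below).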
    def test_resume_paths_join(self):
        def fn(x, c1, c2, c3):
            x = x + 1
            if c1:
                x = x + 2
            x = x + 3
            if c2:
                x = x + 4
            x = x + 5
            if c3:
                x = x + 6
            return x + 7

        v1 = torch.randn(10)
        t = torch.Tensor([True])
        f = torch.Tensor([False])
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for a in (t, f):
            for b in (t, f):
                for c in (t, f):
                    opt_fn(v1, a, b, c)

        # checking here we don't create 2^n graphs
        self.assertEqual(cnt.frame_count, 7)
        self.assertEqual(cnt.op_count, 10)

    def test_resume_with_no_grad1(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
            x = x + 3
            return x

        self._common(fn, 2, 9)
        torch._dynamo.reset()
        with torch.no_grad():
            self._common(fn, 2, 5)

    def test_resume_with_no_grad2(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
                x.sum().tolist()  # graph break
                x = x + 3
            x = x + 4
            return x

        self._common(fn, 3, 13)

    def test_resume_with_no_grad3(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                with torch.no_grad():
                    x = x + 1
                    with torch.enable_grad():
                        x.sum().tolist()  # graph break
                        x = x[0] + 2
                    x = x + 3
            x = x + 4
            return x

        self._common(fn, 2, 11)
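
    # Tuple iterators partially consumed before a graph break must be
    # reconstructed in the resume function so that the remaining items (and
    # any iterator returned to the caller) match eager execution.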
    def test_resume_tuple_iterator(self):
        def fn(a, b):
            x = a + b
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        self._common(fn, 2, 8)

    def test_tuple_iterator_return(self):
        def fn(x):
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            return x, it

        v1 = torch.randn(10)
        v2, it2 = fn(v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        v3, it3 = opt_fn(v1)
        v4, it4 = opt_fn(v1)
        self.assertEqual(v2.tolist(), v3.tolist())
        self.assertEqual(v2.tolist(), v4.tolist())
        self.assertEqual(list(it2), list(it3))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 6)

    def test_tuple_iterator_mutate(self):
        def fn(x, it):
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        v1 = torch.randn(10)
        it1 = iter(tuple(range(10)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
        self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])

    def test_enumerate_not_break_graph(self):
        def fn(a, b):
            for i, x in enumerate(a.shape):
                b = b + x
            for i, x in enumerate(b.shape, 8):
                b = b + x * i
            return b

        self._common(fn, 1, ifdynstaticdefault(2, 3))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()