# Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00).
# Related commits: #139706, #140238, #140247, #140253.
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/140238
# Approved by: https://github.com/soulitzer
# Original file: 196 lines, 4.8 KiB, Python.
# Owner(s): ["module: dynamo"]
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._dynamo.test_case
|
|
from torch._dynamo.testing import CompileCounter
|
|
|
|
|
|
# Module-level side-effect counters. MyModule increments them and
# SkipNonTensorTests.test_do_not_skip_side_effects resets and checks them:
# _variable is bumped on the is_compiling() == True paths, while _variable_2
# is only bumped by mode 1's non-compiling (eager) branch.
_variable = _variable_2 = 0
|
|
|
|
|
|
def user_function():
    """Return ``torch._utils.is_compiling()`` from inside a plain user function.

    Several ``MyModule`` modes call this instead of querying ``is_compiling()``
    inline, so the flag is read through an extra function call.
    """
    compiling = torch._utils.is_compiling()
    return compiling
|
|
|
|
|
|
def user_generator():
    """Yield ``torch._utils.is_compiling()`` exactly once.

    Kept as a loop-based generator so ``MyModule`` mode 4 reads the flag by
    iterating a user-defined generator.
    """
    iterations = 1
    for _ in range(iterations):
        yield torch._utils.is_compiling()
|
|
|
|
|
|
class MyModule(torch.nn.Module):
    """Toy module whose ``mode`` selects how ``torch._utils.is_compiling()`` is reached.

    When is_compiling() is True each mode increments the module-level
    ``_variable`` counter:

    1. called inline in ``forward`` (the eager branch bumps ``_variable_2``)
    2. called through ``user_function``
    3. called through a local lambda
    4. yielded from ``user_generator``
    5. called through ``user_function`` in the forward pre-hook; ``forward``
       itself only adds 1 to ``x`` in place
    6. called through ``user_function`` followed by a Dynamo graph break
    """

    def __init__(self, mode: int):
        super().__init__()
        self.mode = mode
        # with_kwargs=True: the hook is called as (module, args, kwargs) and
        # hands back an (args, kwargs) pair.
        self.register_forward_pre_hook(self.pre_forward, with_kwargs=True)

    def pre_forward(self, module, args, kwargs):
        """Forward pre-hook; only in mode 5 does it touch ``_variable``."""
        global _variable

        # Guard clause: anything other than mode 5 while compiling passes
        # the inputs through untouched.
        if self.mode != 5 or not user_function():
            return args, kwargs
        _variable += 1
        return args, kwargs

    def forward(self, x):
        global _variable, _variable_2

        mode = self.mode
        if mode == 1:
            if torch._utils.is_compiling():
                _variable += 1
            else:
                _variable_2 += 1
        elif mode == 2:
            if user_function():
                _variable += 1
        elif mode == 3:
            compiling_check = lambda: torch._utils.is_compiling()  # noqa: E731
            if compiling_check():
                _variable += 1
        elif mode == 4:
            for flag in user_generator():
                if flag:
                    _variable += 1
        elif mode == 5:
            # Mode 5's counter update happens in pre_forward; here the input
            # is only modified in place.
            x += 1
        elif mode == 6:
            if user_function():
                torch._dynamo.graph_break()
                _variable += 1
        return x
|
|
|
|
|
|
class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests for frames that receive non-tensor inputs (ints, lists, dicts).

    Most tests count the ops recorded by a ``CompileCounter`` backend. Checks use
    ``self.assertEqual`` instead of bare ``assert`` so they are not stripped when
    Python runs with ``-O`` and so failures report both compared values.
    """

    def test_add_tensor1(self):
        # Tensor + Python int through the ``+`` operator.
        def fn(a, b):
            return a + b

        counter = CompileCounter()
        x = torch.randn(4)
        y = 5
        opt_fn = torch._dynamo.optimize_assert(counter)(fn)
        opt_fn(x, y)

        self.assertEqual(counter.op_count, 1)

    def test_add_tensor2(self):
        # Tensor + Python int through ``torch.add``.
        def fn(a, b):
            return torch.add(a, b)

        counter = CompileCounter()

        x = torch.randn(4)
        y = 5
        opt_fn = torch._dynamo.optimize_assert(counter)(fn)
        opt_fn(x, y)

        self.assertEqual(counter.op_count, 1)

    def test_add_tensor_list(self):
        # Tensor and int passed inside a list.
        def fn(lst):
            return lst[0] + lst[1]

        counter = CompileCounter()
        x = torch.randn(4)
        y = 5
        opt_fn = torch._dynamo.optimize_assert(counter)(fn)
        opt_fn([x, y])

        self.assertEqual(counter.op_count, 1)

    def test_add_tensor_dict(self):
        # Tensor and int passed inside a dict.
        def fn(dt):
            return dt["a"] + dt["b"]

        counter = CompileCounter()
        x = torch.randn(4)
        y = 5
        opt_fn = torch._dynamo.optimize_assert(counter)(fn)
        opt_fn({"a": x, "b": y})

        self.assertEqual(counter.op_count, 1)

    def test_add_skip(self):
        # Only Python ints, no tensors: no ops should be recorded.
        def fn(a, b):
            return a + b

        counter = CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(counter)(fn)
        x = 4
        y = 5
        opt_fn(x, y)

        self.assertEqual(counter.op_count, 0)

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_recursive_list(self):
        # A self-referential list; optimize_assert is used as a context manager,
        # which requires raise_on_ctx_manager_usage to be patched off.
        def fn(x):
            return x

        counter = CompileCounter()

        x = []
        x.append(x)
        with torch._dynamo.optimize_assert(counter):
            fn(x)

        self.assertEqual(counter.op_count, 0)

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_custom_list(self):
        # A list subclass whose __iter__ and __len__ raise; no ops should be
        # recorded.
        def fn(x):
            return x[0] + x[1]

        counter = CompileCounter()

        class Foo(list):
            def __iter__(self):
                raise Exception  # noqa: TRY002

            def __len__(self):
                raise Exception  # noqa: TRY002

        x = Foo()
        x.append(torch.randn(4))
        x.append(torch.randn(4))
        with torch._dynamo.optimize_assert(counter):
            fn(x)

        self.assertEqual(counter.op_count, 0)

    def test_do_not_skip_side_effects(self):
        # https://github.com/pytorch/pytorch/issues/110765

        # By invoking torch._utils.is_compiling(),
        # there may be side-effects inconsistent with eager when
        # compiling. Thus we force dynamo to commit the graph,
        # even if it does not perform any tensor operation
        global _variable, _variable_2

        for mode in range(1, 7):
            # subTest labels failures with the MyModule mode under test and
            # lets the remaining modes still run.
            with self.subTest(mode=mode):
                torch._dynamo.reset()

                _variable = 0
                _variable_2 = 0

                mod = MyModule(mode=mode)
                # Mode 6 calls torch._dynamo.graph_break(), so it cannot be
                # compiled with fullgraph=True.
                model = torch.compile(mod, backend="eager", fullgraph=mode != 6)
                self.assertEqual(_variable, 0)
                self.assertEqual(_variable_2, 0)

                model(torch.tensor([1]))
                self.assertEqual(_variable, 1)
                self.assertEqual(_variable_2, 0)

                model(torch.tensor([1]))
                self.assertEqual(_variable, 2)
                self.assertEqual(_variable_2, 0)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests with the torch._dynamo test-case runner when the
    # module is executed directly as a script.
    from torch._dynamo import test_case

    test_case.run_tests()
|