# Owner(s): ["module: dynamo"]
import dataclasses
import unittest.mock

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same

try:
    from transformers import modeling_outputs
    from transformers.configuration_utils import PretrainedConfig
    from transformers.file_utils import ModelOutput
    from transformers.modeling_outputs import BaseModelOutput
except ImportError:
    modeling_outputs = None


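# Decorator: skip a test when the HuggingFace transformers package is not installed.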
def maybe_skip(fn):
    if modeling_outputs is None:
        return unittest.skip("requires HuggingFace")(fn)
    return fn


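# Tests that Dynamo can read attributes off a HuggingFace PretrainedConfig
# inside a compiled function.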
class TestHFPretrained(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_pretrained(self):
        def fn(a, tmp):
            if tmp.return_dict:
                return a + torch.ones(2) * tmp.max_length
            return a

        x = torch.randn(2)
        tmp = PretrainedConfig(return_dict=True, max_length=20)
        ref = fn(x, tmp)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(x, tmp)
        self.assertTrue(same(ref, res))


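# Tests that Dynamo traces HuggingFace ModelOutput objects: construction,
# attribute and item assignment, getattr/getitem access, to_tuple(), integer
# indexing, and user-defined ModelOutput dataclass subclasses.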
class TestModelOutput(torch._dynamo.test_case.TestCase):
    @maybe_skip
    def test_mo_create(self):
        def fn(a, b):
            tmp = BaseModelOutput(a + 1, attentions=b + 3)
            return tmp

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)

    @maybe_skip
    def test_mo_assign(self):
        def fn(a, b):
            tmp = BaseModelOutput(last_hidden_state=b + 3)
            tmp.hidden_states = a + 7
            tmp["attentions"] = a + b + 6
            return tmp

        args = [torch.randn(10), torch.randn(10)]
        obj1 = fn(*args)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
        self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
        self.assertTrue(same(obj1.attentions, obj2.attentions))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 4)

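    # Helper: run fn eagerly and under optimize_assert on a BaseModelOutput input
    # (hidden_states left unset), then compare results and the captured op count.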
    def _common(self, fn, op_count):
        args = [
            BaseModelOutput(
                last_hidden_state=torch.randn(10), attentions=torch.randn(10)
            )
        ]
        obj1 = fn(*args)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, op_count)

    @maybe_skip
    def test_mo_getattr(self):
        def fn(obj: BaseModelOutput):
            x = obj.last_hidden_state * 10
            if obj.hidden_states is not None:
                x += obj.hidden_states
            if obj.attentions is not None:
                x += obj.attentions
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_getitem(self):
        def fn(obj: BaseModelOutput):
            x = obj["last_hidden_state"] * 10
            # hidden_states is never set by _common, so this branch is not taken
            if "hidden_states" in obj:
                x += obj["hidden_states"]
            if "attentions" in obj:
                x += obj["attentions"]
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_tuple(self):
        def fn(obj: BaseModelOutput):
            a, b = obj.to_tuple()
            return a + b * 10

        self._common(fn, 2)

    @maybe_skip
    def test_mo_index(self):
        def fn(obj: BaseModelOutput):
            return obj[0] * 10 + obj[1]

        self._common(fn, 2)

    @maybe_skip
    def test_mo_init(self):
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none

            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v

            return total

        tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
        obj1 = MyDataClass(*tensors)
        correct1 = fn(obj1)

        obj2 = MyDataClass(*tensors)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj2), correct1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()