[Dynamo] Support typing.Mapping & Support function as argument (#88963)

These missing features come from https://github.com/pytorch/benchmark/pull/1302, where we'd like to enable E2E hf_bert dynamo train/eval. The dependent [HuggingFace accelerate library](https://huggingface.co/docs/accelerate/index) requires these improvements.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88963
Approved by: https://github.com/jansel
This commit is contained in:
Yanbo Liang
2022-11-17 06:57:42 +00:00
committed by PyTorch MergeBot
parent 126e44173d
commit b72f5b9ae3
5 changed files with 56 additions and 2 deletions

View File

@ -2792,6 +2792,42 @@ class MiscTests(torch._dynamo.test_case.TestCase):
res = opt_fn(x)
self.assertTrue(torch.allclose(ref, res))
def test_user_function_variable_supports_function_argument(self):
def add1(x):
return x + 1
def add2(x):
return x + 2
def gn(x, f=add1):
if f is add1:
return x + 1
else:
return x + 2
def fn(x, f):
return gn(x, f)
x = torch.randn(2, 3)
ref = fn(x, add2)
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
res = opt_fn(x, add2)
self.assertTrue(torch.allclose(ref, res))
def test_typing_variable_isinstance(self):
def fn(x, m):
if isinstance(m, typing.Mapping):
return x + 1
else:
return x - 1
x = torch.randn(2, 3)
m = {"x": torch.randn(3)}
ref = fn(x, m)
opt_fn = torch._dynamo.optimize("eager")(fn)
res = opt_fn(x, m)
self.assertTrue(torch.allclose(ref, res))
def test_repro_graph_breaks_in__get_item_by_idx(self):
class Mod(torch.nn.Module):
def __init__(self):

View File

@ -19,6 +19,7 @@ import re
import sys
import time
import types
import typing
import weakref
from contextlib import contextmanager
from functools import lru_cache
@ -275,6 +276,13 @@ def istype(obj, allowed_types):
return type(obj) is allowed_types
def is_typing(value):
if sys.version_info < (3, 9):
return isinstance(value, typing._GenericAlias)
else:
return isinstance(value, typing._SpecialGenericAlias)
def is_numpy_int_type(value):
return istype(
value,

View File

@ -9,7 +9,7 @@ import operator
import re
import types
from abc import ABCMeta
from typing import Any, List, Union
from typing import Any, Union
import numpy as np
from functorch.experimental.ops import PyOperator
@ -43,6 +43,7 @@ from ..utils import (
global_key_name,
is_namedtuple,
is_numpy_int_type,
is_typing,
istensor,
istype,
odict_values,
@ -360,7 +361,8 @@ class VariableBuilder:
value,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif value is List:
elif is_typing(value):
# typing.List, typing.Mapping, etc.
return TypingVariable(
value,
guards=make_guards(GuardBuilder.ID_MATCH),

View File

@ -24,6 +24,8 @@ def wrap_bound_arg(val, options):
return cls([wrap_bound_arg(x, options) for x in val], **options)
elif variables.ConstantVariable.is_literal(val):
return variables.ConstantVariable(val, **options)
elif isinstance(val, types.FunctionType):
return variables.UserFunctionVariable(val, **options)
elif isinstance(val, enum.Enum):
return variables.EnumVariable(val, **options)
elif isinstance(val, (type, abc.ABCMeta)):

View File

@ -654,6 +654,12 @@ class TypingVariable(VariableTracker):
)
unimplemented("typing")
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
class NumpyVariable(VariableTracker):
"""