mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	[Dynamo] Support typing.Mapping & Support function as argument (#88963)
These missing features come from https://github.com/pytorch/benchmark/pull/1302, where we'd like to enable E2E hf_bert dynamo train/eval. The dependent [HuggingFace accelerate library](https://huggingface.co/docs/accelerate/index) requires these improvements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88963 Approved by: https://github.com/jansel
This commit is contained in:
		
				
					committed by
					
						
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							126e44173d
						
					
				
				
					commit
					b72f5b9ae3
				
			@ -2792,6 +2792,42 @@ class MiscTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        res = opt_fn(x)
 | 
			
		||||
        self.assertTrue(torch.allclose(ref, res))
 | 
			
		||||
 | 
			
		||||
    def test_user_function_variable_supports_function_argument(self):
 | 
			
		||||
        def add1(x):
 | 
			
		||||
            return x + 1
 | 
			
		||||
 | 
			
		||||
        def add2(x):
 | 
			
		||||
            return x + 2
 | 
			
		||||
 | 
			
		||||
        def gn(x, f=add1):
 | 
			
		||||
            if f is add1:
 | 
			
		||||
                return x + 1
 | 
			
		||||
            else:
 | 
			
		||||
                return x + 2
 | 
			
		||||
 | 
			
		||||
        def fn(x, f):
 | 
			
		||||
            return gn(x, f)
 | 
			
		||||
 | 
			
		||||
        x = torch.randn(2, 3)
 | 
			
		||||
        ref = fn(x, add2)
 | 
			
		||||
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
 | 
			
		||||
        res = opt_fn(x, add2)
 | 
			
		||||
        self.assertTrue(torch.allclose(ref, res))
 | 
			
		||||
 | 
			
		||||
    def test_typing_variable_isinstance(self):
 | 
			
		||||
        def fn(x, m):
 | 
			
		||||
            if isinstance(m, typing.Mapping):
 | 
			
		||||
                return x + 1
 | 
			
		||||
            else:
 | 
			
		||||
                return x - 1
 | 
			
		||||
 | 
			
		||||
        x = torch.randn(2, 3)
 | 
			
		||||
        m = {"x": torch.randn(3)}
 | 
			
		||||
        ref = fn(x, m)
 | 
			
		||||
        opt_fn = torch._dynamo.optimize("eager")(fn)
 | 
			
		||||
        res = opt_fn(x, m)
 | 
			
		||||
        self.assertTrue(torch.allclose(ref, res))
 | 
			
		||||
 | 
			
		||||
    def test_repro_graph_breaks_in__get_item_by_idx(self):
 | 
			
		||||
        class Mod(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,7 @@ import re
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
import types
 | 
			
		||||
import typing
 | 
			
		||||
import weakref
 | 
			
		||||
from contextlib import contextmanager
 | 
			
		||||
from functools import lru_cache
 | 
			
		||||
@ -275,6 +276,13 @@ def istype(obj, allowed_types):
 | 
			
		||||
    return type(obj) is allowed_types
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_typing(value):
 | 
			
		||||
    if sys.version_info < (3, 9):
 | 
			
		||||
        return isinstance(value, typing._GenericAlias)
 | 
			
		||||
    else:
 | 
			
		||||
        return isinstance(value, typing._SpecialGenericAlias)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_numpy_int_type(value):
 | 
			
		||||
    return istype(
 | 
			
		||||
        value,
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,7 @@ import operator
 | 
			
		||||
import re
 | 
			
		||||
import types
 | 
			
		||||
from abc import ABCMeta
 | 
			
		||||
from typing import Any, List, Union
 | 
			
		||||
from typing import Any, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
from functorch.experimental.ops import PyOperator
 | 
			
		||||
@ -43,6 +43,7 @@ from ..utils import (
 | 
			
		||||
    global_key_name,
 | 
			
		||||
    is_namedtuple,
 | 
			
		||||
    is_numpy_int_type,
 | 
			
		||||
    is_typing,
 | 
			
		||||
    istensor,
 | 
			
		||||
    istype,
 | 
			
		||||
    odict_values,
 | 
			
		||||
@ -360,7 +361,8 @@ class VariableBuilder:
 | 
			
		||||
                value,
 | 
			
		||||
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
 | 
			
		||||
            )
 | 
			
		||||
        elif value is List:
 | 
			
		||||
        elif is_typing(value):
 | 
			
		||||
            # typing.List, typing.Mapping, etc.
 | 
			
		||||
            return TypingVariable(
 | 
			
		||||
                value,
 | 
			
		||||
                guards=make_guards(GuardBuilder.ID_MATCH),
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,8 @@ def wrap_bound_arg(val, options):
 | 
			
		||||
        return cls([wrap_bound_arg(x, options) for x in val], **options)
 | 
			
		||||
    elif variables.ConstantVariable.is_literal(val):
 | 
			
		||||
        return variables.ConstantVariable(val, **options)
 | 
			
		||||
    elif isinstance(val, types.FunctionType):
 | 
			
		||||
        return variables.UserFunctionVariable(val, **options)
 | 
			
		||||
    elif isinstance(val, enum.Enum):
 | 
			
		||||
        return variables.EnumVariable(val, **options)
 | 
			
		||||
    elif isinstance(val, (type, abc.ABCMeta)):
 | 
			
		||||
 | 
			
		||||
@ -654,6 +654,12 @@ class TypingVariable(VariableTracker):
 | 
			
		||||
            )
 | 
			
		||||
        unimplemented("typing")
 | 
			
		||||
 | 
			
		||||
    def python_type(self):
 | 
			
		||||
        return type(self.value)
 | 
			
		||||
 | 
			
		||||
    def as_python_constant(self):
 | 
			
		||||
        return self.value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NumpyVariable(VariableTracker):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user