Adds the MethodWrapperVariable and GetSetDescriptorVariable types. These are used in `__torch_function__` tracing to represent attribute reads (`__get__`) and to compare unbound methods (the `func` argument when `__torch_function__` is dispatched from a method call), working towards tracing support for https://github.com/pytorch/pytorch/issues/93723.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109542
Approved by: https://github.com/jansel
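For reference, a minimal illustration (not from the PR itself) of the two kinds of objects these new variable types wrap:

    import types

    # A method-wrapper: the bound form of a C-implemented slot method. Objects
    # like this arrive as the `func` argument when `__torch_function__` is
    # dispatched from a method call.
    assert isinstance(object().__str__, types.MethodWrapperType)

    # A getset descriptor: a C-level attribute whose reads go through __get__;
    # types.GetSetDescriptorType is itself defined in the stdlib as
    # type(types.FunctionType.__code__).
    assert isinstance(types.FunctionType.__code__, types.GetSetDescriptorType)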
import abc
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import logging
import operator
import re
import types
from typing import List, NamedTuple, Optional, Union

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import torch

from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    DimConstraint,
    DimDynamic,
    RelaxedUnspecConstraint,
)
from torch.fx.immutable_collections import immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef, WeakIdRef

from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import (
    is_allowed,
    is_builtin_callable,
    is_numpy,
    is_user_defined_allowed,
)
from ..exc import unimplemented
from ..guards import GuardBuilder, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
    AttrSource,
    ConstantSource,
    ConvertIntSource,
    GetItemSource,
    GlobalWeakRefSource,
    is_constant_source,
    LocalSource,
    NumpyTensorSource,
    RandomValueSource,
    Source,
    TupleIteratorGetItemSource,
)
from ..utils import (
    build_checkpoint_variable,
    clone_input,
    get_fake_value,
    get_static_address_type,
    getfile,
    global_key_name,
    is_namedtuple,
    is_typing,
    is_utils_checkpoint,
    istype,
    odict_values,
    preserve_rng_state,
    tensor_always_has_static_shape,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    wrap_fake_exception,
)

from .base import MutableLocal, typestr, VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import CUDAStreamVariable, NullContextVariable
from .dicts import (
    ConstDictVariable,
    DataClassVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
)
from .distributed import (
    DeviceMeshVariable,
    PlacementClassVariable,
    PlacementVariable,
    ProcessGroupVariable,
)
from .functions import (
    CollectiveFunctionRewriteVariable,
    FunctoolsPartialVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import (
    BaseListVariable,
    DequeVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    SetVariable,
    SizeVariable,
    SliceVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    GetAttrVariable,
    GetSetDescriptorVariable,
    InspectSignatureVariable,
    LambdaVariable,
    MethodWrapperVariable,
    NumpyVariable,
    PythonModuleVariable,
    SkipFilesVariable,
    TypingVariable,
)

from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorSubclassVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
)
from .torch import tensor_dunder_fns, torch_special_class_types, TorchVariable
from .user_defined import (
    KeyedJaggedTensorVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
)


log = logging.getLogger(__name__)


DimList = List


class _missing:
    pass


@dataclasses.dataclass
class GraphArg:
    source: Source
    # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
    # thing to do.  Probably should have example (which stores an int) and
    # fake_example
    _example: Union[TensorWeakRef, torch.SymInt]
    is_unspecialized: bool
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True
    # Sometimes, the Tensor we pass to example is freshly allocated (smh).
    # Then we cannot only keep a weak reference to it.  This lets you
    # stash a strong reference too.
    example_strong_ref: Optional[torch.Tensor] = None

    @property
    def example(self):
        if isinstance(self._example, TensorWeakRef):
            r = self._example()
            assert r is not None
            return r
        else:
            return self._example

    def __post_init__(self):
        if isinstance(self._example, torch.Tensor):
            self._example = TensorWeakRef(self._example)
            assert is_fake(self.fake_tensor)

    def load(self, tx):
        return self.source.reconstruct(tx)

    def erase(self):
        self._example = None

    def __eq__(self, other):
        return self.source.name() == other.source.name()


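# FrameStateSizeEntry records, per input name, the scalar int value (or tensor
# size) last observed for a frame; a field set to None means "treat as dynamic"
# once differing values have been seen across runs (see the automatic dynamic
# handling in wrap_unspecialized_primitive below).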
@dataclasses.dataclass
class FrameStateSizeEntry:
    scalar: Optional[int]
    size: Optional[List[int]]


class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

    def __init__(
        self,
        tx,
        source: Source,
    ):
        assert (
            source is not None
        ), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
        assert TracingContext.get() is not None, "Expected active TracingContext"
        super().__init__()
        self.tx = tx
        self.source = source
        self.name = source.name()

    def __call__(self, value):
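        # If this exact Python object was already wrapped during this trace
        # (tracked by side_effects), reuse the existing VariableTracker, adding
        # a duplicate-input guard so the aliasing between the two sources is
        # re-checked when the compiled code is reused.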
        if value in self.tx.output.side_effects:
            side_effect_result = self.tx.output.side_effects[value]
            dup_guard = make_dupe_guard(self.source, side_effect_result.source)
            if dup_guard:
                side_effect_result = side_effect_result.add_guards(
                    self.make_guards(dup_guard)
                )
            return side_effect_result
        vt = self._wrap(value).clone(**self.options())
        if self._can_lift_attrs_to_inputs(vt):
            vt = self.tx.output.side_effects.track_object_existing(
                self.source, value, vt
            )
        return vt

    def _can_lift_attrs_to_inputs(self, vt):
        if type(vt) in [
            TensorVariable,
            TensorWithTFOverrideVariable,
            UserDefinedObjectVariable,
            NumpyNdarrayVariable,
        ]:
            return True
        return False

    @staticmethod
    @functools.lru_cache(None)
    def _common_constants():
        return {
            # We zero-one specialize shapes, so specialize these constants
            # too
            0,
            1,
            # NB: There used to be more constants here, but honestly it was
            # pretty confusing.  Note we specialize floats by default, and
            # DON'T specialize ints by default.  This all only matters with
            # dynamic_shapes
        }

    @staticmethod
    def list_type(value):
        if is_namedtuple(value):
            return functools.partial(NamedTupleVariable, tuple_cls=type(value))
        # TODO(voz): Why do we have both this and `BaseListVariable`'s `cls_for`?
        return {
            tuple: TupleVariable,
            list: ListVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
            collections.deque: DequeVariable,
        }[type(value)]

    def get_source(self):
        return self.source

    def options(self):
        return {"source": self.get_source()}

    def make_guards(self, *guards):
        source = self.get_source()
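        # Constant sources never carry guards; for them we return None and the
        # value is assumed to be baked into the graph.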
        if (
            isinstance(source, ConstantSource)
            or source.guard_source() == GuardSource.CONSTANT
        ):
            return None
        return {source.make_guard(guard) for guard in guards}

    @classmethod
    @functools.lru_cache(None)
    def _type_dispatch(cls):
        # NB: Careful not to close over self to avoid ref cycle from lru_cache
        entries = [
            (
                (torch.Tensor, torch.nn.Parameter, torch._subclasses.FakeTensor),
                cls.wrap_tensor,
            ),
            ((tuple, list, odict_values, collections.deque), cls.wrap_listlike),
            (tuple_iterator, cls.wrap_tuple_iterator),
            ((slice, range), cls.wrap_slice_range),
            (
                (
                    int,
                    float,
                    bool,
                    type(None),
                    str,
                    torch.Size,
                    torch.device,
                    torch.dtype,
                ),
                cls.wrap_literal,
            ),
        ]

        if config.trace_numpy and np:
            entries.append((np.ndarray, cls.wrap_numpy_ndarray))

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, tuple) else (ts,):
                assert t not in result
                result[t] = fn

        return result

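    # Unlike _type_dispatch, which keys on exact type(value), _id_dispatch keys
    # on id(value) so that specific well-known callables (inspect.signature,
    # dataclasses.fields, comptime, the tensor dunder functions) get dedicated
    # handling regardless of their type.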
    @classmethod
    @functools.lru_cache(None)
    def _id_dispatch(cls):
        from ..comptime import comptime

        entries = [
            (
                inspect.signature,
                lambda self, value: LambdaVariable(
                    InspectSignatureVariable.create,
                    source=self.source,
                    guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
                ),
            ),
            (comptime, lambda self, value: ComptimeVariable()),
            (
                dataclasses.fields,
                lambda self, value: LambdaVariable(
                    _dataclasses_fields_lambda,
                    source=self.source,
                    guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
                ),
            ),
            (
                tensor_dunder_fns,
                lambda self, value: TorchVariable(
                    value,
                    source=self.source,
                    guards=self.make_guards(GuardBuilder.FUNCTION_MATCH),
                ),
            ),
        ]

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, (tuple, list)) else (ts,):
                assert t not in result
                result[id(t)] = fn

        return result

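    # _wrap is the main dispatch: exact type() matches first, then exact id()
    # matches, then a long chain of isinstance/istype checks whose ordering
    # matters (see the "Everything else" note below).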
    def _wrap(self, value):
        make_guards = self.make_guards

        # Handle exact type() match
        type_dispatch = self._type_dispatch().get(type(value))
        if type_dispatch is not None:
            return type_dispatch(self, value)

        # Handle exact id() match
        id_dispatch = self._id_dispatch().get(id(value))
        if id_dispatch is not None:
            return id_dispatch(self, value)

        # Note - There are some nested values where types mismatch!
        # We want to get those out and wrap those.
        value = inspect.getattr_static(value, "_torchdynamo_inline", value)

        # Everything else (NB: order matters!)
        if is_traceable_wrapper_subclass(value) or istype(
            value, config.traceable_tensor_subclasses
        ):
            return self.wrap_tensor(value)
        elif is_namedtuple(value):
            return self.wrap_listlike(value)

        elif value is torch.utils._pytree.SUPPORTED_NODES:
            result = {
                k: UserDefinedObjectVariable(
                    value[k],
                    source=GetItemSource(self.get_source(), k),
                    # For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
                    # under the assumption that the values themselves don't change.
                    guards=self.make_guards(GuardBuilder.DICT_VERSION),
                )
                for k in value.keys()
            }
            return ConstDictVariable(result, type(value))

        elif istype(
            value, (dict, collections.defaultdict, collections.OrderedDict)
        ) and all(
            ConstantVariable.is_literal(k)
            or self.tensor_can_be_dict_key(k)
            or isinstance(k, enum.Enum)
            for k in value.keys()
        ):
            if not value and self.get_source().is_nn_module():
                # It is faster to guard on 'false' property than to guard
                # on actual dict keys, but we can't do this fast guard in general because
                # it omits a crucial type check that ensures the value is actually still a dict at runtime.

                # Why is this OK for (specialized) nnmodules? We set up a setattr hook
                # to check for module property mutations, which does a reasonable,
                # but not completely secure job ensuring a property wasn't changed.
                guards = self.make_guards(GuardBuilder.BOOL_FALSE)
            else:
                guards = self.make_guards(GuardBuilder.DICT_KEYS)

            # store key variables in global location for reconstruction
            for key in value.keys():
                if self.tensor_can_be_dict_key(key):
                    self.tx.store_global_weakref(global_key_name(key), key)

            def index_source(key):
                if self.tensor_can_be_dict_key(key):
                    return GlobalWeakRefSource(global_key_name(key))
                else:
                    return key

            result = {
                k: VariableBuilder(
                    self.tx, GetItemSource(self.get_source(), index_source(k))
                )(value[k]).add_guards(guards)
                for k in value.keys()
            }

            if istype(value, collections.defaultdict):
                result = DefaultDictVariable(
                    result,
                    type(value),
                    self._wrap(value.default_factory),
                    guards=guards,
                )
            else:
                result = ConstDictVariable(result, type(value), guards=guards)

            return self.tx.output.side_effects.track_dict(self.source, value, result)
        elif isinstance(value, torch.nn.Module):
            return self.wrap_module(value)
        elif ConstantVariable.is_literal(value):  # non-atomic literals
            return self.wrap_literal(value)
        elif istype(value, frozenset) and (
            all(is_allowed(x) or ConstantVariable.is_literal(x) for x in value)
        ):
            # For frozenset, we can guard by object ID instead of value
            # equality; this allows us to handle non-literal values
            return ConstantVariable(
                value=value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif isinstance(value, enum.Enum):
            return EnumVariable(
                value=value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif is_builtin_callable(value):
            return BuiltinVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.BUILTIN_MATCH),
            )
        elif is_utils_checkpoint(value):
            return build_checkpoint_variable(source=self.source)
        elif is_allowed(value):
            if is_user_defined_allowed(value):
                self.tx.output.has_user_defined_allowed_in_graph = True
            return TorchVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif isinstance(value, functools.partial):
            func_src = AttrSource(self.get_source(), "func")
            func_obj = VariableBuilder(self.tx, func_src)(value.func)

            args = []
            args_source = AttrSource(self.get_source(), "args")
            for i, arg in enumerate(value.args):
                args.append(
                    VariableBuilder(self.tx, GetItemSource(args_source, i))(arg)
                )

            keywords = {}
            keywords_source = AttrSource(self.get_source(), "keywords")
            for k, v in value.keywords.items():
                keywords[k] = VariableBuilder(
                    self.tx, GetItemSource(keywords_source, k)
                )(v)

            guards = {
                self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
                keywords_source.make_guard(GuardBuilder.DICT_KEYS),
                args_source.make_guard(GuardBuilder.LIST_LENGTH),
            }

            return FunctoolsPartialVariable(
                func_obj, args, keywords, original=value, guards=guards
            )
        elif is_typing(value):
            # typing.List, typing.Mapping, etc.
            return TypingVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif is_numpy(value):
            assert np
            return NumpyVariable(
                value,
                source=self.source,
                guards=make_guards(
                    GuardBuilder.FUNCTION_MATCH
                    if callable(value)
                    else GuardBuilder.TYPE_MATCH
                ),
            )
        elif (
            istype(value, (type, types.FunctionType))
            and skipfiles.check(getfile(value), allow_torch=True)
            and not inspect.getattr_static(value, "_torchdynamo_inline", False)
        ):
            return SkipFilesVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        # NB: These can't be put in type_dispatch, they have to run later
        elif CollectiveFunctionRewriteVariable.can_rewrite(value):
            new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(value)
            old_source = self.source
            self.source = new_source
            return CollectiveFunctionRewriteVariable(
                new_fn,
                orig_fn=value,
                orig_source=old_source,
                source=new_source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif istype(value, (types.FunctionType, torch.jit.ScriptFunction)):
            return UserFunctionVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif istype(value, (types.ModuleType, replay_record.DummyModule)):
            return PythonModuleVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.PYMODULE_MATCH),
            )
        elif istype(value, torch.autograd.function.FunctionMeta):
            return AutogradFunctionVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif isinstance(value, torch.autograd.function.FunctionCtx):
            # The autograd.function context
            return self.tx.output.side_effects.track_object_existing(
                self.source,
                value,
                AutogradFunctionContextVariable(
                    value,
                    source=self.source,
                    guards=make_guards(GuardBuilder.TYPE_MATCH),
                ),
            )
        elif (
            isinstance(value, types.MethodType)
            and istype(
                getattr(value, "__self__", None), torch.autograd.function.FunctionMeta
            )
            and getattr(value, "__name__", "") == "apply"
            and value == getattr(value.__self__, "apply", None)
        ):
            # handle aliased autograd function `apply` calls
            return GetAttrVariable(
                AutogradFunctionVariable(
                    value.__self__,
                    source=self.source,
                    guards=make_guards(GuardBuilder.FUNCTION_MATCH),
                ),
                "apply",
            )
        elif np and isinstance(value, np.number):
            return self.wrap_unspecialized_primitive(value)
        elif DataClassVariable.is_matching_object(value):
            return DataClassVariable.wrap(self, value).add_guards(
                make_guards(GuardBuilder.TYPE_MATCH)
            )
        elif HFPretrainedConfigVariable.is_matching_object(value):
            return HFPretrainedConfigVariable(
                value, guards=make_guards(GuardBuilder.TYPE_MATCH)
            )
        elif isinstance(value, HigherOrderOperator):
            return TorchHigherOrderOperatorVariable.make(
                value,
                source=self.source,
                guards=self.make_guards(
                    GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
                ),
            )
        elif type(value).__name__ == "builtin_function_or_method" and isinstance(
            value.__self__, torch_special_class_types
        ):
            return TorchVariable(
                value,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif isinstance(value, torch.cuda.streams.Stream):
            unimplemented("CUDAStreamVariable does not currently work soundly.")
            # return CUDAStreamVariable(
            #     None,
            #     value,
            #     source=self.source,
            #     guards=self.make_guards(GuardBuilder.ID_MATCH),
            # )
        elif (
            isinstance(value, torch._C._TensorMeta)
            and value in config.traceable_tensor_subclasses
        ):
            return TensorSubclassVariable(value, source=self.source)
        elif isinstance(value, types.MethodType) and isinstance(
            value.__self__, torch.nn.Module
        ):
            # don't let MethodTypes fall through to UserDefinedObject,
            # which doesn't support 'CALL_FUNCTION'

            # TODO(whc): Why do we limit this to methods on NNModules?
            # I don't have a good reason for this, but it preserves the existing behavior
            # for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise.
            # I suspect we probably want to relax this check and dig deeper there.

            # In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python,
            # but need to separately wrap its underlying `__func__` and its `self` argument.  We wrap `self` here
            # and then `__func__` gets wrapped inside UserMethodVariable.
            self_obj = VariableBuilder(
                self.tx, source=AttrSource(self.source, "__self__")
            )(value.__self__)
            assert self_obj and isinstance(
                self_obj, VariableTracker
            ), "Failed to produce a valid self obj"
            return UserMethodVariable(
                value.__func__,
                self_obj,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif (
            istype(value, contextlib.nullcontext)
            and inspect.getattr_static(value, "enter_result", None) is None
        ):
            return NullContextVariable(
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif KeyedJaggedTensorVariable.is_matching_object(value):
            result = KeyedJaggedTensorVariable(
                value,
                source=self.source,
                guards=self.make_guards(GuardBuilder.TYPE_MATCH),
            )
            # TODO: doing this manually is bad
            return self.tx.output.side_effects.track_object_existing(
                self.source, value, result
            )
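        # The following two branches are what this commit adds: getset
        # descriptors show up when tracing C-level attribute reads (__get__),
        # and method-wrapper objects show up as the `func` argument when
        # __torch_function__ is dispatched from a method call.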
        elif isinstance(value, types.GetSetDescriptorType):
            return GetSetDescriptorVariable(
                value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH)
            )
        elif isinstance(value, types.MethodWrapperType):
            return MethodWrapperVariable(
                value, guards=self.make_guards(GuardBuilder.FUNCTION_MATCH)
            )
        elif isinstance(value, torch.optim.Optimizer):
            return OptimizerVariable(
                value,
                source=self.source,
                guards=self.make_guards(GuardBuilder.TYPE_MATCH),
            )
        elif ProcessGroupVariable.is_process_group(value):
            return ProcessGroupVariable(
                value,
                source=self.source,
                guards=self.make_guards(GuardBuilder.ID_MATCH),
            )
        elif DeviceMeshVariable.is_device_mesh(value):
            # TODO: see if we need to add custom guard instead
            # of a simple ID_MATCH
            return DeviceMeshVariable(
                value,
                source=self.source,
                guards=self.make_guards(GuardBuilder.ID_MATCH),
            )
        elif PlacementClassVariable.is_placement_type(value):
            # TODO: see if we need to add custom guard instead
            # of a simple ID_MATCH
            return PlacementClassVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif PlacementVariable.is_placement(value):
            # TODO: see if we need to add custom guard instead
            # of a simple ID_MATCH
            return PlacementVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif issubclass(type(value), type):
            # TODO(whc) the following seems preferable but breaks some tests, debug
            # elif inspect.isclass(value):
            return UserDefinedClassVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif isinstance(value, torch.SymBool):
            # Note: the idea here is to re-use the infra we've built for SymInt by simulating the
            # user provided SymBool with a SymInt in dynamo.

            # Concretely,
            # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source),
            # so that guards on the SymInt can be effectively applied to the original SymBool in the user program.
            # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv, because the original user program
            # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.

            value_hint = value.node.require_hint()
            new_source = ConvertIntSource(self.source)

            new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol(
                int(value_hint),
                new_source,
                dynamic_dim=DimDynamic.DYNAMIC,
            )

            sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                type(new_symint),
                source=new_source,
            )

            sym_node_proxy.node.meta["grapharg"] = GraphArg(
                new_source,
                new_symint,
                False,
                None,
                is_tensor=False,
                example_strong_ref=new_symint,
            )
            self.tx.output.tracked_fakes.append(
                TrackedFake(new_symint, new_source, None)
            )
            return SymNodeVariable(
                sym_node_proxy,
                new_symint == 1,
            )
        else:
            result = UserDefinedObjectVariable(
                value,
                source=self.source,
                guards=self.make_guards(GuardBuilder.TYPE_MATCH),
            )
            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                # don't allow STORE_ATTR mutation with custom __setattr__
                return result
            return self.tx.output.side_effects.track_object_existing(
                self.source, value, result
            )

    def tensor_can_be_dict_key(self, value):
        # only allow a Parameter, or a tensor reached via a local's .state attribute, to be used as a dict key
        return (
            isinstance(value, torch.nn.Parameter)
            or isinstance(self.source, AttrSource)
            and self.source.member == "state"
            and isinstance(self.source.base, LocalSource)
        )

    def tensor_should_specialize(self):
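        # Matches tensors reached via <local>.param_groups[i]["params"][j],
        # i.e. parameters pulled out of an optimizer's param_groups; the result
        # is passed as should_specialize to wrap_fx_proxy in wrap_tensor below.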
        return (
            self.source
            and isinstance(self.source, GetItemSource)
            and isinstance(self.source.base, GetItemSource)
            and self.source.base.index == "params"
            and isinstance(self.source.base.base, GetItemSource)
            and isinstance(self.source.base.base.base, AttrSource)
            and self.source.base.base.base.member == "param_groups"
            and isinstance(self.source.base.base.base.base, LocalSource)
            and (
                isinstance(
                    self.tx.f_locals[self.source.base.base.base.base.local_name],
                    torch.optim.Optimizer,
                )
                if self.source.base.base.base.base.local_name in self.tx.f_locals.keys()
                else True
            )
        )

    def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]):
        # One can index a tensor with a list/tuple. Therefore, we need to
        # have a stricter match.
        guards = self.make_guards(GuardBuilder.LIST_LENGTH)

        for item in value:
            if item is value:
                unimplemented("list elements are pointing to the list itself")

        output = [
            VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
                item
            ).add_guards(guards)
            for i, item in enumerate(value)
        ]
        result = self.list_type(value)(
            output, mutable_local=MutableLocal(), guards=guards
        )
        if istype(value, list):
            return self.tx.output.side_effects.track_list(self.source, value, result)
        return result

    def wrap_tuple_iterator(self, value: tuple_iterator):
        guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
        output = [
            VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))(
                tuple_iterator_getitem(value, i)
            ).add_guards(guards)
            for i in range(tuple_iterator_len(value))
        ]
        return TupleIteratorVariable(
            output, mutable_local=MutableLocal(), guards=guards
        )

    def wrap_slice_range(self, value: Union[slice, range]):
        items = [
            VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
                getattr(value, k)
            )
            for k in ("start", "stop", "step")
        ]
        if isinstance(value, slice):
            return SliceVariable(
                items, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
            )
        else:
            return RangeVariable(
                items, guards=self.make_guards(GuardBuilder.EQUALS_MATCH)
            )

    def wrap_module(self, value: torch.nn.Module):
        from ..eval_frame import OptimizedModule

        if istype(value, OptimizedModule):
            guards = self.make_guards(GuardBuilder.TYPE_MATCH)
            self.source = AttrSource(self.source, "_orig_mod")
            return self.wrap_module(value._orig_mod).add_guards(guards)

        if (
            isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
            and not config.allow_rnn
        ):
            unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
        if mutation_guard.is_dynamic_nn_module(value):
            # created dynamically, don't specialize on it
            result = UnspecializedNNModuleVariable(
                value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
            )
            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                # don't allow STORE_ATTR mutation with custom __setattr__
                return result
            return self.tx.output.side_effects.track_object_existing(
                self.source, value, result
            )
        elif issubclass(
            value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
        ):
            return UnspecializedNNModuleVariable(
                value, guards=self.make_guards(GuardBuilder.TYPE_MATCH)
            )
        elif getattr(value, "_is_fsdp_managed_module", False):
            # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
            # in fully_sharded_data_parallel.py for more information

            # we can't do this assert inside FSDP constructor,
            # since we don't know yet whether dynamo will be used
            assert getattr(
                value, "_fsdp_use_orig_params", False
            ), "Dynamo only supports FSDP with use_orig_params=True"

            # Note on FSDP guarding
            # 1. We expect FSDP wrapping mutates an nn module irreversibly (no way to de-wrap).
            # 2. Eager FSDP already assumes (requires, but without enforcement) that users don't mutate their
            #    model parameters/structure after FSDP wrapping, because FSDP wouldn't notice or update its FlatParams.
            #
            # Due to (1), once we enter this path we expect not to go back nor have to guard on type
            # or _is_fsdp_managed_module.
            #
            # TODO(whc) We could add a guard on the opposite case, where a user compiled/ran
            # a pre-FSDP-wrapped model, then wrapped it, to ensure that we recompile with the FSDP handling.
            #
            # Due to (2), we skip guards on inner contents of fsdp_managed modules, by using FSDPNNModuleSource as the
            # guard source.  This behavior is gated on config.skip_fsdp_guards.
            #
            # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
            # them differently with different FSDP configs.  (test_dynamo_distributed.py -k test_fsdp_aot_eager)
            return FSDPManagedNNModuleVariable(
                value,
                guards=self.make_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH),
                source=self.get_source(),
            )
        else:
            return self.tx.output.register_attr_or_module(
                value,
                self.name,
                source=self.get_source(),
                # Guards are added inside register_attr_or_module
            )

    def wrap_literal(self, value):
        unspec = not config.specialize_int
        if unspec and type(value) is torch.Size:
            return SizeVariable(
                [
                    VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(v)
                    for i, v in enumerate(value)
                ],
                guards=self.make_guards(GuardBuilder.LIST_LENGTH),
            )
        elif unspec and type(value) is int:
            # unspecializing int by default, but still
            # specialize for the following conditions
            if not TracingContext.get().force_unspec_int_unbacked_size_like and (
                value in self._common_constants()
                # Assume integers from global variables want to be specialized
                or not self.source.guard_source().is_local()
                # Assume that integers that came from NN modules want to be
                # specialized (as we don't expect users to be changing the
                # NN modules on the fly)
                or self.source.guard_source().is_nn_module()
            ):
                return ConstantVariable(
                    value=value,
                    guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
                )
            else:
                return self.wrap_unspecialized_primitive(value)
        else:
            return ConstantVariable(
                value=value,
                guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
            )

    def wrap_tensor(self, value: torch.Tensor):
        source = self.get_source()

        if (
            source.guard_source().is_nn_module()
            or get_static_address_type(value) is not None
        ) and not source.guard_source().is_fsdp_module():
            return self.tx.output.register_attr_or_module(
                value,
                self.name,
                source=source,
                # Guards are done inside register_attr_or_module
                # guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
            )

        if is_constant_source(source):
            return self.tx.output.register_attr_or_module(
                value,
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                source=source,
                # Guards are added inside register_attr_or_module
            )

        if type(value) in config.traceable_tensor_subclasses:
            # Ordinarily, we would fakeify a tensor so that it can get dynamic
            # shapes and be computed on without triggering actual operations.
            # However, how can we fakeify a tensor subclass?  Neither ordinary
            # inheritance nor multiple inheritance will work.
            #
            # Instead, our plan is to *manually simulate* the tensor subclass
            # inheriting from a fake tensor with dynamo.  This means our
            # data representation for a tensor subclass will be a fake tensor
            # + tensor subclass type + any extra data the subclass may have
            # been storing on the tensor.  Because all Python accesses are
            # mediated through TensorWithTFOverrideVariable, we can ensure
            # that we dispatch differently, e.g., according to
            # __torch_function__
            #
            # To simplify things for now, the __dict__ tracking bits haven't
            # been implemented yet, but they can be added into this design at
            # a later point in time.
            ignore_subclass = True
        else:
            assert type(value) in (
                torch.Tensor,
                torch.nn.Parameter,
                torch._subclasses.fake_tensor.FakeTensor,
            ) or is_traceable_wrapper_subclass(value), type(value)
            ignore_subclass = False

        # NB: this just says we accessed a tensor from the same source again
        # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice).
        # This is distinct from two distinct sources mapping to the same
        # Tensor (per id())!  No guard is necessary here.  See below for the
        # other case.
        is_duplicate_tensor = source in self.tx.output.input_source_to_var
        if is_duplicate_tensor:
            return self.tx.output.input_source_to_var[source]

        # We have accessed the SAME tensor from a different source.  In some
        # situations, it doesn't matter if you have the same tensor identity
        # or not, but we are unable to do this fine-grained tracking.  So
        # instead we just say, if x is y, then to successfully reuse this
        # compiled tensor again, you must have x is y again.  Negative
        # aliases, that is, that x is not y, are IMPLICITLY checked as part of
        # the code cache matching process, you don't need to explicitly
        # generate a guard for it (nor would you want to, you need O(n^2)
        # pairwise 'is not' tests to do it.)
        if value in self.tx.output.real_value_tensor_positive_aliases:
            stored_value = self.tx.output.real_value_tensor_positive_aliases[value]
            # TODO(voz): Decently common pattern, refactor at some point.
            dup_guard = self._make_dupe_guard(stored_value)
            if dup_guard:
                stored_value = stored_value.add_guards(self.make_guards(dup_guard))
            return stored_value

        # tx.output has multiple tracers if we're introspecting HigherOrderOperator.
        # When we've discovered an untracked tensor, then we actually need
        # to get Dynamo to track the tensor (which is what this function does)
        # and put it as a graph input on the root tracer. Later on,
        # if the input is actually used in the body of the HigherOrderOperator,
        # then the relevant SubgraphTracer will lift it to being an input of
        # the subgraph.
        # See NOTE [HigherOrderOperator tracing design] for more details.

        tensor_proxy = self.tx.output.root_tracer.create_graph_input(
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
        )
        tensor_variable = wrap_fx_proxy(
            tx=self.tx,
            proxy=tensor_proxy,
            example_value=value,
            guards=self.make_guards(
                functools.partial(
                    GuardBuilder.TENSOR_MATCH,
                    value=value
                    if isinstance(source, NumpyTensorSource)
                    else TensorWeakRef(value),
                )
            ),
            should_specialize=self.tensor_should_specialize(),
            ignore_subclass=ignore_subclass,
            source=source,
        )
        self.tx.output.input_source_to_var[source] = tensor_variable
        assert "tensor_dict" not in tensor_proxy.node.meta
        tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()

        # TODO: I think the result is guaranteed to be fake with
        # ignore_subclass changes
        fake_tensor_value = None
        example_value = tensor_variable.proxy.node.meta["example_value"]
        if is_fake(example_value):
            fake_tensor_value = example_value

        grapharg = GraphArg(source, value, False, fake_tensor_value)
        tensor_proxy.node.meta["grapharg"] = grapharg
        self.tx.output.add_symbol_bindings(grapharg)

        if type(value) in config.traceable_tensor_subclasses:
            # NB: This is slightly misnamed, a tensor subclass might not have
            # any explicit __torch_function__ implementation and is relying
            # on the default inherited from torch.Tensor
            return TensorWithTFOverrideVariable.create(
                self.tx,
                tensor_variable,
                source,
                value.__torch_function__.__func__,
                type(value),
            )

        return tensor_variable
 | |
| 
 | |
|     def wrap_numpy_ndarray(self, value):
 | |
|         assert np is not None
 | |
|         assert isinstance(value, np.ndarray)
 | |
| 
 | |
|         source = NumpyTensorSource(self.get_source())
 | |
|         tensor_value = torch.as_tensor(value)
 | |
|         # We do this because we want the full behavior of guarding the numpy ndarray as if it were
 | |
|         # a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here
 | |
|         # that there's not another great way to do this atm.
 | |
|         # This creates the right graphargs, as well as registration for guards in tensor names and shape env.
 | |
|         tensor_vt = VariableBuilder(self.tx, source)(tensor_value)
 | |
|         proxy = self.tx.output.root_tracer.create_graph_input(
 | |
|             re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source
 | |
|         )
 | |
|         options = {"source": source, "guards": tensor_vt.guards}
 | |
|         numpy_ndarray_variable = wrap_fx_proxy_cls(
 | |
|             target_cls=NumpyNdarrayVariable,
 | |
|             tx=self.tx,
 | |
|             proxy=proxy,
 | |
|             example_value=tensor_value,
 | |
|             **options,
 | |
|         )
 | |
| 
 | |
|         self.tx.output.input_source_to_var[source] = numpy_ndarray_variable
 | |
|         example_value = numpy_ndarray_variable.proxy.node.meta["example_value"]
 | |
| 
 | |
|         # is_unspecialized should be true because we are wrapping a np.ndarray as argument input, and it needs to be
 | |
|         # converted to a tensor.
 | |
|         grapharg = GraphArg(
 | |
|             source,
 | |
|             tensor_value,
 | |
|             is_unspecialized=True,
 | |
|             fake_tensor=example_value,
 | |
|             is_tensor=True,
 | |
|             example_strong_ref=tensor_value,
 | |
|         )
 | |
|         proxy.node.meta["grapharg"] = grapharg
 | |
| 
 | |
|         return numpy_ndarray_variable
 | |
| 
 | |
|     def wrap_unspecialized_primitive(self, value):
 | |
|         if self.name in self.tx.output.unspec_variable_map:
 | |
|             return self.tx.output.unspec_variable_map[self.name]
 | |
|         else:
 | |
|             shape_env = self.tx.output.shape_env
 | |
|             if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance(
 | |
|                 value, int
 | |
|             ):
 | |
|                 wrapped_value = shape_env.create_unbacked_symint()
 | |
|                 _constrain_range_for_size(wrapped_value)
 | |
|                 self.tx.output.tracked_fakes.append(
 | |
|                     TrackedFake(wrapped_value, self.source, None)
 | |
|                 )
 | |
| 
 | |
|             # NB: We do not do float.  For motivation, see
 | |
|             # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
 | |
|             # but the general idea is that we generate kernels that can
 | |
|             # take unspecialized floats and use them in sizevar computation
 | |
|             elif (
 | |
|                 isinstance(value, int)
 | |
|                 and not is_constant_source(self.get_source())
 | |
|                 and not isinstance(self.get_source(), RandomValueSource)
 | |
|             ):
 | |
|                 if torch._dynamo.config.specialize_int:
 | |
|                     # If specialize_int is False, also return
 | |
|                     # a constant (but this should have been handled
 | |
|                     # in the caller, TBH)
 | |
|                     return ConstantVariable(
 | |
|                         value=value,
 | |
|                         guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
 | |
|                     )
 | |
| 
 | |
|                 name = self.source.name()
 | |
|                 if name not in self.tx.output.frame_state:
 | |
|                     # Note - this esentially means that if this name gets reused as a tensor,
 | |
|                     # it will start fully dynamic. That should always be a safe option, and not awfully inefficient.
 | |
|                     # Alternatively, if we want to improve pef here, we can add a third state of unset, but I am not
 | |
|                     # sure that is necessary for now.
 | |
|                     frame_state_entry = FrameStateSizeEntry(scalar=value, size=None)
 | |
|                 else:
 | |
|                     frame_state_entry = self.tx.output.frame_state[name]
 | |
|                     if frame_state_entry.scalar != value:
 | |
|                         log.debug(
 | |
|                             "automatic dynamic int %s val %s != %s",
 | |
|                             name,
 | |
|                             value,
 | |
|                             frame_state_entry.scalar,
 | |
|                         )
 | |
|                         frame_state_entry.scalar = None
 | |
|                 self.tx.output.frame_state[name] = frame_state_entry

                # TODO: This should be dynamic, as we in general do not
                # know if bare integers are actually going to be sizevars
                # and it is inappropriate to eagerly duck size them with
                # real sizevars
                if (
                    config.automatic_dynamic_shapes and frame_state_entry.scalar is None
                ) or not config.assume_static_by_default:
                    dynamic_dim = DimDynamic.DYNAMIC
                else:  # assume_static_by_default
                    # TODO: dynamic_dim = DimDynamic.STATIC should work but
                    # for some reason it doesn't
                    return ConstantVariable(
                        value=value,
                        guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
                    )

                wrapped_value = shape_env.create_unspecified_symint_and_symbol(
                    value,
                    source=self.source,
                    dynamic_dim=dynamic_dim,
                )

                self.tx.output.tracked_fakes.append(
                    TrackedFake(wrapped_value, self.source, None)
                )
            else:
                wrapped_value = torch.tensor(value)
            if not isinstance(self.get_source(), RandomValueSource):
                guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH, True)}
                options = {"guards": guards}
            else:
                options = {}
            options.update({"source": self.get_source()})
            if isinstance(wrapped_value, torch.Tensor):
                options.update({"raw_value": value})

            proxy = self.tx.output.root_tracer.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                type(wrapped_value),
                source=self.get_source(),
            )

            unspec_var = wrap_fx_proxy_cls(
                UnspecializedPythonVariable,
                tx=self.tx,
                proxy=proxy,
                example_value=wrapped_value,
                **options,
            )
            self.tx.output.unspec_variable_map[self.name] = unspec_var
            if not is_constant_source(self.get_source()):
                if self.tx.export and not isinstance(self.get_source(), LocalSource):
                    raise AssertionError(
                        "Dynamo attempts to add additional input during export: value={}, source={}".format(
                            wrapped_value, self.get_source()
                        )
                    )
                fake_tensor_value = None
                if isinstance(unspec_var, ConstantVariable):
                    example_value = unspec_var.value
                else:
                    example_value = unspec_var.proxy.node.meta["example_value"]
                if is_fake(example_value):
                    fake_tensor_value = example_value
                    assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
                        f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
                        f"({self.tx.fake_mode}) from InstructionTranslator"
                    )

                proxy.node.meta["grapharg"] = GraphArg(
                    self.get_source(),
                    wrapped_value,
                    isinstance(wrapped_value, torch.Tensor),
                    fake_tensor_value,
                    is_tensor=False,
                    example_strong_ref=wrapped_value,
                )
            return unspec_var


def _dataclasses_fields_lambda(obj):
    if isinstance(obj, UserDefinedObjectVariable):
        value = obj.value
    elif isinstance(obj, DataClassVariable):
        value = obj.user_cls
    else:
        unimplemented(f"Dataclass fields handling fails for type {obj}")
    items = []
    for field in dataclasses.fields(value):
        source = None
        if obj.source:
            source = GetItemSource(
                AttrSource(obj.source, "__dataclass_fields__"), field.name
            )
        items.append(UserDefinedObjectVariable(field, source=source).add_options(obj))
    return TupleVariable(items).add_options(obj)


def wrap_fx_proxy(tx, proxy, example_value=None, **options):
    return wrap_fx_proxy_cls(
        target_cls=TensorVariable,
        tx=tx,
        proxy=proxy,
        example_value=example_value,
        **options,
    )


# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
# Should be compositional instead
#
# This is a horribly complicated function that does too many things.  To
# explain what it does, let's first talk about the classic usage of
# wrap_fx_proxy for a TensorVariable.  There are two primary modes of use:
#
#   1. Wrapping a pre-existing Tensor.  In this case, example_value is set
#      to the pre-existing Tensor.  (Note that this example_value will NOT
#      be the final example_value we put into node.meta['example_value'],
#      instead it is converted into a fake tensor using
#      wrap_to_fake_tensor_and_record and registered as a graph input.)
#
#   2. "Wrapping" the result of some Tensor operation Dynamo traced over.  In
#      this case, example_value is None (and we are going to figure it out
#      ourselves using FakeTensors, via get_fake_value, which will run
#      the operation represented by the (singular!) FX node referenced by
#      the passed in proxy.)
#
# The expectation is you end up with a Tensor output, and everything is
# straightforwardly traced into the graph.
#
# Upon closer inspection, you may notice that there are a slurry of non-Tensor
# output cases.  What gives?  Well, we sometimes trace operations into the
# graph that don't involve tensors.
#
#   * Some operators return tuples; we need to recursively handle their
#     contents
#
#   * Some operators have side effects that will affect subsequent AOTAutograd
#     tracing but don't otherwise return anything.
#
#   * Some operators return symbolic ints/floats/bools which can go in the
#     graph and be traced (but only if they're actually symbolic!  If they're
#     static you don't want to put them in the graph, which means you
#     shouldn't call this function.)
#
# The common theme is that you only use this function WHEN YOU ARE TRACING
# SOMETHING INTO THE GRAPH.  This is sort of obvious, because you can't call
# this function without a proxy.
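#
# A rough usage sketch of the two Tensor modes (illustrative only; `tx`,
# `proxy`, `real_tensor`, and `source` are hypothetical names, and the calls
# are not executed here):
#
#   # (1) wrap a pre-existing input tensor; it gets faked via
#   #     wrap_to_fake_tensor_and_record and registered as a graph input
#   vt = wrap_fx_proxy(tx, proxy, example_value=real_tensor, source=source)
#
#   # (2) wrap the result of an op Dynamo just traced; the example value is
#   #     computed with FakeTensors by running the proxy's FX node
#   vt = wrap_fx_proxy(tx, proxy)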
def wrap_fx_proxy_cls(
    target_cls, tx, proxy, example_value=None, ignore_subclass=False, **options
):
    import torch._export.constraints
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tx, InstructionTranslatorBase)
    if "guards" in options and options["guards"] is not None:
        tx.output.guards.update(options["guards"])

    assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}"

    initial_example_value = example_value

    def _is_functional_tensor_fakified_by_dynamo(x):
        if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
            reapply_views = torch._C._functionalization_reapply_views_tls()
            unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
            return (
                isinstance(unwrapped, FakeTensor)
                and unwrapped.fake_mode == tx.fake_mode
            )
        return False

    def _clone_input(value):
        if isinstance(value, torch.Tensor):
            # tensor subclasses will not be converted to FakeTensors and need to be cloned
            if not (
                isinstance(value, FakeTensor)
                or _is_functional_tensor_fakified_by_dynamo(value)
            ):
                # NB: ensure strides are preserved
                value = clone_input(value)

        return value

    with preserve_rng_state():
        if example_value is None:
            example_value = get_fake_value(proxy.node, tx)

        # Handle recursive calls here
        elif (
            is_fake(example_value)
            and maybe_get_fake_mode(example_value) is tx.fake_mode
        ) or _is_functional_tensor_fakified_by_dynamo(example_value):
            pass

        elif isinstance(example_value, torch.Tensor):
            if tx.export:
                # The legacy behavior for real value cache with subclasses was
                # to perform a clone WITHOUT preserving the subclass.  It's
                # not entirely clear this is what you actually want though.
                with torch._C.DisableTorchFunctionSubclass():
                    proxy.tracer.real_value_cache[proxy.node] = _clone_input(
                        example_value
                    )
            # NB: If we're ignoring subclass, then the expectation is you will
            # take the returned TensorVariable and wrap it into a more
            # accurate TensorVariable that is able to track subclass-ness;
            # otherwise this is wrong!
            kwargs = {
                "ignore_subclass": ignore_subclass,
                "is_tensor": target_cls is TensorVariable,
            }
            assert "source" in options and options["source"] is not None
            kwargs["source"] = options["source"]
            example_value = wrap_to_fake_tensor_and_record(
                example_value, tx=tx, **kwargs
            )

    if isinstance(example_value, torch.Tensor):
        is_parameter = isinstance(example_value, torch.nn.Parameter)
        should_specialize = options.pop("should_specialize", False)
        if is_parameter or should_specialize:
            specialized_value = initial_example_value
        else:
            specialized_value = None

        # NB: In most (all?) cases, this does not actually do a clone.
        # (WARNING: this means that if we mutate metadata on the fake
        # tensor, the stored example value will update too!)
        example_value = _clone_input(example_value)
        proxy.node.meta["example_value"] = example_value
        specialized_props = target_cls.specialize(example_value)
        # TODO: not sure about this fake mode test
        if (
            isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
            and example_value.fake_mode is tx.fake_mode
        ):
            # NB: This will be wrong for ignore_subclass; fix it up later!
            specialized_props["class_type"] = (
                torch.nn.Parameter if is_parameter else torch.Tensor
            )

        specialized_props["specialized_value"] = specialized_value

        options.update(specialized_props)
        return target_cls(proxy, **options)
    elif (
        hasattr(proxy.node.target, "__name__")
        and proxy.node.target.__name__ == "set_state"
        and isinstance(proxy.node.target.__self__, torch._C.Generator)
        or proxy.node.target == torch.random.set_rng_state
    ):
        from . import TorchVariable

        return TorchVariable(proxy.node.target)
    elif (
        proxy.node.target == torch._C._DisableFuncTorch
        or proxy.node.target == torch.cuda._is_in_bad_fork
    ):
        from . import UserDefinedObjectVariable

        return UserDefinedObjectVariable(example_value)
    elif istype(example_value, torch.Size) and all(
        isinstance(x, int) for x in example_value
    ):
        sizes = [ConstantVariable(x) for x in example_value]
        return SizeVariable(sizes, **options)
    elif isinstance(example_value, (tuple, list, set)):
        proxy.node.meta["example_value"] = example_value
        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:
                # nn.MultiheadAttention() can return None, see issue #175
                unpacked.append(
                    ConstantVariable(None, **options),
                )
            else:
                unpacked.append(
                    wrap_fx_proxy_cls(
                        target_cls,
                        tx,
                        proxy.tracer.create_proxy(
                            "call_function", operator.getitem, (proxy, i), {}
                        ),
                        example_value=val,
                        **options,
                    )
                )
        if isinstance(example_value, torch.Size):
            # NB: Keep the old proxy around.  See SizeVariable for an
            # explanation why
            return SizeVariable(unpacked, proxy, **options)
        elif istype(example_value, tuple):
            return TupleVariable(unpacked, **options)
        elif istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
        elif istype(example_value, set):
            return SetVariable(tx, unpacked, mutable_local=MutableLocal(), **options)
        else:
            assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
                example_value, "_fields"
            ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
            return NamedTupleVariable(unpacked, example_value.__class__, **options)
    elif example_value is None or proxy.node.target is torch.manual_seed:
        return ConstantVariable(None, **options)
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        proxy.node.meta["example_value"] = example_value
        return SymNodeVariable(proxy, example_value, **options)
    elif proxy.node.target in [torch.cuda.streams.Stream, torch.cuda.current_stream]:
        proxy.node.meta["example_value"] = example_value
        return CUDAStreamVariable(proxy, example_value, **options)
    elif isinstance(example_value, int) and proxy.node.target in [
        torch.sym_int,
        getattr,
        operator.getitem,
        torch._utils._element_size,
        torch.seed,
        operator.mod,
        # some mac builds are missing torch.distributed.get_rank()
        getattr(torch.distributed, "get_rank", _missing),
        getattr(torch.distributed, "get_world_size", _missing),
        # This always wants to be in the graph, even if the constraint
        # results in a constant int
        torch._export.constraints.constrain_as_value,
    ]:
        proxy.node.meta["example_value"] = example_value
        return ConstantVariable(example_value, **options)
    else:
        unimplemented(
            "torch.* op returned non-Tensor "
            + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
        )


# Tracks the sources of all fake tensors we wrap in Dynamo.
# Used by shape guard computation.
@dataclasses.dataclass
class TrackedFake:
    fake: Union[FakeTensor, SymInt]
    source: Source
    # Is None when fake is SymInt
    constraint_dims: Optional[DimList[DimConstraint]]

    def __hash__(self) -> int:
        return hash((self.fake, self.source.name()))

    def __eq__(self, other: object) -> bool:
        if isinstance(other, TrackedFake):
            return self.fake is other.fake and self.source.name() == other.source.name()
        return False


# Performs automatic dynamic dim determination.
# Returns tuple of (dynamic_dims, constraint_dims) where each is either a list of dims or None.
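#
# E.g., for a 2-D input whose dim 0 was marked dynamic (e.g. via
# torch._dynamo.mark_dynamic) and whose dim 1 stays static under
# assume_static_by_default, this would plausibly return (illustrative values
# only):
#   ([DimDynamic.DYNAMIC, DimDynamic.STATIC],
#    [RelaxedUnspecConstraint(warn_only=False), None])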
def _automatic_dynamic(e, tx, name, static_shapes):
    if static_shapes:
        return [DimDynamic.STATIC] * e.dim(), [None] * e.dim()

    # We preserve the dynamism of inputs. For example, when users call
    # make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes.
    if any(isinstance(s, SymInt) for s in e.size()):
        return [
            DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC
            for s in e.size()
        ], [None] * e.dim()

    # Prep for automatic dynamic
    frame_state_entry = None
    if name not in tx.output.frame_state:
        # If there is no entry for this source, add the tensor to frame state with its current static size.
        # E.g., {} -> {"x": [2, 4]}
        frame_state_entry = FrameStateSizeEntry(None, None)
        frame_state_entry.size = list(e.size())
    else:
        frame_state_entry = tx.output.frame_state[name]
        if frame_state_entry.size is not None:
            if e.ndim != len(frame_state_entry.size):
                # If there is already an entry, and the dim mismatches, replace the frame state entry with None.
                # E.g. {"x": [2, 3, 4]} -> {"x": None}
                log.debug(
                    "automatic dynamic %s dim %s != %s",
                    name,
                    e.ndim,
                    frame_state_entry.size,
                )
                frame_state_entry.size = None
            else:
                # If there is already an entry, and the dim matches, for every size in the frame state which
                # disagrees with the current static size, replace it with None. E.g., {"x": [2, 3]} -> {"x": [2, None]}
                for i, dim in enumerate(frame_state_entry.size):
                    if dim is not None and e.size()[i] != dim:
                        log.debug(
                            "automatic dynamic %s size(%s) %s != %s",
                            name,
                            i,
                            e.size(i),
                            dim,
                        )
                        frame_state_entry.size[i] = None

    # TODO: index export_constraints ahead of time so we don't have to
    # do a linear scan every time here
    t_id = id(e)
    dim2constraint = {}

    def update_dim2constraint(dim, constraint_range, debug_name):
        if dim in dim2constraint:
            from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

            old_constraint_range, old_debug_name = dim2constraint[dim]
            new_constraint_range = StrictMinMaxConstraint(
                vr=constraint_range.vr & old_constraint_range.vr,
                warn_only=False,
            )
            if old_debug_name is not None:
                assert debug_name is None or debug_name == old_debug_name
                new_debug_name = old_debug_name
            else:
                new_debug_name = debug_name
            dim2constraint[dim] = new_constraint_range, new_debug_name
        else:
            dim2constraint[dim] = constraint_range, debug_name

    if tx.output.export_constraints:
        for constraint in tx.output.export_constraints:
            if constraint.t_id == t_id:
                update_dim2constraint(
                    constraint.dim, constraint.constraint_range, constraint.debug_name
                )
            if constraint.shared is not None and constraint.shared.t_id == t_id:
                # We process constraint ranges for each shared dimension separately
                # so that we can directly check range constraint violations on them
                # without looking up which other shared dimensions have this info.
                # In other words, for this t_id, we will have processed all of its
                # constraint ranges, no matter where / how they were specified,
                # by the end of this loop.
                update_dim2constraint(
                    constraint.shared.dim,
                    constraint.constraint_range,
                    constraint.debug_name,
                )

    dynamic_dims = []
    constraint_dims = []
    for i in range(e.dim()):
        # NB: mark dynamic has precedence over static
        marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
        marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
        marked_static = i in getattr(e, "_dynamo_static_indices", set())

        # NB: both static and dynamic have precedence over automatic dynamic
        automatic_dynamic = config.automatic_dynamic_shapes and (
            frame_state_entry.size is None or frame_state_entry.size[i] is None
        )

        # Reflect the user directive in the frame_state
        # For dynamic, apply None always
        if frame_state_entry.size and marked_dynamic:
            log.debug("automatic dynamic %s marked dynamic", name)
            frame_state_entry.size[i] = None

        # We will process constraints first, as they will imply that we
        # have a dynamic dimension
        # Precedence: export constraints > eager constraints
        constraint = dim2constraint.get(i)
        if constraint is None:
            if marked_dynamic and not config.allow_ignore_mark_dynamic:
                constraint_dim = RelaxedUnspecConstraint(warn_only=False)
            elif not marked_static and automatic_dynamic:
                constraint_dim = RelaxedUnspecConstraint(warn_only=True)
            else:
                constraint_dim = None
        else:
            constraint_dim, debug_name = constraint
            if debug_name is not None:
                dim_name = f"{name}.size()[{i}]"
                tx.output.shape_env.source_name_to_debug_name[dim_name] = debug_name
        constraint_dims.append(constraint_dim)

        # Now, figure out if the dim is dynamic/duck/static
        if constraint_dim is not None or marked_dynamic or marked_weak_dynamic:
            # NB: We could assert static_shapes is False here, but it
            # seems better to allow the user to override policy in this
            # case
            dynamic = DimDynamic.DYNAMIC
        elif static_shapes or config.assume_static_by_default or marked_static:
            dynamic = DimDynamic.STATIC
        else:
            dynamic = DimDynamic.DUCK

        dynamic_dims.append(dynamic)

    tx.output.frame_state[name] = frame_state_entry

    return dynamic_dims, constraint_dims


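# Wraps a real tensor `e` as a FakeTensor under tx.fake_mode and, for graph
# input tensors, records it in tx.output.tracked_fakes for shape guard
# computation; values that are not recognized tensor types are returned
# unchanged.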
def wrap_to_fake_tensor_and_record(
    e, tx, ignore_subclass=False, *, source: Optional[Source], is_tensor: bool
):
    if (
        type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor)
        or (ignore_subclass and isinstance(e, torch.Tensor))
        or is_traceable_wrapper_subclass(e)
    ):
        assert source is not None
        static_shapes, reason = tensor_always_has_static_shape(
            e, is_tensor, guard_source=source.guard_source()
        )

        dynamic_dims, constraint_dims = _automatic_dynamic(
            e, tx, source.name(), static_shapes
        )

        log.debug(
            "wrap_to_fake %s %s %s %s",
            source.name(),
            tuple(e.shape),
            dynamic_dims,
            constraint_dims,
        )
        fake_e = wrap_fake_exception(
            lambda: tx.fake_mode.from_tensor(
                e,
                ignore_subclass=ignore_subclass,
                source=source,
                dynamic_dims=dynamic_dims,
                constraint_dims=constraint_dims,
            )
        )
        if is_tensor and not (static_shapes and source.is_nn_module()):
            tx.output.tracked_fakes.append(TrackedFake(fake_e, source, constraint_dims))
            tx.output.tracked_fakes_id_to_source[id(e)].append(source)
        tx.output.tensor_weakref_to_sizes_strides[WeakIdRef(e)] = {
            "size": fake_e.size(),
            "stride": fake_e.stride(),
        }
        return fake_e
    else:
        return e


class SourcelessBuilder:
    """
    Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects
    that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate
    over). Such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph.
    However, there may be reasons to represent it as a ListVariable internally.

    NOTE - Objects produced here are born UNGUARDED due to the nature of sources!

    NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant
    if/else type->VariableTracker trees that were cropping up all over dynamo.
    """
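
    # A minimal usage sketch (illustrative; assumes `tx` is the current
    # InstructionTranslator and the argument names are hypothetical):
    #   SourcelessBuilder()(tx, {"k": some_value})   -> ConstDictVariable (values built recursively)
    #   SourcelessBuilder()(tx, some_plain_function) -> UserFunctionVariable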

    def __call__(self, tx, value) -> VariableTracker:
        if isinstance(value, VariableTracker):
            # This is always valid to call, and useful for recursive calls.
            return value
        if isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS):
            return UserDefinedObjectVariable(value)
        if ConstantVariable.is_literal(value):
            return SourcelessBuilder.wrap_constant_literal(value)
        elif is_builtin_callable(value):
            return BuiltinVariable(value)
        elif is_allowed(value):
            if is_user_defined_allowed(value):
                tx.output.has_user_defined_allowed_in_graph = True
            return TorchVariable(value)
        elif isinstance(value, types.FunctionType):
            return UserFunctionVariable(value)
        elif isinstance(value, enum.Enum):
            return EnumVariable(value)
        elif isinstance(value, (type, abc.ABCMeta)):
            return UserDefinedClassVariable(value)
        elif isinstance(value, dict):
            return ConstDictVariable(
                {k: self(tx, v) for k, v in value.items()},
                dict,
                mutable_local=MutableLocal(),
            )
        elif isinstance(value, (tuple, list)):
            cls = BaseListVariable.cls_for(type(value))
            return cls([self(tx, x) for x in value], mutable_local=MutableLocal())
        elif isinstance(value, types.MethodWrapperType):
            return MethodWrapperVariable(value)
        unimplemented(f"Unexpected type in sourceless builder {type(value)}")

    @staticmethod
    def wrap_constant_literal(value):
        assert ConstantVariable.is_literal(value)
        return ConstantVariable(value=value)