import collections
import dataclasses
import enum
import functools
import inspect
import operator
import re
import types
from typing import Any, Optional, Union

import torch

from torch import SymInt
from torch._guards import GuardSource
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.immutable_collections import immutable_list

from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..side_effects import SideEffects
from ..source import (
    AttrSource,
    ConstantSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    is_constant_source,
    LocalInputSource,
    LocalSource,
    RandomValueSource,
    Source,
    TupleIteratorGetItemSource,
)
from ..utils import (
    clone_input,
    get_fake_value,
    getfile,
    global_key_name,
    HAS_NUMPY,
    is_namedtuple,
    is_numpy_int_type,
    is_typing,
    istensor,
    istype,
    np,
    odict_values,
    preserve_rng_state,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    wrap_fake_exception,
)

from .base import MutableLocal, typestr
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .dicts import (
    ConstDictVariable,
    DataClassVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
)
from .functions import UserFunctionVariable
from .lists import (
    ListIteratorVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    SizeVariable,
    SliceVariable,
    TupleVariable,
)
from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    GetAttrVariable,
    InspectSignatureVariable,
    LambdaVariable,
    NumpyVariable,
    PythonModuleVariable,
    SkipFilesVariable,
    TypingVariable,
)
from .nn_module import UnspecializedNNModuleVariable
from .tensor import (
    SymNodeVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
)
from .torch import (
    tensor_dunder_fns,
    torch_special_class_types,
    TorchPyOperator,
    TorchVariable,
)
from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable


class _missing:
    """Sentinel used as a ``getattr`` default for optional ``torch.distributed`` functions."""

    pass


@dataclasses.dataclass
class GraphArg:
    """Describe one graph input: where it came from and an example value for it."""

    # Source the argument is loaded from (see ``load``).
    source: Source
    # Example value for this input; cleared by ``erase``.
    example: Any
    is_unspecialized: bool
    # Fake counterpart of ``example``; required when ``example`` is a torch.Tensor.
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True

    def __post_init__(self):
        """Validate the example/fake pairing and record the dynamo argument position.

        Raises:
            AssertionError: if ``example`` is a tensor without a FakeTensor
                counterpart, or if ``example`` itself is a FakeTensor.
        """
        if isinstance(self.example, torch.Tensor):
            assert isinstance(
                self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
            )
            # Mapping for downstream systems to remap back into dynamo arg positions
            if isinstance(self.source, LocalInputSource):
                if "graph_arg_pos" not in self.fake_tensor.__dict__:
                    self.fake_tensor.__dict__["graph_arg_pos"] = []
                self.fake_tensor.__dict__["graph_arg_pos"].append(self.source.pos)
        if isinstance(self.example, torch._subclasses.fake_tensor.FakeTensor):
            raise AssertionError("Fake Tensor observed in TorchDynamo Fx graph inputs")

    def load(self, tx):
        """Delegate to ``self.source.reconstruct(tx)`` and return its result."""
        return self.source.reconstruct(tx)

    def get_examples(self):
        """Return the example value wrapped in a one-element list."""
        return [self.example]

    def get_fake_examples(self):
        """Return ``[fake_tensor]``, or ``None`` when no fake tensor is attached."""
        if self.fake_tensor is not None:
            assert isinstance(
                self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
            )
            return [self.fake_tensor]
        return None

    def __len__(self):
        return 1

    def erase(self):
        """Drop the reference to the example value."""
        self.example = None


class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

    def __init__(
        self,
        tx,
        source: Source,
    ):
        assert source is not None
        super().__init__()
        self.tx = tx
        self.source = source
        self.name = source.name()

    def __call__(self, value):
        """Return the tracker for ``value``, reusing one already known to side_effects."""
        if value in self.tx.output.side_effects:
            # TODO(jansel): add guard for alias relationship
            return self.tx.output.side_effects[value]
        return self._wrap(value).clone(**self.options())

    @staticmethod
    @functools.lru_cache(None)
    def _common_constants():
        """Return the int/float values that ``_wrap`` always specializes on.

        The result is cached, so every call returns the same set object.
        """
        return set(range(17)).union(
            {
                20,
                30,
                40,
                32,
                64,
                96,
                128,
                144,
                240,
                256,
                672,
                1024,
                2048,
                4096,
                0.1,
                0.01,
                0.001,
                0.5,
                0.05,
                800,
                1.873536229133606,
                4.135166556742356,  # Work around for vision_maskrcnn where torch.clamp can't be on different devices
            }
        )

    @staticmethod
    def list_type(value):
        """Pick the variable class used to wrap the sequence ``value``.

        Namedtuples get a ``NamedTupleVariable`` partial bound to their class.
        Any other type must be one of the keys of the table below, otherwise
        the lookup raises ``KeyError``.
        """
        if is_namedtuple(value):
            return functools.partial(NamedTupleVariable, tuple_cls=type(value))
        return {
            tuple: TupleVariable,
            list: ListVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
        }[type(value)]

    def get_source(self):
        return self.source

    def options(self):
        """Keyword options applied to every tracker produced by ``__call__``."""
        return {"source": self.get_source()}

    def make_guards(self, *guards):
        """Build a set of guards on this builder's source.

        Returns ``None`` (no guards) when the source is constant.
        """
        source = self.get_source()
        if (
            isinstance(source, ConstantSource)
            or source.guard_source() == GuardSource.CONSTANT
        ):
            return None
        return {source.make_guard(guard) for guard in guards}

    def _graph_input_name(self):
        """Return ``self.name`` with every run of non-alphanumeric characters replaced by ``_``."""
        return re.sub(r"[^a-zA-Z0-9]+", "_", self.name)

    def _wrap(self, value):
        """Dispatch on the kind of ``value`` and build the matching variable tracker.

        The branches are tried in order and the first match wins, so their
        relative order is significant.
        """
        from ..comptime import comptime

        make_guards = self.make_guards
        if istype(value, (torch.SymInt, torch.SymFloat)):
            return self.wrap_sym(value)
        if istensor(value):
            return self.wrap_tensor(value)
        elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value):
            # One can index a tensor with a list/tuple. Therefore, we need to
            # have a stricter match.
            if istype(value, (tuple, list)) and all(
                isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value
            ):
                guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
            else:
                guards = self.make_guards(GuardBuilder.LIST_LENGTH)
            output = [
                VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
                    item
                ).add_guards(guards)
                for i, item in enumerate(value)
            ]
            result = self.list_type(value)(output, guards=guards)
            if istype(value, list):
                return self.tx.output.side_effects.track_list(
                    self.source, value, result
                )
            return result
        elif istype(value, tuple_iterator):
            guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
            output = [
                VariableBuilder(
                    self.tx, TupleIteratorGetItemSource(self.get_source(), i)
                )(tuple_iterator_getitem(value, i)).add_guards(guards)
                for i in range(tuple_iterator_len(value))
            ]
            return ListIteratorVariable(
                output, mutable_local=MutableLocal(), guards=guards
            )
        elif istype(value, (slice, range)):
            items = [
                VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
                    getattr(value, k)
                )
                for k in ("start", "stop", "step")
            ]
            if isinstance(value, slice):
                return SliceVariable(items, guards=make_guards(GuardBuilder.TYPE_MATCH))
            else:
                return RangeVariable(
                    items, guards=make_guards(GuardBuilder.EQUALS_MATCH)
                )
        elif istype(
            value, (dict, collections.defaultdict, collections.OrderedDict)
        ) and all(
            ConstantVariable.is_literal(k)
            or self.tensor_can_be_dict_key(k)
            or isinstance(k, enum.Enum)
            for k in value.keys()
        ):
            guards = self.make_guards(GuardBuilder.DICT_KEYS)

            # store key variables in global location for reconstruction
            for key in value.keys():
                if self.tensor_can_be_dict_key(key):
                    self.tx.store_dict_key(global_key_name(key), key)

            def index_source(key):
                if self.tensor_can_be_dict_key(key):
                    return GlobalWeakRefSource(global_key_name(key))
                else:
                    return key

            result = {
                k: VariableBuilder(
                    self.tx, GetItemSource(self.get_source(), index_source(k))
                )(value[k]).add_guards(guards)
                for k in value.keys()
            }

            if istype(value, collections.defaultdict):
                result = DefaultDictVariable(
                    result, type(value), value.default_factory, guards=guards
                )
            else:
                result = ConstDictVariable(result, type(value), guards=guards)

            return self.tx.output.side_effects.track_dict(self.source, value, result)
        elif isinstance(value, torch.nn.Module):
            if (
                isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
                and not config.allow_rnn
            ):
                unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
            if mutation_guard.is_dynamic_nn_module(value):
                # created dynamically, don't specialize on it
                result = UnspecializedNNModuleVariable(
                    value, guards=make_guards(GuardBuilder.TYPE_MATCH)
                )
                if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                    # don't allow STORE_ATTR mutation with custom __setattr__
                    return result
                return self.tx.output.side_effects.track_object_existing(
                    self.source, value, result
                )
            elif getattr(value, "_is_fsdp_managed_module", False) or issubclass(
                value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
            ):
                if getattr(value, "_is_fsdp_managed_module", False):
                    # Note: we can't do this assert inside FSDP constructor,
                    # since we don't know yet whether dynamo will be used
                    assert getattr(
                        value, "_fsdp_use_orig_params", False
                    ), "Dynamo only supports FSDP with use_orig_params=True"

                # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
                # in fully_sharded_data_parallel.py for more information
                return UnspecializedNNModuleVariable(
                    value, guards=make_guards(GuardBuilder.TYPE_MATCH)
                )
            else:
                return self.tx.output.register_attr_or_module(
                    value,
                    self.name,
                    source=self.get_source(),
                    # Guards are added inside register_attr_or_module
                )
        elif ConstantVariable.is_literal(value) or istype(
            value, (torch.Size, torch.device, torch.dtype)
        ):
            if type(value) in (int, float) and not config.specialize_int_float:
                # unspecializing int/float by default, but still
                # specialize for the following conditions
                if (
                    value in self._common_constants()
                    or isinstance(self.source, GlobalSource)
                    or isinstance(self.source, GetItemSource)
                    or (
                        isinstance(self.source, AttrSource)
                        and isinstance(self.source.base, GlobalSource)
                    )
                ):
                    return ConstantVariable(
                        value=value,
                        guards=make_guards(GuardBuilder.CONSTANT_MATCH),
                    )
                else:
                    return self.wrap_unspecialized_primitive(value)
            else:
                return ConstantVariable(
                    value=value,
                    guards=make_guards(GuardBuilder.CONSTANT_MATCH),
                )
        elif isinstance(value, frozenset) and (
            all(is_allowed(x) or ConstantVariable.is_literal(x) for x in value)
        ):
            # For frozenset, we can guard by object ID instead of value
            # equality, this allows us to handle non-literal values
            return ConstantVariable(
                value=value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif isinstance(value, enum.Enum):
            return EnumVariable(
                value=value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif is_builtin_callable(value):
            return BuiltinVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.BUILTIN_MATCH),
            )
        elif is_allowed(value):
            return TorchVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif is_typing(value):
            # typing.List, typing.Mapping, etc.
            return TypingVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif value is inspect.signature:
            return LambdaVariable(
                InspectSignatureVariable.create,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif value is comptime:
            return ComptimeVariable()
        elif value is dataclasses.fields:
            return LambdaVariable(
                _dataclasses_fields_lambda,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif is_numpy(value):
            return NumpyVariable(
                value,
                source=self.source,
                guards=make_guards(
                    GuardBuilder.FUNCTION_MATCH
                    if callable(value)
                    else GuardBuilder.TYPE_MATCH
                ),
            )
        elif value in tensor_dunder_fns:
            return TorchVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif (
            istype(value, (type, types.FunctionType))
            and skipfiles.check(getfile(value), allow_torch=True)
            and not inspect.getattr_static(value, "_torchdynamo_inline", False)
        ):
            return SkipFilesVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif istype(value, types.FunctionType):
            return UserFunctionVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif istype(value, (types.ModuleType, replay_record.DummyModule)):
            return PythonModuleVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.PYMODULE_MATCH),
            )
        elif type(value) is torch.autograd.function.FunctionMeta:
            return AutogradFunctionVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif isinstance(value, torch.autograd.function.FunctionCtx):
            # The autograd.function context
            return AutogradFunctionContextVariable()
        elif (
            isinstance(value, types.MethodType)
            and type(getattr(value, "__self__", None))
            is torch.autograd.function.FunctionMeta
            and getattr(value, "__name__", "") == "apply"
            and value == getattr(value.__self__, "apply", None)
        ):
            # handle aliased autograd function `apply` calls
            return GetAttrVariable(
                AutogradFunctionVariable(
                    value.__self__,
                    source=self.source,
                    guards=make_guards(GuardBuilder.FUNCTION_MATCH),
                ),
                "apply",
            )
        elif isinstance(value, (int, float)) or (
            HAS_NUMPY and (isinstance(value, np.number))
        ):
            return self.wrap_unspecialized_primitive(value)
        elif DataClassVariable.is_matching_object(value):
            return DataClassVariable.wrap(self, value).add_guards(
                make_guards(GuardBuilder.TYPE_MATCH)
            )
        elif HFPretrainedConfigVariable.is_matching_object(value):
            return HFPretrainedConfigVariable(
                value, guards=make_guards(GuardBuilder.TYPE_MATCH)
            )
        elif isinstance(value, PyOperator):
            return TorchPyOperator(
                value,
                guards=self.make_guards(
                    GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
                ),
            )
        elif type(value).__name__ == "builtin_function_or_method" and isinstance(
            value.__self__, torch_special_class_types
        ):
            return TorchVariable(
                value,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif issubclass(type(value), type):
            # TODO(whc) the following seems preferable but breaks some tests, debug
            # elif inspect.isclass(value):
            return UserDefinedClassVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        else:
            result = UserDefinedObjectVariable(
                value,
                source=self.source,
                guards=self.make_guards(GuardBuilder.TYPE_MATCH),
            )
            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                # don't allow STORE_ATTR mutation with custom __setattr__
                return result
            return self.tx.output.side_effects.track_object_existing(
                self.source, value, result
            )

    def tensor_can_be_dict_key(self, value):
        """Return whether ``value`` is accepted as a dict key by ``_wrap``.

        True for any ``torch.nn.Parameter``, and for any key when this
        builder's source is the ``state`` attribute of a local.
        """
        # only allow Parameter and another specific Tensor can be used as dict key
        return isinstance(value, torch.nn.Parameter) or (
            isinstance(self.source, AttrSource)
            and self.source.member == "state"
            and isinstance(self.source.base, LocalSource)
        )

    def tensor_should_specialize(self):
        """Return a truthy value when the source has the shape
        ``<local>.param_groups[...]["params"][...]``.

        When the local is present in ``tx.f_locals`` it must additionally be a
        ``torch.optim.Optimizer``; when it is absent the check passes.
        """
        return (
            self.source
            and isinstance(self.source, GetItemSource)
            and isinstance(self.source.base, GetItemSource)
            and self.source.base.index == "params"
            and isinstance(self.source.base.base, GetItemSource)
            and isinstance(self.source.base.base.base, AttrSource)
            and self.source.base.base.base.member == "param_groups"
            and isinstance(self.source.base.base.base.base, LocalSource)
            and (
                isinstance(
                    self.tx.f_locals[self.source.base.base.base.base.local_name],
                    torch.optim.Optimizer,
                )
                if self.source.base.base.base.base.local_name
                in self.tx.f_locals.keys()
                else True
            )
        )

    def wrap_sym(self, value: Union[torch.SymInt, torch.SymFloat]):
        """Wrap a SymInt/SymFloat.

        Constant sources are registered via ``register_attr_or_module``; any
        other source is recorded as a graph argument and exposed as a new
        graph input.
        """
        if is_constant_source(self.get_source()):
            return self.tx.output.register_attr_or_module(
                value,
                self._graph_input_name(),
                source=None,
                sym_num=value
                # shape Guards live their own rich life via shape_env
            )

        self.tx.output.add_grapharg(GraphArg(self.get_source(), value, False, None))
        return SymNodeVariable.create(
            tx=self.tx,
            proxy=self.tx.output.create_graph_input(
                self._graph_input_name(), type(value)
            ),
            sym_num=value
            # shape Guards live their own rich life via shape_env
        )

    def wrap_tensor(self, value: torch.Tensor):
        """Wrap a tensor as a registered attribute or as a new graph input.

        Tensors whose guard source is an nn.Module, and tensors from constant
        sources, are handed to ``register_attr_or_module``. Everything else
        becomes a graph input with a TENSOR_MATCH guard and a recorded
        ``GraphArg``.
        """
        if self.get_source().guard_source().is_nn_module():
            return self.tx.output.register_attr_or_module(
                value,
                self.name,
                source=self.get_source(),
                # Guards are done inside register_attr_or_module
                # guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
            )

        if is_constant_source(self.get_source()):
            return self.tx.output.register_attr_or_module(
                value,
                self._graph_input_name(),
                source=self.get_source(),
                # Guards are added inside register_attr_or_module
            )

        if type(value) in config.traceable_tensor_subclasses:
            # Ordinarily, we would fakeify a tensor so that it can get dynamic
            # shapes and be computed on without triggering actual operations.
            # However, how can we fakeify a tensor subclass?  Ordinary
            # inheritance (nor multiple inheritance) won't work.
            #
            # Instead, our plan is to *manually simulate* the tensor subclass
            # inheriting from a fake tensor with dynamo.  This means our
            # data representation for a tensor subclass will be a fake tensor
            # + tensor subclass type + any extra data the subclass may have
            # been storing on the tensor.  Because all Python accesses are
            # mediated through TensorWithTFOverrideVariable, we can ensure
            # that we dispatch differently, e.g., according to
            # __torch_function__
            #
            # To simplify things for now, the __dict__ tracking bits haven't
            # been implemented yet, but they can be added into this design at
            # a later point in time.
            ignore_subclass = True
        else:
            assert type(value) in (torch.Tensor, torch.nn.Parameter)
            ignore_subclass = False

        tensor_proxy = self.tx.output.create_graph_input(
            self._graph_input_name(), type(value)
        )
        tensor_variable = wrap_fx_proxy(
            tx=self.tx,
            proxy=tensor_proxy,
            example_value=value,
            guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
            should_specialize=self.tensor_should_specialize(),
            ignore_subclass=ignore_subclass,
            source=self.get_source(),
        )
        assert "tensor_dict" not in tensor_proxy.node.meta
        tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()

        # TODO: I think the result is guaranteed to be fake with
        # ignore_subclass changes
        fake_tensor_value = None
        example_value = tensor_variable.proxy.node.meta["example_value"]
        if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
            fake_tensor_value = example_value

        self.tx.output.add_grapharg(
            GraphArg(self.get_source(), value, False, fake_tensor_value)
        )

        if type(value) in config.traceable_tensor_subclasses:
            subclass_torch_function__func = value.__torch_function__.__func__
            subclass_type = type(value)
            # NB: This is slightly misnamed, a tensor subclass might not have
            # any explicit __torch_function__ implementation and is relying
            # on the default inherited from torch.Tensor
            return TensorWithTFOverrideVariable(
                tensor_variable,
                self.get_source(),
                subclass_torch_function__func,
                subclass_type,
            )

        return tensor_variable

    def wrap_unspecialized_primitive(self, value):
        """Wrap a Python/numpy number as a graph input instead of a constant.

        Results are cached per source name in ``tx.output.unspec_variable_map``
        so a repeated lookup returns the same variable.
        """
        if self.name in self.tx.output.unspec_variable_map:
            return self.tx.output.unspec_variable_map[self.name]

        if (
            config.dynamic_shapes
            and isinstance(value, int)
            and not is_constant_source(self.get_source())
        ):
            shape_env = self.tx.output.shape_env
            wrapped_value = shape_env.create_symintnode(
                shape_env.create_symbol(value, source=self.source), hint=value
            )
            self.tx.output.tracked_fakes.append(
                TrackedFake(wrapped_value, self.source)
            )
            # TODO: Do float
        else:
            # TODO: Eliminate this case entirely
            wrapped_value = torch.tensor(value)

        if not isinstance(self.get_source(), RandomValueSource):
            guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH, True)}
            options = {"guards": guards}
        else:
            options = {}
        options.update({"source": self.get_source()})
        if isinstance(wrapped_value, torch.Tensor):
            options.update({"raw_value": value})

        proxy = self.tx.output.create_graph_input(
            self._graph_input_name(), type(wrapped_value)
        )

        unspec_var = wrap_fx_proxy_cls(
            UnspecializedPythonVariable,
            tx=self.tx,
            proxy=proxy,
            example_value=wrapped_value,
            **options,
        )
        self.tx.output.unspec_variable_map[self.name] = unspec_var
        if not is_constant_source(self.get_source()):
            fake_tensor_value = None
            example_value = unspec_var.proxy.node.meta["example_value"]
            if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
                fake_tensor_value = example_value
            self.tx.output.add_grapharg(
                GraphArg(
                    self.get_source(),
                    wrapped_value,
                    True,
                    fake_tensor_value,
                    is_tensor=False,
                )
            )
        return unspec_var


def _dataclasses_fields_lambda(obj):
    """Traced stand-in for ``dataclasses.fields``.

    Returns a ``TupleVariable`` holding one ``UserDefinedObjectVariable`` per
    dataclass field of the object/class tracked by ``obj``.
    """
    if isinstance(obj, UserDefinedObjectVariable):
        value = obj.value
    elif isinstance(obj, DataClassVariable):
        value = obj.user_cls
    else:
        unimplemented(f"Dataclass fields handling fails for type {obj}")
    items = []
    for field in dataclasses.fields(value):
        source = None
        if obj.source:
            source = GetItemSource(
                AttrSource(obj.source, "__dataclass_fields__"), field.name
            )
        items.append(UserDefinedObjectVariable(field, source=source).add_options(obj))
    return TupleVariable(items).add_options(obj)


def wrap_fx_proxy(tx, proxy, example_value=None, **options):
    """Call ``wrap_fx_proxy_cls`` with ``TensorVariable`` as the target class."""
    return wrap_fx_proxy_cls(
        target_cls=TensorVariable,
        tx=tx,
        proxy=proxy,
        example_value=example_value,
        **options,
    )


# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
# Should be compositional instead
def wrap_fx_proxy_cls(
    target_cls, tx, proxy, example_value=None, ignore_subclass=False, **options
):
    """Wrap an fx ``proxy`` in the variable tracker that matches its example value.

    Args:
        target_cls: tracker class instantiated when the example value is a tensor.
        tx: the active translator; must be an ``InstructionTranslatorBase``.
        proxy: fx proxy to wrap; its node must not have ``example_value`` meta yet.
        example_value: known value for the proxy. When ``None`` it is computed
            with ``get_fake_value``.
        ignore_subclass: forwarded to ``wrap_to_fake_tensor_and_record``.
        **options: forwarded to the tracker constructor. Any ``guards`` are
            also added to ``tx.output.guards``, and ``should_specialize`` is
            consumed here on the tensor path.

    Non-tensor example values are dispatched to the matching variable class;
    unsupported ones end in ``unimplemented``.
    """
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tx, InstructionTranslatorBase)
    if "guards" in options and options["guards"] is not None:
        tx.output.guards.update(options["guards"])

    assert "example_value" not in proxy.node.meta

    initial_example_value = example_value

    def _clone_input(value):
        if isinstance(value, torch.Tensor):
            # tensor subclasses will not be converted to FakeTensors and need to be cloned
            if not isinstance(value, torch._subclasses.fake_tensor.FakeTensor):
                # NB: ensure strides are preserved
                value = clone_input(value)

        return value

    with preserve_rng_state():
        if example_value is None:
            example_value = get_fake_value(proxy.node, tx)

        # Handle recursive calls here
        elif isinstance(example_value, FakeTensor):
            pass

        elif isinstance(example_value, torch.Tensor):
            if tx.export:
                # The legacy behavior for real value cache with subclasses was
                # to perform a clone WITHOUT preserving the subclass.  It's
                # not entirely clear this is what you actually want though.
                with torch._C.DisableTorchFunctionSubclass():
                    proxy.tracer.real_value_cache[proxy.node] = _clone_input(
                        example_value
                    )
            # NB: If we're ignoring subclass, then the expectation is you will
            # take the returned TensorVariable and wrap it into a more
            # accurate TensorVariable that is able to track subclass-ness;
            # otherwise this is wrong!
            kwargs = {
                "ignore_subclass": ignore_subclass,
                "is_tensor": target_cls is TensorVariable,
            }
            assert "source" in options and options["source"] is not None
            kwargs["source"] = options["source"]
            example_value = wrap_to_fake_tensor_and_record(
                example_value, tx=tx, **kwargs
            )

    if isinstance(example_value, torch.Tensor):
        is_parameter = isinstance(example_value, torch.nn.Parameter)
        should_specialize = options.pop("should_specialize", False)
        if is_parameter or should_specialize:
            specialized_value = initial_example_value
        else:
            specialized_value = None

        # NB: In most (all?) cases, this does not actually do a clone.
        # (WARNING: this means that if we mutate metadata on the fake
        # tensor, the stored example value will update too!)
        example_value = _clone_input(example_value)
        proxy.node.meta["example_value"] = example_value
        specialized_props = target_cls.specialize(example_value)
        if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
            # NB: This will be wrong for ignore_subclass; fix it up later!
            specialized_props["class_type"] = (
                torch.nn.Parameter if is_parameter else torch.Tensor
            )

        specialized_props["specialized_value"] = specialized_value

        options.update(specialized_props)
        return target_cls(proxy, **options)
    elif (
        hasattr(proxy.node.target, "__name__")
        and proxy.node.target.__name__ == "set_state"
        and isinstance(proxy.node.target.__self__, torch._C.Generator)
    ) or proxy.node.target == torch.random.set_rng_state:
        from . import TorchVariable

        return TorchVariable(proxy.node.target)
    elif (
        proxy.node.target == torch._C._DisableFuncTorch
        or proxy.node.target == torch.cuda._is_in_bad_fork
    ):
        from . import UserDefinedObjectVariable

        return UserDefinedObjectVariable(example_value)
    elif istype(example_value, (int, bool, float)) and config.dynamic_shapes:
        proxy.node.meta["example_value"] = example_value
        return SymNodeVariable.create(tx, proxy, example_value, **options)
    elif istype(example_value, torch.Size) and config.dynamic_shapes:
        proxy.node.meta["example_value"] = example_value
        sizes = []
        for i, v in enumerate(example_value):
            proxy_i = proxy[i]
            sizes.append(SymNodeVariable.create(tx, proxy_i, v, **options))
        return SizeVariable(sizes, proxy, **options)
    elif istype(example_value, int) and proxy.node.target in (
        torch.seed,
        operator.mod,
        # some mac builds are missing torch.distributed.get_rank()
        getattr(torch.distributed, "get_rank", _missing),
        getattr(torch.distributed, "get_world_size", _missing),
    ):
        if config.dynamic_shapes:
            proxy.node.meta["example_value"] = example_value
            return SymNodeVariable.create(tx, proxy, example_value, **options)
        else:
            return ConstantVariable(example_value, **options)
    elif istype(example_value, torch.Size) and all(
        isinstance(x, int) for x in example_value
    ):
        sizes = [ConstantVariable(x) for x in example_value]
        return SizeVariable(sizes, **options)
    elif isinstance(example_value, (tuple, list)):
        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:
                # nn.MultiheadAttention() can return None, see issue #175
                unpacked.append(
                    ConstantVariable(None, **options),
                )
            else:
                unpacked.append(
                    wrap_fx_proxy(
                        tx,
                        proxy.tracer.create_proxy(
                            "call_function", operator.getitem, (proxy, i), {}
                        ),
                        example_value=val,
                        **options,
                    )
                )
        if istype(example_value, tuple):
            return TupleVariable(unpacked, **options)
        elif istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
        else:
            assert (
                example_value.__class__.__module__ == "torch.return_types"
                or hasattr(example_value, "_fields")
            ), ("namedtuple?")
            return NamedTupleVariable(unpacked, example_value.__class__, **options)
    elif example_value is None or proxy.node.target is torch.manual_seed:
        return ConstantVariable(None, **options)
    elif (
        isinstance(example_value, int)
        and proxy.node.target is torch._utils._element_size
    ):
        proxy.node.meta["example_value"] = example_value
        return ConstantVariable(example_value, **options)
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat)):
        proxy.node.meta["example_value"] = example_value
        return SymNodeVariable(proxy, example_value, **options)
    elif proxy.node.target in [torch.cuda.streams.Stream, torch.cuda.current_stream]:
        from . import CUDAStreamVariable

        proxy.node.meta["example_value"] = example_value
        return CUDAStreamVariable(proxy, example_value, **options)
    else:
        unimplemented(
            "torch.* op returned non-Tensor "
            + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
        )


# Tracks the sources of all fake tensors we wrap in Dynamo.
# Used by shape guard computation.
@dataclasses.dataclass
class TrackedFake:
    fake: Union[FakeTensor, SymInt]
    source: Source


def wrap_to_fake_tensor_and_record(
    e, tx, ignore_subclass=False, *, source: Optional[Source], is_tensor: bool
):
    """Convert ``e`` to a fake tensor via ``tx.fake_mode`` and record it.

    Only plain ``torch.Tensor``/``torch.nn.Parameter`` values are converted,
    plus any tensor subclass when ``ignore_subclass`` is set. Anything else is
    returned unchanged. The fake is appended to ``tx.output.tracked_fakes``
    only when ``is_tensor`` is true.
    """
    if type(e) in (torch.Tensor, torch.nn.Parameter) or (
        ignore_subclass and isinstance(e, torch.Tensor)
    ):
        # Static shapes are requested when there is no source, for parameters,
        # when dynamic shapes are disabled, or for non-tensor wrappers.
        static_shapes = (
            source is None
            or type(e) is torch.nn.Parameter
            or config.dynamic_shapes is False
            or not is_tensor
        )

        fake_e = wrap_fake_exception(
            lambda: tx.fake_mode.from_tensor(
                e,
                static_shapes=static_shapes,
                ignore_subclass=ignore_subclass,
                source=source,
            )
        )
        if is_tensor:
            tx.output.tracked_fakes.append(TrackedFake(fake_e, source))
        return fake_e
    else:
        return e