I applied some flake8 fixes and enabled checking for them in the linter. I also enabled some checks for my previous comprehensions PR. This is a follow-up to #94323: it enables the flake8 checkers for the fixes I made there and fixes a few more of them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94601
Approved by: https://github.com/ezyang
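For context, this is the kind of rewrite the comprehension checks flag (an illustrative sketch, not a hunk from this PR): an unnecessary list comprehension passed to all()/any() materializes an intermediate list and prevents short-circuiting.

values = [1, 2, 3]
# Before: builds a list just to feed all()
ok = all([isinstance(x, int) for x in values])
# After: a generator expression short-circuits on the first failure
ok = all(isinstance(x, int) for x in values)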
970 lines
36 KiB
Python
import collections
import dataclasses
import enum
import functools
import inspect
import operator
import re
import types
from typing import Any, Optional, Union

import torch

from torch import SymInt
from torch._guards import GuardSource
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.immutable_collections import immutable_list

from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..side_effects import SideEffects
from ..source import (
    AttrSource,
    ConstantSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    is_constant_source,
    LocalInputSource,
    LocalSource,
    RandomValueSource,
    Source,
    TupleIteratorGetItemSource,
)
from ..utils import (
    clone_input,
    get_fake_value,
    getfile,
    global_key_name,
    HAS_NUMPY,
    is_namedtuple,
    is_numpy_int_type,
    is_typing,
    istensor,
    istype,
    np,
    odict_values,
    preserve_rng_state,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    wrap_fake_exception,
)

from .base import MutableLocal, typestr
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .dicts import (
    ConstDictVariable,
    DataClassVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
)
from .functions import UserFunctionVariable
from .lists import (
    ListIteratorVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    SizeVariable,
    SliceVariable,
    TupleVariable,
)
from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    GetAttrVariable,
    InspectSignatureVariable,
    LambdaVariable,
    NumpyVariable,
    PythonModuleVariable,
    SkipFilesVariable,
    TypingVariable,
)
from .nn_module import UnspecializedNNModuleVariable
from .tensor import (
    SymNodeVariable,
    TensorVariable,
    TensorWithTFOverrideVariable,
    UnspecializedPythonVariable,
)
from .torch import (
    tensor_dunder_fns,
    torch_special_class_types,
    TorchPyOperator,
    TorchVariable,
)
from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable


class _missing:
    pass

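# GraphArg records one placeholder input of the FX graph Dynamo is building:
# where the value came from (its Source), an example value, whether it is an
# unspecialized primitive, and, for tensors, the FakeTensor used during tracing.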
@dataclasses.dataclass
class GraphArg:
    source: Source
    example: Any
    is_unspecialized: bool
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True

    def __post_init__(self):
        if isinstance(self.example, torch.Tensor):
            assert isinstance(
                self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
            )
            # Mapping for downstream systems to remap back into dynamo arg positions
            if isinstance(self.source, LocalInputSource):
                if "graph_arg_pos" not in self.fake_tensor.__dict__:
                    self.fake_tensor.__dict__["graph_arg_pos"] = []
                self.fake_tensor.__dict__["graph_arg_pos"].append(self.source.pos)
        if isinstance(self.example, torch._subclasses.fake_tensor.FakeTensor):
            raise AssertionError("Fake Tensor observed in TorchDynamo Fx graph inputs")

    def load(self, tx):
        return self.source.reconstruct(tx)

    def get_examples(self):
        return [self.example]

    def get_fake_examples(self):
        if self.fake_tensor is not None:
            assert isinstance(
                self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
            )
            return [self.fake_tensor]

    def __len__(self):
        return 1

    def erase(self):
        self.example = None

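# VariableBuilder is how the symbolic interpreter turns a raw Python value into
# a VariableTracker, dispatching on the value's type in _wrap() below.
# Illustrative usage (assuming `tx` is the current instruction translator):
#   VariableBuilder(tx, LocalSource("x"))(some_value)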
class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

    def __init__(
        self,
        tx,
        source: Source,
    ):
        assert source is not None
        super().__init__()
        self.tx = tx
        self.source = source
        self.name = source.name()

    def __call__(self, value):
        if value in self.tx.output.side_effects:
            # TODO(jansel): add guard for alias relationship
            return self.tx.output.side_effects[value]
        return self._wrap(value).clone(**self.options())

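    # Small int/float constants that are always specialized on, even when
    # config.specialize_int_float is off; see the ConstantVariable branch of
    # _wrap() below.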
    @staticmethod
    @functools.lru_cache(None)
    def _common_constants():
        return set(range(17)).union(
            {
                20,
                30,
                40,
                32,
                64,
                96,
                128,
                144,
                240,
                256,
                672,
                1024,
                2048,
                4096,
                0.1,
                0.01,
                0.001,
                0.5,
                0.05,
                800,
                1.873536229133606,
                4.135166556742356,  # Work around for vision_maskrcnn where torch.clamp can't be on different devices
            }
        )

    @staticmethod
    def list_type(value):
        if is_namedtuple(value):
            return functools.partial(NamedTupleVariable, tuple_cls=type(value))
        return {
            tuple: TupleVariable,
            list: ListVariable,
            odict_values: ListVariable,
            torch.nn.ParameterList: ListVariable,
            torch.nn.ModuleList: ListVariable,
        }[type(value)]

    def get_source(self):
        return self.source

    def options(self):
        return {"source": self.get_source()}

    def make_guards(self, *guards):
        source = self.get_source()
        if (
            isinstance(source, ConstantSource)
            or source.guard_source() == GuardSource.CONSTANT
        ):
            return None
        return {source.make_guard(guard) for guard in guards}

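    # The main type dispatch: returns the VariableTracker subclass that models
    # `value` (tensor, list/tuple, dict, nn.Module, constant, function,
    # user-defined object, ...) together with the guards needed to safely
    # reuse the compiled code.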
    def _wrap(self, value):
        from ..comptime import comptime

        make_guards = self.make_guards
        if istype(value, (torch.SymInt, torch.SymFloat)):
            return self.wrap_sym(value)
        if istensor(value):
            return self.wrap_tensor(value)
        elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value):
            # One can index a tensor with a list/tuple. Therefore, we need to
            # have a stricter match.
            if istype(value, (tuple, list)) and all(
                [isinstance(x, int) or is_numpy_int_type(x) or x is None for x in value]
            ):
                guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
            else:
                guards = self.make_guards(GuardBuilder.LIST_LENGTH)
            output = [
                VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
                    item
                ).add_guards(guards)
                for i, item in enumerate(value)
            ]
            result = self.list_type(value)(output, guards=guards)
            if istype(value, list):
                return self.tx.output.side_effects.track_list(
                    self.source, value, result
                )
            return result
        elif istype(value, tuple_iterator):
            guards = self.make_guards(GuardBuilder.TUPLE_ITERATOR_LEN)
            output = [
                VariableBuilder(
                    self.tx, TupleIteratorGetItemSource(self.get_source(), i)
                )(tuple_iterator_getitem(value, i)).add_guards(guards)
                for i in range(tuple_iterator_len(value))
            ]
            return ListIteratorVariable(
                output, mutable_local=MutableLocal(), guards=guards
            )
        elif istype(value, (slice, range)):
            items = [
                VariableBuilder(self.tx, AttrSource(self.get_source(), k))(
                    getattr(value, k)
                )
                for k in ("start", "stop", "step")
            ]
            if isinstance(value, slice):
                return SliceVariable(items, guards=make_guards(GuardBuilder.TYPE_MATCH))
            else:
                return RangeVariable(
                    items, guards=make_guards(GuardBuilder.EQUALS_MATCH)
                )
        elif istype(
            value, (dict, collections.defaultdict, collections.OrderedDict)
        ) and all(
            map(
                lambda k: ConstantVariable.is_literal(k)
                or self.tensor_can_be_dict_key(k)
                or isinstance(k, enum.Enum),
                value.keys(),
            )
        ):
            guards = self.make_guards(GuardBuilder.DICT_KEYS)

            # store key variables in global location for reconstruction
            for key in value.keys():
                if self.tensor_can_be_dict_key(key):
                    self.tx.store_dict_key(global_key_name(key), key)

            def index_source(key):
                if self.tensor_can_be_dict_key(key):
                    return GlobalWeakRefSource(global_key_name(key))
                else:
                    return key

            result = {
                k: VariableBuilder(
                    self.tx, GetItemSource(self.get_source(), index_source(k))
                )(value[k]).add_guards(guards)
                for k in value.keys()
            }

            if istype(value, collections.defaultdict):
                result = DefaultDictVariable(
                    result, type(value), value.default_factory, guards=guards
                )
            else:
                result = ConstDictVariable(result, type(value), guards=guards)

            return self.tx.output.side_effects.track_dict(self.source, value, result)
        elif isinstance(value, torch.nn.Module):
            if (
                isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM))
                and not config.allow_rnn
            ):
                unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs")
            if mutation_guard.is_dynamic_nn_module(value):
                # created dynamically, don't specialize on it
                result = UnspecializedNNModuleVariable(
                    value, guards=make_guards(GuardBuilder.TYPE_MATCH)
                )
                if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                    # don't allow STORE_ATTR mutation with custom __setattr__
                    return result
                return self.tx.output.side_effects.track_object_existing(
                    self.source, value, result
                )
            elif getattr(value, "_is_fsdp_managed_module", False) or issubclass(
                value.__class__, torch.nn.parallel.distributed.DistributedDataParallel
            ):
                if getattr(value, "_is_fsdp_managed_module", False):
                    # Note: we can't do this assert inside FSDP constructor,
                    # since we don't know yet whether dynamo will be used
                    assert getattr(
                        value, "_fsdp_use_orig_params", False
                    ), "Dynamo only supports FSDP with use_orig_params=True"

                # See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule]
                # in fully_sharded_data_parallel.py for more information
                return UnspecializedNNModuleVariable(
                    value, guards=make_guards(GuardBuilder.TYPE_MATCH)
                )
            else:
                return self.tx.output.register_attr_or_module(
                    value,
                    self.name,
                    source=self.get_source(),
                    # Guards are added inside register_attr_or_module
                )
        elif ConstantVariable.is_literal(value) or istype(
            value, (torch.Size, torch.device, torch.dtype)
        ):
            if type(value) in (int, float) and not config.specialize_int_float:
                # unspecializing int/float by default, but still
                # specialize for the following conditions
                if (
                    value in self._common_constants()
                    or isinstance(self.source, GlobalSource)
                    or isinstance(self.source, GetItemSource)
                    or (
                        isinstance(self.source, AttrSource)
                        and isinstance(self.source.base, GlobalSource)
                    )
                ):
                    return ConstantVariable(
                        value=value,
                        guards=make_guards(GuardBuilder.CONSTANT_MATCH),
                    )
                else:
                    return self.wrap_unspecialized_primitive(value)
            else:
                return ConstantVariable(
                    value=value,
                    guards=make_guards(GuardBuilder.CONSTANT_MATCH),
                )
        elif isinstance(value, frozenset) and (
            all(is_allowed(x) or ConstantVariable.is_literal(x) for x in value)
        ):
            # For frozenset, we can guard by object ID instead of value
            # equality, this allows us to handle non-literal values
            return ConstantVariable(
                value=value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif isinstance(value, enum.Enum):
            return EnumVariable(
                value=value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif is_builtin_callable(value):
            return BuiltinVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.BUILTIN_MATCH),
            )
        elif is_allowed(value):
            return TorchVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif is_typing(value):
            # typing.List, typing.Mapping, etc.
            return TypingVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.ID_MATCH),
            )
        elif value is inspect.signature:
            return LambdaVariable(
                InspectSignatureVariable.create,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif value is comptime:
            return ComptimeVariable()
        elif value is dataclasses.fields:
            return LambdaVariable(
                _dataclasses_fields_lambda,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif is_numpy(value):
            return NumpyVariable(
                value,
                source=self.source,
                guards=make_guards(
                    GuardBuilder.FUNCTION_MATCH
                    if callable(value)
                    else GuardBuilder.TYPE_MATCH
                ),
            )
        elif value in tensor_dunder_fns:
            return TorchVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif (
            istype(value, (type, types.FunctionType))
            and skipfiles.check(getfile(value), allow_torch=True)
            and not inspect.getattr_static(value, "_torchdynamo_inline", False)
        ):
            return SkipFilesVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif istype(value, types.FunctionType):
            return UserFunctionVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif istype(value, (types.ModuleType, replay_record.DummyModule)):
            return PythonModuleVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.PYMODULE_MATCH),
            )
        elif type(value) is torch.autograd.function.FunctionMeta:
            return AutogradFunctionVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif isinstance(value, torch.autograd.function.FunctionCtx):
            # The autograd.function context
            return AutogradFunctionContextVariable()
        elif (
            isinstance(value, types.MethodType)
            and type(getattr(value, "__self__", None))
            is torch.autograd.function.FunctionMeta
            and getattr(value, "__name__", "") == "apply"
            and value == getattr(value.__self__, "apply", None)
        ):
            # handle aliased autograd function `apply` calls
            return GetAttrVariable(
                AutogradFunctionVariable(
                    value.__self__,
                    source=self.source,
                    guards=make_guards(GuardBuilder.FUNCTION_MATCH),
                ),
                "apply",
            )
        elif isinstance(value, (int, float)) or (
            HAS_NUMPY and (isinstance(value, np.number))
        ):
            return self.wrap_unspecialized_primitive(value)
        elif DataClassVariable.is_matching_object(value):
            return DataClassVariable.wrap(self, value).add_guards(
                make_guards(GuardBuilder.TYPE_MATCH)
            )
        elif HFPretrainedConfigVariable.is_matching_object(value):
            return HFPretrainedConfigVariable(
                value, guards=make_guards(GuardBuilder.TYPE_MATCH)
            )
        elif isinstance(value, PyOperator):
            return TorchPyOperator(
                value,
                guards=self.make_guards(
                    GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH
                ),
            )
        elif type(value).__name__ == "builtin_function_or_method" and isinstance(
            value.__self__, torch_special_class_types
        ):
            return TorchVariable(
                value,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        elif issubclass(type(value), type):
            # TODO(whc) the following seems preferable but breaks some tests, debug
            # elif inspect.isclass(value):
            return UserDefinedClassVariable(
                value,
                source=self.source,
                guards=make_guards(GuardBuilder.FUNCTION_MATCH),
            )
        else:
            result = UserDefinedObjectVariable(
                value,
                source=self.source,
                guards=self.make_guards(GuardBuilder.TYPE_MATCH),
            )
            if not SideEffects.cls_supports_mutation_side_effects(type(value)):
                # don't allow STORE_ATTR mutation with custom __setattr__
                return result
            return self.tx.output.side_effects.track_object_existing(
                self.source, value, result
            )

    def tensor_can_be_dict_key(self, value):
        # only allow a Parameter, or a tensor reached via `<local>.state`, to
        # be used as a dict key
        return (
            isinstance(value, torch.nn.Parameter)
            or isinstance(self.source, AttrSource)
            and self.source.member == "state"
            and isinstance(self.source.base, LocalSource)
        )

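    # True when the tensor was reached through an optimizer's param_groups,
    # i.e. a source chain like local.param_groups[i]["params"][j]; such tensors
    # are marked for specialization when passed to wrap_fx_proxy below.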
    def tensor_should_specialize(self):
        return (
            self.source
            and isinstance(self.source, GetItemSource)
            and isinstance(self.source.base, GetItemSource)
            and self.source.base.index == "params"
            and isinstance(self.source.base.base, GetItemSource)
            and isinstance(self.source.base.base.base, AttrSource)
            and self.source.base.base.base.member == "param_groups"
            and isinstance(self.source.base.base.base.base, LocalSource)
            and (
                isinstance(
                    self.tx.f_locals[self.source.base.base.base.base.local_name],
                    torch.optim.Optimizer,
                )
                if self.source.base.base.base.base.local_name in self.tx.f_locals.keys()
                else True
            )
        )

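    # Wrap an existing SymInt/SymFloat: constant sources are registered as
    # attributes on the output graph, everything else becomes a graph input
    # modeled by SymNodeVariable (shape guards themselves live in the shape_env).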
    def wrap_sym(self, value: Union[torch.SymInt, torch.SymFloat]):
        if not is_constant_source(self.get_source()):
            self.tx.output.add_grapharg(GraphArg(self.get_source(), value, False, None))
        elif is_constant_source(self.get_source()):
            return self.tx.output.register_attr_or_module(
                value,
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                source=None,
                sym_num=value
                # shape Guards live their own rich life via shape_env
            )
        return SymNodeVariable.create(
            tx=self.tx,
            proxy=self.tx.output.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
            ),
            sym_num=value
            # shape Guards live their own rich life via shape_env
        )

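    # Wrap a real torch.Tensor: nn.Module attributes and constants are
    # registered on the output graph; everything else becomes a graph
    # placeholder whose example value is faked via wrap_fx_proxy below.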
    def wrap_tensor(self, value: torch.Tensor):
        if self.get_source().guard_source().is_nn_module():
            return self.tx.output.register_attr_or_module(
                value,
                self.name,
                source=self.get_source(),
                # Guards are done inside register_attr_or_module
                # guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
            )

        if is_constant_source(self.get_source()):
            return self.tx.output.register_attr_or_module(
                value,
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
                source=self.get_source(),
                # Guards are added inside register_attr_or_module
            )

        if type(value) in config.traceable_tensor_subclasses:
            # Ordinarily, we would fakeify a tensor so that it can get dynamic
            # shapes and be computed on without triggering actual operations.
            # However, how can we fakeify a tensor subclass? Ordinary
            # inheritance (even multiple inheritance) won't work.
            #
            # Instead, our plan is to *manually simulate* the tensor subclass
            # inheriting from a fake tensor with dynamo. This means our
            # data representation for a tensor subclass will be a fake tensor
            # + tensor subclass type + any extra data the subclass may have
            # been storing on the tensor. Because all Python accesses are
            # mediated through TensorWithTFOverrideVariable, we can ensure
            # that we dispatch differently, e.g., according to
            # __torch_function__
            #
            # To simplify things for now, the __dict__ tracking bits haven't
            # been implemented yet, but they can be added into this design at
            # a later point in time.
            ignore_subclass = True
        else:
            assert type(value) in (torch.Tensor, torch.nn.Parameter)
            ignore_subclass = False

        tensor_proxy = self.tx.output.create_graph_input(
            re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
        )
        tensor_variable = wrap_fx_proxy(
            tx=self.tx,
            proxy=tensor_proxy,
            example_value=value,
            guards=self.make_guards(GuardBuilder.TENSOR_MATCH),
            should_specialize=self.tensor_should_specialize(),
            ignore_subclass=ignore_subclass,
            source=self.get_source(),
        )
        assert "tensor_dict" not in tensor_proxy.node.meta
        tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy()

        # TODO: I think the result is guaranteed to be fake with
        # ignore_subclass changes
        fake_tensor_value = None
        example_value = tensor_variable.proxy.node.meta["example_value"]
        if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
            fake_tensor_value = example_value

        self.tx.output.add_grapharg(
            GraphArg(self.get_source(), value, False, fake_tensor_value)
        )

        if type(value) in config.traceable_tensor_subclasses:
            subclass_torch_function__func = value.__torch_function__.__func__
            subclass_type = type(value)
            # NB: This is slightly misnamed, a tensor subclass might not have
            # any explicit __torch_function__ implementation and is relying
            # on the default inherited from torch.Tensor
            return TensorWithTFOverrideVariable(
                tensor_variable,
                self.get_source(),
                subclass_torch_function__func,
                subclass_type,
            )

        return tensor_variable

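    # Wrap an int/float (or numpy scalar) that is not being specialized on:
    # with dynamic shapes an int becomes a SymInt-backed graph input, otherwise
    # the value is lifted into a tensor input via UnspecializedPythonVariable.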
    def wrap_unspecialized_primitive(self, value):
        if self.name in self.tx.output.unspec_variable_map:
            return self.tx.output.unspec_variable_map[self.name]
        else:
            if (
                config.dynamic_shapes
                and isinstance(value, int)
                and not is_constant_source(self.get_source())
            ):
                shape_env = self.tx.output.shape_env
                wrapped_value = shape_env.create_symintnode(
                    shape_env.create_symbol(value, source=self.source), hint=value
                )
                self.tx.output.tracked_fakes.append(
                    TrackedFake(wrapped_value, self.source)
                )
                # TODO: Do float
            else:
                # TODO: Eliminate this case entirely
                wrapped_value = torch.tensor(value)
            if not isinstance(self.get_source(), RandomValueSource):
                guards = {self.get_source().make_guard(GuardBuilder.TYPE_MATCH, True)}
                options = {"guards": guards}
            else:
                options = {}
            options.update({"source": self.get_source()})
            if isinstance(wrapped_value, torch.Tensor):
                options.update({"raw_value": value})

            proxy = self.tx.output.create_graph_input(
                re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(wrapped_value)
            )

            unspec_var = wrap_fx_proxy_cls(
                UnspecializedPythonVariable,
                tx=self.tx,
                proxy=proxy,
                example_value=wrapped_value,
                **options,
            )
            self.tx.output.unspec_variable_map[self.name] = unspec_var
            if not is_constant_source(self.get_source()):
                fake_tensor_value = None
                example_value = unspec_var.proxy.node.meta["example_value"]
                if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
                    fake_tensor_value = example_value
                self.tx.output.add_grapharg(
                    GraphArg(
                        self.get_source(),
                        wrapped_value,
                        True,
                        fake_tensor_value,
                        is_tensor=False,
                    )
                )
            return unspec_var

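# Implements dataclasses.fields() for traced values; installed as a
# LambdaVariable when _wrap() above sees `dataclasses.fields` itself.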
def _dataclasses_fields_lambda(obj):
    if isinstance(obj, UserDefinedObjectVariable):
        value = obj.value
    elif isinstance(obj, DataClassVariable):
        value = obj.user_cls
    else:
        unimplemented(f"Dataclass fields handling fails for type {obj}")
    items = []
    for field in dataclasses.fields(value):
        source = None
        if obj.source:
            source = GetItemSource(
                AttrSource(obj.source, "__dataclass_fields__"), field.name
            )
        items.append(UserDefinedObjectVariable(field, source=source).add_options(obj))
    return TupleVariable(items).add_options(obj)

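# Convenience wrapper around wrap_fx_proxy_cls for the common case where the
# result should be a TensorVariable.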
def wrap_fx_proxy(tx, proxy, example_value=None, **options):
    return wrap_fx_proxy_cls(
        target_cls=TensorVariable,
        tx=tx,
        proxy=proxy,
        example_value=example_value,
        **options,
    )

# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable
# Should be compositional instead
def wrap_fx_proxy_cls(
    target_cls, tx, proxy, example_value=None, ignore_subclass=False, **options
):
    from ..symbolic_convert import InstructionTranslatorBase

    assert isinstance(tx, InstructionTranslatorBase)
    if "guards" in options and options["guards"] is not None:
        tx.output.guards.update(options["guards"])

    assert "example_value" not in proxy.node.meta

    initial_example_value = example_value

    def _clone_input(value):
        if isinstance(value, torch.Tensor):
            # tensor subclasses will not be converted to FakeTensors and need to be cloned
            if not isinstance(value, torch._subclasses.fake_tensor.FakeTensor):
                # NB: ensure strides are preserved
                value = clone_input(value)

        return value

    with preserve_rng_state():
        if example_value is None:
            example_value = get_fake_value(proxy.node, tx)

        # Handle recursive calls here
        elif isinstance(example_value, FakeTensor):
            pass

        elif isinstance(example_value, torch.Tensor):
            if tx.export:
                # The legacy behavior for real value cache with subclasses was
                # to perform a clone WITHOUT preserving the subclass. It's
                # not entirely clear this is what you actually want though.
                with torch._C.DisableTorchFunctionSubclass():
                    proxy.tracer.real_value_cache[proxy.node] = _clone_input(
                        example_value
                    )
            # NB: If we're ignoring subclass, then the expectation is you will
            # take the returned TensorVariable and wrap it into a more
            # accurate TensorVariable that is able to track subclass-ness;
            # otherwise this is wrong!
            kwargs = {
                "ignore_subclass": ignore_subclass,
                "is_tensor": target_cls is TensorVariable,
            }
            assert "source" in options and options["source"] is not None
            kwargs["source"] = options["source"]
            example_value = wrap_to_fake_tensor_and_record(
                example_value, tx=tx, **kwargs
            )

    if isinstance(example_value, torch.Tensor):
        is_parameter = isinstance(example_value, torch.nn.Parameter)
        should_specialize = options.pop("should_specialize", False)
        if is_parameter or should_specialize:
            specialized_value = initial_example_value
        else:
            specialized_value = None

        # NB: In most (all?) cases, this does not actually do a clone.
        # (WARNING: this means that if we mutate metadata on the fake
        # tensor, the stored example value will update too!)
        example_value = _clone_input(example_value)
        proxy.node.meta["example_value"] = example_value
        specialized_props = target_cls.specialize(example_value)
        if isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor):
            # NB: This will be wrong for ignore_subclass; fix it up later!
            specialized_props["class_type"] = (
                torch.nn.Parameter if is_parameter else torch.Tensor
            )

        specialized_props["specialized_value"] = specialized_value

        options.update(specialized_props)
        return target_cls(proxy, **options)
    elif (
        hasattr(proxy.node.target, "__name__")
        and proxy.node.target.__name__ == "set_state"
        and isinstance(proxy.node.target.__self__, torch._C.Generator)
        or proxy.node.target == torch.random.set_rng_state
    ):
        from . import TorchVariable

        return TorchVariable(proxy.node.target)
    elif (
        proxy.node.target == torch._C._DisableFuncTorch
        or proxy.node.target == torch.cuda._is_in_bad_fork
    ):
        from . import UserDefinedObjectVariable

        return UserDefinedObjectVariable(example_value)
    elif istype(example_value, (int, bool, float)) and config.dynamic_shapes:
        proxy.node.meta["example_value"] = example_value
        return SymNodeVariable.create(tx, proxy, example_value, **options)
    elif istype(example_value, torch.Size) and config.dynamic_shapes:
        proxy.node.meta["example_value"] = example_value
        sizes = []
        for i, v in enumerate(example_value):
            proxy_i = proxy[i]
            sizes.append(SymNodeVariable.create(tx, proxy_i, v, **options))
        return SizeVariable(sizes, proxy, **options)
    elif istype(example_value, int) and proxy.node.target in (
        torch.seed,
        operator.mod,
        # some mac builds are missing torch.distributed.get_rank()
        getattr(torch.distributed, "get_rank", _missing),
        getattr(torch.distributed, "get_world_size", _missing),
    ):
        if config.dynamic_shapes:
            proxy.node.meta["example_value"] = example_value
            return SymNodeVariable.create(tx, proxy, example_value, **options)
        else:
            return ConstantVariable(example_value, **options)
    elif istype(example_value, torch.Size) and all(
        [isinstance(x, int) for x in example_value]
    ):
        sizes = [ConstantVariable(x) for x in example_value]
        return SizeVariable(sizes, **options)
    elif isinstance(example_value, (tuple, list)):
        unpacked = []
        for i, val in enumerate(example_value):
            if val is None:
                # nn.MultiheadAttention() can return None, see issue #175
                unpacked.append(
                    ConstantVariable(None, **options),
                )
            else:
                unpacked.append(
                    wrap_fx_proxy(
                        tx,
                        proxy.tracer.create_proxy(
                            "call_function", operator.getitem, (proxy, i), {}
                        ),
                        example_value=val,
                        **options,
                    )
                )
        if istype(example_value, tuple):
            return TupleVariable(unpacked, **options)
        elif istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
        else:
            assert (
                example_value.__class__.__module__ == "torch.return_types"
                or hasattr(example_value, "_fields")
            ), "namedtuple?"
            return NamedTupleVariable(unpacked, example_value.__class__, **options)
    elif example_value is None or proxy.node.target is torch.manual_seed:
        return ConstantVariable(None, **options)
    elif (
        isinstance(example_value, int)
        and proxy.node.target is torch._utils._element_size
    ):
        proxy.node.meta["example_value"] = example_value
        return ConstantVariable(example_value, **options)
    elif isinstance(example_value, (torch.SymInt, torch.SymFloat)):
        proxy.node.meta["example_value"] = example_value
        return SymNodeVariable(proxy, example_value, **options)
    elif proxy.node.target in [torch.cuda.streams.Stream, torch.cuda.current_stream]:
        from . import CUDAStreamVariable

        proxy.node.meta["example_value"] = example_value
        return CUDAStreamVariable(proxy, example_value, **options)
    else:
        unimplemented(
            "torch.* op returned non-Tensor "
            + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}"
        )

# Tracks the sources of all fake tensors we wrap in Dynamo.
# Used by shape guard computation.
@dataclasses.dataclass
class TrackedFake:
    fake: Union[FakeTensor, SymInt]
    source: Source

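# Convert a real tensor to a FakeTensor through tx.fake_mode and record it in
# tx.output.tracked_fakes so shape guards can be computed later; non-tensor
# values are returned unchanged.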
def wrap_to_fake_tensor_and_record(
    e, tx, ignore_subclass=False, *, source: Optional[Source], is_tensor: bool
):
    if type(e) in (torch.Tensor, torch.nn.Parameter) or (
        ignore_subclass and isinstance(e, torch.Tensor)
    ):
        static_shapes = (
            source is None
            or type(e) is torch.nn.Parameter
            or config.dynamic_shapes is False
            or not is_tensor
        )
        fake_e = wrap_fake_exception(
            lambda: tx.fake_mode.from_tensor(
                e,
                static_shapes=static_shapes,
                ignore_subclass=ignore_subclass,
                source=source,
            )
        )
        if is_tensor:
            tx.output.tracked_fakes.append(TrackedFake(fake_e, source))
        return fake_e
    else:
        return e