mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamo] Simplify creation of VariableTrackers (#135714)
## `VariableTracker::build()` hides the Builders
### The problem
In the current code, creating a `VariableTracker` involves choosing one of two `Builder` classes and either calling a method, or calling a constructor that creates an object that you immediately call, [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)
).
Variations on this code are repeated in many places.
More, the `Builder` classes have a lot of dependencies, so they have to be loaded late in the whole import process to avoid circular imports, so they end up being repeatedly imported at local scope.
### The solution
In this commit, the import from `builder` and the logic of choosing and calling the Builder class are hidden in a single static factory method, `VariableTracker.build()`, easier to reason about and to import.
This commit net lowers the total lines of code by over 150 lines by removing repetitive logic and unnecessary local imports.
**CHANGES:** Originally the name of the static method was `VariableTracker.create()` but a static method on a derived class, `LazyVariableTracker.create()` now exists with a different signature that's irreconcilable, so the new static method was renamed to `VariableTracker.build()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
1581a93e87
commit
e1c4548441
@ -91,7 +91,6 @@ from .variables.builder import (
|
|||||||
BackwardStateGraphArg,
|
BackwardStateGraphArg,
|
||||||
GraphArg,
|
GraphArg,
|
||||||
TrackedFake,
|
TrackedFake,
|
||||||
VariableBuilder,
|
|
||||||
wrap_fx_proxy,
|
wrap_fx_proxy,
|
||||||
)
|
)
|
||||||
from .variables.lists import BaseListVariable
|
from .variables.lists import BaseListVariable
|
||||||
@ -498,7 +497,7 @@ class OutputGraph:
|
|||||||
cg.store(varname)
|
cg.store(varname)
|
||||||
self.pregraph_bytecode.extend(cg.get_instructions())
|
self.pregraph_bytecode.extend(cg.get_instructions())
|
||||||
source = SyntheticLocalSource(varname)
|
source = SyntheticLocalSource(varname)
|
||||||
result = VariableBuilder(self.root_tx, source)(example_value)
|
result = VariableTracker.build(self.root_tx, example_value, source)
|
||||||
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
|
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
|
||||||
source
|
source
|
||||||
)
|
)
|
||||||
@ -767,8 +766,8 @@ class OutputGraph:
|
|||||||
):
|
):
|
||||||
if is_dynamic_nn_module(target, self.root_tx.export):
|
if is_dynamic_nn_module(target, self.root_tx.export):
|
||||||
# Instead of returning UnspecializedNNModuleVariable, call
|
# Instead of returning UnspecializedNNModuleVariable, call
|
||||||
# VariableBuilder so that it is tracked for mutation.
|
# VariableTracker.build so that it is tracked for mutation.
|
||||||
return VariableBuilder(self.current_tx, **options)(target)
|
return VariableTracker.build(self.current_tx, target, **options)
|
||||||
|
|
||||||
options = dict(options)
|
options = dict(options)
|
||||||
assert "source" in options
|
assert "source" in options
|
||||||
@ -860,8 +859,8 @@ class OutputGraph:
|
|||||||
def wrap_name(module_key):
|
def wrap_name(module_key):
|
||||||
self.output.update_co_names(module_key)
|
self.output.update_co_names(module_key)
|
||||||
self.global_scope[module_key] = target
|
self.global_scope[module_key] = target
|
||||||
return VariableBuilder(self, ConstantSource(source_name=module_key))(
|
return VariableTracker.build(
|
||||||
target
|
self, target, ConstantSource(source_name=module_key)
|
||||||
)
|
)
|
||||||
|
|
||||||
for k, v in self.nn_modules.items():
|
for k, v in self.nn_modules.items():
|
||||||
|
@ -71,7 +71,7 @@ from .utils import (
|
|||||||
proxy_args_kwargs,
|
proxy_args_kwargs,
|
||||||
)
|
)
|
||||||
from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker
|
from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker
|
||||||
from .variables.builder import VariableBuilder, wrap_fx_proxy
|
from .variables.builder import wrap_fx_proxy
|
||||||
from .variables.builtin import BuiltinVariable
|
from .variables.builtin import BuiltinVariable
|
||||||
from .variables.constant import ConstantVariable
|
from .variables.constant import ConstantVariable
|
||||||
from .variables.ctx_manager import (
|
from .variables.ctx_manager import (
|
||||||
@ -1224,15 +1224,14 @@ class InstructionTranslatorBase(
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
return self.load_builtin(inst)
|
return self.load_builtin(inst)
|
||||||
|
|
||||||
source = GlobalSource(name)
|
self.push(VariableTracker.build(self, value, GlobalSource(name)))
|
||||||
self.push(VariableBuilder(self, source)(value))
|
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def nn_modules_globals_vt(self):
|
def nn_modules_globals_vt(self):
|
||||||
module_name = "torch.nn.modules.module"
|
module_name = "torch.nn.modules.module"
|
||||||
module_source = self.import_source(module_name)
|
module_source = self.import_source(module_name)
|
||||||
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
|
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
|
||||||
return VariableBuilder(self, module_source)(fglobals_value)
|
return VariableTracker.build(self, fglobals_value, module_source)
|
||||||
|
|
||||||
def LOAD_GLOBAL(self, inst):
|
def LOAD_GLOBAL(self, inst):
|
||||||
if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
|
if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
|
||||||
@ -1374,7 +1373,7 @@ class InstructionTranslatorBase(
|
|||||||
self.output.name_of_builtins_dict_key_in_fglobals
|
self.output.name_of_builtins_dict_key_in_fglobals
|
||||||
)
|
)
|
||||||
var_source = GetItemSource(builtins_source, argval)
|
var_source = GetItemSource(builtins_source, argval)
|
||||||
self.push(VariableBuilder(self, var_source)(val))
|
self.push(VariableTracker.build(self, val, var_source))
|
||||||
else:
|
else:
|
||||||
assert is_builtin_constant(val)
|
assert is_builtin_constant(val)
|
||||||
self.push(ConstantVariable.create(value=val))
|
self.push(ConstantVariable.create(value=val))
|
||||||
@ -3403,7 +3402,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||||||
fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment]
|
fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment]
|
||||||
else:
|
else:
|
||||||
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
|
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
|
||||||
fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
|
fglobals_vt = VariableTracker.build(self, fglobals_value, module_source)
|
||||||
global_source = AttrSource(module_source, name)
|
global_source = AttrSource(module_source, name)
|
||||||
else:
|
else:
|
||||||
globals_name = self.output.install_global_by_id(
|
globals_name = self.output.install_global_by_id(
|
||||||
@ -3411,7 +3410,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||||||
)
|
)
|
||||||
globals_source = GlobalSource(globals_name)
|
globals_source = GlobalSource(globals_name)
|
||||||
fglobals_value = self.f_globals # type: ignore[assignment]
|
fglobals_value = self.f_globals # type: ignore[assignment]
|
||||||
fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
|
fglobals_vt = VariableTracker.build(self, fglobals_value, globals_source)
|
||||||
global_source = GetItemSource(globals_source, name) # type: ignore[assignment]
|
global_source = GetItemSource(globals_source, name) # type: ignore[assignment]
|
||||||
return fglobals_value, fglobals_vt, global_source
|
return fglobals_value, fglobals_vt, global_source
|
||||||
|
|
||||||
@ -3430,7 +3429,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
return self.load_builtin(inst)
|
return self.load_builtin(inst)
|
||||||
|
|
||||||
self.push(VariableBuilder(self, global_source)(value))
|
self.push(VariableTracker.build(self, value, global_source))
|
||||||
|
|
||||||
def STORE_GLOBAL(self, inst):
|
def STORE_GLOBAL(self, inst):
|
||||||
if self.f_globals is self.parent.f_globals:
|
if self.f_globals is self.parent.f_globals:
|
||||||
|
@ -12,7 +12,7 @@ from ..utils import istype
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase
|
||||||
|
|
||||||
|
|
||||||
class MutableLocalSource(Enum):
|
class MutableLocalSource(Enum):
|
||||||
@ -121,6 +121,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
|||||||
|
|
||||||
VariableTracker instances are immutable and should be copied in
|
VariableTracker instances are immutable and should be copied in
|
||||||
order to change them.
|
order to change them.
|
||||||
|
|
||||||
|
Prefer the factory function VariableTracker.build() over VariableTracker.__init__().
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# fields to leave unmodified in apply()
|
# fields to leave unmodified in apply()
|
||||||
@ -244,9 +246,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
|||||||
value = self.const_getattr(tx, name)
|
value = self.const_getattr(tx, name)
|
||||||
if not variables.ConstantVariable.is_literal(value):
|
if not variables.ConstantVariable.is_literal(value):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
source = None
|
source = self.source and AttrSource(self.source, name)
|
||||||
if self.source:
|
|
||||||
source = AttrSource(self.source, name)
|
|
||||||
return variables.ConstantVariable.create(value, source=source)
|
return variables.ConstantVariable.create(value, source=source)
|
||||||
|
|
||||||
def is_proxy(self):
|
def is_proxy(self):
|
||||||
@ -363,6 +363,20 @@ class VariableTracker(metaclass=VariableTrackerMeta):
|
|||||||
def is_strict_mode(self, tx):
|
def is_strict_mode(self, tx):
|
||||||
return tx.strict_checks_fn and tx.strict_checks_fn(self)
|
return tx.strict_checks_fn and tx.strict_checks_fn(self)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build(
|
||||||
|
tx: "InstructionTranslatorBase",
|
||||||
|
value: Any,
|
||||||
|
source: Optional[Source] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Create a new VariableTracker from a value and optional Source"""
|
||||||
|
from . import builder
|
||||||
|
|
||||||
|
if source is None:
|
||||||
|
return builder.SourcelessBuilder.create(tx, value)
|
||||||
|
else:
|
||||||
|
return builder.VariableBuilder(tx, source)(value)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -701,7 +701,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
|
def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
from .lazy import LazyVariableTracker
|
from .lazy import LazyVariableTracker
|
||||||
|
|
||||||
obj = BuiltinVariable(fn)
|
obj = BuiltinVariable(fn)
|
||||||
@ -794,8 +793,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
handlers.append(call_self_handler)
|
handlers.append(call_self_handler)
|
||||||
|
|
||||||
if obj.can_constant_fold_through():
|
if obj.can_constant_fold_through():
|
||||||
builder = SourcelessBuilder.create
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
all(issubclass(x, ConstantVariable) for x in arg_types)
|
all(issubclass(x, ConstantVariable) for x in arg_types)
|
||||||
and not has_kwargs
|
and not has_kwargs
|
||||||
@ -809,7 +806,7 @@ class BuiltinVariable(VariableTracker):
|
|||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
unimplemented(f"constant fold exception: {repr(exc)}")
|
unimplemented(f"constant fold exception: {repr(exc)}")
|
||||||
return builder(tx, res)
|
return VariableTracker.build(tx, res)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@ -825,7 +822,7 @@ class BuiltinVariable(VariableTracker):
|
|||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
unimplemented(f"constant fold exception: {repr(exc)}")
|
unimplemented(f"constant fold exception: {repr(exc)}")
|
||||||
return builder(tx, res)
|
return VariableTracker.build(tx, res)
|
||||||
|
|
||||||
handlers.append(constant_fold_handler)
|
handlers.append(constant_fold_handler)
|
||||||
|
|
||||||
@ -1361,8 +1358,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
|
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
if not kwargs:
|
if not kwargs:
|
||||||
if not args:
|
if not args:
|
||||||
args = ({},)
|
args = ({},)
|
||||||
@ -1399,7 +1394,7 @@ class BuiltinVariable(VariableTracker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
new_dict = dict(arg.value.items())
|
new_dict = dict(arg.value.items())
|
||||||
return SourcelessBuilder.create(tx, new_dict)
|
return VariableTracker.build(tx, new_dict)
|
||||||
else:
|
else:
|
||||||
func_var = arg.var_getattr(tx, "items")
|
func_var = arg.var_getattr(tx, "items")
|
||||||
if not isinstance(func_var, variables.UserFunctionVariable):
|
if not isinstance(func_var, variables.UserFunctionVariable):
|
||||||
@ -1631,7 +1626,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
TorchInGraphFunctionVariable,
|
TorchInGraphFunctionVariable,
|
||||||
UserFunctionVariable,
|
UserFunctionVariable,
|
||||||
)
|
)
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
name = name_var.as_python_constant()
|
name = name_var.as_python_constant()
|
||||||
|
|
||||||
@ -1666,34 +1660,21 @@ class BuiltinVariable(VariableTracker):
|
|||||||
if not hasattr_var.as_python_constant():
|
if not hasattr_var.as_python_constant():
|
||||||
return default
|
return default
|
||||||
|
|
||||||
options = {}
|
source = obj.source and AttrSource(obj.source, name)
|
||||||
if obj.source:
|
|
||||||
source = AttrSource(obj.source, name)
|
|
||||||
options["source"] = source
|
|
||||||
else:
|
|
||||||
source = None
|
|
||||||
|
|
||||||
if name in {"__bases__", "__base__", "__flags__"}:
|
if name in {"__bases__", "__base__", "__flags__"}:
|
||||||
try:
|
try:
|
||||||
value = obj.as_python_constant()
|
value = obj.as_python_constant()
|
||||||
if isinstance(value, type):
|
if isinstance(value, type):
|
||||||
if name == "__bases__":
|
if name == "__bases__":
|
||||||
bases = value.__bases__
|
tuple_args = [
|
||||||
if source is not None:
|
VariableTracker.build(
|
||||||
tuple_args = [
|
tx, b, source and GetItemSource(source, i)
|
||||||
VariableBuilder(tx, GetItemSource(source, i))(b)
|
)
|
||||||
for i, b in enumerate(bases)
|
for i, b in enumerate(value.__bases__)
|
||||||
]
|
]
|
||||||
else:
|
return variables.TupleVariable(tuple_args, source=source)
|
||||||
tuple_args = [
|
|
||||||
SourcelessBuilder.create(tx, b) for b in bases
|
|
||||||
]
|
|
||||||
return variables.TupleVariable(tuple_args, **options)
|
|
||||||
if name == "__base__":
|
if name == "__base__":
|
||||||
base = value.__base__
|
return VariableTracker.build(tx, value.__base__, source)
|
||||||
if source is not None:
|
|
||||||
return VariableBuilder(tx, source)(base)
|
|
||||||
return SourcelessBuilder.create(tx, base)
|
|
||||||
if name == "__flags__":
|
if name == "__flags__":
|
||||||
return ConstantVariable.create(value.__flags__)
|
return ConstantVariable.create(value.__flags__)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
@ -1715,14 +1696,14 @@ class BuiltinVariable(VariableTracker):
|
|||||||
try:
|
try:
|
||||||
return obj.var_getattr(tx, name)
|
return obj.var_getattr(tx, name)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return GetAttrVariable(obj, name, **options)
|
return GetAttrVariable(obj, name, source=source)
|
||||||
elif isinstance(obj, TorchInGraphFunctionVariable):
|
elif isinstance(obj, TorchInGraphFunctionVariable):
|
||||||
# Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
|
# Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
|
||||||
member = getattr(obj.value, name)
|
member = getattr(obj.value, name)
|
||||||
if isinstance(
|
if isinstance(
|
||||||
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
|
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
|
||||||
) and trace_rules.is_aten_op_or_tensor_method(member):
|
) and trace_rules.is_aten_op_or_tensor_method(member):
|
||||||
return TorchInGraphFunctionVariable(member, **options)
|
return TorchInGraphFunctionVariable(member, source=source)
|
||||||
elif isinstance(obj, DummyModule):
|
elif isinstance(obj, DummyModule):
|
||||||
# TODO(mlazos) - Do we need this?
|
# TODO(mlazos) - Do we need this?
|
||||||
if obj.is_torch or name not in obj.value.__dict__:
|
if obj.is_torch or name not in obj.value.__dict__:
|
||||||
@ -1732,18 +1713,15 @@ class BuiltinVariable(VariableTracker):
|
|||||||
|
|
||||||
if config.replay_record_enabled:
|
if config.replay_record_enabled:
|
||||||
tx.exec_recorder.record_module_access(obj.value, name, member)
|
tx.exec_recorder.record_module_access(obj.value, name, member)
|
||||||
|
return VariableTracker.build(tx, member, source)
|
||||||
|
|
||||||
if source is not None:
|
|
||||||
return VariableBuilder(tx, source)(member)
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, member)
|
|
||||||
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
|
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
|
||||||
return ConstantVariable.create(getattr(obj.fn, name))
|
return ConstantVariable.create(getattr(obj.fn, name))
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return obj.var_getattr(tx, name)
|
return obj.var_getattr(tx, name)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return GetAttrVariable(obj, name, **options)
|
return GetAttrVariable(obj, name, source=source)
|
||||||
|
|
||||||
def call_setattr(
|
def call_setattr(
|
||||||
self,
|
self,
|
||||||
@ -1882,8 +1860,6 @@ class BuiltinVariable(VariableTracker):
|
|||||||
return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())
|
return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())
|
||||||
|
|
||||||
def call_type(self, tx: "InstructionTranslator", obj: VariableTracker):
|
def call_type(self, tx: "InstructionTranslator", obj: VariableTracker):
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
py_type = obj.python_type()
|
py_type = obj.python_type()
|
||||||
except NotImplementedError as error:
|
except NotImplementedError as error:
|
||||||
@ -1893,10 +1869,8 @@ class BuiltinVariable(VariableTracker):
|
|||||||
case_name="unknown_python_type",
|
case_name="unknown_python_type",
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
if obj.source is None:
|
source = obj.source and TypeSource(obj.source)
|
||||||
return SourcelessBuilder.create(tx, py_type)
|
return VariableTracker.build(tx, py_type, source)
|
||||||
else:
|
|
||||||
return VariableBuilder(tx, TypeSource(obj.source))(py_type)
|
|
||||||
|
|
||||||
def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker):
|
def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker):
|
||||||
if obj.has_unpack_var_sequence(tx):
|
if obj.has_unpack_var_sequence(tx):
|
||||||
|
@ -984,12 +984,10 @@ class HFPretrainedConfigVariable(VariableTracker):
|
|||||||
assert self.is_matching_cls(type(obj))
|
assert self.is_matching_cls(type(obj))
|
||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
attr_value = getattr(self.obj, name)
|
attr_value = getattr(self.obj, name)
|
||||||
attr_source = AttrSource(self.source, name)
|
source = self.source and AttrSource(self.source, name)
|
||||||
return VariableBuilder(tx, attr_source)(attr_value)
|
return VariableTracker.build(tx, attr_value, source)
|
||||||
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
unimplemented(f"getattr({self.value}, {name})")
|
unimplemented(f"getattr({self.value}, {name})")
|
||||||
@ -1053,15 +1051,11 @@ class PythonSysModulesVariable(VariableTracker):
|
|||||||
key: VariableTracker,
|
key: VariableTracker,
|
||||||
default: Optional[VariableTracker] = None,
|
default: Optional[VariableTracker] = None,
|
||||||
):
|
):
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
k, has_key = self._contains_helper(tx, key)
|
k, has_key = self._contains_helper(tx, key)
|
||||||
|
|
||||||
if has_key:
|
if has_key:
|
||||||
return VariableBuilder(
|
source = self.source and GetItemSource(self.source, k)
|
||||||
tx,
|
return VariableTracker.build(tx, sys.modules[k], source)
|
||||||
GetItemSource(self.source, k),
|
|
||||||
)(sys.modules[k])
|
|
||||||
|
|
||||||
if default is not None:
|
if default is not None:
|
||||||
return default
|
return default
|
||||||
@ -1069,10 +1063,6 @@ class PythonSysModulesVariable(VariableTracker):
|
|||||||
return ConstantVariable.create(value=None)
|
return ConstantVariable.create(value=None)
|
||||||
|
|
||||||
def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
|
def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
k, has_key = self._contains_helper(tx, key)
|
k, has_key = self._contains_helper(tx, key)
|
||||||
return VariableBuilder(
|
source = self.source and GetItemSource(self.source, k)
|
||||||
tx,
|
return VariableTracker.build(tx, sys.modules[k], source)
|
||||||
GetItemSource(self.source, k),
|
|
||||||
)(sys.modules[k])
|
|
||||||
|
@ -46,9 +46,7 @@ def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
|
|||||||
if isinstance(val, VariableTracker):
|
if isinstance(val, VariableTracker):
|
||||||
return val
|
return val
|
||||||
elif not source:
|
elif not source:
|
||||||
from torch._dynamo.variables.builder import SourcelessBuilder
|
return VariableTracker.build(tx, val)
|
||||||
|
|
||||||
return SourcelessBuilder.create(tx, val)
|
|
||||||
else:
|
else:
|
||||||
# Create a lazy variable to avoid guarding on __defaults__ unless really
|
# Create a lazy variable to avoid guarding on __defaults__ unless really
|
||||||
# needed.
|
# needed.
|
||||||
@ -240,8 +238,6 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
|||||||
# optimization for cleaner codegen
|
# optimization for cleaner codegen
|
||||||
result[name] = var
|
result[name] = var
|
||||||
elif self.source:
|
elif self.source:
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
side_effects = parent.output.side_effects
|
side_effects = parent.output.side_effects
|
||||||
if cell in side_effects:
|
if cell in side_effects:
|
||||||
out = side_effects[cell]
|
out = side_effects[cell]
|
||||||
@ -253,9 +249,9 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
|||||||
closure_cell, "cell_contents"
|
closure_cell, "cell_contents"
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
contents_var = VariableBuilder(
|
contents_var = VariableTracker.build(
|
||||||
parent, closure_cell_contents
|
parent, cell.cell_contents, closure_cell_contents
|
||||||
)(cell.cell_contents)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Cell has not yet been assigned
|
# Cell has not yet been assigned
|
||||||
contents_var = variables.DeletedVariable()
|
contents_var = variables.DeletedVariable()
|
||||||
@ -286,9 +282,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
|||||||
result[name] = out
|
result[name] = out
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from .builder import SourcelessBuilder
|
result[name] = VariableTracker.build(tx, cell.cell_contents)
|
||||||
|
|
||||||
result[name] = SourcelessBuilder.create(tx, cell.cell_contents)
|
|
||||||
|
|
||||||
return result, closure_cells
|
return result, closure_cells
|
||||||
|
|
||||||
@ -296,17 +290,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
def var_getattr(self, tx: "InstructionTranslator", name: str):
|
||||||
source = AttrSource(self.source, name) if self.source else None
|
source = self.source and AttrSource(self.source, name)
|
||||||
try:
|
try:
|
||||||
subobj = inspect.getattr_static(self.fn, name)
|
subobj = inspect.getattr_static(self.fn, name)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
options = {"source": source}
|
return variables.GetAttrVariable(self, name, source=source)
|
||||||
return variables.GetAttrVariable(self, name, **options)
|
|
||||||
if source:
|
if source:
|
||||||
return variables.LazyVariableTracker.create(subobj, source)
|
return variables.LazyVariableTracker.create(subobj, source)
|
||||||
from .builder import SourcelessBuilder
|
return VariableTracker.build(tx, subobj)
|
||||||
|
|
||||||
return SourcelessBuilder.create(tx, subobj)
|
|
||||||
|
|
||||||
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||||
result = hasattr(self.fn, name)
|
result = hasattr(self.fn, name)
|
||||||
@ -757,14 +748,8 @@ class WrapperUserFunctionVariable(VariableTracker):
|
|||||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||||
if name == self.attr_to_trace:
|
if name == self.attr_to_trace:
|
||||||
val = getattr(self.wrapper_obj, self.attr_to_trace)
|
val = getattr(self.wrapper_obj, self.attr_to_trace)
|
||||||
if self.source:
|
source = self.source and AttrSource(self.source, name)
|
||||||
from .builder import VariableBuilder
|
return VariableTracker.build(tx, val, source)
|
||||||
|
|
||||||
return VariableBuilder(tx, AttrSource(self.source, name))(val)
|
|
||||||
else:
|
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
return SourcelessBuilder.create(tx, val)
|
|
||||||
|
|
||||||
return super().var_getattr(tx, name)
|
return super().var_getattr(tx, name)
|
||||||
|
|
||||||
@ -999,8 +984,6 @@ class PolyfilledFunctionVariable(VariableTracker):
|
|||||||
args: "List[VariableTracker]",
|
args: "List[VariableTracker]",
|
||||||
kwargs: "Dict[str, VariableTracker]",
|
kwargs: "Dict[str, VariableTracker]",
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
from torch._dynamo.variables.builder import SourcelessBuilder
|
|
||||||
|
|
||||||
if self.can_constant_fold_through() and check_unspec_or_constant_args(
|
if self.can_constant_fold_through() and check_unspec_or_constant_args(
|
||||||
args, kwargs
|
args, kwargs
|
||||||
):
|
):
|
||||||
@ -1010,9 +993,9 @@ class PolyfilledFunctionVariable(VariableTracker):
|
|||||||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return SourcelessBuilder.create(tx, result)
|
return VariableTracker.build(tx, result)
|
||||||
|
|
||||||
traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn)
|
traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
|
||||||
return traceable_function_variable.call_function(tx, args, kwargs)
|
return traceable_function_variable.call_function(tx, args, kwargs)
|
||||||
|
|
||||||
def call_method(
|
def call_method(
|
||||||
|
@ -13,7 +13,6 @@ import torch.fx
|
|||||||
import torch.nn
|
import torch.nn
|
||||||
from torch._dynamo.utils import get_fake_value
|
from torch._dynamo.utils import get_fake_value
|
||||||
from torch._dynamo.variables import ConstantVariable
|
from torch._dynamo.variables import ConstantVariable
|
||||||
from torch._dynamo.variables.base import VariableTracker
|
|
||||||
from torch._dynamo.variables.builtin import BuiltinVariable
|
from torch._dynamo.variables.builtin import BuiltinVariable
|
||||||
from torch._dynamo.variables.functions import UserFunctionVariable
|
from torch._dynamo.variables.functions import UserFunctionVariable
|
||||||
from torch._dynamo.variables.tensor import SymNodeVariable
|
from torch._dynamo.variables.tensor import SymNodeVariable
|
||||||
@ -31,6 +30,7 @@ from ..exc import (
|
|||||||
)
|
)
|
||||||
from ..source import AttrSource
|
from ..source import AttrSource
|
||||||
from ..utils import proxy_args_kwargs
|
from ..utils import proxy_args_kwargs
|
||||||
|
from .base import VariableTracker
|
||||||
from .dicts import ConstDictVariable
|
from .dicts import ConstDictVariable
|
||||||
from .lazy import LazyVariableTracker
|
from .lazy import LazyVariableTracker
|
||||||
from .lists import ListVariable, TupleVariable
|
from .lists import ListVariable, TupleVariable
|
||||||
@ -1040,7 +1040,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||||||
args: List[VariableTracker],
|
args: List[VariableTracker],
|
||||||
kwargs: Dict[str, VariableTracker],
|
kwargs: Dict[str, VariableTracker],
|
||||||
) -> VariableTracker:
|
) -> VariableTracker:
|
||||||
from .builder import SourcelessBuilder, wrap_fx_proxy
|
from .builder import wrap_fx_proxy
|
||||||
|
|
||||||
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
|
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
|
||||||
|
|
||||||
@ -1062,7 +1062,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||||||
tx,
|
tx,
|
||||||
"new_empty",
|
"new_empty",
|
||||||
args=(
|
args=(
|
||||||
SourcelessBuilder.create(
|
VariableTracker.build(
|
||||||
tx,
|
tx,
|
||||||
leaf.size
|
leaf.size
|
||||||
if leaf.size is not None
|
if leaf.size is not None
|
||||||
@ -1072,8 +1072,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
kwargs={
|
kwargs={
|
||||||
"dtype": SourcelessBuilder.create(tx, leaf.dtype),
|
"dtype": VariableTracker.build(tx, leaf.dtype),
|
||||||
"requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad),
|
"requires_grad": VariableTracker.build(tx, leaf.requires_grad),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for leaf in itertools.chain(xs.items, xs.items)
|
for leaf in itertools.chain(xs.items, xs.items)
|
||||||
@ -2057,7 +2057,6 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||||||
fn_name: str,
|
fn_name: str,
|
||||||
):
|
):
|
||||||
from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
tx: InstructionTranslator = tx
|
tx: InstructionTranslator = tx
|
||||||
|
|
||||||
@ -2065,9 +2064,9 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||||||
return query.call_method(
|
return query.call_method(
|
||||||
tx,
|
tx,
|
||||||
"new_empty",
|
"new_empty",
|
||||||
(SourcelessBuilder.create(tx, []),),
|
(VariableTracker.build(tx, []),),
|
||||||
{
|
{
|
||||||
"dtype": SourcelessBuilder.create(tx, torch.int32),
|
"dtype": VariableTracker.build(tx, torch.int32),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2077,8 +2076,8 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||||||
score = query.call_method(
|
score = query.call_method(
|
||||||
tx,
|
tx,
|
||||||
"new_empty",
|
"new_empty",
|
||||||
(SourcelessBuilder.create(tx, []),),
|
(VariableTracker.build(tx, []),),
|
||||||
{"requires_grad": SourcelessBuilder.create(tx, scores_require_grad)},
|
{"requires_grad": VariableTracker.build(tx, scores_require_grad)},
|
||||||
)
|
)
|
||||||
new_args = [score, *bhmn]
|
new_args = [score, *bhmn]
|
||||||
else:
|
else:
|
||||||
|
@ -172,10 +172,8 @@ class ItertoolsVariable(VariableTracker):
|
|||||||
*args, mutable_local=MutableLocal()
|
*args, mutable_local=MutableLocal()
|
||||||
)
|
)
|
||||||
|
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
|
VariableTracker.build(tx, polyfills.repeat), args, kwargs
|
||||||
)
|
)
|
||||||
elif self.value is itertools.count:
|
elif self.value is itertools.count:
|
||||||
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
|
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
|
||||||
|
@ -20,14 +20,15 @@ class LazyCache:
|
|||||||
def realize(self) -> None:
|
def realize(self) -> None:
|
||||||
assert self.vt is None
|
assert self.vt is None
|
||||||
from ..symbolic_convert import InstructionTranslator
|
from ..symbolic_convert import InstructionTranslator
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
tx = InstructionTranslator.current_tx()
|
tx = InstructionTranslator.current_tx()
|
||||||
if isinstance(self.value, LazySymNodeFormatString):
|
|
||||||
self.vt = SourcelessBuilder.create(tx, self.value)
|
|
||||||
else:
|
|
||||||
self.vt = VariableBuilder(tx, self.source)(self.value)
|
|
||||||
|
|
||||||
|
if isinstance(self.value, LazySymNodeFormatString):
|
||||||
|
source = None
|
||||||
|
else:
|
||||||
|
source = self.source
|
||||||
|
|
||||||
|
self.vt = VariableTracker.build(tx, self.value, source)
|
||||||
del self.value
|
del self.value
|
||||||
del self.source
|
del self.source
|
||||||
|
|
||||||
@ -37,7 +38,7 @@ class LazyVariableTracker(VariableTracker):
|
|||||||
A structure that defers the creation of the actual VariableTracker
|
A structure that defers the creation of the actual VariableTracker
|
||||||
for a given underlying value until it is accessed.
|
for a given underlying value until it is accessed.
|
||||||
|
|
||||||
The `realize` function invokes VariableBuilder to produce the real object.
|
The `realize` function invokes VariableTracker.build() to produce the real object.
|
||||||
Once a LazyVariableTracker has been realized, internal bookkeeping will
|
Once a LazyVariableTracker has been realized, internal bookkeeping will
|
||||||
prevent double realization.
|
prevent double realization.
|
||||||
|
|
||||||
|
@ -135,10 +135,8 @@ class BaseListVariable(VariableTracker):
|
|||||||
assert not kwargs
|
assert not kwargs
|
||||||
return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
|
return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
|
||||||
elif name == "index":
|
elif name == "index":
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(tx, polyfills.index),
|
VariableTracker.build(tx, polyfills.index),
|
||||||
[self] + list(args),
|
[self] + list(args),
|
||||||
kwargs,
|
kwargs,
|
||||||
)
|
)
|
||||||
|
@ -207,12 +207,10 @@ class SuperVariable(VariableTracker):
|
|||||||
and len(kwargs) == 0
|
and len(kwargs) == 0
|
||||||
and args[0].is_python_constant()
|
and args[0].is_python_constant()
|
||||||
):
|
):
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
key = args[0].as_python_constant()
|
key = args[0].as_python_constant()
|
||||||
return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
|
value = collections.OrderedDict.__getitem__(self.objvar.value, key)
|
||||||
collections.OrderedDict.__getitem__(self.objvar.value, key)
|
source = ODictGetItemSource(self.objvar.source, key)
|
||||||
)
|
return VariableTracker.build(tx, value, source)
|
||||||
elif inner_fn in (
|
elif inner_fn in (
|
||||||
collections.OrderedDict.__setitem__,
|
collections.OrderedDict.__setitem__,
|
||||||
object.__setattr__,
|
object.__setattr__,
|
||||||
@ -467,15 +465,10 @@ class InspectParameterVariable(VariableTracker):
|
|||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
attr_value = getattr(self.value, name)
|
attr_value = getattr(self.value, name)
|
||||||
if self.source:
|
source = self.source and AttrSource(self.source, name)
|
||||||
attr_source = AttrSource(self.source, name)
|
return VariableTracker.build(tx, attr_value, source)
|
||||||
return VariableBuilder(tx, attr_source)(attr_value)
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, attr_value)
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
unimplemented(f"getattr({self.value}, {name})")
|
unimplemented(f"getattr({self.value}, {name})")
|
||||||
|
|
||||||
@ -909,11 +902,9 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
|
|||||||
if self.needs_input_grad is not None:
|
if self.needs_input_grad is not None:
|
||||||
return variables.ConstantVariable.create(self.needs_input_grad)
|
return variables.ConstantVariable.create(self.needs_input_grad)
|
||||||
if self.source:
|
if self.source:
|
||||||
from .builder import VariableBuilder
|
source = AttrSource(self.source, "needs_input_grad")
|
||||||
|
return VariableTracker.build(tx, self.value.needs_input_grad, source)
|
||||||
|
|
||||||
return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))(
|
|
||||||
self.value.needs_input_grad
|
|
||||||
)
|
|
||||||
return super().var_getattr(tx, name)
|
return super().var_getattr(tx, name)
|
||||||
|
|
||||||
|
|
||||||
@ -1118,11 +1109,8 @@ class GetSetDescriptorVariable(VariableTracker):
|
|||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||||
if name == "__get__" and self.source:
|
if name == "__get__" and self.source:
|
||||||
from .builder import VariableBuilder
|
source = AttrSource(self.source, "__get__")
|
||||||
|
return VariableTracker.build(tx, self.desc.__get__, source)
|
||||||
return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
|
|
||||||
self.desc.__get__
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return super().var_getattr(tx, name)
|
return super().var_getattr(tx, name)
|
||||||
|
|
||||||
@ -1162,18 +1150,13 @@ class PythonModuleVariable(VariableTracker):
|
|||||||
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
|
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
|
||||||
return tx.output.side_effects.load_attr(self, name)
|
return tx.output.side_effects.load_attr(self, name)
|
||||||
|
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
if self.is_torch or name not in self.value.__dict__:
|
if self.is_torch or name not in self.value.__dict__:
|
||||||
attr_value = getattr(self.value, name)
|
attr_value = getattr(self.value, name)
|
||||||
else:
|
else:
|
||||||
attr_value = self.value.__dict__[name]
|
attr_value = self.value.__dict__[name]
|
||||||
|
|
||||||
if self.source:
|
source = self.source and AttrSource(self.source, name)
|
||||||
new_source = AttrSource(self.source, name)
|
return VariableTracker.build(tx, attr_value, source)
|
||||||
return VariableBuilder(tx, new_source)(attr_value)
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, attr_value)
|
|
||||||
|
|
||||||
|
|
||||||
class TypingVariable(VariableTracker):
|
class TypingVariable(VariableTracker):
|
||||||
|
@ -244,12 +244,7 @@ class NNModuleVariable(VariableTracker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||||
from .builder import VariableBuilder
|
source = self.source and AttrSource(self.source, name)
|
||||||
|
|
||||||
if self.source:
|
|
||||||
source = AttrSource(self.source, name)
|
|
||||||
else:
|
|
||||||
source = None
|
|
||||||
|
|
||||||
base = tx.output.get_submodule(self.module_key)
|
base = tx.output.get_submodule(self.module_key)
|
||||||
base_dict = object.__getattribute__(base, "__dict__")
|
base_dict = object.__getattribute__(base, "__dict__")
|
||||||
@ -297,7 +292,7 @@ class NNModuleVariable(VariableTracker):
|
|||||||
return variables.UserDefinedClassVariable(base.__class__, source=source)
|
return variables.UserDefinedClassVariable(base.__class__, source=source)
|
||||||
|
|
||||||
if object_member:
|
if object_member:
|
||||||
out = VariableBuilder(tx, NNModuleSource(source))(subobj)
|
out = VariableTracker.build(tx, subobj, NNModuleSource(source))
|
||||||
|
|
||||||
if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
|
if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
|
||||||
# nn_module_stack source is BC surface area. Ensure that
|
# nn_module_stack source is BC surface area. Ensure that
|
||||||
@ -333,7 +328,7 @@ class NNModuleVariable(VariableTracker):
|
|||||||
return variables.UserMethodVariable(subobj, self, source=source)
|
return variables.UserMethodVariable(subobj, self, source=source)
|
||||||
elif is_safe_constant(subobj) or istensor(subobj):
|
elif is_safe_constant(subobj) or istensor(subobj):
|
||||||
# Support possibly common cases of class members
|
# Support possibly common cases of class members
|
||||||
return VariableBuilder(tx, NNModuleSource(source))(subobj)
|
return VariableTracker.build(tx, subobj, NNModuleSource(source))
|
||||||
else:
|
else:
|
||||||
unimplemented(
|
unimplemented(
|
||||||
f"class property {name} - {typestr(base)} {typestr(subobj)}"
|
f"class property {name} - {typestr(base)} {typestr(subobj)}"
|
||||||
@ -1083,7 +1078,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
|
|||||||
)
|
)
|
||||||
return variables.ConstDictVariable({})
|
return variables.ConstDictVariable({})
|
||||||
|
|
||||||
# For non-empty hook dicts, one way is to just fallback to VariableBuilder and create a ConstDictVariable.
|
# For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable.
|
||||||
# However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for
|
# However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for
|
||||||
# differnt nn module instances, because the key keeps changing (look more into RemovableHandle to understand why
|
# differnt nn module instances, because the key keeps changing (look more into RemovableHandle to understand why
|
||||||
# key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a
|
# key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a
|
||||||
|
@ -17,6 +17,7 @@ from ..source import (
|
|||||||
GradSource,
|
GradSource,
|
||||||
)
|
)
|
||||||
from ..utils import GLOBAL_KEY_PREFIX
|
from ..utils import GLOBAL_KEY_PREFIX
|
||||||
|
from .base import VariableTracker
|
||||||
from .constant import ConstantVariable
|
from .constant import ConstantVariable
|
||||||
from .dicts import ConstDictVariable
|
from .dicts import ConstDictVariable
|
||||||
from .lists import ListVariable
|
from .lists import ListVariable
|
||||||
@ -27,8 +28,6 @@ from .user_defined import UserDefinedObjectVariable
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||||
|
|
||||||
from .base import VariableTracker
|
|
||||||
|
|
||||||
|
|
||||||
class ArgMappingException(Exception):
|
class ArgMappingException(Exception):
|
||||||
pass
|
pass
|
||||||
@ -147,7 +146,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
|
|
||||||
def _set_capturable(self, tx):
|
def _set_capturable(self, tx):
|
||||||
from . import LazyVariableTracker
|
from . import LazyVariableTracker
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
# We only set capturable if params are on cuda
|
# We only set capturable if params are on cuda
|
||||||
# and the state is not initialized
|
# and the state is not initialized
|
||||||
@ -168,10 +166,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
if safe_to_set_capturable(group):
|
if safe_to_set_capturable(group):
|
||||||
group["capturable"] = True
|
group["capturable"] = True
|
||||||
|
|
||||||
|
source = self.source and AttrSource(self.source, "param_groups")
|
||||||
param_groups_vt = LazyVariableTracker.realize_all(
|
param_groups_vt = LazyVariableTracker.realize_all(
|
||||||
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
|
VariableTracker.build(tx, self.value.param_groups, source)
|
||||||
self.value.param_groups
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
for ind, param_group_vt in enumerate(param_groups_vt.items):
|
for ind, param_group_vt in enumerate(param_groups_vt.items):
|
||||||
key = ConstDictVariable._HashableTracker(
|
key = ConstDictVariable._HashableTracker(
|
||||||
@ -214,7 +211,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
|
|
||||||
def map_sources_and_install_guards(self, tx):
|
def map_sources_and_install_guards(self, tx):
|
||||||
from ..decorators import mark_static_address
|
from ..decorators import mark_static_address
|
||||||
from .builder import VariableBuilder
|
|
||||||
from .lazy import LazyVariableTracker
|
from .lazy import LazyVariableTracker
|
||||||
|
|
||||||
self.grad_to_source = {}
|
self.grad_to_source = {}
|
||||||
@ -235,15 +231,13 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
|
|
||||||
# Recursively realize the variable trackers for optim.state and
|
# Recursively realize the variable trackers for optim.state and
|
||||||
# optim.param_groups, which recursively install the necessary guards.
|
# optim.param_groups, which recursively install the necessary guards.
|
||||||
|
params_groups_source = self.source and AttrSource(self.source, "param_groups")
|
||||||
param_groups_vt = LazyVariableTracker.realize_all(
|
param_groups_vt = LazyVariableTracker.realize_all(
|
||||||
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
|
VariableTracker.build(tx, self.value.param_groups, params_groups_source)
|
||||||
self.value.param_groups
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
|
state_source = self.source and AttrSource(self.source, "state")
|
||||||
self.value.state
|
state_vt = VariableTracker.build(tx, self.value.state, state_source)
|
||||||
)
|
|
||||||
|
|
||||||
# We need to realize the top level state dict to populate
|
# We need to realize the top level state dict to populate
|
||||||
# the guard locals
|
# the guard locals
|
||||||
@ -265,15 +259,15 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
key_index = i
|
key_index = i
|
||||||
break
|
break
|
||||||
if key_index:
|
if key_index:
|
||||||
state_source = AttrSource(self.source, "state")
|
|
||||||
LazyVariableTracker.realize_all(
|
LazyVariableTracker.realize_all(
|
||||||
VariableBuilder(
|
VariableTracker.build(
|
||||||
tx,
|
tx,
|
||||||
|
self.value.state[param],
|
||||||
GetItemSource(
|
GetItemSource(
|
||||||
state_source,
|
state_source,
|
||||||
ConstDictKeySource(state_source, key_index),
|
ConstDictKeySource(state_source, key_index),
|
||||||
),
|
),
|
||||||
)(self.value.state[param])
|
)
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -312,7 +306,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
|
|
||||||
# We have to again iterate over the state dict to collect the
|
# We have to again iterate over the state dict to collect the
|
||||||
# tensor_to_source dict. This is used for the finalizer.
|
# tensor_to_source dict. This is used for the finalizer.
|
||||||
state_source = AttrSource(self.source, "state")
|
|
||||||
for idx, (p, value) in enumerate(self.value.state.items()):
|
for idx, (p, value) in enumerate(self.value.state.items()):
|
||||||
p_state_source = GetItemSource(
|
p_state_source = GetItemSource(
|
||||||
state_source, ConstDictKeySource(state_source, idx)
|
state_source, ConstDictKeySource(state_source, idx)
|
||||||
@ -328,7 +321,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
|
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
|
||||||
"""Wrap state tensor in a TensorVariable"""
|
"""Wrap state tensor in a TensorVariable"""
|
||||||
from ..decorators import mark_static_address
|
from ..decorators import mark_static_address
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
# If we have a source for a tensor already use it,
|
# If we have a source for a tensor already use it,
|
||||||
# if we have not seen a tensor before, stash and use a
|
# if we have not seen a tensor before, stash and use a
|
||||||
@ -338,20 +330,19 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
if tensor_value in self.tensor_to_source:
|
if tensor_value in self.tensor_to_source:
|
||||||
# mark these tensors as static for cudagraphs
|
# mark these tensors as static for cudagraphs
|
||||||
mark_static_address(tensor_value)
|
mark_static_address(tensor_value)
|
||||||
builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
|
source = self.tensor_to_source[tensor_value]
|
||||||
self.static_tensor_names.add(tx.output.module_key_name(builder.name))
|
self.static_tensor_names.add(tx.output.module_key_name(source.name))
|
||||||
elif tensor_value in self.grad_to_source:
|
elif tensor_value in self.grad_to_source:
|
||||||
builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
|
source = self.grad_to_source[tensor_value]
|
||||||
else:
|
else:
|
||||||
# mark these tensors as static for cudagraphs
|
# mark these tensors as static for cudagraphs
|
||||||
mark_static_address(tensor_value)
|
mark_static_address(tensor_value)
|
||||||
|
|
||||||
global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
|
global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
|
||||||
builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
|
source = GlobalWeakRefSource(global_name)
|
||||||
self.static_tensor_names.add(tx.output.module_key_name(builder.name))
|
self.static_tensor_names.add(tx.output.module_key_name(source.name))
|
||||||
|
|
||||||
result = builder(tensor_value)
|
return VariableTracker.build(tx, tensor_value, source)
|
||||||
return result
|
|
||||||
|
|
||||||
def update_list_args(
|
def update_list_args(
|
||||||
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
|
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
|
||||||
@ -367,14 +358,8 @@ class OptimizerVariable(UserDefinedObjectVariable):
|
|||||||
if isinstance(val, torch.Tensor):
|
if isinstance(val, torch.Tensor):
|
||||||
arg.items.append(self.wrap_tensor(tx, val))
|
arg.items.append(self.wrap_tensor(tx, val))
|
||||||
else:
|
else:
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
source = arg.source and GetItemSource(arg.source, i)
|
||||||
|
arg.items.append(VariableTracker.build(tx, val, source))
|
||||||
if arg.source:
|
|
||||||
arg.items.append(
|
|
||||||
VariableBuilder(tx, GetItemSource(arg.source, i))(val)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
arg.items.append(SourcelessBuilder.create(tx, val))
|
|
||||||
|
|
||||||
def create_finalizer(self, tx):
|
def create_finalizer(self, tx):
|
||||||
names_to_delete = self.static_tensor_names
|
names_to_delete = self.static_tensor_names
|
||||||
|
@ -5,12 +5,15 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from ..bytecode_transformation import create_call_function
|
from ..bytecode_transformation import create_call_function
|
||||||
from ..exc import Unsupported
|
from ..exc import Unsupported
|
||||||
|
from ..source import AttrSource
|
||||||
from .base import VariableTracker
|
from .base import VariableTracker
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||||
|
|
||||||
|
PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
|
||||||
|
|
||||||
|
|
||||||
class SDPAParamsVariable(VariableTracker):
|
class SDPAParamsVariable(VariableTracker):
|
||||||
"""Represents the c++ params struct for scaled dot product attention.
|
"""Represents the c++ params struct for scaled dot product attention.
|
||||||
@ -20,35 +23,13 @@ class SDPAParamsVariable(VariableTracker):
|
|||||||
def create(tx: "InstructionTranslator", value, source):
|
def create(tx: "InstructionTranslator", value, source):
|
||||||
from torch.backends.cuda import SDPAParams
|
from torch.backends.cuda import SDPAParams
|
||||||
|
|
||||||
from ..source import AttrSource
|
|
||||||
from .builder import VariableBuilder
|
|
||||||
from .torch import TorchInGraphFunctionVariable
|
from .torch import TorchInGraphFunctionVariable
|
||||||
|
|
||||||
query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
|
params = [
|
||||||
key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
|
VariableTracker.build(tx, getattr(value, p), AttrSource(source, p))
|
||||||
value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
|
for p in PARAM_NAMES
|
||||||
attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
|
|
||||||
value.attn_mask
|
|
||||||
)
|
|
||||||
dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
|
|
||||||
is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
|
|
||||||
value.is_causal
|
|
||||||
)
|
|
||||||
enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
|
|
||||||
value.enable_gqa
|
|
||||||
)
|
|
||||||
param_vars = [
|
|
||||||
query_var,
|
|
||||||
key_var,
|
|
||||||
value_var,
|
|
||||||
attn_mask_var,
|
|
||||||
dropout_var,
|
|
||||||
is_causal_var,
|
|
||||||
enable_gqa_var,
|
|
||||||
]
|
]
|
||||||
return TorchInGraphFunctionVariable(SDPAParams).call_function(
|
return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})
|
||||||
tx, param_vars, {}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self, proxy, param_vars, **kwargs) -> None:
|
def __init__(self, proxy, param_vars, **kwargs) -> None:
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
@ -70,7 +51,6 @@ class SDPAParamsVariable(VariableTracker):
|
|||||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||||
import torch._C
|
import torch._C
|
||||||
|
|
||||||
from ..source import AttrSource
|
|
||||||
from .builder import wrap_fx_proxy
|
from .builder import wrap_fx_proxy
|
||||||
from .misc import GetAttrVariable
|
from .misc import GetAttrVariable
|
||||||
|
|
||||||
|
@ -238,9 +238,7 @@ class TensorVariable(VariableTracker):
|
|||||||
# any other attributes on the subclass (that are not methods)
|
# any other attributes on the subclass (that are not methods)
|
||||||
# are assumed to be constant metadata.
|
# are assumed to be constant metadata.
|
||||||
elif not callable(example_value):
|
elif not callable(example_value):
|
||||||
from .builder import SourcelessBuilder
|
return VariableTracker.build(tx, example_value)
|
||||||
|
|
||||||
return SourcelessBuilder.create(tx, example_value)
|
|
||||||
|
|
||||||
if not (self.source and self.source.subguards_allowed()):
|
if not (self.source and self.source.subguards_allowed()):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -277,12 +275,9 @@ class TensorVariable(VariableTracker):
|
|||||||
# Note - at a certain point we may want to handle
|
# Note - at a certain point we may want to handle
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
from ..guards import GuardBuilder
|
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
attr_source = AttrSource(self.source, name)
|
attr_source = AttrSource(self.source, name)
|
||||||
install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
|
install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
|
||||||
return VariableBuilder(tx, attr_source)(real_value)
|
return VariableTracker.build(tx, real_value, attr_source)
|
||||||
|
|
||||||
def method_attr_ndim(self, tx):
|
def method_attr_ndim(self, tx):
|
||||||
if self.ndim is not None:
|
if self.ndim is not None:
|
||||||
@ -695,7 +690,6 @@ class TensorVariable(VariableTracker):
|
|||||||
def method_as_subclass(self, cls):
|
def method_as_subclass(self, cls):
|
||||||
if isinstance(cls, TensorSubclassVariable) and cls.source:
|
if isinstance(cls, TensorSubclassVariable) and cls.source:
|
||||||
from ..symbolic_convert import InstructionTranslator
|
from ..symbolic_convert import InstructionTranslator
|
||||||
from .builder import VariableBuilder
|
|
||||||
from .torch_function import TensorWithTFOverrideVariable
|
from .torch_function import TensorWithTFOverrideVariable
|
||||||
|
|
||||||
tx = InstructionTranslator.current_tx()
|
tx = InstructionTranslator.current_tx()
|
||||||
@ -705,10 +699,11 @@ class TensorVariable(VariableTracker):
|
|||||||
# defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
|
# defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
|
||||||
# It is up to the user whether this is correct behavior or not.
|
# It is up to the user whether this is correct behavior or not.
|
||||||
py_cls = cls.as_python_constant()
|
py_cls = cls.as_python_constant()
|
||||||
torch_fn = VariableBuilder(
|
torch_fn = VariableTracker.build(
|
||||||
tx,
|
tx,
|
||||||
|
py_cls.__torch_function__.__func__,
|
||||||
AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
|
AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
|
||||||
)(py_cls.__torch_function__.__func__)
|
)
|
||||||
|
|
||||||
return TensorWithTFOverrideVariable.from_tensor_var(
|
return TensorWithTFOverrideVariable.from_tensor_var(
|
||||||
tx, self, py_cls, torch_fn
|
tx, self, py_cls, torch_fn
|
||||||
@ -750,7 +745,6 @@ class TensorVariable(VariableTracker):
|
|||||||
|
|
||||||
def method_tolist(self):
|
def method_tolist(self):
|
||||||
from ..symbolic_convert import InstructionTranslator
|
from ..symbolic_convert import InstructionTranslator
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
tx = InstructionTranslator.current_tx()
|
tx = InstructionTranslator.current_tx()
|
||||||
|
|
||||||
@ -787,7 +781,7 @@ class TensorVariable(VariableTracker):
|
|||||||
|
|
||||||
tensor = self.as_proxy().node.meta["example_value"]
|
tensor = self.as_proxy().node.meta["example_value"]
|
||||||
out = tolist(tensor, self.as_proxy())
|
out = tolist(tensor, self.as_proxy())
|
||||||
return SourcelessBuilder.create(tx, out)
|
return VariableTracker.build(tx, out)
|
||||||
|
|
||||||
def method_backward(self, *args, **kwargs):
|
def method_backward(self, *args, **kwargs):
|
||||||
unimplemented("Tensor.backward")
|
unimplemented("Tensor.backward")
|
||||||
@ -857,10 +851,9 @@ class TensorVariable(VariableTracker):
|
|||||||
tx = InstructionTranslator.current_tx()
|
tx = InstructionTranslator.current_tx()
|
||||||
if value is not None:
|
if value is not None:
|
||||||
from .. import polyfills
|
from .. import polyfills
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(tx, polyfills.addcmul_inplace),
|
VariableTracker.build(tx, polyfills.addcmul_inplace),
|
||||||
[self, tensor1, tensor2, value],
|
[self, tensor1, tensor2, value],
|
||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
@ -1155,9 +1148,7 @@ class SymNodeVariable(VariableTracker):
|
|||||||
|
|
||||||
def as_tensor(self, tx):
|
def as_tensor(self, tx):
|
||||||
if self._tensor_var is None:
|
if self._tensor_var is None:
|
||||||
from .builder import SourcelessBuilder
|
self._tensor_var = VariableTracker.build(
|
||||||
|
|
||||||
self._tensor_var = SourcelessBuilder.create(
|
|
||||||
tx, torch.scalar_tensor
|
tx, torch.scalar_tensor
|
||||||
).call_function(tx, [self], {})
|
).call_function(tx, [self], {})
|
||||||
return self._tensor_var
|
return self._tensor_var
|
||||||
@ -1362,12 +1353,10 @@ class TensorSubclassVariable(VariableTracker):
|
|||||||
kwargs: Dict[str, VariableTracker],
|
kwargs: Dict[str, VariableTracker],
|
||||||
) -> VariableTracker:
|
) -> VariableTracker:
|
||||||
if len(args) == 1 and isinstance(args[0], TensorVariable):
|
if len(args) == 1 and isinstance(args[0], TensorVariable):
|
||||||
from .builder import VariableBuilder
|
|
||||||
from .torch_function import TensorWithTFOverrideVariable
|
from .torch_function import TensorWithTFOverrideVariable
|
||||||
|
|
||||||
torch_fn = VariableBuilder(
|
source = AttrSource(self.source, "__torch_function__")
|
||||||
tx, AttrSource(self.source, "__torch_function__")
|
torch_fn = VariableTracker.build(tx, self.value.__torch_function__, source)
|
||||||
)(self.value.__torch_function__)
|
|
||||||
|
|
||||||
return TensorWithTFOverrideVariable.from_tensor_var(
|
return TensorWithTFOverrideVariable.from_tensor_var(
|
||||||
tx, args[0], self.value, torch_fn
|
tx, args[0], self.value, torch_fn
|
||||||
|
@ -397,7 +397,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||||||
TensorVariable,
|
TensorVariable,
|
||||||
UserDefinedObjectVariable,
|
UserDefinedObjectVariable,
|
||||||
)
|
)
|
||||||
from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls
|
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
|
||||||
|
|
||||||
@register(*tracing_state_functions)
|
@register(*tracing_state_functions)
|
||||||
def handle_tracing_state_functions(
|
def handle_tracing_state_functions(
|
||||||
@ -422,14 +422,14 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||||||
# the set of functions that we trace __torch_function__ on to
|
# the set of functions that we trace __torch_function__ on to
|
||||||
# functions outside of the actual set. Implementing this properly will require implementing
|
# functions outside of the actual set. Implementing this properly will require implementing
|
||||||
# some variable types to track and compare tensor getset descriptors
|
# some variable types to track and compare tensor getset descriptors
|
||||||
return SourcelessBuilder.create(
|
return VariableTracker.build(
|
||||||
tx, torch.overrides.get_default_nowrap_functions()
|
tx, torch.overrides.get_default_nowrap_functions()
|
||||||
)
|
)
|
||||||
|
|
||||||
@register(torch.ops.inductor.accumulate_grad_.default)
|
@register(torch.ops.inductor.accumulate_grad_.default)
|
||||||
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
|
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(tx, polyfills.accumulate_grad), args, kwargs
|
VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@register(math.radians)
|
@register(math.radians)
|
||||||
@ -437,7 +437,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||||||
if not check_unspec_or_constant_args(args, kwargs):
|
if not check_unspec_or_constant_args(args, kwargs):
|
||||||
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
|
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(tx, polyfills.radians), args, kwargs
|
VariableTracker.build(tx, polyfills.radians), args, kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
@register(torch.is_tensor, torch.overrides.is_tensor_like)
|
@register(torch.is_tensor, torch.overrides.is_tensor_like)
|
||||||
@ -622,7 +622,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||||||
):
|
):
|
||||||
if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
|
if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace),
|
VariableTracker.build(tx, polyfills.foreach_lerp_inplace),
|
||||||
args,
|
args,
|
||||||
kwargs,
|
kwargs,
|
||||||
)
|
)
|
||||||
@ -635,7 +635,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||||||
# in compile, it's more performant to not graph break.
|
# in compile, it's more performant to not graph break.
|
||||||
if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
|
if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar),
|
VariableTracker.build(tx, polyfills.foreach_pow_scalar),
|
||||||
args,
|
args,
|
||||||
kwargs,
|
kwargs,
|
||||||
)
|
)
|
||||||
@ -704,7 +704,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||||||
# Note - while we *could* cook up sources around invocations, like a FunctionSource
|
# Note - while we *could* cook up sources around invocations, like a FunctionSource
|
||||||
# the space of invoking functions in the middle of the guard chain is very iffy. As such,
|
# the space of invoking functions in the middle of the guard chain is very iffy. As such,
|
||||||
# guard propagation via options is the best we can do.
|
# guard propagation via options is the best we can do.
|
||||||
return SourcelessBuilder.create(tx, invocation_result)
|
return VariableTracker.build(tx, invocation_result)
|
||||||
|
|
||||||
@register(DTensor.from_local)
|
@register(DTensor.from_local)
|
||||||
def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):
|
def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):
|
||||||
@ -1143,8 +1143,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad):
|
def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad):
|
||||||
# Alternate version if we have a .source
|
# Alternate version if we have a .source
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
varname = tx.output.new_var()
|
varname = tx.output.new_var()
|
||||||
|
|
||||||
# construct the nn.Parmeter before the graph save it to varname
|
# construct the nn.Parmeter before the graph save it to varname
|
||||||
@ -1167,7 +1165,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
|
|||||||
example_value = torch.nn.Parameter(
|
example_value = torch.nn.Parameter(
|
||||||
tx.output.example_value_from_input_node(data.as_proxy().node)
|
tx.output.example_value_from_input_node(data.as_proxy().node)
|
||||||
)
|
)
|
||||||
result = VariableBuilder(tx, source)(example_value)
|
result = VariableTracker.build(tx, example_value, source)
|
||||||
# No need to guard on this since we already guarded on `data`.
|
# No need to guard on this since we already guarded on `data`.
|
||||||
# These guards would fail since varname doesn't exist until after the function starts
|
# These guards would fail since varname doesn't exist until after the function starts
|
||||||
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
|
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
|
||||||
|
@ -474,12 +474,8 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var):
|
|||||||
if isinstance(var, TensorWithTFOverrideVariable):
|
if isinstance(var, TensorWithTFOverrideVariable):
|
||||||
return var.class_type_var(tx)
|
return var.class_type_var(tx)
|
||||||
elif isinstance(var, UserDefinedObjectVariable):
|
elif isinstance(var, UserDefinedObjectVariable):
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
source = var.source and TypeSource(var.source)
|
||||||
|
return VariableTracker.build(tx, var.python_type(), source)
|
||||||
if var.source:
|
|
||||||
return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, var.python_type())
|
|
||||||
|
|
||||||
|
|
||||||
def _is_attr_overidden(tx: "InstructionTranslator", var, name):
|
def _is_attr_overidden(tx: "InstructionTranslator", var, name):
|
||||||
@ -498,16 +494,14 @@ def _is_attr_overidden(tx: "InstructionTranslator", var, name):
|
|||||||
def call_torch_function(
|
def call_torch_function(
|
||||||
tx, torch_function_type, torch_function_var, fn, types, args, kwargs
|
tx, torch_function_type, torch_function_var, fn, types, args, kwargs
|
||||||
):
|
):
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
# signature:
|
# signature:
|
||||||
# def __torch_function__(cls, func, types, args=(), kwargs=None):
|
# def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||||
tf_args = (
|
tf_args = (
|
||||||
torch_function_type,
|
torch_function_type,
|
||||||
fn,
|
fn,
|
||||||
types,
|
types,
|
||||||
SourcelessBuilder.create(tx, tuple(args)),
|
VariableTracker.build(tx, tuple(args)),
|
||||||
SourcelessBuilder.create(tx, kwargs),
|
VariableTracker.build(tx, kwargs),
|
||||||
)
|
)
|
||||||
return tx.inline_user_function_return(torch_function_var, tf_args, {})
|
return tx.inline_user_function_return(torch_function_var, tf_args, {})
|
||||||
|
|
||||||
@ -515,20 +509,13 @@ def call_torch_function(
|
|||||||
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
|
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
|
||||||
from types import FunctionType
|
from types import FunctionType
|
||||||
|
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
func = value.__torch_function__.__func__
|
func = value.__torch_function__.__func__
|
||||||
|
|
||||||
if not isinstance(func, FunctionType):
|
if not isinstance(func, FunctionType):
|
||||||
unimplemented("Builtin/C++ torch function implementations NYI")
|
unimplemented("Builtin/C++ torch function implementations NYI")
|
||||||
|
|
||||||
if source:
|
source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__")
|
||||||
return VariableBuilder(
|
return VariableTracker.build(tx, func, source)
|
||||||
tx,
|
|
||||||
AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
|
|
||||||
)(value.__torch_function__.__func__)
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
|
|
||||||
|
|
||||||
|
|
||||||
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
|
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
|
||||||
@ -625,8 +612,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||||||
# base tensors, custom attribute accesses will graph break.
|
# base tensors, custom attribute accesses will graph break.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
if name in banned_attrs:
|
if name in banned_attrs:
|
||||||
unimplemented(
|
unimplemented(
|
||||||
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
|
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
|
||||||
@ -645,7 +630,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||||||
GuardBuilder.FUNCTION_MATCH
|
GuardBuilder.FUNCTION_MATCH
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
|
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
|
||||||
|
|
||||||
return self.call_torch_function(
|
return self.call_torch_function(
|
||||||
tx,
|
tx,
|
||||||
@ -680,8 +665,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||||||
if tx.output.torch_function_enabled:
|
if tx.output.torch_function_enabled:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
if _is_attr_overidden(tx, self, name):
|
if _is_attr_overidden(tx, self, name):
|
||||||
unimplemented(
|
unimplemented(
|
||||||
f"Calling overridden method {name} on a tensor"
|
f"Calling overridden method {name} on a tensor"
|
||||||
@ -693,11 +676,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||||||
# We've established with the above check that the method is not overridden, so we guard that the method is the same
|
# We've established with the above check that the method is not overridden, so we guard that the method is the same
|
||||||
# as the impl defined on tensor and retrieve it
|
# as the impl defined on tensor and retrieve it
|
||||||
if self.source:
|
if self.source:
|
||||||
func_var = VariableBuilder(
|
source = AttrSource(AttrSource(self.source, "__class__"), name)
|
||||||
tx, AttrSource(AttrSource(self.source, "__class__"), name)
|
value = inspect.getattr_static(self.python_type(), name)
|
||||||
)(inspect.getattr_static(self.python_type(), name))
|
|
||||||
else:
|
else:
|
||||||
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
|
source = None
|
||||||
|
value = getattr(torch.Tensor, name)
|
||||||
|
func_var = VariableTracker.build(tx, value, source)
|
||||||
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
|
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
|
||||||
else:
|
else:
|
||||||
return super().call_method(tx, name, args, kwargs)
|
return super().call_method(tx, name, args, kwargs)
|
||||||
|
@ -158,7 +158,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
|
|
||||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||||
from . import ConstantVariable, EnumVariable
|
from . import ConstantVariable, EnumVariable
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
source = AttrSource(self.source, name) if self.source is not None else None
|
source = AttrSource(self.source, name) if self.source is not None else None
|
||||||
|
|
||||||
@ -187,11 +186,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
obj = None
|
obj = None
|
||||||
|
|
||||||
if isinstance(obj, staticmethod):
|
if isinstance(obj, staticmethod):
|
||||||
func = obj.__get__(self.value)
|
return VariableTracker.build(tx, obj.__get__(self.value), source)
|
||||||
if source is not None:
|
|
||||||
return VariableBuilder(tx, source)(func)
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, func)
|
|
||||||
elif isinstance(obj, classmethod):
|
elif isinstance(obj, classmethod):
|
||||||
if isinstance(obj.__func__, property):
|
if isinstance(obj.__func__, property):
|
||||||
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
|
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
|
||||||
@ -202,16 +197,13 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
# e.g.: inspect.getattr_static(dict, "fromkeys")
|
# e.g.: inspect.getattr_static(dict, "fromkeys")
|
||||||
# inspect.getattr_static(itertools.chain, "from_iterable")
|
# inspect.getattr_static(itertools.chain, "from_iterable")
|
||||||
func = obj.__get__(None, self.value)
|
func = obj.__get__(None, self.value)
|
||||||
if source is not None:
|
return VariableTracker.build(tx, func, source)
|
||||||
return VariableBuilder(tx, source)(func)
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, func)
|
|
||||||
elif source:
|
elif source:
|
||||||
# __mro__ is a member in < 3.12, an attribute in >= 3.12
|
# __mro__ is a member in < 3.12, an attribute in >= 3.12
|
||||||
if inspect.ismemberdescriptor(obj) or (
|
if inspect.ismemberdescriptor(obj) or (
|
||||||
sys.version_info >= (3, 12) and name == "__mro__"
|
sys.version_info >= (3, 12) and name == "__mro__"
|
||||||
):
|
):
|
||||||
return VariableBuilder(tx, source)(obj.__get__(self.value))
|
return VariableTracker.build(tx, obj.__get__(self.value), source)
|
||||||
|
|
||||||
if ConstantVariable.is_literal(obj):
|
if ConstantVariable.is_literal(obj):
|
||||||
return ConstantVariable.create(obj)
|
return ConstantVariable.create(obj)
|
||||||
@ -222,14 +214,15 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
or self.value.__module__ == "torch"
|
or self.value.__module__ == "torch"
|
||||||
):
|
):
|
||||||
if source:
|
if source:
|
||||||
return VariableBuilder(tx, source)(obj)
|
return VariableTracker.build(tx, obj, source)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
source
|
source
|
||||||
and not inspect.ismethoddescriptor(obj)
|
and not inspect.ismethoddescriptor(obj)
|
||||||
and not is_wrapper_or_member_descriptor(obj)
|
and not is_wrapper_or_member_descriptor(obj)
|
||||||
):
|
):
|
||||||
return VariableBuilder(tx, source)(obj)
|
return VariableTracker.build(tx, obj, source)
|
||||||
|
|
||||||
return super().var_getattr(tx, name)
|
return super().var_getattr(tx, name)
|
||||||
|
|
||||||
def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs):
|
def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs):
|
||||||
@ -341,7 +334,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
kwargs: "Dict[str, VariableTracker]",
|
kwargs: "Dict[str, VariableTracker]",
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
from ..side_effects import SideEffects
|
from ..side_effects import SideEffects
|
||||||
from .builder import SourcelessBuilder, wrap_fx_proxy
|
from .builder import wrap_fx_proxy
|
||||||
from .builtin import BuiltinVariable
|
from .builtin import BuiltinVariable
|
||||||
|
|
||||||
constant_args = check_constant_args(args, kwargs)
|
constant_args = check_constant_args(args, kwargs)
|
||||||
@ -452,7 +445,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
field_var = kwargs[field_name]
|
field_var = kwargs[field_name]
|
||||||
else:
|
else:
|
||||||
assert field_name in field_defaults
|
assert field_name in field_defaults
|
||||||
field_var = SourcelessBuilder.create(
|
field_var = VariableTracker.build(
|
||||||
tx, field_defaults[field_name]
|
tx, field_defaults[field_name]
|
||||||
)
|
)
|
||||||
var_tracker_kwargs[field_name] = field_var
|
var_tracker_kwargs[field_name] = field_var
|
||||||
@ -465,8 +458,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
|
|
||||||
return variables.NamedTupleVariable(items, self.value)
|
return variables.NamedTupleVariable(items, self.value)
|
||||||
elif is_frozen_dataclass(self.value) and self.is_standard_new():
|
elif is_frozen_dataclass(self.value) and self.is_standard_new():
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
fields = dataclasses.fields(self.value)
|
fields = dataclasses.fields(self.value)
|
||||||
items = list(args)
|
items = list(args)
|
||||||
items.extend([None] * (len(fields) - len(items)))
|
items.extend([None] * (len(fields) - len(items)))
|
||||||
@ -481,9 +472,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if field.default is not dataclasses.MISSING:
|
if field.default is not dataclasses.MISSING:
|
||||||
var_tracker = SourcelessBuilder.create(tx, field.default)
|
var_tracker = VariableTracker.build(tx, field.default)
|
||||||
elif field.default_factory is not dataclasses.MISSING:
|
elif field.default_factory is not dataclasses.MISSING:
|
||||||
factory_fn = SourcelessBuilder.create(
|
factory_fn = VariableTracker.build(
|
||||||
tx, field.default_factory
|
tx, field.default_factory
|
||||||
)
|
)
|
||||||
var_tracker = factory_fn.call_function(tx, [], {})
|
var_tracker = factory_fn.call_function(tx, [], {})
|
||||||
@ -573,7 +564,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||||||
and self.source
|
and self.source
|
||||||
):
|
):
|
||||||
return tx.inline_user_function_return(
|
return tx.inline_user_function_return(
|
||||||
SourcelessBuilder.create(
|
VariableTracker.build(
|
||||||
tx, polyfills.instantiate_user_defined_class_object
|
tx, polyfills.instantiate_user_defined_class_object
|
||||||
),
|
),
|
||||||
[self, *args],
|
[self, *args],
|
||||||
@ -857,7 +848,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||||||
kwargs: "Dict[str, VariableTracker]",
|
kwargs: "Dict[str, VariableTracker]",
|
||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
from .. import trace_rules
|
from .. import trace_rules
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.is_supported_random()
|
self.is_supported_random()
|
||||||
@ -894,9 +884,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||||||
"Sourceless UserDefinedObjectVariable method not supported"
|
"Sourceless UserDefinedObjectVariable method not supported"
|
||||||
)
|
)
|
||||||
func_src = AttrSource(self.source, "__func__")
|
func_src = AttrSource(self.source, "__func__")
|
||||||
func_var = VariableBuilder(tx, func_src)(func)
|
func_var = VariableTracker.build(tx, func, func_src)
|
||||||
obj_src = AttrSource(self.source, "__self__")
|
obj_src = AttrSource(self.source, "__self__")
|
||||||
obj_var = VariableBuilder(tx, obj_src)(obj)
|
obj_var = VariableTracker.build(tx, obj, obj_src)
|
||||||
return func_var.call_function(tx, [obj_var] + args, kwargs)
|
return func_var.call_function(tx, [obj_var] + args, kwargs)
|
||||||
elif (
|
elif (
|
||||||
istype(self.value, functools.partial)
|
istype(self.value, functools.partial)
|
||||||
@ -998,7 +988,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||||
from .. import trace_rules
|
from .. import trace_rules
|
||||||
from . import ConstantVariable
|
from . import ConstantVariable
|
||||||
from .builder import SourcelessBuilder, VariableBuilder
|
|
||||||
|
|
||||||
source = AttrSource(self.source, name) if self.source else None
|
source = AttrSource(self.source, name) if self.source else None
|
||||||
self._check_for_getattribute()
|
self._check_for_getattribute()
|
||||||
@ -1090,10 +1079,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||||||
elif isinstance(subobj, types.ClassMethodDescriptorType):
|
elif isinstance(subobj, types.ClassMethodDescriptorType):
|
||||||
# e.g.: inspect.getattr_static({}, "fromkeys")
|
# e.g.: inspect.getattr_static({}, "fromkeys")
|
||||||
func = subobj.__get__(self.value, None)
|
func = subobj.__get__(self.value, None)
|
||||||
if source is not None:
|
return VariableTracker.build(tx, func, source)
|
||||||
return VariableBuilder(tx, source)(func)
|
|
||||||
else:
|
|
||||||
return SourcelessBuilder.create(tx, func)
|
|
||||||
elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
|
elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
|
||||||
subobj.__get__
|
subobj.__get__
|
||||||
):
|
):
|
||||||
@ -1188,7 +1174,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||||||
subobj_from_class, src_from_class
|
subobj_from_class, src_from_class
|
||||||
)
|
)
|
||||||
|
|
||||||
return SourcelessBuilder.create(tx, subobj)
|
return VariableTracker.build(tx, subobj)
|
||||||
|
|
||||||
# Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
|
# Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
|
||||||
raise_observed_exception(AttributeError, tx)
|
raise_observed_exception(AttributeError, tx)
|
||||||
@ -1212,7 +1198,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||||||
return variables.ConstantVariable.create(False)
|
return variables.ConstantVariable.create(False)
|
||||||
|
|
||||||
def odict_getitem(self, tx: "InstructionTranslator", key):
|
def odict_getitem(self, tx: "InstructionTranslator", key):
|
||||||
from .builder import VariableBuilder
|
|
||||||
from .dicts import is_hashable
|
from .dicts import is_hashable
|
||||||
|
|
||||||
# TODO this should probably be merged with the dict handling
|
# TODO this should probably be merged with the dict handling
|
||||||
@ -1223,10 +1208,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
|||||||
else key.as_python_constant()
|
else key.as_python_constant()
|
||||||
)
|
)
|
||||||
|
|
||||||
return VariableBuilder(
|
return VariableTracker.build(
|
||||||
tx,
|
tx,
|
||||||
ODictGetItemSource(self.source, index),
|
collections.OrderedDict.__getitem__(self.value, key.as_python_constant()),
|
||||||
)(collections.OrderedDict.__getitem__(self.value, key.as_python_constant()))
|
self.source and ODictGetItemSource(self.source, index),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FrozenDataClassVariable(UserDefinedObjectVariable):
|
class FrozenDataClassVariable(UserDefinedObjectVariable):
|
||||||
@ -1236,14 +1222,14 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
|
|||||||
|
|
||||||
assert is_frozen_dataclass(value)
|
assert is_frozen_dataclass(value)
|
||||||
|
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
field_map = {}
|
field_map = {}
|
||||||
for field in fields(value):
|
for field in fields(value):
|
||||||
if hasattr(value, field.name):
|
if hasattr(value, field.name):
|
||||||
field_map[field.name] = VariableBuilder(
|
field_map[field.name] = VariableTracker.build(
|
||||||
tx, AttrSource(source, field.name)
|
tx,
|
||||||
)(getattr(value, field.name))
|
getattr(value, field.name),
|
||||||
|
source and AttrSource(source, field.name),
|
||||||
|
)
|
||||||
|
|
||||||
return FrozenDataClassVariable(value, fields=field_map, source=source)
|
return FrozenDataClassVariable(value, fields=field_map, source=source)
|
||||||
|
|
||||||
@ -1315,16 +1301,8 @@ class WeakRefVariable(UserDefinedObjectVariable):
|
|||||||
) -> "VariableTracker":
|
) -> "VariableTracker":
|
||||||
call_source = None
|
call_source = None
|
||||||
referent = self.value()
|
referent = self.value()
|
||||||
|
source = self.source and WeakRefCallSource(self.source)
|
||||||
if self.source:
|
return VariableTracker.build(tx, referent, source)
|
||||||
from .builder import VariableBuilder
|
|
||||||
|
|
||||||
call_source = WeakRefCallSource(self.source)
|
|
||||||
return VariableBuilder(tx, call_source)(referent)
|
|
||||||
else:
|
|
||||||
from .builder import SourcelessBuilder
|
|
||||||
|
|
||||||
return SourcelessBuilder.create(tx, referent)
|
|
||||||
|
|
||||||
|
|
||||||
class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
|
class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
|
||||||
|
Reference in New Issue
Block a user