[dynamo] Simplify creation of VariableTrackers (#135714)

## `VariableTracker::build()` hides the Builders

### The problem

In the current code, creating a `VariableTracker` means choosing between two `Builder` classes and then either calling a method on one (`SourcelessBuilder.create(tx, value)`) or constructing the other and immediately calling the resulting object (`VariableBuilder(tx, source)(value)`), [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)).
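For illustration, the existing call sites look roughly like this (a simplified sketch of the pattern removed throughout this diff, not a verbatim excerpt; `tx`, `value`, and `source` stand for the usual translator, Python value, and `Source`):

```python
# Old pattern: pick a Builder class, import it locally to dodge circular
# imports, then either call a method on it or construct-and-call it.
from .builder import SourcelessBuilder, VariableBuilder

if source is not None:
    vt = VariableBuilder(tx, source)(value)   # construct an object, then call it
else:
    vt = SourcelessBuilder.create(tx, value)  # call a method
```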

Variations on this code are repeated in many places.

Moreover, the `Builder` classes have many dependencies, so they must be loaded late in the import process to avoid circular imports, which means they end up being repeatedly imported at local scope.

### The solution

In this commit, the import from `builder` and the logic for choosing and calling the right Builder class are hidden behind a single static factory method, `VariableTracker.build()`, which is easier to reason about and to import.
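Call sites collapse to a single call and the dispatch lives in one place; roughly (a sketch mirroring the new factory added to `variables/base.py` in the diff below):

```python
# New pattern: one factory that hides the local import and the choice of builder.
vt = VariableTracker.build(tx, value, source)  # source may be None

# Inside VariableTracker.build(), the old decision is made exactly once:
#   if source is None:  builder.SourcelessBuilder.create(tx, value)
#   else:               builder.VariableBuilder(tx, source)(value)
```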

By removing the repeated logic and the unnecessary local imports, this commit lowers the total line count by a net of over 150 lines.

**CHANGES:** The static method was originally named `VariableTracker.create()`, but the derived class `LazyVariableTracker` already has a `create()` static method with an incompatible signature, so the new method was renamed to `VariableTracker.build()`.
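For context, the two signatures that clashed, as they appear in the diff below (`LazyVariableTracker.create()` takes no translator argument):

```python
variables.LazyVariableTracker.create(value, source)   # existing lazy constructor
VariableTracker.build(tx, value, source=None)         # new eager factory
```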

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
This commit is contained in:
Tom Ritchford
2024-10-17 16:21:48 +00:00
committed by PyTorch MergeBot
parent 1581a93e87
commit e1c4548441
18 changed files with 180 additions and 333 deletions

View File

@@ -91,7 +91,6 @@ from .variables.builder import (
     BackwardStateGraphArg,
     GraphArg,
     TrackedFake,
-    VariableBuilder,
     wrap_fx_proxy,
 )
 from .variables.lists import BaseListVariable
@@ -498,7 +497,7 @@ class OutputGraph:
             cg.store(varname)
             self.pregraph_bytecode.extend(cg.get_instructions())
             source = SyntheticLocalSource(varname)
-            result = VariableBuilder(self.root_tx, source)(example_value)
+            result = VariableTracker.build(self.root_tx, example_value, source)
             TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
                 source
             )
@@ -767,8 +766,8 @@ class OutputGraph:
         ):
             if is_dynamic_nn_module(target, self.root_tx.export):
                 # Instead of returning UnspecializedNNModuleVariable, call
-                # VariableBuilder so that it is tracked for mutation.
-                return VariableBuilder(self.current_tx, **options)(target)
+                # VariableTracker.build so that it is tracked for mutation.
+                return VariableTracker.build(self.current_tx, target, **options)

         options = dict(options)
         assert "source" in options
@@ -860,8 +859,8 @@ class OutputGraph:
         def wrap_name(module_key):
             self.output.update_co_names(module_key)
             self.global_scope[module_key] = target
-            return VariableBuilder(self, ConstantSource(source_name=module_key))(
-                target
-            )
+            return VariableTracker.build(
+                self, target, ConstantSource(source_name=module_key)
+            )

         for k, v in self.nn_modules.items():

View File

@@ -71,7 +71,7 @@ from .utils import (
     proxy_args_kwargs,
 )
 from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker
-from .variables.builder import VariableBuilder, wrap_fx_proxy
+from .variables.builder import wrap_fx_proxy
 from .variables.builtin import BuiltinVariable
 from .variables.constant import ConstantVariable
 from .variables.ctx_manager import (
@@ -1224,15 +1224,14 @@ class InstructionTranslatorBase(
         except KeyError:
             return self.load_builtin(inst)

-        source = GlobalSource(name)
-        self.push(VariableBuilder(self, source)(value))
+        self.push(VariableTracker.build(self, value, GlobalSource(name)))

     @functools.cached_property
     def nn_modules_globals_vt(self):
         module_name = "torch.nn.modules.module"
         module_source = self.import_source(module_name)
         fglobals_value = importlib.import_module(module_name)  # type: ignore[assignment]
-        return VariableBuilder(self, module_source)(fglobals_value)
+        return VariableTracker.build(self, fglobals_value, module_source)

     def LOAD_GLOBAL(self, inst):
         if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
@@ -1374,7 +1373,7 @@ class InstructionTranslatorBase(
                 self.output.name_of_builtins_dict_key_in_fglobals
             )
             var_source = GetItemSource(builtins_source, argval)
-            self.push(VariableBuilder(self, var_source)(val))
+            self.push(VariableTracker.build(self, val, var_source))
         else:
             assert is_builtin_constant(val)
             self.push(ConstantVariable.create(value=val))
@@ -3403,7 +3402,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
                 fglobals_value = torch.package.package_importer._package_imported_modules[module_name]  # type: ignore[assignment]
             else:
                 fglobals_value = importlib.import_module(module_name)  # type: ignore[assignment]
-            fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
+            fglobals_vt = VariableTracker.build(self, fglobals_value, module_source)
             global_source = AttrSource(module_source, name)
         else:
             globals_name = self.output.install_global_by_id(
@@ -3411,7 +3410,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
             )
             globals_source = GlobalSource(globals_name)
             fglobals_value = self.f_globals  # type: ignore[assignment]
-            fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
+            fglobals_vt = VariableTracker.build(self, fglobals_value, globals_source)
             global_source = GetItemSource(globals_source, name)  # type: ignore[assignment]
         return fglobals_value, fglobals_vt, global_source
@@ -3430,7 +3429,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
         except KeyError:
             return self.load_builtin(inst)

-        self.push(VariableBuilder(self, global_source)(value))
+        self.push(VariableTracker.build(self, value, global_source))

     def STORE_GLOBAL(self, inst):
         if self.f_globals is self.parent.f_globals:

View File

@@ -12,7 +12,7 @@ from ..utils import istype
 if TYPE_CHECKING:
-    from torch._dynamo.symbolic_convert import InstructionTranslator
+    from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase


 class MutableLocalSource(Enum):
@@ -121,6 +121,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     VariableTracker instances are immutable and should be copied in
     order to change them.
+
+    Prefer the factory function VariableTracker.build() over VariableTracker.__init__().
     """

     # fields to leave unmodified in apply()
@@ -244,9 +246,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         value = self.const_getattr(tx, name)
         if not variables.ConstantVariable.is_literal(value):
             raise NotImplementedError
-        source = None
-        if self.source:
-            source = AttrSource(self.source, name)
+        source = self.source and AttrSource(self.source, name)
         return variables.ConstantVariable.create(value, source=source)

     def is_proxy(self):
@@ -363,6 +363,20 @@ class VariableTracker(metaclass=VariableTrackerMeta):
     def is_strict_mode(self, tx):
         return tx.strict_checks_fn and tx.strict_checks_fn(self)

+    @staticmethod
+    def build(
+        tx: "InstructionTranslatorBase",
+        value: Any,
+        source: Optional[Source] = None,
+    ) -> Any:
+        """Create a new VariableTracker from a value and optional Source"""
+        from . import builder
+
+        if source is None:
+            return builder.SourcelessBuilder.create(tx, value)
+        else:
+            return builder.VariableBuilder(tx, source)(value)
+
     def __init__(
         self,
         *,

View File

@@ -701,7 +701,6 @@ class BuiltinVariable(VariableTracker):
     @staticmethod
     def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
-        from .builder import SourcelessBuilder
         from .lazy import LazyVariableTracker

         obj = BuiltinVariable(fn)
@@ -794,8 +793,6 @@ class BuiltinVariable(VariableTracker):
             handlers.append(call_self_handler)

         if obj.can_constant_fold_through():
-            builder = SourcelessBuilder.create
-
             if (
                 all(issubclass(x, ConstantVariable) for x in arg_types)
                 and not has_kwargs
@@ -809,7 +806,7 @@ class BuiltinVariable(VariableTracker):
                     )
                 except Exception as exc:
                     unimplemented(f"constant fold exception: {repr(exc)}")
-                return builder(tx, res)
+                return VariableTracker.build(tx, res)

             else:
@@ -825,7 +822,7 @@ class BuiltinVariable(VariableTracker):
                     )
                 except Exception as exc:
                     unimplemented(f"constant fold exception: {repr(exc)}")
-                return builder(tx, res)
+                return VariableTracker.build(tx, res)

             handlers.append(constant_fold_handler)
@@ -1361,8 +1358,6 @@ class BuiltinVariable(VariableTracker):
     @staticmethod
     def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
-        from .builder import SourcelessBuilder
-
         if not kwargs:
             if not args:
                 args = ({},)
@@ -1399,7 +1394,7 @@ class BuiltinVariable(VariableTracker):
                     )
                 new_dict = dict(arg.value.items())
-                return SourcelessBuilder.create(tx, new_dict)
+                return VariableTracker.build(tx, new_dict)
             else:
                 func_var = arg.var_getattr(tx, "items")
                 if not isinstance(func_var, variables.UserFunctionVariable):
@@ -1631,7 +1626,6 @@ class BuiltinVariable(VariableTracker):
             TorchInGraphFunctionVariable,
             UserFunctionVariable,
         )
-        from .builder import SourcelessBuilder, VariableBuilder

         name = name_var.as_python_constant()
@@ -1666,34 +1660,21 @@ class BuiltinVariable(VariableTracker):
             if not hasattr_var.as_python_constant():
                 return default

-        options = {}
-        if obj.source:
-            source = AttrSource(obj.source, name)
-            options["source"] = source
-        else:
-            source = None
+        source = obj.source and AttrSource(obj.source, name)

         if name in {"__bases__", "__base__", "__flags__"}:
             try:
                 value = obj.as_python_constant()
                 if isinstance(value, type):
                     if name == "__bases__":
-                        bases = value.__bases__
-                        if source is not None:
-                            tuple_args = [
-                                VariableBuilder(tx, GetItemSource(source, i))(b)
-                                for i, b in enumerate(bases)
-                            ]
-                        else:
-                            tuple_args = [
-                                SourcelessBuilder.create(tx, b) for b in bases
-                            ]
-                        return variables.TupleVariable(tuple_args, **options)
+                        tuple_args = [
+                            VariableTracker.build(
+                                tx, b, source and GetItemSource(source, i)
+                            )
+                            for i, b in enumerate(value.__bases__)
+                        ]
+                        return variables.TupleVariable(tuple_args, source=source)
                     if name == "__base__":
-                        base = value.__base__
-                        if source is not None:
-                            return VariableBuilder(tx, source)(base)
-                        return SourcelessBuilder.create(tx, base)
+                        return VariableTracker.build(tx, value.__base__, source)
                     if name == "__flags__":
                         return ConstantVariable.create(value.__flags__)
             except NotImplementedError:
@@ -1715,14 +1696,14 @@ class BuiltinVariable(VariableTracker):
             try:
                 return obj.var_getattr(tx, name)
             except NotImplementedError:
-                return GetAttrVariable(obj, name, **options)
+                return GetAttrVariable(obj, name, source=source)
         elif isinstance(obj, TorchInGraphFunctionVariable):
             # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
             member = getattr(obj.value, name)
             if isinstance(
                 member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
             ) and trace_rules.is_aten_op_or_tensor_method(member):
-                return TorchInGraphFunctionVariable(member, **options)
+                return TorchInGraphFunctionVariable(member, source=source)
         elif isinstance(obj, DummyModule):
             # TODO(mlazos) - Do we need this?
             if obj.is_torch or name not in obj.value.__dict__:
@@ -1732,18 +1713,15 @@ class BuiltinVariable(VariableTracker):
             if config.replay_record_enabled:
                 tx.exec_recorder.record_module_access(obj.value, name, member)
+            return VariableTracker.build(tx, member, source)

-            if source is not None:
-                return VariableBuilder(tx, source)(member)
-            else:
-                return SourcelessBuilder.create(tx, member)
         elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
             return ConstantVariable.create(getattr(obj.fn, name))
         else:
             try:
                 return obj.var_getattr(tx, name)
             except NotImplementedError:
-                return GetAttrVariable(obj, name, **options)
+                return GetAttrVariable(obj, name, source=source)

     def call_setattr(
         self,
@@ -1882,8 +1860,6 @@ class BuiltinVariable(VariableTracker):
         return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())

     def call_type(self, tx: "InstructionTranslator", obj: VariableTracker):
-        from .builder import SourcelessBuilder, VariableBuilder
-
         try:
             py_type = obj.python_type()
         except NotImplementedError as error:
@@ -1893,10 +1869,8 @@ class BuiltinVariable(VariableTracker):
                 case_name="unknown_python_type",
             ) from None

-        if obj.source is None:
-            return SourcelessBuilder.create(tx, py_type)
-        else:
-            return VariableBuilder(tx, TypeSource(obj.source))(py_type)
+        source = obj.source and TypeSource(obj.source)
+        return VariableTracker.build(tx, py_type, source)

     def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker):
         if obj.has_unpack_var_sequence(tx):

View File

@@ -984,12 +984,10 @@ class HFPretrainedConfigVariable(VariableTracker):
         assert self.is_matching_cls(type(obj))

     def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
-        from .builder import VariableBuilder
-
         try:
             attr_value = getattr(self.obj, name)
-            attr_source = AttrSource(self.source, name)
-            return VariableBuilder(tx, attr_source)(attr_value)
+            source = self.source and AttrSource(self.source, name)
+            return VariableTracker.build(tx, attr_value, source)
         except AttributeError:
             unimplemented(f"getattr({self.value}, {name})")
@@ -1053,15 +1051,11 @@ class PythonSysModulesVariable(VariableTracker):
         key: VariableTracker,
         default: Optional[VariableTracker] = None,
     ):
-        from .builder import VariableBuilder
-
         k, has_key = self._contains_helper(tx, key)

         if has_key:
-            return VariableBuilder(
-                tx,
-                GetItemSource(self.source, k),
-            )(sys.modules[k])
+            source = self.source and GetItemSource(self.source, k)
+            return VariableTracker.build(tx, sys.modules[k], source)

         if default is not None:
             return default
@@ -1069,10 +1063,6 @@ class PythonSysModulesVariable(VariableTracker):
         return ConstantVariable.create(value=None)

     def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
-        from .builder import VariableBuilder
-
         k, has_key = self._contains_helper(tx, key)
-        return VariableBuilder(
-            tx,
-            GetItemSource(self.source, k),
-        )(sys.modules[k])
+        source = self.source and GetItemSource(self.source, k)
+        return VariableTracker.build(tx, sys.modules[k], source)

View File

@@ -46,9 +46,7 @@ def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
     if isinstance(val, VariableTracker):
         return val
     elif not source:
-        from torch._dynamo.variables.builder import SourcelessBuilder
-
-        return SourcelessBuilder.create(tx, val)
+        return VariableTracker.build(tx, val)
     else:
         # Create a lazy variable to avoid guarding on __defaults__ unless really
         # needed.
@@ -240,8 +238,6 @@ class UserFunctionVariable(BaseUserFunctionVariable):
                 # optimization for cleaner codegen
                 result[name] = var
             elif self.source:
-                from .builder import VariableBuilder
-
                 side_effects = parent.output.side_effects
                 if cell in side_effects:
                     out = side_effects[cell]
@@ -253,9 +249,9 @@ class UserFunctionVariable(BaseUserFunctionVariable):
                         closure_cell, "cell_contents"
                     )
                     try:
-                        contents_var = VariableBuilder(
-                            parent, closure_cell_contents
-                        )(cell.cell_contents)
+                        contents_var = VariableTracker.build(
+                            parent, cell.cell_contents, closure_cell_contents
+                        )
                     except ValueError:
                         # Cell has not yet been assigned
                         contents_var = variables.DeletedVariable()
@@ -286,9 +282,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
                     result[name] = out
             else:
-                from .builder import SourcelessBuilder
-
-                result[name] = SourcelessBuilder.create(tx, cell.cell_contents)
+                result[name] = VariableTracker.build(tx, cell.cell_contents)

         return result, closure_cells
@@ -296,17 +290,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
         pass

     def var_getattr(self, tx: "InstructionTranslator", name: str):
-        source = AttrSource(self.source, name) if self.source else None
+        source = self.source and AttrSource(self.source, name)
         try:
             subobj = inspect.getattr_static(self.fn, name)
         except AttributeError:
-            options = {"source": source}
-            return variables.GetAttrVariable(self, name, **options)
+            return variables.GetAttrVariable(self, name, source=source)
         if source:
             return variables.LazyVariableTracker.create(subobj, source)
-        from .builder import SourcelessBuilder
-
-        return SourcelessBuilder.create(tx, subobj)
+        return VariableTracker.build(tx, subobj)

     def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
         result = hasattr(self.fn, name)
@@ -757,14 +748,8 @@ class WrapperUserFunctionVariable(VariableTracker):
     def var_getattr(self, tx: "InstructionTranslator", name):
         if name == self.attr_to_trace:
             val = getattr(self.wrapper_obj, self.attr_to_trace)
-            if self.source:
-                from .builder import VariableBuilder
-
-                return VariableBuilder(tx, AttrSource(self.source, name))(val)
-            else:
-                from .builder import SourcelessBuilder
-
-                return SourcelessBuilder.create(tx, val)
+            source = self.source and AttrSource(self.source, name)
+            return VariableTracker.build(tx, val, source)

         return super().var_getattr(tx, name)
@@ -999,8 +984,6 @@ class PolyfilledFunctionVariable(VariableTracker):
         args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
-        from torch._dynamo.variables.builder import SourcelessBuilder
-
         if self.can_constant_fold_through() and check_unspec_or_constant_args(
             args, kwargs
         ):
@@ -1010,9 +993,9 @@ class PolyfilledFunctionVariable(VariableTracker):
                     **{k: v.as_python_constant() for k, v in kwargs.items()},
                 )
             )
-            return SourcelessBuilder.create(tx, result)
+            return VariableTracker.build(tx, result)

-        traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn)
+        traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
         return traceable_function_variable.call_function(tx, args, kwargs)

     def call_method(

View File

@@ -13,7 +13,6 @@ import torch.fx
 import torch.nn
 from torch._dynamo.utils import get_fake_value
 from torch._dynamo.variables import ConstantVariable
-from torch._dynamo.variables.base import VariableTracker
 from torch._dynamo.variables.builtin import BuiltinVariable
 from torch._dynamo.variables.functions import UserFunctionVariable
 from torch._dynamo.variables.tensor import SymNodeVariable
@@ -31,6 +30,7 @@ from ..exc import (
 )
 from ..source import AttrSource
 from ..utils import proxy_args_kwargs
+from .base import VariableTracker
 from .dicts import ConstDictVariable
 from .lazy import LazyVariableTracker
 from .lists import ListVariable, TupleVariable
@@ -1040,7 +1040,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
         args: List[VariableTracker],
         kwargs: Dict[str, VariableTracker],
     ) -> VariableTracker:
-        from .builder import SourcelessBuilder, wrap_fx_proxy
+        from .builder import wrap_fx_proxy

         args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
@@ -1062,7 +1062,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
                 tx,
                 "new_empty",
                 args=(
-                    SourcelessBuilder.create(
+                    VariableTracker.build(
                         tx,
                         leaf.size
                         if leaf.size is not None
@@ -1072,8 +1072,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
                     ),
                 ),
                 kwargs={
-                    "dtype": SourcelessBuilder.create(tx, leaf.dtype),
-                    "requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad),
+                    "dtype": VariableTracker.build(tx, leaf.dtype),
+                    "requires_grad": VariableTracker.build(tx, leaf.requires_grad),
                 },
             )
             for leaf in itertools.chain(xs.items, xs.items)
@@ -2057,7 +2057,6 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
         fn_name: str,
     ):
         from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
-        from .builder import SourcelessBuilder

         tx: InstructionTranslator = tx
@@ -2065,9 +2064,9 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
             return query.call_method(
                 tx,
                 "new_empty",
-                (SourcelessBuilder.create(tx, []),),
+                (VariableTracker.build(tx, []),),
                 {
-                    "dtype": SourcelessBuilder.create(tx, torch.int32),
+                    "dtype": VariableTracker.build(tx, torch.int32),
                 },
             )
@@ -2077,8 +2076,8 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
             score = query.call_method(
                 tx,
                 "new_empty",
-                (SourcelessBuilder.create(tx, []),),
-                {"requires_grad": SourcelessBuilder.create(tx, scores_require_grad)},
+                (VariableTracker.build(tx, []),),
+                {"requires_grad": VariableTracker.build(tx, scores_require_grad)},
             )
             new_args = [score, *bhmn]
         else:

View File

@@ -172,10 +172,8 @@ class ItertoolsVariable(VariableTracker):
                     *args, mutable_local=MutableLocal()
                 )
-            from .builder import SourcelessBuilder
-
             return tx.inline_user_function_return(
-                SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
+                VariableTracker.build(tx, polyfills.repeat), args, kwargs
             )
         elif self.value is itertools.count:
             return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())

View File

@@ -20,14 +20,15 @@ class LazyCache:
     def realize(self) -> None:
         assert self.vt is None
         from ..symbolic_convert import InstructionTranslator
-        from .builder import SourcelessBuilder, VariableBuilder

         tx = InstructionTranslator.current_tx()
-        if isinstance(self.value, LazySymNodeFormatString):
-            self.vt = SourcelessBuilder.create(tx, self.value)
-        else:
-            self.vt = VariableBuilder(tx, self.source)(self.value)
+
+        if isinstance(self.value, LazySymNodeFormatString):
+            source = None
+        else:
+            source = self.source
+        self.vt = VariableTracker.build(tx, self.value, source)

         del self.value
         del self.source
@@ -37,7 +38,7 @@ class LazyVariableTracker(VariableTracker):
     A structure that defers the creation of the actual VariableTracker
     for a given underlying value until it is accessed.

-    The `realize` function invokes VariableBuilder to produce the real object.
+    The `realize` function invokes VariableTracker.build() to produce the real object.
     Once a LazyVariableTracker has been realized, internal bookkeeping will
     prevent double realization.

View File

@@ -135,10 +135,8 @@ class BaseListVariable(VariableTracker):
             assert not kwargs
             return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
         elif name == "index":
-            from .builder import SourcelessBuilder
-
             return tx.inline_user_function_return(
-                SourcelessBuilder.create(tx, polyfills.index),
+                VariableTracker.build(tx, polyfills.index),
                 [self] + list(args),
                 kwargs,
             )

View File

@@ -207,12 +207,10 @@ class SuperVariable(VariableTracker):
             and len(kwargs) == 0
             and args[0].is_python_constant()
         ):
-            from .builder import VariableBuilder
-
             key = args[0].as_python_constant()
-            return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
-                collections.OrderedDict.__getitem__(self.objvar.value, key)
-            )
+            value = collections.OrderedDict.__getitem__(self.objvar.value, key)
+            source = ODictGetItemSource(self.objvar.source, key)
+            return VariableTracker.build(tx, value, source)
         elif inner_fn in (
             collections.OrderedDict.__setitem__,
             object.__setattr__,
@@ -467,15 +465,10 @@ class InspectParameterVariable(VariableTracker):
         self.value = value

     def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
-        from .builder import SourcelessBuilder, VariableBuilder
-
         try:
             attr_value = getattr(self.value, name)
-            if self.source:
-                attr_source = AttrSource(self.source, name)
-                return VariableBuilder(tx, attr_source)(attr_value)
-            else:
-                return SourcelessBuilder.create(tx, attr_value)
+            source = self.source and AttrSource(self.source, name)
+            return VariableTracker.build(tx, attr_value, source)
         except AttributeError:
             unimplemented(f"getattr({self.value}, {name})")
@@ -909,11 +902,9 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
         if self.needs_input_grad is not None:
             return variables.ConstantVariable.create(self.needs_input_grad)
         if self.source:
-            from .builder import VariableBuilder
-
-            return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))(
-                self.value.needs_input_grad
-            )
+            source = AttrSource(self.source, "needs_input_grad")
+            return VariableTracker.build(tx, self.value.needs_input_grad, source)
         return super().var_getattr(tx, name)
@@ -1118,11 +1109,8 @@ class GetSetDescriptorVariable(VariableTracker):
     def var_getattr(self, tx: "InstructionTranslator", name):
         if name == "__get__" and self.source:
-            from .builder import VariableBuilder
-
-            return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
-                self.desc.__get__
-            )
+            source = AttrSource(self.source, "__get__")
+            return VariableTracker.build(tx, self.desc.__get__, source)
         else:
             return super().var_getattr(tx, name)
@@ -1162,18 +1150,13 @@ class PythonModuleVariable(VariableTracker):
         if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
             return tx.output.side_effects.load_attr(self, name)

-        from .builder import SourcelessBuilder, VariableBuilder
-
         if self.is_torch or name not in self.value.__dict__:
             attr_value = getattr(self.value, name)
         else:
             attr_value = self.value.__dict__[name]

-        if self.source:
-            new_source = AttrSource(self.source, name)
-            return VariableBuilder(tx, new_source)(attr_value)
-        else:
-            return SourcelessBuilder.create(tx, attr_value)
+        source = self.source and AttrSource(self.source, name)
+        return VariableTracker.build(tx, attr_value, source)


 class TypingVariable(VariableTracker):

View File

@@ -244,12 +244,7 @@ class NNModuleVariable(VariableTracker):
         )

     def var_getattr(self, tx: "InstructionTranslator", name):
-        from .builder import VariableBuilder
-
-        if self.source:
-            source = AttrSource(self.source, name)
-        else:
-            source = None
+        source = self.source and AttrSource(self.source, name)

         base = tx.output.get_submodule(self.module_key)
         base_dict = object.__getattribute__(base, "__dict__")
@@ -297,7 +292,7 @@ class NNModuleVariable(VariableTracker):
             return variables.UserDefinedClassVariable(base.__class__, source=source)

         if object_member:
-            out = VariableBuilder(tx, NNModuleSource(source))(subobj)
+            out = VariableTracker.build(tx, subobj, NNModuleSource(source))

             if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
                 # nn_module_stack source is BC surface area. Ensure that
@@ -333,7 +328,7 @@ class NNModuleVariable(VariableTracker):
                 return variables.UserMethodVariable(subobj, self, source=source)
             elif is_safe_constant(subobj) or istensor(subobj):
                 # Support possibly common cases of class members
-                return VariableBuilder(tx, NNModuleSource(source))(subobj)
+                return VariableTracker.build(tx, subobj, NNModuleSource(source))
             else:
                 unimplemented(
                     f"class property {name} - {typestr(base)} {typestr(subobj)}"
@@ -1083,7 +1078,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
             )
             return variables.ConstDictVariable({})

-        # For non-empty hook dicts, one way is to just fallback to VariableBuilder and create a ConstDictVariable.
+        # For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable.
         # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for
         # differnt nn module instances, because the key keeps changing (look more into RemovableHandle to understand why
         # key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a

View File

@@ -17,6 +17,7 @@ from ..source import (
     GradSource,
 )
 from ..utils import GLOBAL_KEY_PREFIX
+from .base import VariableTracker
 from .constant import ConstantVariable
 from .dicts import ConstDictVariable
 from .lists import ListVariable
@@ -27,8 +28,6 @@ from .user_defined import UserDefinedObjectVariable
 if TYPE_CHECKING:
     from torch._dynamo.symbolic_convert import InstructionTranslator

-    from .base import VariableTracker
-

 class ArgMappingException(Exception):
     pass
@@ -147,7 +146,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
     def _set_capturable(self, tx):
         from . import LazyVariableTracker
-        from .builder import VariableBuilder

         # We only set capturable if params are on cuda
         # and the state is not initialized
@@ -168,10 +166,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
             if safe_to_set_capturable(group):
                 group["capturable"] = True

+        source = self.source and AttrSource(self.source, "param_groups")
         param_groups_vt = LazyVariableTracker.realize_all(
-            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
-                self.value.param_groups
-            )
+            VariableTracker.build(tx, self.value.param_groups, source)
         )
         for ind, param_group_vt in enumerate(param_groups_vt.items):
             key = ConstDictVariable._HashableTracker(
@@ -214,7 +211,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
     def map_sources_and_install_guards(self, tx):
         from ..decorators import mark_static_address
-        from .builder import VariableBuilder
         from .lazy import LazyVariableTracker

         self.grad_to_source = {}
@@ -235,15 +231,13 @@ class OptimizerVariable(UserDefinedObjectVariable):
         # Recursively realize the variable trackers for optim.state and
         # optim.param_groups, which recursively install the necessary guards.
+        params_groups_source = self.source and AttrSource(self.source, "param_groups")
         param_groups_vt = LazyVariableTracker.realize_all(
-            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
-                self.value.param_groups
-            )
+            VariableTracker.build(tx, self.value.param_groups, params_groups_source)
         )

-        state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
-            self.value.state
-        )
+        state_source = self.source and AttrSource(self.source, "state")
+        state_vt = VariableTracker.build(tx, self.value.state, state_source)

         # We need to realize the top level state dict to populate
         # the guard locals
@@ -265,15 +259,15 @@ class OptimizerVariable(UserDefinedObjectVariable):
                         key_index = i
                         break
                 if key_index:
-                    state_source = AttrSource(self.source, "state")
                     LazyVariableTracker.realize_all(
-                        VariableBuilder(
+                        VariableTracker.build(
                             tx,
+                            self.value.state[param],
                             GetItemSource(
                                 state_source,
                                 ConstDictKeySource(state_source, key_index),
                             ),
-                        )(self.value.state[param])
+                        )
                     )
                     break
@@ -312,7 +306,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
         # We have to again iterate over the state dict to collect the
         # tensor_to_source dict. This is used for the finalizer.
-        state_source = AttrSource(self.source, "state")
         for idx, (p, value) in enumerate(self.value.state.items()):
             p_state_source = GetItemSource(
                 state_source, ConstDictKeySource(state_source, idx)
@@ -328,7 +321,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
     def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
         """Wrap state tensor in a TensorVariable"""
         from ..decorators import mark_static_address
-        from .builder import VariableBuilder

         # If we have a source for a tensor already use it,
         # if we have not seen a tensor before, stash and use a
@@ -338,20 +330,19 @@ class OptimizerVariable(UserDefinedObjectVariable):
         if tensor_value in self.tensor_to_source:
             # mark these tensors as static for cudagraphs
             mark_static_address(tensor_value)
-            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
-            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
+            source = self.tensor_to_source[tensor_value]
+            self.static_tensor_names.add(tx.output.module_key_name(source.name))
         elif tensor_value in self.grad_to_source:
-            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
+            source = self.grad_to_source[tensor_value]
         else:
             # mark these tensors as static for cudagraphs
             mark_static_address(tensor_value)

             global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
-            builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
-            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
+            source = GlobalWeakRefSource(global_name)
+            self.static_tensor_names.add(tx.output.module_key_name(source.name))

-        result = builder(tensor_value)
-        return result
+        return VariableTracker.build(tx, tensor_value, source)

     def update_list_args(
         self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
@@ -367,14 +358,8 @@ class OptimizerVariable(UserDefinedObjectVariable):
                 if isinstance(val, torch.Tensor):
                     arg.items.append(self.wrap_tensor(tx, val))
                 else:
-                    from .builder import SourcelessBuilder, VariableBuilder
-
-                    if arg.source:
-                        arg.items.append(
-                            VariableBuilder(tx, GetItemSource(arg.source, i))(val)
-                        )
-                    else:
-                        arg.items.append(SourcelessBuilder.create(tx, val))
+                    source = arg.source and GetItemSource(arg.source, i)
+                    arg.items.append(VariableTracker.build(tx, val, source))

     def create_finalizer(self, tx):
         names_to_delete = self.static_tensor_names

View File

@@ -5,12 +5,15 @@ from typing import TYPE_CHECKING
 from ..bytecode_transformation import create_call_function
 from ..exc import Unsupported
+from ..source import AttrSource
 from .base import VariableTracker

 if TYPE_CHECKING:
     from torch._dynamo.symbolic_convert import InstructionTranslator

+PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
+

 class SDPAParamsVariable(VariableTracker):
     """Represents the c++ params struct for scaled dot product attention.
@@ -20,35 +23,13 @@ class SDPAParamsVariable(VariableTracker):
     def create(tx: "InstructionTranslator", value, source):
         from torch.backends.cuda import SDPAParams
-        from ..source import AttrSource
-        from .builder import VariableBuilder
         from .torch import TorchInGraphFunctionVariable

-        query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
-        key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
-        value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
-        attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
-            value.attn_mask
-        )
-        dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
-        is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
-            value.is_causal
-        )
-        enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
-            value.enable_gqa
-        )
-        param_vars = [
-            query_var,
-            key_var,
-            value_var,
-            attn_mask_var,
-            dropout_var,
-            is_causal_var,
-            enable_gqa_var,
+        params = [
+            VariableTracker.build(tx, getattr(value, p), AttrSource(source, p))
+            for p in PARAM_NAMES
         ]
-        return TorchInGraphFunctionVariable(SDPAParams).call_function(
-            tx, param_vars, {}
-        )
+        return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})

     def __init__(self, proxy, param_vars, **kwargs) -> None:
         self.proxy = proxy
@@ -70,7 +51,6 @@ class SDPAParamsVariable(VariableTracker):
     def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
         import torch._C
-        from ..source import AttrSource
         from .builder import wrap_fx_proxy
         from .misc import GetAttrVariable

View File

@@ -238,9 +238,7 @@ class TensorVariable(VariableTracker):
         # any other attributes on the subclass (that are not methods)
         # are assumed to be constant metadata.
         elif not callable(example_value):
-            from .builder import SourcelessBuilder
-
-            return SourcelessBuilder.create(tx, example_value)
+            return VariableTracker.build(tx, example_value)

         if not (self.source and self.source.subguards_allowed()):
             raise NotImplementedError
@@ -277,12 +275,9 @@ class TensorVariable(VariableTracker):
             # Note - at a certain point we may want to handle
             raise NotImplementedError

-        from ..guards import GuardBuilder
-        from .builder import VariableBuilder
-
         attr_source = AttrSource(self.source, name)
         install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
-        return VariableBuilder(tx, attr_source)(real_value)
+        return VariableTracker.build(tx, real_value, attr_source)

     def method_attr_ndim(self, tx):
         if self.ndim is not None:
@@ -695,7 +690,6 @@ class TensorVariable(VariableTracker):
     def method_as_subclass(self, cls):
         if isinstance(cls, TensorSubclassVariable) and cls.source:
             from ..symbolic_convert import InstructionTranslator
-            from .builder import VariableBuilder
             from .torch_function import TensorWithTFOverrideVariable

             tx = InstructionTranslator.current_tx()
@@ -705,10 +699,11 @@ class TensorVariable(VariableTracker):
             # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
             # It is up to the user whether this is correct behavior or not.
             py_cls = cls.as_python_constant()
-            torch_fn = VariableBuilder(
+            torch_fn = VariableTracker.build(
                 tx,
+                py_cls.__torch_function__.__func__,
                 AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
-            )(py_cls.__torch_function__.__func__)
+            )

             return TensorWithTFOverrideVariable.from_tensor_var(
                 tx, self, py_cls, torch_fn
@@ -750,7 +745,6 @@ class TensorVariable(VariableTracker):
     def method_tolist(self):
         from ..symbolic_convert import InstructionTranslator
-        from .builder import SourcelessBuilder

         tx = InstructionTranslator.current_tx()
@@ -787,7 +781,7 @@ class TensorVariable(VariableTracker):
         tensor = self.as_proxy().node.meta["example_value"]
         out = tolist(tensor, self.as_proxy())
-        return SourcelessBuilder.create(tx, out)
+        return VariableTracker.build(tx, out)

     def method_backward(self, *args, **kwargs):
         unimplemented("Tensor.backward")
@@ -857,10 +851,9 @@ class TensorVariable(VariableTracker):
         tx = InstructionTranslator.current_tx()
         if value is not None:
             from .. import polyfills
-            from .builder import SourcelessBuilder

             return tx.inline_user_function_return(
-                SourcelessBuilder.create(tx, polyfills.addcmul_inplace),
+                VariableTracker.build(tx, polyfills.addcmul_inplace),
                 [self, tensor1, tensor2, value],
                 {},
             )
@@ -1155,9 +1148,7 @@ class SymNodeVariable(VariableTracker):
     def as_tensor(self, tx):
         if self._tensor_var is None:
-            from .builder import SourcelessBuilder
-
-            self._tensor_var = SourcelessBuilder.create(
+            self._tensor_var = VariableTracker.build(
                 tx, torch.scalar_tensor
             ).call_function(tx, [self], {})
         return self._tensor_var
@@ -1362,12 +1353,10 @@ class TensorSubclassVariable(VariableTracker):
         kwargs: Dict[str, VariableTracker],
     ) -> VariableTracker:
         if len(args) == 1 and isinstance(args[0], TensorVariable):
-            from .builder import VariableBuilder
             from .torch_function import TensorWithTFOverrideVariable

-            torch_fn = VariableBuilder(
-                tx, AttrSource(self.source, "__torch_function__")
-            )(self.value.__torch_function__)
+            source = AttrSource(self.source, "__torch_function__")
+            torch_fn = VariableTracker.build(tx, self.value.__torch_function__, source)

             return TensorWithTFOverrideVariable.from_tensor_var(
                 tx, args[0], self.value, torch_fn

View File

@ -397,7 +397,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
TensorVariable, TensorVariable,
UserDefinedObjectVariable, UserDefinedObjectVariable,
) )
from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
@register(*tracing_state_functions) @register(*tracing_state_functions)
def handle_tracing_state_functions( def handle_tracing_state_functions(
@ -422,14 +422,14 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# the set of functions that we trace __torch_function__ on to # the set of functions that we trace __torch_function__ on to
# functions outside of the actual set. Implementing this properly will require implementing # functions outside of the actual set. Implementing this properly will require implementing
# some variable types to track and compare tensor getset descriptors # some variable types to track and compare tensor getset descriptors
return SourcelessBuilder.create( return VariableTracker.build(
tx, torch.overrides.get_default_nowrap_functions() tx, torch.overrides.get_default_nowrap_functions()
) )
@register(torch.ops.inductor.accumulate_grad_.default) @register(torch.ops.inductor.accumulate_grad_.default)
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs): def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
return tx.inline_user_function_return( return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.accumulate_grad), args, kwargs VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs
) )
@register(math.radians) @register(math.radians)
@ -437,7 +437,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if not check_unspec_or_constant_args(args, kwargs): if not check_unspec_or_constant_args(args, kwargs):
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0 # Use polyfill to convert math.radians(x) into math.pi * x / 180.0
return tx.inline_user_function_return( return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.radians), args, kwargs VariableTracker.build(tx, polyfills.radians), args, kwargs
) )
@register(torch.is_tensor, torch.overrides.is_tensor_like) @register(torch.is_tensor, torch.overrides.is_tensor_like)
@@ -622,7 +622,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
         ):
             if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
                 return tx.inline_user_function_return(
-                    SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace),
+                    VariableTracker.build(tx, polyfills.foreach_lerp_inplace),
                     args,
                     kwargs,
                 )
@@ -635,7 +635,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             # in compile, it's more performant to not graph break.
             if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
                 return tx.inline_user_function_return(
-                    SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar),
+                    VariableTracker.build(tx, polyfills.foreach_pow_scalar),
                     args,
                     kwargs,
                 )
@@ -704,7 +704,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
             # Note - while we *could* cook up sources around invocations, like a FunctionSource
             # the space of invoking functions in the middle of the guard chain is very iffy. As such,
             # guard propagation via options is the best we can do.
-            return SourcelessBuilder.create(tx, invocation_result)
+            return VariableTracker.build(tx, invocation_result)

         @register(DTensor.from_local)
         def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):
@@ -1143,8 +1143,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
     @staticmethod
     def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad):
         # Alternate version if we have a .source
-        from .builder import VariableBuilder
-
         varname = tx.output.new_var()
         # construct the nn.Parmeter before the graph save it to varname
@@ -1167,7 +1165,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
         example_value = torch.nn.Parameter(
             tx.output.example_value_from_input_node(data.as_proxy().node)
         )
-        result = VariableBuilder(tx, source)(example_value)
+        result = VariableTracker.build(tx, example_value, source)
         # No need to guard on this since we already guarded on `data`.
         # These guards would fail since varname doesn't exist until after the function starts
         TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
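
Throughout these hunks the call sites all take the same shape: the value to wrap comes first, an optional `Source` comes last, and one entry point decides which builder to use. A minimal, self-contained sketch of that dispatch, with stand-in builders rather than the verbatim upstream definitions:

```python
from typing import Any, Optional


# Stand-ins for torch._dynamo.variables.builder, for illustration only.
class SourcelessBuilder:
    @staticmethod
    def create(tx: Any, value: Any) -> Any:
        return ("sourceless", value)


class VariableBuilder:
    def __init__(self, tx: Any, source: Any) -> None:
        self.tx, self.source = tx, source

    def __call__(self, value: Any) -> Any:
        return ("sourced", value, self.source)


class VariableTracker:
    @staticmethod
    def build(tx: Any, value: Any, source: Optional[Any] = None) -> Any:
        # One entry point picks the right builder, so call sites no longer
        # import or choose between the two builder classes themselves.
        if source is None:
            return SourcelessBuilder.create(tx, value)
        return VariableBuilder(tx, source)(value)


print(VariableTracker.build(None, 42))            # sourceless path
print(VariableTracker.build(None, 42, "L['x']"))  # sourced (guarded) path
```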

View File

@@ -474,12 +474,8 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var):
     if isinstance(var, TensorWithTFOverrideVariable):
         return var.class_type_var(tx)
     elif isinstance(var, UserDefinedObjectVariable):
-        from .builder import SourcelessBuilder, VariableBuilder
-
-        if var.source:
-            return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
-        else:
-            return SourcelessBuilder.create(tx, var.python_type())
+        source = var.source and TypeSource(var.source)
+        return VariableTracker.build(tx, var.python_type(), source)


 def _is_attr_overidden(tx: "InstructionTranslator", var, name):
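
The `var.source and TypeSource(var.source)` expression used here, and in several later hunks, is a short-circuit: when no `Source` is attached the result stays `None` and `VariableTracker.build` falls back to the sourceless path. A tiny runnable illustration, with `TypeSource` stubbed out as a tuple:

```python
def optional_type_source(source):
    # Mirrors `var.source and TypeSource(var.source)`: None short-circuits,
    # an attached source gets wrapped. Real sources are None or truthy objects.
    return source and ("TypeSource", source)


print(optional_type_source(None))        # None -> sourceless build
print(optional_type_source("L['obj']"))  # ('TypeSource', "L['obj']") -> guarded build
```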
@@ -498,16 +494,14 @@ def _is_attr_overidden(tx: "InstructionTranslator", var, name):
 def call_torch_function(
     tx, torch_function_type, torch_function_var, fn, types, args, kwargs
 ):
-    from .builder import SourcelessBuilder
-
     # signature:
     # def __torch_function__(cls, func, types, args=(), kwargs=None):
     tf_args = (
         torch_function_type,
         fn,
         types,
-        SourcelessBuilder.create(tx, tuple(args)),
-        SourcelessBuilder.create(tx, kwargs),
+        VariableTracker.build(tx, tuple(args)),
+        VariableTracker.build(tx, kwargs),
     )
     return tx.inline_user_function_return(torch_function_var, tf_args, {})
@@ -515,20 +509,13 @@ def call_torch_function(
 def build_torch_function_fn(tx: "InstructionTranslator", value, source):
     from types import FunctionType

-    from .builder import SourcelessBuilder, VariableBuilder
-
     func = value.__torch_function__.__func__

     if not isinstance(func, FunctionType):
         unimplemented("Builtin/C++ torch function implementations NYI")

-    if source:
-        return VariableBuilder(
-            tx,
-            AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
-        )(value.__torch_function__.__func__)
-    else:
-        return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
+    source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__")
+    return VariableTracker.build(tx, func, source)


 def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
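
For orientation, the object these helpers introspect is the `__torch_function__` classmethod that a tensor subclass defines; `value.__torch_function__.__func__` is the plain function underneath the classmethod wrapper. A minimal, invented subclass showing the protocol signature quoted in the comment above:

```python
import torch


class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Same signature as in the comment above: log the call, then defer
        # to the default implementation.
        kwargs = kwargs or {}
        print(f"__torch_function__ saw {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)


# __func__ strips the classmethod binding; this function object is what
# build_torch_function_fn turns into a VariableTracker.
fn = LoggingTensor.__torch_function__.__func__
print(type(fn))  # <class 'function'>
```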
@@ -625,8 +612,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
         # base tensors, custom attribute accesses will graph break.
         import torch

-        from .builder import SourcelessBuilder
-
         if name in banned_attrs:
             unimplemented(
                 f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
@@ -645,7 +630,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
                     GuardBuilder.FUNCTION_MATCH
                 )
             )
-            get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
+            get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)

             return self.call_torch_function(
                 tx,
@@ -680,8 +665,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
         if tx.output.torch_function_enabled:
             import torch

-            from .builder import SourcelessBuilder, VariableBuilder
-
             if _is_attr_overidden(tx, self, name):
                 unimplemented(
                     f"Calling overridden method {name} on a tensor"
@@ -693,11 +676,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
             # We've established with the above check that the method is not overridden, so we guard that the method is the same
             # as the impl defined on tensor and retrieve it
             if self.source:
-                func_var = VariableBuilder(
-                    tx, AttrSource(AttrSource(self.source, "__class__"), name)
-                )(inspect.getattr_static(self.python_type(), name))
+                source = AttrSource(AttrSource(self.source, "__class__"), name)
+                value = inspect.getattr_static(self.python_type(), name)
             else:
-                func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
+                source = None
+                value = getattr(torch.Tensor, name)
+            func_var = VariableTracker.build(tx, value, source)
             return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
         else:
             return super().call_method(tx, name, args, kwargs)

View File

@@ -158,7 +158,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
     def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
         from . import ConstantVariable, EnumVariable
-        from .builder import SourcelessBuilder, VariableBuilder

         source = AttrSource(self.source, name) if self.source is not None else None
@@ -187,11 +186,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
             obj = None

         if isinstance(obj, staticmethod):
-            func = obj.__get__(self.value)
-            if source is not None:
-                return VariableBuilder(tx, source)(func)
-            else:
-                return SourcelessBuilder.create(tx, func)
+            return VariableTracker.build(tx, obj.__get__(self.value), source)
         elif isinstance(obj, classmethod):
             if isinstance(obj.__func__, property):
                 return variables.UserFunctionVariable(obj.__func__.fget).call_function(
@@ -202,16 +197,13 @@ class UserDefinedClassVariable(UserDefinedVariable):
             # e.g.: inspect.getattr_static(dict, "fromkeys")
             # inspect.getattr_static(itertools.chain, "from_iterable")
             func = obj.__get__(None, self.value)
-            if source is not None:
-                return VariableBuilder(tx, source)(func)
-            else:
-                return SourcelessBuilder.create(tx, func)
+            return VariableTracker.build(tx, func, source)
         elif source:
             # __mro__ is a member in < 3.12, an attribute in >= 3.12
             if inspect.ismemberdescriptor(obj) or (
                 sys.version_info >= (3, 12) and name == "__mro__"
             ):
-                return VariableBuilder(tx, source)(obj.__get__(self.value))
+                return VariableTracker.build(tx, obj.__get__(self.value), source)

         if ConstantVariable.is_literal(obj):
             return ConstantVariable.create(obj)
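
Both descriptor branches above unwrap the raw object that `inspect.getattr_static` returns: a `staticmethod` or `classmethod`-style wrapper whose `__get__` yields the callable Dynamo actually wants to trace. A standalone illustration with an invented class:

```python
import inspect


class C:
    @staticmethod
    def s():
        return "static"

    @classmethod
    def c(cls):
        return "class"


raw_s = inspect.getattr_static(C, "s")  # the staticmethod wrapper, not the function
raw_c = inspect.getattr_static(C, "c")  # the classmethod wrapper

print(type(raw_s).__name__, type(raw_c).__name__)  # staticmethod classmethod
print(raw_s.__get__(C)())        # 'static': __get__ unwraps to the plain function
print(raw_c.__get__(None, C)())  # 'class': bound to the class, like obj.__get__(None, self.value)
```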
@@ -222,14 +214,15 @@ class UserDefinedClassVariable(UserDefinedVariable):
             or self.value.__module__ == "torch"
         ):
             if source:
-                return VariableBuilder(tx, source)(obj)
+                return VariableTracker.build(tx, obj, source)

         if (
             source
             and not inspect.ismethoddescriptor(obj)
             and not is_wrapper_or_member_descriptor(obj)
         ):
-            return VariableBuilder(tx, source)(obj)
+            return VariableTracker.build(tx, obj, source)

         return super().var_getattr(tx, name)

     def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs):
@@ -341,7 +334,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
         from ..side_effects import SideEffects
-        from .builder import SourcelessBuilder, wrap_fx_proxy
+        from .builder import wrap_fx_proxy
         from .builtin import BuiltinVariable

         constant_args = check_constant_args(args, kwargs)
@@ -452,7 +445,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
                         field_var = kwargs[field_name]
                     else:
                         assert field_name in field_defaults
-                        field_var = SourcelessBuilder.create(
+                        field_var = VariableTracker.build(
                             tx, field_defaults[field_name]
                         )
                     var_tracker_kwargs[field_name] = field_var
@@ -465,8 +458,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
             return variables.NamedTupleVariable(items, self.value)
         elif is_frozen_dataclass(self.value) and self.is_standard_new():
-            from .builder import SourcelessBuilder
-
             fields = dataclasses.fields(self.value)
             items = list(args)
             items.extend([None] * (len(fields) - len(items)))
@@ -481,9 +472,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
                         continue

                     if field.default is not dataclasses.MISSING:
-                        var_tracker = SourcelessBuilder.create(tx, field.default)
+                        var_tracker = VariableTracker.build(tx, field.default)
                     elif field.default_factory is not dataclasses.MISSING:
-                        factory_fn = SourcelessBuilder.create(
+                        factory_fn = VariableTracker.build(
                             tx, field.default_factory
                         )
                         var_tracker = factory_fn.call_function(tx, [], {})
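
The hunk above mirrors how the `dataclasses` module distinguishes plain defaults from default factories: a factory is a callable that has to be invoked to produce the value, which is why the traced `factory_fn` is then called with no arguments. A standalone illustration (the `Config` class is invented for the example):

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Config:
    name: str = "default"                                 # plain default
    tags: list = dataclasses.field(default_factory=list)  # factory default


for f in dataclasses.fields(Config):
    if f.default is not dataclasses.MISSING:
        print(f.name, "->", f.default)            # name -> default
    elif f.default_factory is not dataclasses.MISSING:
        print(f.name, "->", f.default_factory())  # tags -> []
```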
@@ -573,7 +564,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
             and self.source
         ):
             return tx.inline_user_function_return(
-                SourcelessBuilder.create(
+                VariableTracker.build(
                     tx, polyfills.instantiate_user_defined_class_object
                 ),
                 [self, *args],
@@ -857,7 +848,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
         from .. import trace_rules
-        from .builder import VariableBuilder

         if (
             self.is_supported_random()
@@ -894,9 +884,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
                     "Sourceless UserDefinedObjectVariable method not supported"
                 )
             func_src = AttrSource(self.source, "__func__")
-            func_var = VariableBuilder(tx, func_src)(func)
+            func_var = VariableTracker.build(tx, func, func_src)
             obj_src = AttrSource(self.source, "__self__")
-            obj_var = VariableBuilder(tx, obj_src)(obj)
+            obj_var = VariableTracker.build(tx, obj, obj_src)
             return func_var.call_function(tx, [obj_var] + args, kwargs)
         elif (
             istype(self.value, functools.partial)
@@ -998,7 +988,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
     def var_getattr(self, tx: "InstructionTranslator", name):
         from .. import trace_rules
         from . import ConstantVariable
-        from .builder import SourcelessBuilder, VariableBuilder

         source = AttrSource(self.source, name) if self.source else None
         self._check_for_getattribute()
@@ -1090,10 +1079,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         elif isinstance(subobj, types.ClassMethodDescriptorType):
             # e.g.: inspect.getattr_static({}, "fromkeys")
             func = subobj.__get__(self.value, None)
-            if source is not None:
-                return VariableBuilder(tx, source)(func)
-            else:
-                return SourcelessBuilder.create(tx, func)
+            return VariableTracker.build(tx, func, source)
         elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
             subobj.__get__
         ):
@@ -1188,7 +1174,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
                         subobj_from_class, src_from_class
                     )

-            return SourcelessBuilder.create(tx, subobj)
+            return VariableTracker.build(tx, subobj)

        # Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
        raise_observed_exception(AttributeError, tx)
@@ -1212,7 +1198,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         return variables.ConstantVariable.create(False)

     def odict_getitem(self, tx: "InstructionTranslator", key):
-        from .builder import VariableBuilder
         from .dicts import is_hashable

         # TODO this should probably be merged with the dict handling
@@ -1223,10 +1208,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
             else key.as_python_constant()
         )

-        return VariableBuilder(
-            tx,
-            ODictGetItemSource(self.source, index),
-        )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant()))
+        return VariableTracker.build(
+            tx,
+            collections.OrderedDict.__getitem__(self.value, key.as_python_constant()),
+            self.source and ODictGetItemSource(self.source, index),
+        )


 class FrozenDataClassVariable(UserDefinedObjectVariable):
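
One detail in the hunk above: the value is read through the unbound `collections.OrderedDict.__getitem__` rather than `self.value[key]`, which has the effect of bypassing any `__getitem__` override on the user's subclass, presumably the reason the original code is written this way. A standalone illustration with an invented subclass:

```python
import collections


class NoisyDict(collections.OrderedDict):
    def __getitem__(self, key):
        raise RuntimeError("override should not run here")


d = NoisyDict(a=1)
# d["a"] would raise, but the unbound base-class call reads the stored value:
print(collections.OrderedDict.__getitem__(d, "a"))  # 1
```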
@@ -1236,14 +1222,14 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
         assert is_frozen_dataclass(value)

-        from .builder import VariableBuilder
-
         field_map = {}
         for field in fields(value):
             if hasattr(value, field.name):
-                field_map[field.name] = VariableBuilder(
-                    tx, AttrSource(source, field.name)
-                )(getattr(value, field.name))
+                field_map[field.name] = VariableTracker.build(
+                    tx,
+                    getattr(value, field.name),
+                    source and AttrSource(source, field.name),
+                )

         return FrozenDataClassVariable(value, fields=field_map, source=source)
@@ -1315,16 +1301,8 @@ class WeakRefVariable(UserDefinedObjectVariable):
     ) -> "VariableTracker":
         call_source = None
         referent = self.value()

-        if self.source:
-            from .builder import VariableBuilder
-
-            call_source = WeakRefCallSource(self.source)
-            return VariableBuilder(tx, call_source)(referent)
-        else:
-            from .builder import SourcelessBuilder
-
-            return SourcelessBuilder.create(tx, referent)
+        source = self.source and WeakRefCallSource(self.source)
+        return VariableTracker.build(tx, referent, source)


 class KeyedJaggedTensorVariable(UserDefinedObjectVariable):