[dynamo] Simplify creation of VariableTrackers (#135714)

## `VariableTracker::build()` hides the Builders

### The problem

In the current code, creating a `VariableTracker` involves choosing one of two `Builder` classes and either calling a method, or calling a constructor that creates an object that you immediately call, [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)).

Variations on this code are repeated in many places.

More, the `Builder` classes have a lot of dependencies, so they have to be loaded late in the whole import process to avoid circular imports, so they end up being repeatedly imported at local scope.

### The solution

In this commit, the import from `builder` and the logic of choosing and calling the Builder class are hidden in a single static factory method, `VariableTracker.build()`, easier to reason about and to import.

This commit net lowers the total lines of code by over 150 lines by removing repetitive logic and unnecessary local imports.

**CHANGES:** Originally the name of the static method was `VariableTracker.create()` but a static method on a derived class, `LazyVariableTracker.create()` now exists with a different signature that's irreconcilable, so the new static method was renamed to `VariableTracker.build()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
This commit is contained in:
Tom Ritchford
2024-10-17 16:21:48 +00:00
committed by PyTorch MergeBot
parent 1581a93e87
commit e1c4548441
18 changed files with 180 additions and 333 deletions

View File

@ -91,7 +91,6 @@ from .variables.builder import (
BackwardStateGraphArg,
GraphArg,
TrackedFake,
VariableBuilder,
wrap_fx_proxy,
)
from .variables.lists import BaseListVariable
@ -498,7 +497,7 @@ class OutputGraph:
cg.store(varname)
self.pregraph_bytecode.extend(cg.get_instructions())
source = SyntheticLocalSource(varname)
result = VariableBuilder(self.root_tx, source)(example_value)
result = VariableTracker.build(self.root_tx, example_value, source)
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
source
)
@ -767,8 +766,8 @@ class OutputGraph:
):
if is_dynamic_nn_module(target, self.root_tx.export):
# Instead of returning UnspecializedNNModuleVariable, call
# VariableBuilder so that it is tracked for mutation.
return VariableBuilder(self.current_tx, **options)(target)
# VariableTracker.build so that it is tracked for mutation.
return VariableTracker.build(self.current_tx, target, **options)
options = dict(options)
assert "source" in options
@ -860,8 +859,8 @@ class OutputGraph:
def wrap_name(module_key):
self.output.update_co_names(module_key)
self.global_scope[module_key] = target
return VariableBuilder(self, ConstantSource(source_name=module_key))(
target
return VariableTracker.build(
self, target, ConstantSource(source_name=module_key)
)
for k, v in self.nn_modules.items():

View File

@ -71,7 +71,7 @@ from .utils import (
proxy_args_kwargs,
)
from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builder import wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable
from .variables.ctx_manager import (
@ -1224,15 +1224,14 @@ class InstructionTranslatorBase(
except KeyError:
return self.load_builtin(inst)
source = GlobalSource(name)
self.push(VariableBuilder(self, source)(value))
self.push(VariableTracker.build(self, value, GlobalSource(name)))
@functools.cached_property
def nn_modules_globals_vt(self):
module_name = "torch.nn.modules.module"
module_source = self.import_source(module_name)
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
return VariableBuilder(self, module_source)(fglobals_value)
return VariableTracker.build(self, fglobals_value, module_source)
def LOAD_GLOBAL(self, inst):
if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
@ -1374,7 +1373,7 @@ class InstructionTranslatorBase(
self.output.name_of_builtins_dict_key_in_fglobals
)
var_source = GetItemSource(builtins_source, argval)
self.push(VariableBuilder(self, var_source)(val))
self.push(VariableTracker.build(self, val, var_source))
else:
assert is_builtin_constant(val)
self.push(ConstantVariable.create(value=val))
@ -3403,7 +3402,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment]
else:
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
fglobals_vt = VariableTracker.build(self, fglobals_value, module_source)
global_source = AttrSource(module_source, name)
else:
globals_name = self.output.install_global_by_id(
@ -3411,7 +3410,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
)
globals_source = GlobalSource(globals_name)
fglobals_value = self.f_globals # type: ignore[assignment]
fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
fglobals_vt = VariableTracker.build(self, fglobals_value, globals_source)
global_source = GetItemSource(globals_source, name) # type: ignore[assignment]
return fglobals_value, fglobals_vt, global_source
@ -3430,7 +3429,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
except KeyError:
return self.load_builtin(inst)
self.push(VariableBuilder(self, global_source)(value))
self.push(VariableTracker.build(self, value, global_source))
def STORE_GLOBAL(self, inst):
if self.f_globals is self.parent.f_globals:

View File

@ -12,7 +12,7 @@ from ..utils import istype
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase
class MutableLocalSource(Enum):
@ -121,6 +121,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):
VariableTracker instances are immutable and should be copied in
order to change them.
Prefer the factory function VariableTracker.build() over VariableTracker.__init__().
"""
# fields to leave unmodified in apply()
@ -244,9 +246,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
value = self.const_getattr(tx, name)
if not variables.ConstantVariable.is_literal(value):
raise NotImplementedError
source = None
if self.source:
source = AttrSource(self.source, name)
source = self.source and AttrSource(self.source, name)
return variables.ConstantVariable.create(value, source=source)
def is_proxy(self):
@ -363,6 +363,20 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def is_strict_mode(self, tx):
return tx.strict_checks_fn and tx.strict_checks_fn(self)
@staticmethod
def build(
tx: "InstructionTranslatorBase",
value: Any,
source: Optional[Source] = None,
) -> Any:
"""Create a new VariableTracker from a value and optional Source"""
from . import builder
if source is None:
return builder.SourcelessBuilder.create(tx, value)
else:
return builder.VariableBuilder(tx, source)(value)
def __init__(
self,
*,

View File

@ -701,7 +701,6 @@ class BuiltinVariable(VariableTracker):
@staticmethod
def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
from .builder import SourcelessBuilder
from .lazy import LazyVariableTracker
obj = BuiltinVariable(fn)
@ -794,8 +793,6 @@ class BuiltinVariable(VariableTracker):
handlers.append(call_self_handler)
if obj.can_constant_fold_through():
builder = SourcelessBuilder.create
if (
all(issubclass(x, ConstantVariable) for x in arg_types)
and not has_kwargs
@ -809,7 +806,7 @@ class BuiltinVariable(VariableTracker):
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)
return VariableTracker.build(tx, res)
else:
@ -825,7 +822,7 @@ class BuiltinVariable(VariableTracker):
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
return builder(tx, res)
return VariableTracker.build(tx, res)
handlers.append(constant_fold_handler)
@ -1361,8 +1358,6 @@ class BuiltinVariable(VariableTracker):
@staticmethod
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
from .builder import SourcelessBuilder
if not kwargs:
if not args:
args = ({},)
@ -1399,7 +1394,7 @@ class BuiltinVariable(VariableTracker):
)
new_dict = dict(arg.value.items())
return SourcelessBuilder.create(tx, new_dict)
return VariableTracker.build(tx, new_dict)
else:
func_var = arg.var_getattr(tx, "items")
if not isinstance(func_var, variables.UserFunctionVariable):
@ -1631,7 +1626,6 @@ class BuiltinVariable(VariableTracker):
TorchInGraphFunctionVariable,
UserFunctionVariable,
)
from .builder import SourcelessBuilder, VariableBuilder
name = name_var.as_python_constant()
@ -1666,34 +1660,21 @@ class BuiltinVariable(VariableTracker):
if not hasattr_var.as_python_constant():
return default
options = {}
if obj.source:
source = AttrSource(obj.source, name)
options["source"] = source
else:
source = None
source = obj.source and AttrSource(obj.source, name)
if name in {"__bases__", "__base__", "__flags__"}:
try:
value = obj.as_python_constant()
if isinstance(value, type):
if name == "__bases__":
bases = value.__bases__
if source is not None:
tuple_args = [
VariableBuilder(tx, GetItemSource(source, i))(b)
for i, b in enumerate(bases)
]
else:
tuple_args = [
SourcelessBuilder.create(tx, b) for b in bases
]
return variables.TupleVariable(tuple_args, **options)
tuple_args = [
VariableTracker.build(
tx, b, source and GetItemSource(source, i)
)
for i, b in enumerate(value.__bases__)
]
return variables.TupleVariable(tuple_args, source=source)
if name == "__base__":
base = value.__base__
if source is not None:
return VariableBuilder(tx, source)(base)
return SourcelessBuilder.create(tx, base)
return VariableTracker.build(tx, value.__base__, source)
if name == "__flags__":
return ConstantVariable.create(value.__flags__)
except NotImplementedError:
@ -1715,14 +1696,14 @@ class BuiltinVariable(VariableTracker):
try:
return obj.var_getattr(tx, name)
except NotImplementedError:
return GetAttrVariable(obj, name, **options)
return GetAttrVariable(obj, name, source=source)
elif isinstance(obj, TorchInGraphFunctionVariable):
# Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
member = getattr(obj.value, name)
if isinstance(
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
) and trace_rules.is_aten_op_or_tensor_method(member):
return TorchInGraphFunctionVariable(member, **options)
return TorchInGraphFunctionVariable(member, source=source)
elif isinstance(obj, DummyModule):
# TODO(mlazos) - Do we need this?
if obj.is_torch or name not in obj.value.__dict__:
@ -1732,18 +1713,15 @@ class BuiltinVariable(VariableTracker):
if config.replay_record_enabled:
tx.exec_recorder.record_module_access(obj.value, name, member)
return VariableTracker.build(tx, member, source)
if source is not None:
return VariableBuilder(tx, source)(member)
else:
return SourcelessBuilder.create(tx, member)
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
return ConstantVariable.create(getattr(obj.fn, name))
else:
try:
return obj.var_getattr(tx, name)
except NotImplementedError:
return GetAttrVariable(obj, name, **options)
return GetAttrVariable(obj, name, source=source)
def call_setattr(
self,
@ -1882,8 +1860,6 @@ class BuiltinVariable(VariableTracker):
return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())
def call_type(self, tx: "InstructionTranslator", obj: VariableTracker):
from .builder import SourcelessBuilder, VariableBuilder
try:
py_type = obj.python_type()
except NotImplementedError as error:
@ -1893,10 +1869,8 @@ class BuiltinVariable(VariableTracker):
case_name="unknown_python_type",
) from None
if obj.source is None:
return SourcelessBuilder.create(tx, py_type)
else:
return VariableBuilder(tx, TypeSource(obj.source))(py_type)
source = obj.source and TypeSource(obj.source)
return VariableTracker.build(tx, py_type, source)
def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker):
if obj.has_unpack_var_sequence(tx):

View File

@ -984,12 +984,10 @@ class HFPretrainedConfigVariable(VariableTracker):
assert self.is_matching_cls(type(obj))
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
from .builder import VariableBuilder
try:
attr_value = getattr(self.obj, name)
attr_source = AttrSource(self.source, name)
return VariableBuilder(tx, attr_source)(attr_value)
source = self.source and AttrSource(self.source, name)
return VariableTracker.build(tx, attr_value, source)
except AttributeError:
unimplemented(f"getattr({self.value}, {name})")
@ -1053,15 +1051,11 @@ class PythonSysModulesVariable(VariableTracker):
key: VariableTracker,
default: Optional[VariableTracker] = None,
):
from .builder import VariableBuilder
k, has_key = self._contains_helper(tx, key)
if has_key:
return VariableBuilder(
tx,
GetItemSource(self.source, k),
)(sys.modules[k])
source = self.source and GetItemSource(self.source, k)
return VariableTracker.build(tx, sys.modules[k], source)
if default is not None:
return default
@ -1069,10 +1063,6 @@ class PythonSysModulesVariable(VariableTracker):
return ConstantVariable.create(value=None)
def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
from .builder import VariableBuilder
k, has_key = self._contains_helper(tx, key)
return VariableBuilder(
tx,
GetItemSource(self.source, k),
)(sys.modules[k])
source = self.source and GetItemSource(self.source, k)
return VariableTracker.build(tx, sys.modules[k], source)

View File

@ -46,9 +46,7 @@ def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
if isinstance(val, VariableTracker):
return val
elif not source:
from torch._dynamo.variables.builder import SourcelessBuilder
return SourcelessBuilder.create(tx, val)
return VariableTracker.build(tx, val)
else:
# Create a lazy variable to avoid guarding on __defaults__ unless really
# needed.
@ -240,8 +238,6 @@ class UserFunctionVariable(BaseUserFunctionVariable):
# optimization for cleaner codegen
result[name] = var
elif self.source:
from .builder import VariableBuilder
side_effects = parent.output.side_effects
if cell in side_effects:
out = side_effects[cell]
@ -253,9 +249,9 @@ class UserFunctionVariable(BaseUserFunctionVariable):
closure_cell, "cell_contents"
)
try:
contents_var = VariableBuilder(
parent, closure_cell_contents
)(cell.cell_contents)
contents_var = VariableTracker.build(
parent, cell.cell_contents, closure_cell_contents
)
except ValueError:
# Cell has not yet been assigned
contents_var = variables.DeletedVariable()
@ -286,9 +282,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
result[name] = out
else:
from .builder import SourcelessBuilder
result[name] = SourcelessBuilder.create(tx, cell.cell_contents)
result[name] = VariableTracker.build(tx, cell.cell_contents)
return result, closure_cells
@ -296,17 +290,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
pass
def var_getattr(self, tx: "InstructionTranslator", name: str):
source = AttrSource(self.source, name) if self.source else None
source = self.source and AttrSource(self.source, name)
try:
subobj = inspect.getattr_static(self.fn, name)
except AttributeError:
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)
return variables.GetAttrVariable(self, name, source=source)
if source:
return variables.LazyVariableTracker.create(subobj, source)
from .builder import SourcelessBuilder
return SourcelessBuilder.create(tx, subobj)
return VariableTracker.build(tx, subobj)
def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
result = hasattr(self.fn, name)
@ -757,14 +748,8 @@ class WrapperUserFunctionVariable(VariableTracker):
def var_getattr(self, tx: "InstructionTranslator", name):
if name == self.attr_to_trace:
val = getattr(self.wrapper_obj, self.attr_to_trace)
if self.source:
from .builder import VariableBuilder
return VariableBuilder(tx, AttrSource(self.source, name))(val)
else:
from .builder import SourcelessBuilder
return SourcelessBuilder.create(tx, val)
source = self.source and AttrSource(self.source, name)
return VariableTracker.build(tx, val, source)
return super().var_getattr(tx, name)
@ -999,8 +984,6 @@ class PolyfilledFunctionVariable(VariableTracker):
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from torch._dynamo.variables.builder import SourcelessBuilder
if self.can_constant_fold_through() and check_unspec_or_constant_args(
args, kwargs
):
@ -1010,9 +993,9 @@ class PolyfilledFunctionVariable(VariableTracker):
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
)
return SourcelessBuilder.create(tx, result)
return VariableTracker.build(tx, result)
traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn)
traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
return traceable_function_variable.call_function(tx, args, kwargs)
def call_method(

View File

@ -13,7 +13,6 @@ import torch.fx
import torch.nn
from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.builtin import BuiltinVariable
from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.tensor import SymNodeVariable
@ -31,6 +30,7 @@ from ..exc import (
)
from ..source import AttrSource
from ..utils import proxy_args_kwargs
from .base import VariableTracker
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable, TupleVariable
@ -1040,7 +1040,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
from .builder import SourcelessBuilder, wrap_fx_proxy
from .builder import wrap_fx_proxy
args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
@ -1062,7 +1062,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
tx,
"new_empty",
args=(
SourcelessBuilder.create(
VariableTracker.build(
tx,
leaf.size
if leaf.size is not None
@ -1072,8 +1072,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
),
),
kwargs={
"dtype": SourcelessBuilder.create(tx, leaf.dtype),
"requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad),
"dtype": VariableTracker.build(tx, leaf.dtype),
"requires_grad": VariableTracker.build(tx, leaf.requires_grad),
},
)
for leaf in itertools.chain(xs.items, xs.items)
@ -2057,7 +2057,6 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
fn_name: str,
):
from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
from .builder import SourcelessBuilder
tx: InstructionTranslator = tx
@ -2065,9 +2064,9 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
return query.call_method(
tx,
"new_empty",
(SourcelessBuilder.create(tx, []),),
(VariableTracker.build(tx, []),),
{
"dtype": SourcelessBuilder.create(tx, torch.int32),
"dtype": VariableTracker.build(tx, torch.int32),
},
)
@ -2077,8 +2076,8 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
score = query.call_method(
tx,
"new_empty",
(SourcelessBuilder.create(tx, []),),
{"requires_grad": SourcelessBuilder.create(tx, scores_require_grad)},
(VariableTracker.build(tx, []),),
{"requires_grad": VariableTracker.build(tx, scores_require_grad)},
)
new_args = [score, *bhmn]
else:

View File

@ -172,10 +172,8 @@ class ItertoolsVariable(VariableTracker):
*args, mutable_local=MutableLocal()
)
from .builder import SourcelessBuilder
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
VariableTracker.build(tx, polyfills.repeat), args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())

View File

@ -20,14 +20,15 @@ class LazyCache:
def realize(self) -> None:
assert self.vt is None
from ..symbolic_convert import InstructionTranslator
from .builder import SourcelessBuilder, VariableBuilder
tx = InstructionTranslator.current_tx()
if isinstance(self.value, LazySymNodeFormatString):
self.vt = SourcelessBuilder.create(tx, self.value)
else:
self.vt = VariableBuilder(tx, self.source)(self.value)
if isinstance(self.value, LazySymNodeFormatString):
source = None
else:
source = self.source
self.vt = VariableTracker.build(tx, self.value, source)
del self.value
del self.source
@ -37,7 +38,7 @@ class LazyVariableTracker(VariableTracker):
A structure that defers the creation of the actual VariableTracker
for a given underlying value until it is accessed.
The `realize` function invokes VariableBuilder to produce the real object.
The `realize` function invokes VariableTracker.build() to produce the real object.
Once a LazyVariableTracker has been realized, internal bookkeeping will
prevent double realization.

View File

@ -135,10 +135,8 @@ class BaseListVariable(VariableTracker):
assert not kwargs
return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
elif name == "index":
from .builder import SourcelessBuilder
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.index),
VariableTracker.build(tx, polyfills.index),
[self] + list(args),
kwargs,
)

View File

@ -207,12 +207,10 @@ class SuperVariable(VariableTracker):
and len(kwargs) == 0
and args[0].is_python_constant()
):
from .builder import VariableBuilder
key = args[0].as_python_constant()
return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
collections.OrderedDict.__getitem__(self.objvar.value, key)
)
value = collections.OrderedDict.__getitem__(self.objvar.value, key)
source = ODictGetItemSource(self.objvar.source, key)
return VariableTracker.build(tx, value, source)
elif inner_fn in (
collections.OrderedDict.__setitem__,
object.__setattr__,
@ -467,15 +465,10 @@ class InspectParameterVariable(VariableTracker):
self.value = value
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
from .builder import SourcelessBuilder, VariableBuilder
try:
attr_value = getattr(self.value, name)
if self.source:
attr_source = AttrSource(self.source, name)
return VariableBuilder(tx, attr_source)(attr_value)
else:
return SourcelessBuilder.create(tx, attr_value)
source = self.source and AttrSource(self.source, name)
return VariableTracker.build(tx, attr_value, source)
except AttributeError:
unimplemented(f"getattr({self.value}, {name})")
@ -909,11 +902,9 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
if self.needs_input_grad is not None:
return variables.ConstantVariable.create(self.needs_input_grad)
if self.source:
from .builder import VariableBuilder
source = AttrSource(self.source, "needs_input_grad")
return VariableTracker.build(tx, self.value.needs_input_grad, source)
return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))(
self.value.needs_input_grad
)
return super().var_getattr(tx, name)
@ -1118,11 +1109,8 @@ class GetSetDescriptorVariable(VariableTracker):
def var_getattr(self, tx: "InstructionTranslator", name):
if name == "__get__" and self.source:
from .builder import VariableBuilder
return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
self.desc.__get__
)
source = AttrSource(self.source, "__get__")
return VariableTracker.build(tx, self.desc.__get__, source)
else:
return super().var_getattr(tx, name)
@ -1162,18 +1150,13 @@ class PythonModuleVariable(VariableTracker):
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
return tx.output.side_effects.load_attr(self, name)
from .builder import SourcelessBuilder, VariableBuilder
if self.is_torch or name not in self.value.__dict__:
attr_value = getattr(self.value, name)
else:
attr_value = self.value.__dict__[name]
if self.source:
new_source = AttrSource(self.source, name)
return VariableBuilder(tx, new_source)(attr_value)
else:
return SourcelessBuilder.create(tx, attr_value)
source = self.source and AttrSource(self.source, name)
return VariableTracker.build(tx, attr_value, source)
class TypingVariable(VariableTracker):

View File

@ -244,12 +244,7 @@ class NNModuleVariable(VariableTracker):
)
def var_getattr(self, tx: "InstructionTranslator", name):
from .builder import VariableBuilder
if self.source:
source = AttrSource(self.source, name)
else:
source = None
source = self.source and AttrSource(self.source, name)
base = tx.output.get_submodule(self.module_key)
base_dict = object.__getattribute__(base, "__dict__")
@ -297,7 +292,7 @@ class NNModuleVariable(VariableTracker):
return variables.UserDefinedClassVariable(base.__class__, source=source)
if object_member:
out = VariableBuilder(tx, NNModuleSource(source))(subobj)
out = VariableTracker.build(tx, subobj, NNModuleSource(source))
if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
# nn_module_stack source is BC surface area. Ensure that
@ -333,7 +328,7 @@ class NNModuleVariable(VariableTracker):
return variables.UserMethodVariable(subobj, self, source=source)
elif is_safe_constant(subobj) or istensor(subobj):
# Support possibly common cases of class members
return VariableBuilder(tx, NNModuleSource(source))(subobj)
return VariableTracker.build(tx, subobj, NNModuleSource(source))
else:
unimplemented(
f"class property {name} - {typestr(base)} {typestr(subobj)}"
@ -1083,7 +1078,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
)
return variables.ConstDictVariable({})
# For non-empty hook dicts, one way is to just fallback to VariableBuilder and create a ConstDictVariable.
# For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable.
# However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for
# differnt nn module instances, because the key keeps changing (look more into RemovableHandle to understand why
# key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a

View File

@ -17,6 +17,7 @@ from ..source import (
GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
@ -27,8 +28,6 @@ from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .base import VariableTracker
class ArgMappingException(Exception):
pass
@ -147,7 +146,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
def _set_capturable(self, tx):
from . import LazyVariableTracker
from .builder import VariableBuilder
# We only set capturable if params are on cuda
# and the state is not initialized
@ -168,10 +166,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
if safe_to_set_capturable(group):
group["capturable"] = True
source = self.source and AttrSource(self.source, "param_groups")
param_groups_vt = LazyVariableTracker.realize_all(
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
VariableTracker.build(tx, self.value.param_groups, source)
)
for ind, param_group_vt in enumerate(param_groups_vt.items):
key = ConstDictVariable._HashableTracker(
@ -214,7 +211,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
def map_sources_and_install_guards(self, tx):
from ..decorators import mark_static_address
from .builder import VariableBuilder
from .lazy import LazyVariableTracker
self.grad_to_source = {}
@ -235,15 +231,13 @@ class OptimizerVariable(UserDefinedObjectVariable):
# Recursively realize the variable trackers for optim.state and
# optim.param_groups, which recursively install the necessary guards.
params_groups_source = self.source and AttrSource(self.source, "param_groups")
param_groups_vt = LazyVariableTracker.realize_all(
VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
VariableTracker.build(tx, self.value.param_groups, params_groups_source)
)
state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
self.value.state
)
state_source = self.source and AttrSource(self.source, "state")
state_vt = VariableTracker.build(tx, self.value.state, state_source)
# We need to realize the top level state dict to populate
# the guard locals
@ -265,15 +259,15 @@ class OptimizerVariable(UserDefinedObjectVariable):
key_index = i
break
if key_index:
state_source = AttrSource(self.source, "state")
LazyVariableTracker.realize_all(
VariableBuilder(
VariableTracker.build(
tx,
self.value.state[param],
GetItemSource(
state_source,
ConstDictKeySource(state_source, key_index),
),
)(self.value.state[param])
)
)
break
@ -312,7 +306,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
# We have to again iterate over the state dict to collect the
# tensor_to_source dict. This is used for the finalizer.
state_source = AttrSource(self.source, "state")
for idx, (p, value) in enumerate(self.value.state.items()):
p_state_source = GetItemSource(
state_source, ConstDictKeySource(state_source, idx)
@ -328,7 +321,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
"""Wrap state tensor in a TensorVariable"""
from ..decorators import mark_static_address
from .builder import VariableBuilder
# If we have a source for a tensor already use it,
# if we have not seen a tensor before, stash and use a
@ -338,20 +330,19 @@ class OptimizerVariable(UserDefinedObjectVariable):
if tensor_value in self.tensor_to_source:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)
builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
self.static_tensor_names.add(tx.output.module_key_name(builder.name))
source = self.tensor_to_source[tensor_value]
self.static_tensor_names.add(tx.output.module_key_name(source.name))
elif tensor_value in self.grad_to_source:
builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
source = self.grad_to_source[tensor_value]
else:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)
global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
self.static_tensor_names.add(tx.output.module_key_name(builder.name))
source = GlobalWeakRefSource(global_name)
self.static_tensor_names.add(tx.output.module_key_name(source.name))
result = builder(tensor_value)
return result
return VariableTracker.build(tx, tensor_value, source)
def update_list_args(
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
@ -367,14 +358,8 @@ class OptimizerVariable(UserDefinedObjectVariable):
if isinstance(val, torch.Tensor):
arg.items.append(self.wrap_tensor(tx, val))
else:
from .builder import SourcelessBuilder, VariableBuilder
if arg.source:
arg.items.append(
VariableBuilder(tx, GetItemSource(arg.source, i))(val)
)
else:
arg.items.append(SourcelessBuilder.create(tx, val))
source = arg.source and GetItemSource(arg.source, i)
arg.items.append(VariableTracker.build(tx, val, source))
def create_finalizer(self, tx):
names_to_delete = self.static_tensor_names

View File

@ -5,12 +5,15 @@ from typing import TYPE_CHECKING
from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from ..source import AttrSource
from .base import VariableTracker
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
class SDPAParamsVariable(VariableTracker):
"""Represents the c++ params struct for scaled dot product attention.
@ -20,35 +23,13 @@ class SDPAParamsVariable(VariableTracker):
def create(tx: "InstructionTranslator", value, source):
from torch.backends.cuda import SDPAParams
from ..source import AttrSource
from .builder import VariableBuilder
from .torch import TorchInGraphFunctionVariable
query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
value.attn_mask
)
dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
value.is_causal
)
enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
value.enable_gqa
)
param_vars = [
query_var,
key_var,
value_var,
attn_mask_var,
dropout_var,
is_causal_var,
enable_gqa_var,
params = [
VariableTracker.build(tx, getattr(value, p), AttrSource(source, p))
for p in PARAM_NAMES
]
return TorchInGraphFunctionVariable(SDPAParams).call_function(
tx, param_vars, {}
)
return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})
def __init__(self, proxy, param_vars, **kwargs) -> None:
self.proxy = proxy
@ -70,7 +51,6 @@ class SDPAParamsVariable(VariableTracker):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
import torch._C
from ..source import AttrSource
from .builder import wrap_fx_proxy
from .misc import GetAttrVariable

View File

@ -238,9 +238,7 @@ class TensorVariable(VariableTracker):
# any other attributes on the subclass (that are not methods)
# are assumed to be constant metadata.
elif not callable(example_value):
from .builder import SourcelessBuilder
return SourcelessBuilder.create(tx, example_value)
return VariableTracker.build(tx, example_value)
if not (self.source and self.source.subguards_allowed()):
raise NotImplementedError
@ -277,12 +275,9 @@ class TensorVariable(VariableTracker):
# Note - at a certain point we may want to handle
raise NotImplementedError
from ..guards import GuardBuilder
from .builder import VariableBuilder
attr_source = AttrSource(self.source, name)
install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
return VariableBuilder(tx, attr_source)(real_value)
return VariableTracker.build(tx, real_value, attr_source)
def method_attr_ndim(self, tx):
if self.ndim is not None:
@ -695,7 +690,6 @@ class TensorVariable(VariableTracker):
def method_as_subclass(self, cls):
if isinstance(cls, TensorSubclassVariable) and cls.source:
from ..symbolic_convert import InstructionTranslator
from .builder import VariableBuilder
from .torch_function import TensorWithTFOverrideVariable
tx = InstructionTranslator.current_tx()
@ -705,10 +699,11 @@ class TensorVariable(VariableTracker):
# defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
# It is up to the user whether this is correct behavior or not.
py_cls = cls.as_python_constant()
torch_fn = VariableBuilder(
torch_fn = VariableTracker.build(
tx,
py_cls.__torch_function__.__func__,
AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
)(py_cls.__torch_function__.__func__)
)
return TensorWithTFOverrideVariable.from_tensor_var(
tx, self, py_cls, torch_fn
@ -750,7 +745,6 @@ class TensorVariable(VariableTracker):
def method_tolist(self):
from ..symbolic_convert import InstructionTranslator
from .builder import SourcelessBuilder
tx = InstructionTranslator.current_tx()
@ -787,7 +781,7 @@ class TensorVariable(VariableTracker):
tensor = self.as_proxy().node.meta["example_value"]
out = tolist(tensor, self.as_proxy())
return SourcelessBuilder.create(tx, out)
return VariableTracker.build(tx, out)
def method_backward(self, *args, **kwargs):
unimplemented("Tensor.backward")
@ -857,10 +851,9 @@ class TensorVariable(VariableTracker):
tx = InstructionTranslator.current_tx()
if value is not None:
from .. import polyfills
from .builder import SourcelessBuilder
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.addcmul_inplace),
VariableTracker.build(tx, polyfills.addcmul_inplace),
[self, tensor1, tensor2, value],
{},
)
@ -1155,9 +1148,7 @@ class SymNodeVariable(VariableTracker):
def as_tensor(self, tx):
if self._tensor_var is None:
from .builder import SourcelessBuilder
self._tensor_var = SourcelessBuilder.create(
self._tensor_var = VariableTracker.build(
tx, torch.scalar_tensor
).call_function(tx, [self], {})
return self._tensor_var
@ -1362,12 +1353,10 @@ class TensorSubclassVariable(VariableTracker):
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
if len(args) == 1 and isinstance(args[0], TensorVariable):
from .builder import VariableBuilder
from .torch_function import TensorWithTFOverrideVariable
torch_fn = VariableBuilder(
tx, AttrSource(self.source, "__torch_function__")
)(self.value.__torch_function__)
source = AttrSource(self.source, "__torch_function__")
torch_fn = VariableTracker.build(tx, self.value.__torch_function__, source)
return TensorWithTFOverrideVariable.from_tensor_var(
tx, args[0], self.value, torch_fn

View File

@ -397,7 +397,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
TensorVariable,
UserDefinedObjectVariable,
)
from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
@register(*tracing_state_functions)
def handle_tracing_state_functions(
@ -422,14 +422,14 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# the set of functions that we trace __torch_function__ on to
# functions outside of the actual set. Implementing this properly will require implementing
# some variable types to track and compare tensor getset descriptors
return SourcelessBuilder.create(
return VariableTracker.build(
tx, torch.overrides.get_default_nowrap_functions()
)
@register(torch.ops.inductor.accumulate_grad_.default)
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.accumulate_grad), args, kwargs
VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs
)
@register(math.radians)
@ -437,7 +437,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if not check_unspec_or_constant_args(args, kwargs):
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.radians), args, kwargs
VariableTracker.build(tx, polyfills.radians), args, kwargs
)
@register(torch.is_tensor, torch.overrides.is_tensor_like)
@ -622,7 +622,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace),
VariableTracker.build(tx, polyfills.foreach_lerp_inplace),
args,
kwargs,
)
@ -635,7 +635,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# in compile, it's more performant to not graph break.
if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar),
VariableTracker.build(tx, polyfills.foreach_pow_scalar),
args,
kwargs,
)
@ -704,7 +704,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# Note - while we *could* cook up sources around invocations, like a FunctionSource
# the space of invoking functions in the middle of the guard chain is very iffy. As such,
# guard propagation via options is the best we can do.
return SourcelessBuilder.create(tx, invocation_result)
return VariableTracker.build(tx, invocation_result)
@register(DTensor.from_local)
def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):
@ -1143,8 +1143,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
@staticmethod
def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad):
# Alternate version if we have a .source
from .builder import VariableBuilder
varname = tx.output.new_var()
# construct the nn.Parmeter before the graph save it to varname
@ -1167,7 +1165,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
example_value = torch.nn.Parameter(
tx.output.example_value_from_input_node(data.as_proxy().node)
)
result = VariableBuilder(tx, source)(example_value)
result = VariableTracker.build(tx, example_value, source)
# No need to guard on this since we already guarded on `data`.
# These guards would fail since varname doesn't exist until after the function starts
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(

View File

@ -474,12 +474,8 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var):
if isinstance(var, TensorWithTFOverrideVariable):
return var.class_type_var(tx)
elif isinstance(var, UserDefinedObjectVariable):
from .builder import SourcelessBuilder, VariableBuilder
if var.source:
return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
else:
return SourcelessBuilder.create(tx, var.python_type())
source = var.source and TypeSource(var.source)
return VariableTracker.build(tx, var.python_type(), source)
def _is_attr_overidden(tx: "InstructionTranslator", var, name):
@ -498,16 +494,14 @@ def _is_attr_overidden(tx: "InstructionTranslator", var, name):
def call_torch_function(
tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
from .builder import SourcelessBuilder
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
torch_function_type,
fn,
types,
SourcelessBuilder.create(tx, tuple(args)),
SourcelessBuilder.create(tx, kwargs),
VariableTracker.build(tx, tuple(args)),
VariableTracker.build(tx, kwargs),
)
return tx.inline_user_function_return(torch_function_var, tf_args, {})
@ -515,20 +509,13 @@ def call_torch_function(
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
from types import FunctionType
from .builder import SourcelessBuilder, VariableBuilder
func = value.__torch_function__.__func__
if not isinstance(func, FunctionType):
unimplemented("Builtin/C++ torch function implementations NYI")
if source:
return VariableBuilder(
tx,
AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
)(value.__torch_function__.__func__)
else:
return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__")
return VariableTracker.build(tx, func, source)
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
@ -625,8 +612,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
# base tensors, custom attribute accesses will graph break.
import torch
from .builder import SourcelessBuilder
if name in banned_attrs:
unimplemented(
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
@ -645,7 +630,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
GuardBuilder.FUNCTION_MATCH
)
)
get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function(
tx,
@ -680,8 +665,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
if tx.output.torch_function_enabled:
import torch
from .builder import SourcelessBuilder, VariableBuilder
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Calling overridden method {name} on a tensor"
@ -693,11 +676,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
# We've established with the above check that the method is not overridden, so we guard that the method is the same
# as the impl defined on tensor and retrieve it
if self.source:
func_var = VariableBuilder(
tx, AttrSource(AttrSource(self.source, "__class__"), name)
)(inspect.getattr_static(self.python_type(), name))
source = AttrSource(AttrSource(self.source, "__class__"), name)
value = inspect.getattr_static(self.python_type(), name)
else:
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
source = None
value = getattr(torch.Tensor, name)
func_var = VariableTracker.build(tx, value, source)
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
else:
return super().call_method(tx, name, args, kwargs)

View File

@ -158,7 +158,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
from . import ConstantVariable, EnumVariable
from .builder import SourcelessBuilder, VariableBuilder
source = AttrSource(self.source, name) if self.source is not None else None
@ -187,11 +186,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
obj = None
if isinstance(obj, staticmethod):
func = obj.__get__(self.value)
if source is not None:
return VariableBuilder(tx, source)(func)
else:
return SourcelessBuilder.create(tx, func)
return VariableTracker.build(tx, obj.__get__(self.value), source)
elif isinstance(obj, classmethod):
if isinstance(obj.__func__, property):
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
@ -202,16 +197,13 @@ class UserDefinedClassVariable(UserDefinedVariable):
# e.g.: inspect.getattr_static(dict, "fromkeys")
# inspect.getattr_static(itertools.chain, "from_iterable")
func = obj.__get__(None, self.value)
if source is not None:
return VariableBuilder(tx, source)(func)
else:
return SourcelessBuilder.create(tx, func)
return VariableTracker.build(tx, func, source)
elif source:
# __mro__ is a member in < 3.12, an attribute in >= 3.12
if inspect.ismemberdescriptor(obj) or (
sys.version_info >= (3, 12) and name == "__mro__"
):
return VariableBuilder(tx, source)(obj.__get__(self.value))
return VariableTracker.build(tx, obj.__get__(self.value), source)
if ConstantVariable.is_literal(obj):
return ConstantVariable.create(obj)
@ -222,14 +214,15 @@ class UserDefinedClassVariable(UserDefinedVariable):
or self.value.__module__ == "torch"
):
if source:
return VariableBuilder(tx, source)(obj)
return VariableTracker.build(tx, obj, source)
if (
source
and not inspect.ismethoddescriptor(obj)
and not is_wrapper_or_member_descriptor(obj)
):
return VariableBuilder(tx, source)(obj)
return VariableTracker.build(tx, obj, source)
return super().var_getattr(tx, name)
def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs):
@ -341,7 +334,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from ..side_effects import SideEffects
from .builder import SourcelessBuilder, wrap_fx_proxy
from .builder import wrap_fx_proxy
from .builtin import BuiltinVariable
constant_args = check_constant_args(args, kwargs)
@ -452,7 +445,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
field_var = kwargs[field_name]
else:
assert field_name in field_defaults
field_var = SourcelessBuilder.create(
field_var = VariableTracker.build(
tx, field_defaults[field_name]
)
var_tracker_kwargs[field_name] = field_var
@ -465,8 +458,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
return variables.NamedTupleVariable(items, self.value)
elif is_frozen_dataclass(self.value) and self.is_standard_new():
from .builder import SourcelessBuilder
fields = dataclasses.fields(self.value)
items = list(args)
items.extend([None] * (len(fields) - len(items)))
@ -481,9 +472,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
continue
if field.default is not dataclasses.MISSING:
var_tracker = SourcelessBuilder.create(tx, field.default)
var_tracker = VariableTracker.build(tx, field.default)
elif field.default_factory is not dataclasses.MISSING:
factory_fn = SourcelessBuilder.create(
factory_fn = VariableTracker.build(
tx, field.default_factory
)
var_tracker = factory_fn.call_function(tx, [], {})
@ -573,7 +564,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
):
return tx.inline_user_function_return(
SourcelessBuilder.create(
VariableTracker.build(
tx, polyfills.instantiate_user_defined_class_object
),
[self, *args],
@ -857,7 +848,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .. import trace_rules
from .builder import VariableBuilder
if (
self.is_supported_random()
@ -894,9 +884,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
"Sourceless UserDefinedObjectVariable method not supported"
)
func_src = AttrSource(self.source, "__func__")
func_var = VariableBuilder(tx, func_src)(func)
func_var = VariableTracker.build(tx, func, func_src)
obj_src = AttrSource(self.source, "__self__")
obj_var = VariableBuilder(tx, obj_src)(obj)
obj_var = VariableTracker.build(tx, obj, obj_src)
return func_var.call_function(tx, [obj_var] + args, kwargs)
elif (
istype(self.value, functools.partial)
@ -998,7 +988,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def var_getattr(self, tx: "InstructionTranslator", name):
from .. import trace_rules
from . import ConstantVariable
from .builder import SourcelessBuilder, VariableBuilder
source = AttrSource(self.source, name) if self.source else None
self._check_for_getattribute()
@ -1090,10 +1079,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
elif isinstance(subobj, types.ClassMethodDescriptorType):
# e.g.: inspect.getattr_static({}, "fromkeys")
func = subobj.__get__(self.value, None)
if source is not None:
return VariableBuilder(tx, source)(func)
else:
return SourcelessBuilder.create(tx, func)
return VariableTracker.build(tx, func, source)
elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
subobj.__get__
):
@ -1188,7 +1174,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
subobj_from_class, src_from_class
)
return SourcelessBuilder.create(tx, subobj)
return VariableTracker.build(tx, subobj)
# Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
raise_observed_exception(AttributeError, tx)
@ -1212,7 +1198,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return variables.ConstantVariable.create(False)
def odict_getitem(self, tx: "InstructionTranslator", key):
from .builder import VariableBuilder
from .dicts import is_hashable
# TODO this should probably be merged with the dict handling
@ -1223,10 +1208,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
else key.as_python_constant()
)
return VariableBuilder(
return VariableTracker.build(
tx,
ODictGetItemSource(self.source, index),
)(collections.OrderedDict.__getitem__(self.value, key.as_python_constant()))
collections.OrderedDict.__getitem__(self.value, key.as_python_constant()),
self.source and ODictGetItemSource(self.source, index),
)
class FrozenDataClassVariable(UserDefinedObjectVariable):
@ -1236,14 +1222,14 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
assert is_frozen_dataclass(value)
from .builder import VariableBuilder
field_map = {}
for field in fields(value):
if hasattr(value, field.name):
field_map[field.name] = VariableBuilder(
tx, AttrSource(source, field.name)
)(getattr(value, field.name))
field_map[field.name] = VariableTracker.build(
tx,
getattr(value, field.name),
source and AttrSource(source, field.name),
)
return FrozenDataClassVariable(value, fields=field_map, source=source)
@ -1315,16 +1301,8 @@ class WeakRefVariable(UserDefinedObjectVariable):
) -> "VariableTracker":
call_source = None
referent = self.value()
if self.source:
from .builder import VariableBuilder
call_source = WeakRefCallSource(self.source)
return VariableBuilder(tx, call_source)(referent)
else:
from .builder import SourcelessBuilder
return SourcelessBuilder.create(tx, referent)
source = self.source and WeakRefCallSource(self.source)
return VariableTracker.build(tx, referent, source)
class KeyedJaggedTensorVariable(UserDefinedObjectVariable):