[dynamo] Simplify creation of VariableTrackers (#135714)
## `VariableTracker::build()` hides the Builders
### The problem
In the current code, creating a `VariableTracker` involves choosing one of two `Builder` classes and then either calling a method on it or constructing an object that you immediately call, [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)).
Variations on this code are repeated in many places.
Moreover, the `Builder` classes have many dependencies, so they must be imported late in the import process to avoid circular imports; as a result, they end up being repeatedly imported at local scope.
### The solution
In this commit, the import from `builder` and the logic for choosing and calling the Builder class are hidden behind a single static factory method, `VariableTracker.build()`, which is easier to reason about and to import.
This commit removes over 150 net lines of code by eliminating repetitive logic and unnecessary local imports.
**CHANGES:** The static method was originally named `VariableTracker.create()`, but a derived class already defines `LazyVariableTracker.create()` with an irreconcilably different signature, so the new static method was renamed to `VariableTracker.build()`.
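To illustrate the shape of the change (a minimal sketch; `tx`, `value`, and `source` are stand-ins for the usual translator, Python value, and optional `Source` seen at the call sites in the diff below):

```python
# Before: every call site picked a Builder class by hand,
# importing it locally to dodge circular imports.
from torch._dynamo.variables.builder import SourcelessBuilder, VariableBuilder

if source is not None:
    vt = VariableBuilder(tx, source)(value)
else:
    vt = SourcelessBuilder.create(tx, value)

# After: one factory hides the choice of Builder (and the local import).
from torch._dynamo.variables.base import VariableTracker

vt = VariableTracker.build(tx, value, source)  # source may be None
```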
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: 1581a93e87
Commit: e1c4548441
@@ -91,7 +91,6 @@ from .variables.builder import (
BackwardStateGraphArg,
GraphArg,
TrackedFake,
- VariableBuilder,
wrap_fx_proxy,
)
from .variables.lists import BaseListVariable

@@ -498,7 +497,7 @@ class OutputGraph:
cg.store(varname)
self.pregraph_bytecode.extend(cg.get_instructions())
source = SyntheticLocalSource(varname)
- result = VariableBuilder(self.root_tx, source)(example_value)
+ result = VariableTracker.build(self.root_tx, example_value, source)
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
source
)

@@ -767,8 +766,8 @@ class OutputGraph:
):
if is_dynamic_nn_module(target, self.root_tx.export):
# Instead of returning UnspecializedNNModuleVariable, call
- # VariableBuilder so that it is tracked for mutation.
- return VariableBuilder(self.current_tx, **options)(target)
+ # VariableTracker.build so that it is tracked for mutation.
+ return VariableTracker.build(self.current_tx, target, **options)

options = dict(options)
assert "source" in options

@@ -860,8 +859,8 @@ class OutputGraph:
def wrap_name(module_key):
self.output.update_co_names(module_key)
self.global_scope[module_key] = target
- return VariableBuilder(self, ConstantSource(source_name=module_key))(
- target
+ return VariableTracker.build(
+ self, target, ConstantSource(source_name=module_key)
)

for k, v in self.nn_modules.items():

@@ -71,7 +71,7 @@ from .utils import (
proxy_args_kwargs,
)
from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker
- from .variables.builder import VariableBuilder, wrap_fx_proxy
+ from .variables.builder import wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable
from .variables.ctx_manager import (

@@ -1224,15 +1224,14 @@ class InstructionTranslatorBase(
except KeyError:
return self.load_builtin(inst)

- source = GlobalSource(name)
- self.push(VariableBuilder(self, source)(value))
+ self.push(VariableTracker.build(self, value, GlobalSource(name)))

@functools.cached_property
def nn_modules_globals_vt(self):
module_name = "torch.nn.modules.module"
module_source = self.import_source(module_name)
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
- return VariableBuilder(self, module_source)(fglobals_value)
+ return VariableTracker.build(self, fglobals_value, module_source)

def LOAD_GLOBAL(self, inst):
if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:

@@ -1374,7 +1373,7 @@ class InstructionTranslatorBase(
self.output.name_of_builtins_dict_key_in_fglobals
)
var_source = GetItemSource(builtins_source, argval)
- self.push(VariableBuilder(self, var_source)(val))
+ self.push(VariableTracker.build(self, val, var_source))
else:
assert is_builtin_constant(val)
self.push(ConstantVariable.create(value=val))

@@ -3403,7 +3402,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment]
else:
fglobals_value = importlib.import_module(module_name) # type: ignore[assignment]
- fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
+ fglobals_vt = VariableTracker.build(self, fglobals_value, module_source)
global_source = AttrSource(module_source, name)
else:
globals_name = self.output.install_global_by_id(

@@ -3411,7 +3410,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
)
globals_source = GlobalSource(globals_name)
fglobals_value = self.f_globals # type: ignore[assignment]
- fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
+ fglobals_vt = VariableTracker.build(self, fglobals_value, globals_source)
global_source = GetItemSource(globals_source, name) # type: ignore[assignment]
return fglobals_value, fglobals_vt, global_source

@@ -3430,7 +3429,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
except KeyError:
return self.load_builtin(inst)

- self.push(VariableBuilder(self, global_source)(value))
+ self.push(VariableTracker.build(self, value, global_source))

def STORE_GLOBAL(self, inst):
if self.f_globals is self.parent.f_globals:

@@ -12,7 +12,7 @@ from ..utils import istype


if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
+ from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase


class MutableLocalSource(Enum):

@@ -121,6 +121,8 @@ class VariableTracker(metaclass=VariableTrackerMeta):

VariableTracker instances are immutable and should be copied in
order to change them.
+
+ Prefer the factory function VariableTracker.build() over VariableTracker.__init__().
"""

# fields to leave unmodified in apply()

@@ -244,9 +246,7 @@ class VariableTracker(metaclass=VariableTrackerMeta):
value = self.const_getattr(tx, name)
if not variables.ConstantVariable.is_literal(value):
raise NotImplementedError
- source = None
- if self.source:
- source = AttrSource(self.source, name)
+ source = self.source and AttrSource(self.source, name)
return variables.ConstantVariable.create(value, source=source)

def is_proxy(self):

@@ -363,6 +363,20 @@ class VariableTracker(metaclass=VariableTrackerMeta):
def is_strict_mode(self, tx):
return tx.strict_checks_fn and tx.strict_checks_fn(self)

+ @staticmethod
+ def build(
+ tx: "InstructionTranslatorBase",
+ value: Any,
+ source: Optional[Source] = None,
+ ) -> Any:
+ """Create a new VariableTracker from a value and optional Source"""
+ from . import builder
+
+ if source is None:
+ return builder.SourcelessBuilder.create(tx, value)
+ else:
+ return builder.VariableBuilder(tx, source)(value)

def __init__(
self,
*,

@@ -701,7 +701,6 @@ class BuiltinVariable(VariableTracker):

@staticmethod
def _make_handler(fn, arg_types: List[type], has_kwargs: bool):
- from .builder import SourcelessBuilder
from .lazy import LazyVariableTracker

obj = BuiltinVariable(fn)

@@ -794,8 +793,6 @@ class BuiltinVariable(VariableTracker):
handlers.append(call_self_handler)

if obj.can_constant_fold_through():
- builder = SourcelessBuilder.create
-
if (
all(issubclass(x, ConstantVariable) for x in arg_types)
and not has_kwargs

@@ -809,7 +806,7 @@ class BuiltinVariable(VariableTracker):
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
- return builder(tx, res)
+ return VariableTracker.build(tx, res)

else:

@@ -825,7 +822,7 @@ class BuiltinVariable(VariableTracker):
)
except Exception as exc:
unimplemented(f"constant fold exception: {repr(exc)}")
- return builder(tx, res)
+ return VariableTracker.build(tx, res)

handlers.append(constant_fold_handler)

@@ -1361,8 +1358,6 @@ class BuiltinVariable(VariableTracker):

@staticmethod
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
- from .builder import SourcelessBuilder
-
if not kwargs:
if not args:
args = ({},)

@@ -1399,7 +1394,7 @@ class BuiltinVariable(VariableTracker):
)

new_dict = dict(arg.value.items())
- return SourcelessBuilder.create(tx, new_dict)
+ return VariableTracker.build(tx, new_dict)
else:
func_var = arg.var_getattr(tx, "items")
if not isinstance(func_var, variables.UserFunctionVariable):

@@ -1631,7 +1626,6 @@ class BuiltinVariable(VariableTracker):
TorchInGraphFunctionVariable,
UserFunctionVariable,
)
- from .builder import SourcelessBuilder, VariableBuilder

name = name_var.as_python_constant()

@@ -1666,34 +1660,21 @@ class BuiltinVariable(VariableTracker):
if not hasattr_var.as_python_constant():
return default

- options = {}
- if obj.source:
- source = AttrSource(obj.source, name)
- options["source"] = source
- else:
- source = None
-
+ source = obj.source and AttrSource(obj.source, name)
if name in {"__bases__", "__base__", "__flags__"}:
try:
value = obj.as_python_constant()
if isinstance(value, type):
if name == "__bases__":
- bases = value.__bases__
- if source is not None:
- tuple_args = [
- VariableBuilder(tx, GetItemSource(source, i))(b)
- for i, b in enumerate(bases)
- ]
- else:
- tuple_args = [
- SourcelessBuilder.create(tx, b) for b in bases
- ]
- return variables.TupleVariable(tuple_args, **options)
+ tuple_args = [
+ VariableTracker.build(
+ tx, b, source and GetItemSource(source, i)
+ )
+ for i, b in enumerate(value.__bases__)
+ ]
+ return variables.TupleVariable(tuple_args, source=source)
if name == "__base__":
- base = value.__base__
- if source is not None:
- return VariableBuilder(tx, source)(base)
- return SourcelessBuilder.create(tx, base)
+ return VariableTracker.build(tx, value.__base__, source)
if name == "__flags__":
return ConstantVariable.create(value.__flags__)
except NotImplementedError:

@@ -1715,14 +1696,14 @@ class BuiltinVariable(VariableTracker):
try:
return obj.var_getattr(tx, name)
except NotImplementedError:
- return GetAttrVariable(obj, name, **options)
+ return GetAttrVariable(obj, name, source=source)
elif isinstance(obj, TorchInGraphFunctionVariable):
# Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default.
member = getattr(obj.value, name)
if isinstance(
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)
) and trace_rules.is_aten_op_or_tensor_method(member):
- return TorchInGraphFunctionVariable(member, **options)
+ return TorchInGraphFunctionVariable(member, source=source)
elif isinstance(obj, DummyModule):
# TODO(mlazos) - Do we need this?
if obj.is_torch or name not in obj.value.__dict__:

@@ -1732,18 +1713,15 @@ class BuiltinVariable(VariableTracker):

if config.replay_record_enabled:
tx.exec_recorder.record_module_access(obj.value, name, member)
+ return VariableTracker.build(tx, member, source)

- if source is not None:
- return VariableBuilder(tx, source)(member)
- else:
- return SourcelessBuilder.create(tx, member)
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
return ConstantVariable.create(getattr(obj.fn, name))
else:
try:
return obj.var_getattr(tx, name)
except NotImplementedError:
- return GetAttrVariable(obj, name, **options)
+ return GetAttrVariable(obj, name, source=source)

def call_setattr(
self,

@@ -1882,8 +1860,6 @@ class BuiltinVariable(VariableTracker):
return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())

def call_type(self, tx: "InstructionTranslator", obj: VariableTracker):
- from .builder import SourcelessBuilder, VariableBuilder
-
try:
py_type = obj.python_type()
except NotImplementedError as error:

@@ -1893,10 +1869,8 @@ class BuiltinVariable(VariableTracker):
case_name="unknown_python_type",
) from None

- if obj.source is None:
- return SourcelessBuilder.create(tx, py_type)
- else:
- return VariableBuilder(tx, TypeSource(obj.source))(py_type)
+ source = obj.source and TypeSource(obj.source)
+ return VariableTracker.build(tx, py_type, source)

def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker):
if obj.has_unpack_var_sequence(tx):

@@ -984,12 +984,10 @@ class HFPretrainedConfigVariable(VariableTracker):
assert self.is_matching_cls(type(obj))

def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
- from .builder import VariableBuilder
-
try:
attr_value = getattr(self.obj, name)
- attr_source = AttrSource(self.source, name)
- return VariableBuilder(tx, attr_source)(attr_value)
+ source = self.source and AttrSource(self.source, name)
+ return VariableTracker.build(tx, attr_value, source)

except AttributeError:
unimplemented(f"getattr({self.value}, {name})")

@@ -1053,15 +1051,11 @@ class PythonSysModulesVariable(VariableTracker):
key: VariableTracker,
default: Optional[VariableTracker] = None,
):
- from .builder import VariableBuilder
-
k, has_key = self._contains_helper(tx, key)

if has_key:
- return VariableBuilder(
- tx,
- GetItemSource(self.source, k),
- )(sys.modules[k])
+ source = self.source and GetItemSource(self.source, k)
+ return VariableTracker.build(tx, sys.modules[k], source)

if default is not None:
return default

@@ -1069,10 +1063,6 @@ class PythonSysModulesVariable(VariableTracker):
return ConstantVariable.create(value=None)

def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
- from .builder import VariableBuilder
-
k, has_key = self._contains_helper(tx, key)
- return VariableBuilder(
- tx,
- GetItemSource(self.source, k),
- )(sys.modules[k])
+ source = self.source and GetItemSource(self.source, k)
+ return VariableTracker.build(tx, sys.modules[k], source)

@@ -46,9 +46,7 @@ def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
if isinstance(val, VariableTracker):
return val
elif not source:
- from torch._dynamo.variables.builder import SourcelessBuilder
-
- return SourcelessBuilder.create(tx, val)
+ return VariableTracker.build(tx, val)
else:
# Create a lazy variable to avoid guarding on __defaults__ unless really
# needed.

@@ -240,8 +238,6 @@ class UserFunctionVariable(BaseUserFunctionVariable):
# optimization for cleaner codegen
result[name] = var
elif self.source:
- from .builder import VariableBuilder
-
side_effects = parent.output.side_effects
if cell in side_effects:
out = side_effects[cell]

@@ -253,9 +249,9 @@ class UserFunctionVariable(BaseUserFunctionVariable):
closure_cell, "cell_contents"
)
try:
- contents_var = VariableBuilder(
- parent, closure_cell_contents
- )(cell.cell_contents)
+ contents_var = VariableTracker.build(
+ parent, cell.cell_contents, closure_cell_contents
+ )
except ValueError:
# Cell has not yet been assigned
contents_var = variables.DeletedVariable()

@@ -286,9 +282,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
result[name] = out

else:
- from .builder import SourcelessBuilder
-
- result[name] = SourcelessBuilder.create(tx, cell.cell_contents)
+ result[name] = VariableTracker.build(tx, cell.cell_contents)

return result, closure_cells

@@ -296,17 +290,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
pass

def var_getattr(self, tx: "InstructionTranslator", name: str):
- source = AttrSource(self.source, name) if self.source else None
+ source = self.source and AttrSource(self.source, name)
try:
subobj = inspect.getattr_static(self.fn, name)
except AttributeError:
- options = {"source": source}
- return variables.GetAttrVariable(self, name, **options)
+ return variables.GetAttrVariable(self, name, source=source)
if source:
return variables.LazyVariableTracker.create(subobj, source)
- from .builder import SourcelessBuilder
-
- return SourcelessBuilder.create(tx, subobj)
+ return VariableTracker.build(tx, subobj)

def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
result = hasattr(self.fn, name)

@@ -757,14 +748,8 @@ class WrapperUserFunctionVariable(VariableTracker):
def var_getattr(self, tx: "InstructionTranslator", name):
if name == self.attr_to_trace:
val = getattr(self.wrapper_obj, self.attr_to_trace)
- if self.source:
- from .builder import VariableBuilder
-
- return VariableBuilder(tx, AttrSource(self.source, name))(val)
- else:
- from .builder import SourcelessBuilder
-
- return SourcelessBuilder.create(tx, val)
+ source = self.source and AttrSource(self.source, name)
+ return VariableTracker.build(tx, val, source)

return super().var_getattr(tx, name)

@@ -999,8 +984,6 @@ class PolyfilledFunctionVariable(VariableTracker):
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
- from torch._dynamo.variables.builder import SourcelessBuilder
-
if self.can_constant_fold_through() and check_unspec_or_constant_args(
args, kwargs
):

@@ -1010,9 +993,9 @@ class PolyfilledFunctionVariable(VariableTracker):
**{k: v.as_python_constant() for k, v in kwargs.items()},
)
)
- return SourcelessBuilder.create(tx, result)
+ return VariableTracker.build(tx, result)

- traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn)
+ traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
return traceable_function_variable.call_function(tx, args, kwargs)

def call_method(

@@ -13,7 +13,6 @@ import torch.fx
import torch.nn
from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import ConstantVariable
- from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.builtin import BuiltinVariable
from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.tensor import SymNodeVariable

@@ -31,6 +30,7 @@ from ..exc import (
)
from ..source import AttrSource
from ..utils import proxy_args_kwargs
+ from .base import VariableTracker
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable, TupleVariable

@@ -1040,7 +1040,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
args: List[VariableTracker],
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
- from .builder import SourcelessBuilder, wrap_fx_proxy
+ from .builder import wrap_fx_proxy

args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

@@ -1062,7 +1062,7 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
tx,
"new_empty",
args=(
- SourcelessBuilder.create(
+ VariableTracker.build(
tx,
leaf.size
if leaf.size is not None

@@ -1072,8 +1072,8 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
),
),
kwargs={
- "dtype": SourcelessBuilder.create(tx, leaf.dtype),
- "requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad),
+ "dtype": VariableTracker.build(tx, leaf.dtype),
+ "requires_grad": VariableTracker.build(tx, leaf.requires_grad),
},
)
for leaf in itertools.chain(xs.items, xs.items)

@@ -2057,7 +2057,6 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
fn_name: str,
):
from .._trace_wrapped_higher_order_op import TransformGetItemToIndex
- from .builder import SourcelessBuilder

tx: InstructionTranslator = tx

@@ -2065,9 +2064,9 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
return query.call_method(
tx,
"new_empty",
- (SourcelessBuilder.create(tx, []),),
+ (VariableTracker.build(tx, []),),
{
- "dtype": SourcelessBuilder.create(tx, torch.int32),
+ "dtype": VariableTracker.build(tx, torch.int32),
},
)

@@ -2077,8 +2076,8 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
score = query.call_method(
tx,
"new_empty",
- (SourcelessBuilder.create(tx, []),),
- {"requires_grad": SourcelessBuilder.create(tx, scores_require_grad)},
+ (VariableTracker.build(tx, []),),
+ {"requires_grad": VariableTracker.build(tx, scores_require_grad)},
)
new_args = [score, *bhmn]
else:

@@ -172,10 +172,8 @@ class ItertoolsVariable(VariableTracker):
*args, mutable_local=MutableLocal()
)

- from .builder import SourcelessBuilder
-
return tx.inline_user_function_return(
- SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
+ VariableTracker.build(tx, polyfills.repeat), args, kwargs
)
elif self.value is itertools.count:
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())

@@ -20,14 +20,15 @@ class LazyCache:
def realize(self) -> None:
assert self.vt is None
from ..symbolic_convert import InstructionTranslator
- from .builder import SourcelessBuilder, VariableBuilder

tx = InstructionTranslator.current_tx()
- if isinstance(self.value, LazySymNodeFormatString):
- self.vt = SourcelessBuilder.create(tx, self.value)
- else:
- self.vt = VariableBuilder(tx, self.source)(self.value)

+ if isinstance(self.value, LazySymNodeFormatString):
+ source = None
+ else:
+ source = self.source
+
+ self.vt = VariableTracker.build(tx, self.value, source)
del self.value
del self.source

@@ -37,7 +38,7 @@ class LazyVariableTracker(VariableTracker):
A structure that defers the creation of the actual VariableTracker
for a given underlying value until it is accessed.

- The `realize` function invokes VariableBuilder to produce the real object.
+ The `realize` function invokes VariableTracker.build() to produce the real object.
Once a LazyVariableTracker has been realized, internal bookkeeping will
prevent double realization.

@@ -135,10 +135,8 @@ class BaseListVariable(VariableTracker):
assert not kwargs
return iter_contains(self.unpack_var_sequence(tx), args[0], tx)
elif name == "index":
- from .builder import SourcelessBuilder
-
return tx.inline_user_function_return(
- SourcelessBuilder.create(tx, polyfills.index),
+ VariableTracker.build(tx, polyfills.index),
[self] + list(args),
kwargs,
)

@@ -207,12 +207,10 @@ class SuperVariable(VariableTracker):
and len(kwargs) == 0
and args[0].is_python_constant()
):
- from .builder import VariableBuilder
-
key = args[0].as_python_constant()
- return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
- collections.OrderedDict.__getitem__(self.objvar.value, key)
- )
+ value = collections.OrderedDict.__getitem__(self.objvar.value, key)
+ source = ODictGetItemSource(self.objvar.source, key)
+ return VariableTracker.build(tx, value, source)
elif inner_fn in (
collections.OrderedDict.__setitem__,
object.__setattr__,

@@ -467,15 +465,10 @@ class InspectParameterVariable(VariableTracker):
self.value = value

def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
- from .builder import SourcelessBuilder, VariableBuilder
-
try:
attr_value = getattr(self.value, name)
- if self.source:
- attr_source = AttrSource(self.source, name)
- return VariableBuilder(tx, attr_source)(attr_value)
- else:
- return SourcelessBuilder.create(tx, attr_value)
+ source = self.source and AttrSource(self.source, name)
+ return VariableTracker.build(tx, attr_value, source)
except AttributeError:
unimplemented(f"getattr({self.value}, {name})")

@@ -909,11 +902,9 @@ class AutogradFunctionContextVariable(UserDefinedObjectVariable):
if self.needs_input_grad is not None:
return variables.ConstantVariable.create(self.needs_input_grad)
if self.source:
- from .builder import VariableBuilder
+ source = AttrSource(self.source, "needs_input_grad")
+ return VariableTracker.build(tx, self.value.needs_input_grad, source)

- return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))(
- self.value.needs_input_grad
- )
return super().var_getattr(tx, name)


@@ -1118,11 +1109,8 @@ class GetSetDescriptorVariable(VariableTracker):

def var_getattr(self, tx: "InstructionTranslator", name):
if name == "__get__" and self.source:
- from .builder import VariableBuilder
-
- return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
- self.desc.__get__
- )
+ source = AttrSource(self.source, "__get__")
+ return VariableTracker.build(tx, self.desc.__get__, source)
else:
return super().var_getattr(tx, name)

@@ -1162,18 +1150,13 @@ class PythonModuleVariable(VariableTracker):
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
return tx.output.side_effects.load_attr(self, name)

- from .builder import SourcelessBuilder, VariableBuilder
-
if self.is_torch or name not in self.value.__dict__:
attr_value = getattr(self.value, name)
else:
attr_value = self.value.__dict__[name]

- if self.source:
- new_source = AttrSource(self.source, name)
- return VariableBuilder(tx, new_source)(attr_value)
- else:
- return SourcelessBuilder.create(tx, attr_value)
+ source = self.source and AttrSource(self.source, name)
+ return VariableTracker.build(tx, attr_value, source)


class TypingVariable(VariableTracker):

@@ -244,12 +244,7 @@ class NNModuleVariable(VariableTracker):
)

def var_getattr(self, tx: "InstructionTranslator", name):
- from .builder import VariableBuilder
-
- if self.source:
- source = AttrSource(self.source, name)
- else:
- source = None
+ source = self.source and AttrSource(self.source, name)

base = tx.output.get_submodule(self.module_key)
base_dict = object.__getattribute__(base, "__dict__")

@@ -297,7 +292,7 @@ class NNModuleVariable(VariableTracker):
return variables.UserDefinedClassVariable(base.__class__, source=source)

if object_member:
- out = VariableBuilder(tx, NNModuleSource(source))(subobj)
+ out = VariableTracker.build(tx, subobj, NNModuleSource(source))

if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)):
# nn_module_stack source is BC surface area. Ensure that

@@ -333,7 +328,7 @@ class NNModuleVariable(VariableTracker):
return variables.UserMethodVariable(subobj, self, source=source)
elif is_safe_constant(subobj) or istensor(subobj):
# Support possibly common cases of class members
- return VariableBuilder(tx, NNModuleSource(source))(subobj)
+ return VariableTracker.build(tx, subobj, NNModuleSource(source))
else:
unimplemented(
f"class property {name} - {typestr(base)} {typestr(subobj)}"

@@ -1083,7 +1078,7 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable):
)
return variables.ConstDictVariable({})

- # For non-empty hook dicts, one way is to just fallback to VariableBuilder and create a ConstDictVariable.
+ # For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable.
# However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for
# differnt nn module instances, because the key keeps changing (look more into RemovableHandle to understand why
# key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a

@@ -17,6 +17,7 @@ from ..source import (
GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
+ from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable

@@ -27,8 +28,6 @@ from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
-
- from .base import VariableTracker


class ArgMappingException(Exception):
pass

@@ -147,7 +146,6 @@ class OptimizerVariable(UserDefinedObjectVariable):

def _set_capturable(self, tx):
from . import LazyVariableTracker
- from .builder import VariableBuilder

# We only set capturable if params are on cuda
# and the state is not initialized

@@ -168,10 +166,9 @@ class OptimizerVariable(UserDefinedObjectVariable):
if safe_to_set_capturable(group):
group["capturable"] = True

+ source = self.source and AttrSource(self.source, "param_groups")
param_groups_vt = LazyVariableTracker.realize_all(
- VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
- self.value.param_groups
- )
+ VariableTracker.build(tx, self.value.param_groups, source)
)
for ind, param_group_vt in enumerate(param_groups_vt.items):
key = ConstDictVariable._HashableTracker(

@@ -214,7 +211,6 @@ class OptimizerVariable(UserDefinedObjectVariable):

def map_sources_and_install_guards(self, tx):
from ..decorators import mark_static_address
- from .builder import VariableBuilder
from .lazy import LazyVariableTracker

self.grad_to_source = {}

@@ -235,15 +231,13 @@ class OptimizerVariable(UserDefinedObjectVariable):

# Recursively realize the variable trackers for optim.state and
# optim.param_groups, which recursively install the necessary guards.
+ params_groups_source = self.source and AttrSource(self.source, "param_groups")
param_groups_vt = LazyVariableTracker.realize_all(
- VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
- self.value.param_groups
- )
+ VariableTracker.build(tx, self.value.param_groups, params_groups_source)
)

- state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
- self.value.state
- )
+ state_source = self.source and AttrSource(self.source, "state")
+ state_vt = VariableTracker.build(tx, self.value.state, state_source)

# We need to realize the top level state dict to populate
# the guard locals

@@ -265,15 +259,15 @@ class OptimizerVariable(UserDefinedObjectVariable):
key_index = i
break
if key_index:
- state_source = AttrSource(self.source, "state")
LazyVariableTracker.realize_all(
- VariableBuilder(
+ VariableTracker.build(
tx,
+ self.value.state[param],
GetItemSource(
state_source,
ConstDictKeySource(state_source, key_index),
),
- )(self.value.state[param])
+ )
)
break

@@ -312,7 +306,6 @@ class OptimizerVariable(UserDefinedObjectVariable):

# We have to again iterate over the state dict to collect the
# tensor_to_source dict. This is used for the finalizer.
- state_source = AttrSource(self.source, "state")
for idx, (p, value) in enumerate(self.value.state.items()):
p_state_source = GetItemSource(
state_source, ConstDictKeySource(state_source, idx)

@@ -328,7 +321,6 @@ class OptimizerVariable(UserDefinedObjectVariable):
def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
"""Wrap state tensor in a TensorVariable"""
from ..decorators import mark_static_address
- from .builder import VariableBuilder

# If we have a source for a tensor already use it,
# if we have not seen a tensor before, stash and use a

@@ -338,20 +330,19 @@ class OptimizerVariable(UserDefinedObjectVariable):
if tensor_value in self.tensor_to_source:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)
- builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
- self.static_tensor_names.add(tx.output.module_key_name(builder.name))
+ source = self.tensor_to_source[tensor_value]
+ self.static_tensor_names.add(tx.output.module_key_name(source.name))
elif tensor_value in self.grad_to_source:
- builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
+ source = self.grad_to_source[tensor_value]
else:
# mark these tensors as static for cudagraphs
mark_static_address(tensor_value)

global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
- builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
- self.static_tensor_names.add(tx.output.module_key_name(builder.name))
+ source = GlobalWeakRefSource(global_name)
+ self.static_tensor_names.add(tx.output.module_key_name(source.name))

- result = builder(tensor_value)
- return result
+ return VariableTracker.build(tx, tensor_value, source)

def update_list_args(
self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs

@@ -367,14 +358,8 @@ class OptimizerVariable(UserDefinedObjectVariable):
if isinstance(val, torch.Tensor):
arg.items.append(self.wrap_tensor(tx, val))
else:
- from .builder import SourcelessBuilder, VariableBuilder
-
- if arg.source:
- arg.items.append(
- VariableBuilder(tx, GetItemSource(arg.source, i))(val)
- )
- else:
- arg.items.append(SourcelessBuilder.create(tx, val))
+ source = arg.source and GetItemSource(arg.source, i)
+ arg.items.append(VariableTracker.build(tx, val, source))

def create_finalizer(self, tx):
names_to_delete = self.static_tensor_names

@@ -5,12 +5,15 @@ from typing import TYPE_CHECKING

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
+ from ..source import AttrSource
from .base import VariableTracker


if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator

+ PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
+

class SDPAParamsVariable(VariableTracker):
"""Represents the c++ params struct for scaled dot product attention.

@@ -20,35 +23,13 @@ class SDPAParamsVariable(VariableTracker):
def create(tx: "InstructionTranslator", value, source):
from torch.backends.cuda import SDPAParams

- from ..source import AttrSource
- from .builder import VariableBuilder
from .torch import TorchInGraphFunctionVariable

- query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
- key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
- value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
- attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
- value.attn_mask
- )
- dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
- is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
- value.is_causal
- )
- enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))(
- value.enable_gqa
- )
- param_vars = [
- query_var,
- key_var,
- value_var,
- attn_mask_var,
- dropout_var,
- is_causal_var,
- enable_gqa_var,
+ params = [
+ VariableTracker.build(tx, getattr(value, p), AttrSource(source, p))
+ for p in PARAM_NAMES
]
- return TorchInGraphFunctionVariable(SDPAParams).call_function(
- tx, param_vars, {}
- )
+ return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {})

def __init__(self, proxy, param_vars, **kwargs) -> None:
self.proxy = proxy

@@ -70,7 +51,6 @@ class SDPAParamsVariable(VariableTracker):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
import torch._C

- from ..source import AttrSource
from .builder import wrap_fx_proxy
from .misc import GetAttrVariable

@@ -238,9 +238,7 @@ class TensorVariable(VariableTracker):
# any other attributes on the subclass (that are not methods)
# are assumed to be constant metadata.
elif not callable(example_value):
- from .builder import SourcelessBuilder
-
- return SourcelessBuilder.create(tx, example_value)
+ return VariableTracker.build(tx, example_value)

if not (self.source and self.source.subguards_allowed()):
raise NotImplementedError

@@ -277,12 +275,9 @@ class TensorVariable(VariableTracker):
# Note - at a certain point we may want to handle
raise NotImplementedError

from ..guards import GuardBuilder
- from .builder import VariableBuilder

attr_source = AttrSource(self.source, name)
install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
- return VariableBuilder(tx, attr_source)(real_value)
+ return VariableTracker.build(tx, real_value, attr_source)

def method_attr_ndim(self, tx):
if self.ndim is not None:

@@ -695,7 +690,6 @@ class TensorVariable(VariableTracker):
def method_as_subclass(self, cls):
if isinstance(cls, TensorSubclassVariable) and cls.source:
from ..symbolic_convert import InstructionTranslator
- from .builder import VariableBuilder
from .torch_function import TensorWithTFOverrideVariable

tx = InstructionTranslator.current_tx()

@@ -705,10 +699,11 @@ class TensorVariable(VariableTracker):
# defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call.
# It is up to the user whether this is correct behavior or not.
py_cls = cls.as_python_constant()
- torch_fn = VariableBuilder(
+ torch_fn = VariableTracker.build(
tx,
+ py_cls.__torch_function__.__func__,
AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"),
- )(py_cls.__torch_function__.__func__)
+ )

return TensorWithTFOverrideVariable.from_tensor_var(
tx, self, py_cls, torch_fn

@@ -750,7 +745,6 @@ class TensorVariable(VariableTracker):

def method_tolist(self):
from ..symbolic_convert import InstructionTranslator
- from .builder import SourcelessBuilder

tx = InstructionTranslator.current_tx()

@@ -787,7 +781,7 @@ class TensorVariable(VariableTracker):

tensor = self.as_proxy().node.meta["example_value"]
out = tolist(tensor, self.as_proxy())
- return SourcelessBuilder.create(tx, out)
+ return VariableTracker.build(tx, out)

def method_backward(self, *args, **kwargs):
unimplemented("Tensor.backward")

@@ -857,10 +851,9 @@ class TensorVariable(VariableTracker):
tx = InstructionTranslator.current_tx()
if value is not None:
from .. import polyfills
- from .builder import SourcelessBuilder

return tx.inline_user_function_return(
- SourcelessBuilder.create(tx, polyfills.addcmul_inplace),
+ VariableTracker.build(tx, polyfills.addcmul_inplace),
[self, tensor1, tensor2, value],
{},
)

@@ -1155,9 +1148,7 @@ class SymNodeVariable(VariableTracker):

def as_tensor(self, tx):
if self._tensor_var is None:
- from .builder import SourcelessBuilder
-
- self._tensor_var = SourcelessBuilder.create(
+ self._tensor_var = VariableTracker.build(
tx, torch.scalar_tensor
).call_function(tx, [self], {})
return self._tensor_var

@@ -1362,12 +1353,10 @@ class TensorSubclassVariable(VariableTracker):
kwargs: Dict[str, VariableTracker],
) -> VariableTracker:
if len(args) == 1 and isinstance(args[0], TensorVariable):
- from .builder import VariableBuilder
from .torch_function import TensorWithTFOverrideVariable

- torch_fn = VariableBuilder(
- tx, AttrSource(self.source, "__torch_function__")
- )(self.value.__torch_function__)
+ source = AttrSource(self.source, "__torch_function__")
+ torch_fn = VariableTracker.build(tx, self.value.__torch_function__, source)

return TensorWithTFOverrideVariable.from_tensor_var(
tx, args[0], self.value, torch_fn

@@ -397,7 +397,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
TensorVariable,
UserDefinedObjectVariable,
)
- from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls
+ from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

@register(*tracing_state_functions)
def handle_tracing_state_functions(

@@ -422,14 +422,14 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# the set of functions that we trace __torch_function__ on to
# functions outside of the actual set. Implementing this properly will require implementing
# some variable types to track and compare tensor getset descriptors
- return SourcelessBuilder.create(
+ return VariableTracker.build(
tx, torch.overrides.get_default_nowrap_functions()
)

@register(torch.ops.inductor.accumulate_grad_.default)
def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
return tx.inline_user_function_return(
- SourcelessBuilder.create(tx, polyfills.accumulate_grad), args, kwargs
+ VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs
)

@register(math.radians)

@@ -437,7 +437,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
if not check_unspec_or_constant_args(args, kwargs):
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
return tx.inline_user_function_return(
- SourcelessBuilder.create(tx, polyfills.radians), args, kwargs
+ VariableTracker.build(tx, polyfills.radians), args, kwargs
)

@register(torch.is_tensor, torch.overrides.is_tensor_like)

@@ -622,7 +622,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs:
return tx.inline_user_function_return(
- SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace),
+ VariableTracker.build(tx, polyfills.foreach_lerp_inplace),
args,
kwargs,
)

@@ -635,7 +635,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# in compile, it's more performant to not graph break.
if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs:
return tx.inline_user_function_return(
- SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar),
+ VariableTracker.build(tx, polyfills.foreach_pow_scalar),
args,
kwargs,
)

@@ -704,7 +704,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
# Note - while we *could* cook up sources around invocations, like a FunctionSource
# the space of invoking functions in the middle of the guard chain is very iffy. As such,
# guard propagation via options is the best we can do.
- return SourcelessBuilder.create(tx, invocation_result)
+ return VariableTracker.build(tx, invocation_result)

@register(DTensor.from_local)
def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):

@@ -1143,8 +1143,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
@staticmethod
def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad):
# Alternate version if we have a .source
- from .builder import VariableBuilder
-
varname = tx.output.new_var()

# construct the nn.Parmeter before the graph save it to varname

@@ -1167,7 +1165,7 @@ Either create the tensor outside the compiled region, or do not set the tensor t
example_value = torch.nn.Parameter(
tx.output.example_value_from_input_node(data.as_proxy().node)
)
- result = VariableBuilder(tx, source)(example_value)
+ result = VariableTracker.build(tx, example_value, source)
# No need to guard on this since we already guarded on `data`.
# These guards would fail since varname doesn't exist until after the function starts
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(

@@ -474,12 +474,8 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var):
if isinstance(var, TensorWithTFOverrideVariable):
return var.class_type_var(tx)
elif isinstance(var, UserDefinedObjectVariable):
- from .builder import SourcelessBuilder, VariableBuilder
-
- if var.source:
- return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
- else:
- return SourcelessBuilder.create(tx, var.python_type())
+ source = var.source and TypeSource(var.source)
+ return VariableTracker.build(tx, var.python_type(), source)


def _is_attr_overidden(tx: "InstructionTranslator", var, name):

@@ -498,16 +494,14 @@ def _is_attr_overidden(tx: "InstructionTranslator", var, name):
def call_torch_function(
tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
- from .builder import SourcelessBuilder
-
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
torch_function_type,
fn,
types,
- SourcelessBuilder.create(tx, tuple(args)),
- SourcelessBuilder.create(tx, kwargs),
+ VariableTracker.build(tx, tuple(args)),
+ VariableTracker.build(tx, kwargs),
)
return tx.inline_user_function_return(torch_function_var, tf_args, {})

@@ -515,20 +509,13 @@ def call_torch_function(
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
from types import FunctionType

- from .builder import SourcelessBuilder, VariableBuilder
-
func = value.__torch_function__.__func__

if not isinstance(func, FunctionType):
unimplemented("Builtin/C++ torch function implementations NYI")

- if source:
- return VariableBuilder(
- tx,
- AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
- )(value.__torch_function__.__func__)
- else:
- return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
+ source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__")
+ return VariableTracker.build(tx, func, source)


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):

@@ -625,8 +612,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
# base tensors, custom attribute accesses will graph break.
import torch

- from .builder import SourcelessBuilder
-
if name in banned_attrs:
unimplemented(
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"

@@ -645,7 +630,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
GuardBuilder.FUNCTION_MATCH
)
)
- get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
+ get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)

return self.call_torch_function(
tx,

@@ -680,8 +665,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
if tx.output.torch_function_enabled:
import torch

- from .builder import SourcelessBuilder, VariableBuilder
-
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Calling overridden method {name} on a tensor"

@@ -693,11 +676,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
# We've established with the above check that the method is not overridden, so we guard that the method is the same
# as the impl defined on tensor and retrieve it
if self.source:
- func_var = VariableBuilder(
- tx, AttrSource(AttrSource(self.source, "__class__"), name)
- )(inspect.getattr_static(self.python_type(), name))
+ source = AttrSource(AttrSource(self.source, "__class__"), name)
+ value = inspect.getattr_static(self.python_type(), name)
else:
- func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
+ source = None
+ value = getattr(torch.Tensor, name)
+ func_var = VariableTracker.build(tx, value, source)
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
else:
return super().call_method(tx, name, args, kwargs)

@@ -158,7 +158,6 @@ class UserDefinedClassVariable(UserDefinedVariable):

def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
from . import ConstantVariable, EnumVariable
- from .builder import SourcelessBuilder, VariableBuilder

source = AttrSource(self.source, name) if self.source is not None else None

@@ -187,11 +186,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
obj = None

if isinstance(obj, staticmethod):
- func = obj.__get__(self.value)
- if source is not None:
- return VariableBuilder(tx, source)(func)
- else:
- return SourcelessBuilder.create(tx, func)
+ return VariableTracker.build(tx, obj.__get__(self.value), source)
elif isinstance(obj, classmethod):
if isinstance(obj.__func__, property):
return variables.UserFunctionVariable(obj.__func__.fget).call_function(

@@ -202,16 +197,13 @@ class UserDefinedClassVariable(UserDefinedVariable):
# e.g.: inspect.getattr_static(dict, "fromkeys")
# inspect.getattr_static(itertools.chain, "from_iterable")
func = obj.__get__(None, self.value)
- if source is not None:
- return VariableBuilder(tx, source)(func)
- else:
- return SourcelessBuilder.create(tx, func)
+ return VariableTracker.build(tx, func, source)
elif source:
# __mro__ is a member in < 3.12, an attribute in >= 3.12
if inspect.ismemberdescriptor(obj) or (
sys.version_info >= (3, 12) and name == "__mro__"
):
- return VariableBuilder(tx, source)(obj.__get__(self.value))
+ return VariableTracker.build(tx, obj.__get__(self.value), source)

if ConstantVariable.is_literal(obj):
return ConstantVariable.create(obj)

@@ -222,14 +214,15 @@ class UserDefinedClassVariable(UserDefinedVariable):
or self.value.__module__ == "torch"
):
if source:
- return VariableBuilder(tx, source)(obj)
+ return VariableTracker.build(tx, obj, source)
+
if (
source
and not inspect.ismethoddescriptor(obj)
and not is_wrapper_or_member_descriptor(obj)
):
- return VariableBuilder(tx, source)(obj)
+ return VariableTracker.build(tx, obj, source)

return super().var_getattr(tx, name)

def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs):

@@ -341,7 +334,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from ..side_effects import SideEffects
- from .builder import SourcelessBuilder, wrap_fx_proxy
+ from .builder import wrap_fx_proxy
from .builtin import BuiltinVariable

constant_args = check_constant_args(args, kwargs)

@@ -452,7 +445,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
field_var = kwargs[field_name]
else:
assert field_name in field_defaults
- field_var = SourcelessBuilder.create(
+ field_var = VariableTracker.build(
tx, field_defaults[field_name]
)
var_tracker_kwargs[field_name] = field_var

@@ -465,8 +458,6 @@ class UserDefinedClassVariable(UserDefinedVariable):

return variables.NamedTupleVariable(items, self.value)
elif is_frozen_dataclass(self.value) and self.is_standard_new():
- from .builder import SourcelessBuilder
-
fields = dataclasses.fields(self.value)
items = list(args)
items.extend([None] * (len(fields) - len(items)))

@@ -481,9 +472,9 @@ class UserDefinedClassVariable(UserDefinedVariable):
continue

if field.default is not dataclasses.MISSING:
- var_tracker = SourcelessBuilder.create(tx, field.default)
+ var_tracker = VariableTracker.build(tx, field.default)
elif field.default_factory is not dataclasses.MISSING:
- factory_fn = SourcelessBuilder.create(
+ factory_fn = VariableTracker.build(
tx, field.default_factory
)
var_tracker = factory_fn.call_function(tx, [], {})

@@ -573,7 +564,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
and self.source
):
return tx.inline_user_function_return(
- SourcelessBuilder.create(
+ VariableTracker.build(
tx, polyfills.instantiate_user_defined_class_object
),
[self, *args],

@@ -857,8 +848,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .. import trace_rules
- from .builder import VariableBuilder
-
if (
self.is_supported_random()

@@ -894,9 +884,9 @@ class UserDefinedObjectVariable(UserDefinedVariable):
"Sourceless UserDefinedObjectVariable method not supported"
)
func_src = AttrSource(self.source, "__func__")
- func_var = VariableBuilder(tx, func_src)(func)
+ func_var = VariableTracker.build(tx, func, func_src)
obj_src = AttrSource(self.source, "__self__")
- obj_var = VariableBuilder(tx, obj_src)(obj)
+ obj_var = VariableTracker.build(tx, obj, obj_src)
return func_var.call_function(tx, [obj_var] + args, kwargs)
elif (
istype(self.value, functools.partial)

@@ -998,7 +988,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
def var_getattr(self, tx: "InstructionTranslator", name):
from .. import trace_rules
from . import ConstantVariable
- from .builder import SourcelessBuilder, VariableBuilder

source = AttrSource(self.source, name) if self.source else None
self._check_for_getattribute()

@@ -1090,10 +1079,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
elif isinstance(subobj, types.ClassMethodDescriptorType):
# e.g.: inspect.getattr_static({}, "fromkeys")
func = subobj.__get__(self.value, None)
- if source is not None:
- return VariableBuilder(tx, source)(func)
- else:
- return SourcelessBuilder.create(tx, func)
+ return VariableTracker.build(tx, func, source)
elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor(
subobj.__get__
):

@@ -1188,7 +1174,7 @@ class UserDefinedObjectVariable(UserDefinedVariable):
subobj_from_class, src_from_class
)

- return SourcelessBuilder.create(tx, subobj)
+ return VariableTracker.build(tx, subobj)

# Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError.
raise_observed_exception(AttributeError, tx)

@@ -1212,7 +1198,6 @@ class UserDefinedObjectVariable(UserDefinedVariable):
return variables.ConstantVariable.create(False)

def odict_getitem(self, tx: "InstructionTranslator", key):
- from .builder import VariableBuilder
from .dicts import is_hashable

# TODO this should probably be merged with the dict handling

@@ -1223,10 +1208,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
else key.as_python_constant()
)

- return VariableBuilder(
+ return VariableTracker.build(
tx,
- ODictGetItemSource(self.source, index),
- )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant()))
+ collections.OrderedDict.__getitem__(self.value, key.as_python_constant()),
+ self.source and ODictGetItemSource(self.source, index),
+ )


class FrozenDataClassVariable(UserDefinedObjectVariable):

@@ -1236,14 +1222,14 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):

assert is_frozen_dataclass(value)

- from .builder import VariableBuilder
-
field_map = {}
for field in fields(value):
if hasattr(value, field.name):
- field_map[field.name] = VariableBuilder(
- tx, AttrSource(source, field.name)
- )(getattr(value, field.name))
+ field_map[field.name] = VariableTracker.build(
+ tx,
+ getattr(value, field.name),
+ source and AttrSource(source, field.name),
+ )

return FrozenDataClassVariable(value, fields=field_map, source=source)

@@ -1315,16 +1301,8 @@ class WeakRefVariable(UserDefinedObjectVariable):
) -> "VariableTracker":
- call_source = None
referent = self.value()

- if self.source:
- from .builder import VariableBuilder
-
- call_source = WeakRefCallSource(self.source)
- return VariableBuilder(tx, call_source)(referent)
- else:
- from .builder import SourcelessBuilder
-
- return SourcelessBuilder.create(tx, referent)
+ source = self.source and WeakRefCallSource(self.source)
+ return VariableTracker.build(tx, referent, source)


class KeyedJaggedTensorVariable(UserDefinedObjectVariable):
