Files
pytorch/torch/_dynamo/variables/lazy.py
Tom Ritchford e1c4548441 [dynamo] Simplify creation of VariableTrackers (#135714)
## `VariableTracker::build()` hides the Builders

### The problem

In the current code, creating a `VariableTracker` involves choosing one of two `Builder` classes and either calling a method, or calling a constructor that creates an object that you immediately call, [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)).

Variations on this code are repeated in many places.

Moreover, the `Builder` classes have many dependencies, so they must be loaded late in the import process to avoid circular imports; as a result, they end up being imported repeatedly at local scope.

### The solution

In this commit, the import from `builder` and the logic of choosing and calling the Builder class are hidden in a single static factory method, `VariableTracker.build()`, which is easier to reason about and to import.

This commit reduces the total line count by a net of over 150 lines by removing repetitive logic and unnecessary local imports.

**CHANGES:** Originally the name of the static method was `VariableTracker.create()` but a static method on a derived class, `LazyVariableTracker.create()` now exists with a different signature that's irreconcilable, so the new static method was renamed to `VariableTracker.build()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
2024-10-18 09:36:46 +00:00

183 lines
5.9 KiB
Python

import collections
import functools
from typing import Any, Callable, Dict, Optional, Tuple, Union
from typing_extensions import Self
from .base import VariableTracker
from .tensor import SymNodeVariable
class LazyCache:
    """Holds a raw value and its source until the real VariableTracker is built.

    Before `realize()` runs, `value` and `source` hold the constructor
    arguments and `vt` is None. Afterwards `vt` holds the built tracker and
    the `value` and `source` attributes have been deleted.
    """

    def __init__(self, value: Any, source: Any) -> None:
        # Only LazySymNodeFormatString values are allowed to lack a source.
        is_format_string = isinstance(value, LazySymNodeFormatString)
        assert is_format_string or source
        self.value = value
        self.source = source
        self.vt: Optional[VariableTracker] = None

    def realize(self) -> None:
        """Build the tracker with `VariableTracker.build()` and store it in `vt`.

        Must be called at most once: asserts that `vt` is still None.
        """
        assert self.vt is None
        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import -- confirm).
        from ..symbolic_convert import InstructionTranslator

        translator = InstructionTranslator.current_tx()
        raw_value = self.value
        # LazySymNodeFormatString values are always built with no source.
        build_source = (
            None if isinstance(raw_value, LazySymNodeFormatString) else self.source
        )
        self.vt = VariableTracker.build(translator, raw_value, build_source)
        # The inputs are dropped once the tracker has been built.
        del self.value, self.source
class LazyVariableTracker(VariableTracker):
    """
    A structure that defers the creation of the actual VariableTracker
    for a given underlying value until it is accessed.

    The `realize` function invokes VariableTracker.build() to produce the real object.
    Once a LazyVariableTracker has been realized, internal bookkeeping will
    prevent double realization.

    This object should be utilized for processing containers, or objects that
    reference other objects where we may not want to take on creating all the
    VariableTrackers right away.
    """

    # `_cache` holds a LazyCache, not a VariableTracker, so `realize_all`
    # must not walk into it when updating fields in place.
    _nonvar_fields = {"_cache", *VariableTracker._nonvar_fields}

    @staticmethod
    def create(value: Any, source: Any, **options: Any) -> "LazyVariableTracker":
        """Return an unrealized tracker wrapping `value` and `source` in a new LazyCache."""
        return LazyVariableTracker(LazyCache(value, source), source=source, **options)

    def __init__(self, _cache: LazyCache, **kwargs: Any) -> None:
        """Store `_cache`; the remaining keyword arguments go to VariableTracker."""
        assert isinstance(_cache, LazyCache)
        super().__init__(**kwargs)
        self._cache = _cache

    def realize(self) -> VariableTracker:
        """Force construction of the real VariableTracker"""
        if self._cache.vt is None:
            self._cache.realize()
            assert self._cache.vt is not None
        return self._cache.vt

    def unwrap(self) -> Union[VariableTracker, Self]:
        """Return the real VariableTracker if it already exists"""
        if self.is_realized():
            assert self._cache.vt is not None
            return self._cache.vt
        return self

    def is_realized(self) -> bool:
        """Return True once the cache holds a built VariableTracker."""
        return self._cache.vt is not None

    def clone(self, **kwargs: Any) -> VariableTracker:
        """Clone this tracker, realizing first if a different `source` is requested.

        A `_cache` keyword, if given, must be this tracker's own cache.
        """
        assert kwargs.get("_cache", self._cache) is self._cache
        if kwargs.get("source", self.source) is not self.source:
            # After realizing, `unwrap()` returns the real tracker, so the
            # clone is made from it rather than from this lazy wrapper.
            self.realize()
        return VariableTracker.clone(self.unwrap(), **kwargs)

    def peek_type(self) -> type[Any]:
        """Return the type of the underlying value; only valid before realization."""
        assert not self.is_realized()
        return type(self._cache.value)

    def peek_value(self) -> Any:
        """Return the underlying value; only valid before realization."""
        assert not self.is_realized()
        return self._cache.value

    def __str__(self) -> str:
        """Use the real tracker's `__str__` if realized, else VariableTracker's."""
        if self.is_realized():
            return self.unwrap().__str__()
        # Not realized: `unwrap()` returns `self` here.
        return VariableTracker.__str__(self.unwrap())

    def __getattr__(self, item: str) -> Any:
        """Forward attributes not found on this object to the realized tracker.

        Raises:
            AttributeError: if `_cache` itself is missing from the instance.
        """
        # `__getattr__` only runs after normal lookup has failed. If `_cache`
        # is the missing attribute (e.g. the instance was created without
        # running `__init__`), `realize()` would look up `self._cache` again
        # and recurse until RecursionError. Fail cleanly instead.
        if item == "_cache":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self.realize(), item)

    # most methods are auto-generated below, these are the ones we want to exclude
    visit = VariableTracker.visit  # type: ignore[assignment]
    __repr__ = VariableTracker.__repr__

    @classmethod
    def realize_all(
        cls,
        value: Any,
        cache: Optional[Dict[int, Tuple[Any, Any]]] = None,
    ) -> Any:
        """
        Walk an object and realize all LazyVariableTrackers inside it.

        Args:
            value: the object to walk.
            cache: maps `id(obj)` to `(result, obj)` for objects already
                processed; a fresh dict is used when None.

        Returns:
            The realized tracker for a LazyVariableTracker (itself walked);
            the same object, with its fields updated in place, for any other
            VariableTracker; a new list, tuple or plain dict for values whose
            type is exactly list, tuple, dict or OrderedDict; and `value`
            unchanged for anything else.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return cache[idx][0]

        value_cls = type(value)
        if issubclass(value_cls, LazyVariableTracker):
            result = cls.realize_all(value.realize(), cache)
        elif issubclass(value_cls, VariableTracker):
            # update value in-place
            result = value
            value_dict = value.__dict__
            nonvars = value._nonvar_fields
            for key in value_dict:
                if key not in nonvars:
                    value_dict[key] = cls.realize_all(value_dict[key], cache)
        elif value_cls is list:
            result = [cls.realize_all(v, cache) for v in value]
        elif value_cls is tuple:
            result = tuple(cls.realize_all(v, cache) for v in value)
        elif value_cls in (dict, collections.OrderedDict):
            # The items are snapshotted with list() before the walk; note the
            # result is a plain dict even when the input is an OrderedDict.
            result = {k: cls.realize_all(v, cache) for k, v in list(value.items())}
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = (result, value)
        return result
class LazySymNodeFormatString:
    """Pairs a SymNodeVariable with a format spec; the text is produced by `__str__`."""

    def __init__(
        self, sym_node_variable: SymNodeVariable, fmt_spec_var: VariableTracker
    ) -> None:
        """Keep the sym node variable and build a `"{:<spec>}"` template constant."""
        from .constant import ConstantVariable

        self.sym_node_var = sym_node_variable
        spec = fmt_spec_var.as_python_constant()
        template = "{:" + spec + "}"
        self.fmt_var = ConstantVariable.create(template)

    def __str__(self) -> str:
        """Format the stringified result of `evaluate_expr()` with the stored template."""
        template = self.fmt_var.as_python_constant()
        evaluated_text = str(self.sym_node_var.evaluate_expr())
        return str.format(template, evaluated_text)
def _create_realize_and_forward(
    name: str,
) -> Callable[[LazyVariableTracker, Any, Any], Any]:
    """Build a method that realizes the lazy tracker, then calls `name` on the result.

    The returned function carries the metadata of `VariableTracker.<name>`
    via `functools.wraps`.
    """
    base_method = getattr(VariableTracker, name)

    def realize_and_forward(
        self: LazyVariableTracker, *args: Any, **kwargs: Any
    ) -> Any:
        real_tracker = self.realize()
        target = getattr(real_tracker, name)
        return target(*args, **kwargs)

    return functools.wraps(base_method)(realize_and_forward)
def _populate() -> None:
    """Install realize-and-forward wrappers on LazyVariableTracker.

    Every callable entry in `VariableTracker.__dict__` whose name is not
    already present in `LazyVariableTracker.__dict__` gets a wrapper from
    `_create_realize_and_forward`.
    """
    own_namespace = LazyVariableTracker.__dict__
    for attr_name, attr in VariableTracker.__dict__.items():
        # Skip names LazyVariableTracker defines itself, and non-callables.
        if attr_name in own_namespace or not callable(attr):
            continue
        setattr(LazyVariableTracker, attr_name, _create_realize_and_forward(attr_name))


# Runs once at import time so the wrappers exist as soon as the module loads.
_populate()