[dynamo] Simplify creation of VariableTrackers (#135714)

## `VariableTracker::build()` hides the Builders

### The problem

In the current code, creating a `VariableTracker` involves choosing one of two `Builder` classes and either calling a method, or calling a constructor that creates an object that you immediately call, [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)).

Variations on this code are repeated in many places.

More, the `Builder` classes have a lot of dependencies, so they have to be loaded late in the whole import process to avoid circular imports, so they end up being repeatedly imported at local scope.

### The solution

In this commit, the import from `builder` and the logic of choosing and calling the Builder class are hidden in a single static factory method, `VariableTracker.build()`, easier to reason about and to import.

This commit net lowers the total lines of code by over 150 lines by removing repetitive logic and unnecessary local imports.

**CHANGES:** Originally the name of the static method was `VariableTracker.create()` but a static method on a derived class, `LazyVariableTracker.create()` now exists with a different signature that's irreconcilable, so the new static method was renamed to `VariableTracker.build()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
This commit is contained in:
Tom Ritchford
2024-10-17 16:21:48 +00:00
committed by PyTorch MergeBot
parent 1581a93e87
commit e1c4548441
18 changed files with 180 additions and 333 deletions

View File

@ -474,12 +474,8 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var):
if isinstance(var, TensorWithTFOverrideVariable):
return var.class_type_var(tx)
elif isinstance(var, UserDefinedObjectVariable):
from .builder import SourcelessBuilder, VariableBuilder
if var.source:
return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
else:
return SourcelessBuilder.create(tx, var.python_type())
source = var.source and TypeSource(var.source)
return VariableTracker.build(tx, var.python_type(), source)
def _is_attr_overidden(tx: "InstructionTranslator", var, name):
@ -498,16 +494,14 @@ def _is_attr_overidden(tx: "InstructionTranslator", var, name):
def call_torch_function(
tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
from .builder import SourcelessBuilder
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
torch_function_type,
fn,
types,
SourcelessBuilder.create(tx, tuple(args)),
SourcelessBuilder.create(tx, kwargs),
VariableTracker.build(tx, tuple(args)),
VariableTracker.build(tx, kwargs),
)
return tx.inline_user_function_return(torch_function_var, tf_args, {})
@ -515,20 +509,13 @@ def call_torch_function(
def build_torch_function_fn(tx: "InstructionTranslator", value, source):
from types import FunctionType
from .builder import SourcelessBuilder, VariableBuilder
func = value.__torch_function__.__func__
if not isinstance(func, FunctionType):
unimplemented("Builtin/C++ torch function implementations NYI")
if source:
return VariableBuilder(
tx,
AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
)(value.__torch_function__.__func__)
else:
return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__")
return VariableTracker.build(tx, func, source)
def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
@ -625,8 +612,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
# base tensors, custom attribute accesses will graph break.
import torch
from .builder import SourcelessBuilder
if name in banned_attrs:
unimplemented(
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
@ -645,7 +630,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
GuardBuilder.FUNCTION_MATCH
)
)
get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function(
tx,
@ -680,8 +665,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
if tx.output.torch_function_enabled:
import torch
from .builder import SourcelessBuilder, VariableBuilder
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Calling overridden method {name} on a tensor"
@ -693,11 +676,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
# We've established with the above check that the method is not overridden, so we guard that the method is the same
# as the impl defined on tensor and retrieve it
if self.source:
func_var = VariableBuilder(
tx, AttrSource(AttrSource(self.source, "__class__"), name)
)(inspect.getattr_static(self.python_type(), name))
source = AttrSource(AttrSource(self.source, "__class__"), name)
value = inspect.getattr_static(self.python_type(), name)
else:
func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
source = None
value = getattr(torch.Tensor, name)
func_var = VariableTracker.build(tx, value, source)
return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
else:
return super().call_method(tx, name, args, kwargs)