mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Simplify creation of VariableTrackers (#135714)
## `VariableTracker::build()` hides the Builders
### The problem
In the current code, creating a `VariableTracker` involves choosing one of two `Builder` classes and then either calling a static method on one (`SourcelessBuilder.create(tx, value)`) or constructing the other and immediately calling the resulting object (`VariableBuilder(tx, source)(value)`), [like this](083c9149b7/torch/_dynamo/variables/functions.py (L761-L768)).
Variations on this code are repeated in many places.
Moreover, the `Builder` classes have many dependencies, so they must be loaded late in the import process to avoid circular imports; as a result, they end up being imported repeatedly at local scope.
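For concreteness, the boilerplate repeated at each call site looks roughly like this (a simplified sketch based on the call sites touched in this diff; `tx`, `value`, and `source` stand in for whatever the caller has in scope):

```python
# Sketch of the old call-site boilerplate. In the real call sites the import is
# the relative `from .builder import SourcelessBuilder, VariableBuilder`, done
# locally to avoid circular imports.
from torch._dynamo.variables.builder import SourcelessBuilder, VariableBuilder

if source is not None:
    # With a Source: construct a VariableBuilder, then call it on the value.
    var = VariableBuilder(tx, source)(value)
else:
    # Without a Source: use SourcelessBuilder's factory method.
    var = SourcelessBuilder.create(tx, value)
```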
### The solution
In this commit, the import from `builder` and the logic for choosing and calling the right Builder class are hidden behind a single static factory method, `VariableTracker.build()`, which is easier to reason about and to import.
This commit removes over 150 net lines of code by eliminating repetitive logic and unnecessary local imports.
**CHANGES:** The static method was originally named `VariableTracker.create()`, but a static method with an irreconcilably different signature, `LazyVariableTracker.create()`, already exists on a derived class, so the new method was renamed to `VariableTracker.build()`.
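Roughly, the new factory method centralizes that dispatch in one place; a minimal sketch under the assumption that it simply branches on whether a source is supplied (not the literal implementation):

```python
def build(tx, value, source=None):
    """Sketch of VariableTracker.build(): pick the right Builder for `value`."""
    # The builder module is imported here rather than at module scope; hiding
    # this import behind the factory is part of the point of the change.
    from torch._dynamo.variables.builder import SourcelessBuilder, VariableBuilder

    if source is None:
        return SourcelessBuilder.create(tx, value)
    return VariableBuilder(tx, source)(value)
```

Call sites then collapse to a single line, e.g. `VariableTracker.build(tx, var.python_type(), source)` in the first hunk below.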
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135714
Approved by: https://github.com/jansel
committed by PyTorch MergeBot
parent 1581a93e87
commit e1c4548441
@@ -474,12 +474,8 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var):
     if isinstance(var, TensorWithTFOverrideVariable):
         return var.class_type_var(tx)
     elif isinstance(var, UserDefinedObjectVariable):
-        from .builder import SourcelessBuilder, VariableBuilder
-
-        if var.source:
-            return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
-        else:
-            return SourcelessBuilder.create(tx, var.python_type())
+        source = var.source and TypeSource(var.source)
+        return VariableTracker.build(tx, var.python_type(), source)
 
 
 def _is_attr_overidden(tx: "InstructionTranslator", var, name):
@@ -498,16 +494,14 @@ def _is_attr_overidden(tx: "InstructionTranslator", var, name):
 def call_torch_function(
     tx, torch_function_type, torch_function_var, fn, types, args, kwargs
 ):
-    from .builder import SourcelessBuilder
-
     # signature:
     # def __torch_function__(cls, func, types, args=(), kwargs=None):
     tf_args = (
         torch_function_type,
         fn,
         types,
-        SourcelessBuilder.create(tx, tuple(args)),
-        SourcelessBuilder.create(tx, kwargs),
+        VariableTracker.build(tx, tuple(args)),
+        VariableTracker.build(tx, kwargs),
     )
     return tx.inline_user_function_return(torch_function_var, tf_args, {})
 
@@ -515,20 +509,13 @@ def call_torch_function(
 def build_torch_function_fn(tx: "InstructionTranslator", value, source):
     from types import FunctionType
 
-    from .builder import SourcelessBuilder, VariableBuilder
-
     func = value.__torch_function__.__func__
 
     if not isinstance(func, FunctionType):
         unimplemented("Builtin/C++ torch function implementations NYI")
 
-    if source:
-        return VariableBuilder(
-            tx,
-            AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
-        )(value.__torch_function__.__func__)
-    else:
-        return SourcelessBuilder.create(tx, value.__torch_function__.__func__)
+    source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__")
+    return VariableTracker.build(tx, func, source)
 
 
 def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
@@ -625,8 +612,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
         # base tensors, custom attribute accesses will graph break.
         import torch
 
-        from .builder import SourcelessBuilder
-
         if name in banned_attrs:
             unimplemented(
                 f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
@@ -645,7 +630,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
                         GuardBuilder.FUNCTION_MATCH
                     )
                 )
-            get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)
+            get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
 
             return self.call_torch_function(
                 tx,
@@ -680,8 +665,6 @@ class TensorWithTFOverrideVariable(TensorVariable):
         if tx.output.torch_function_enabled:
             import torch
 
-            from .builder import SourcelessBuilder, VariableBuilder
-
             if _is_attr_overidden(tx, self, name):
                 unimplemented(
                     f"Calling overridden method {name} on a tensor"
@@ -693,11 +676,12 @@ class TensorWithTFOverrideVariable(TensorVariable):
             # We've established with the above check that the method is not overridden, so we guard that the method is the same
            # as the impl defined on tensor and retrieve it
             if self.source:
-                func_var = VariableBuilder(
-                    tx, AttrSource(AttrSource(self.source, "__class__"), name)
-                )(inspect.getattr_static(self.python_type(), name))
+                source = AttrSource(AttrSource(self.source, "__class__"), name)
+                value = inspect.getattr_static(self.python_type(), name)
             else:
-                func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
+                source = None
+                value = getattr(torch.Tensor, name)
+            func_var = VariableTracker.build(tx, value, source)
             return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
         else:
             return super().call_method(tx, name, args, kwargs)