[dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__ (#149482)

This fixes most of https://github.com/huggingface/diffusers/issues/10795,
except for `torch.Tensor._make_subclass`, which will be fixed in a
subsequent patch.

The relevant tensor subclass from the aforementioned issue is defined
here: fbf6b856cc/src/diffusers/quantizers/gguf/utils.py (L398-L435).

There are three things to note about the tensor subclass:
1. it calls `super().__torch_function__`, which is
   `torch._C._disabled_torch_function_impl`, so this patch updates
   `SuperVariable.call_method` to handle it (we can't use a simpler
   polyfill due to a bug where `var_getattr` raises
   `NotImplementedError` without restoring the symbolic context).
2. it sets and reads attributes (`quant_type`), and
   defines new methods (`as_data`), so this patch adds support for those.
3. it has an `__init__`, which Dynamo needs to trace through in
   `TensorSubclassVariable.call_function`.

Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482
Approved by: https://github.com/jansel, https://github.com/mlazos
Author: Ryan Guo
Date: 2025-04-01 17:29:39 -07:00
Committed by: PyTorch MergeBot
Parent: 85df0dc246
Commit: 33535b3eee
16 changed files with 374 additions and 88 deletions

View File

@ -586,6 +586,22 @@ class MiscTests(torch._inductor.test_case.TestCase):
ref = f(x)
self.assertEqual(res, ref)
def test_newly_constructed_tensor_attr_mutation(self):
def f(x):
y = x + 10
y.grad = x
y.foo = 42
return y
opt_f = torch.compile(f, backend="eager", fullgraph=True)
x = torch.ones(5)
res = opt_f(x)
ref = f(x)
self.assertEqual(res, ref)
self.assertEqual(res.grad, ref.grad)
self.assertEqual(res.foo, ref.foo)
def test_closure_recompiles(self):
cnt = CompileCounter()

View File

@ -954,6 +954,140 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
res_act = fn_opt(input)
self.assertEqual(res_exp, res_act)
def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self):
# This is a slight variation of
# https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435
# which basically
# 1. uses a tensor subclass to attach quantization metadata onto tensors,
# 2. preserves it across torch ops,
# 3. uses the metadata to dequantize the tensor, and
# 4. converts it to a regular tensor.
#
# The test is meant to make sure Dynamo won't graph break over it.
class GGUFParameter(torch.nn.Parameter):
def __new__(cls, data, requires_grad=False, quant_type=None):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
return self
def __init__(self, *args, quant_type=None, **kwargs):
self.quant_type = quant_type
def as_tensor(self):
return torch.Tensor(self.data)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
result = super().__torch_function__(func, types, args, kwargs)
quant_type = None
for arg in args:
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
quant_type = arg[0].quant_type
break
if isinstance(arg, GGUFParameter):
quant_type = arg.quant_type
break
if isinstance(result, torch.Tensor):
return cls(result, quant_type=quant_type)
# Handle tuples and lists
elif isinstance(result, (tuple, list)):
# Preserve the original type (tuple or list)
wrapped = [
cls(x, quant_type=quant_type)
if isinstance(x, torch.Tensor)
else x
for x in result
]
return type(result)(wrapped)
else:
return result
def f(x):
tmp = x * 2
tmp = tmp + tmp.quant_type
tmp = tmp.as_tensor()
return tmp * 3
opt_f = torch.compile(f, backend="eager", fullgraph=True)
x = GGUFParameter(torch.ones(2), quant_type=42)
with traceable_subclass(GGUFParameter):
res = f(x)
ref = opt_f(x)
self.assertEqual(res, ref)
def test_newly_constructed_tensor_subclass_attr_mutation(self):
# Make sure the attribute mutation for newly constructed tensor subclass
# object (from constructor call) is handled both during Dynamo tracing
# and codegen-ed to be visible outside `torch.compile`.
class MySubclass(torch.Tensor):
pass
def f():
x = MySubclass(torch.ones(2))
x.bar = 42
return x, x * x.bar
opt_f = compile_full_eager(f)
with traceable_subclass(MySubclass):
res = f()
ref = opt_f()
self.assertEqual(res, ref)
self.assertEqual(res[0].bar, ref[0].bar)
def test_as_subclass_attr_mutation(self):
# Make sure the attribute mutation for newly constructed tensor subclass
# object (from as_subclass call) is handled both during Dynamo tracing
# and codegen-ed to be visible outside `torch.compile`.
class MySubclass(torch.Tensor):
pass
def f():
x = torch.ones(2).as_subclass(MySubclass)
x.bar = 42
return x, x * x.bar
opt_f = compile_full_eager(f)
with traceable_subclass(MySubclass):
res = f()
ref = opt_f()
self.assertEqual(res, ref)
self.assertEqual(res[0].bar, ref[0].bar)
def test_tensor_subclass_attr_codegen_tos(self):
# This repros a very subtle interaction between
# `TensorWithTFOverrideVariable` attribute mutation codegen and
# `PyCodegen.top_of_stack`. It was uncovered from
# `test_tensor_subclass_deepcopy`.
class MySubclass(torch.Tensor):
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_subclass(cls, torch.ones(0))
r.elem = elem
return r
def f(t):
return MySubclass(t.elem.clone())
opt_f = compile_full_eager(f)
t = MySubclass(torch.ones(2))
with traceable_subclass(MySubclass):
res = f(t)
ref = opt_f(t)
# TODO uncomment once we trace into `__new__`.
# self.assertEqual(res, ref)
# self.assertEqual(res.elem, ref.elem)
self.assertEqual(type(res), type(ref))
def test_compile_with_fake_tensor_dynamic_dim(self):
x = torch.randn([3, 4])

View File

@ -0,0 +1 @@
https://github.com/pytorch/pytorch/issues/149881

View File

@ -295,9 +295,7 @@ class SideEffects:
variable: VariableTracker,
mutation_type_cls=ValueMutationExisting,
):
"""Start tracking a new variable for mutation"""
assert variable.source is not None
"""Start tracking an existing or new variable for mutation"""
if id(item) in self.id_to_variable:
raise AssertionError(
f"{variable} is already tracked for mutation. This could be "
@ -576,12 +574,18 @@ class SideEffects:
return [var for var in self.id_to_variable.values() if self.is_modified(var)]
def codegen_save_tempvars(self, cg: PyCodegen):
# Make sure we codegen these modified VT to their source by default, so
# that mutation and aliasing are properly accounted for.
# We must codegen modified VT to their source by default, so that
# mutation and aliasing are properly accounted for.
#
# Since newly constructed objects don't have a source, we manually
# codegen their construction and store them to a newly assigned local
# source. Note that `ValueMutationNew` isn't tracked by SideEffects.
for var in self._get_modified_vars():
if isinstance(var.mutation_type, AttributeMutationNew) and isinstance(
var, variables.CellVariable
):
if not isinstance(var.mutation_type, AttributeMutationNew):
assert var.source is not None
continue
if isinstance(var, variables.CellVariable):
# Cells created in the root frame are created either by
# `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
# `make_cell` for the non-root-frame cells here.
@ -595,18 +599,38 @@ class SideEffects:
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
elif var.source is None:
var.source = LocalCellSource(var.local_name)
elif isinstance(var.mutation_type, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented_v2(
gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
context="",
explanation="We cannot reconstruct a torch.autograd.Function's context object.",
hints=[],
)
elif isinstance(var, variables.TensorVariable):
# NOTE: for historical reasons we never assigned local sources
# to newly constructed tensor objects, so we keep it that way.
# They are always loaded from the output of the fx graph, so one can
# think of them as having an "OutputGraphSource" for codegen
# purposes.
#
# However, tensor subclass objects are different, because the
# reconstruction logic in `PyCodegen` loads the data tensor from
# graph output and then calls `as_subclass`, meaning we must
# assign a source to it to ensure we only reconstruct one
# subclass instance.
if isinstance(
var, variables.torch_function.TensorWithTFOverrideVariable
):
# Don't codegen from temp source assigned from the 1st pass.
cg(var, allow_cache=False)
cg.add_cache(var)
# `add_cache` generates STORE and consumes TOS, but we never
# cleared it. TODO move this call into `add_cache`
cg.clear_tos()
var.source = LocalSource(cg.tempvars[var])
elif isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented_v2(
gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
context="",
explanation="We cannot reconstruct a torch.autograd.Function's context object.",
hints=[],
)
else:
# Reconstruct the bytecode for
# base_cls.__new__(user_cls, *args)
if isinstance(var, variables.UserDefinedObjectVariable):
def load_new_method():
@ -630,10 +654,6 @@ class SideEffects:
cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var])
else:
# The remaining cases here are `AttributeMutationExisting` and
# `MutableSideEffects`, which have sources already.
assert var.source is not None
for ctx, args in self.save_for_backward:
cg(ctx.source)
@ -993,7 +1013,7 @@ class SideEffects:
else:
cg.tx.output.update_co_names(name)
cg(value)
cg(var.source)
cg(var)
suffixes.append([create_instruction("STORE_ATTR", argval=name)])
elif isinstance(var, variables.ListIteratorVariable):
for _ in range(var.index):

View File

@ -510,7 +510,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._debug_set_fusion_group_inlining",
"torch._C._demangle",
"torch._C._disabled_torch_dispatch_impl",
"torch._C._disabled_torch_function_impl",
"torch._C._dispatch_call_boxed",
"torch._C._dispatch_check_all_invariants",
"torch._C._dispatch_check_invariants",

View File

@ -140,6 +140,7 @@ from ..utils import (
wrap_fake_exception,
)
from .base import (
AttributeMutationNew,
typestr,
ValueMutationExisting,
ValueMutationNew,
@ -2470,7 +2471,9 @@ def _wrap_fx_preexisting_tensor(
f"wrapped by this instance of Dynamo. Found: {tensor}"
)
return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls)
return construct_tensor_variable(
target_cls, tx, proxy, tensor, subclass_type, options
)
# This is 2 in the above comment (wrapping the output of a traced op)
@ -2504,36 +2507,23 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
import torch._utils
if isinstance(example_value, torch.Tensor):
is_parameter = isinstance(example_value, torch.nn.Parameter)
is_buffer = isinstance(example_value, torch.nn.Buffer)
# NB: In most (all?) cases, this does not actually do a clone.
# (WARNING: this means that if we mutate metadata on the fake
# tensor, the stored example value will update too!)
example_value = _clone_input(example_value, tx.fake_mode)
set_example_value(proxy.node, example_value)
# We bind the unbacked symints in the sizes/strides of the tensor lazily,
# so that subgraphs can access the unbacked symbol's proxy in the parent graph
# when lifting unbacked symbols of input tensors to subgraph inputs.
# We do it lazily because the tensor may not be used in subgraphs.
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
specialized_props = target_cls.specialize(example_value)
# TODO: not sure about this fake mode test
if (
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
and example_value.fake_mode is tx.fake_mode
):
tensor_type = subclass_type if subclass_type else torch.Tensor
specialized_props["class_type"] = (
torch.nn.Parameter
if is_parameter
else torch.nn.Buffer
if is_buffer
else tensor_type
)
options.update(specialized_props)
return target_cls(proxy, **options)
var = construct_tensor_variable(
target_cls, tx, proxy, example_value, subclass_type, options
)
# NOTE: [Side effect tracking for newly constructed tensor]
# For newly constructed objects that have mutable attributes, we usually
# construct their VariableTracker via `track_object_new`, but since
# tensor variable construction is a bit different, so we handle them
# specially here. This ensures that codegen will actually generate the
# attribute mutations on this tensor.
#
# NOTE we pass a dummy object as the `item` argument to avoid
# constructing a dummy _tensor_ object. The object isn't used for
# newly constructed VTs anyway.
tx.output.side_effects._track_obj(
proxy, var, mutation_type_cls=AttributeMutationNew
)
return var
elif (
hasattr(proxy.node.target, "__name__")
and proxy.node.target.__name__ == "set_state"
@ -2702,6 +2692,43 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
)
def construct_tensor_variable(
target_cls, tx, proxy, example_value, subclass_type, options
):
"""
Actually construct a tensor variable after all the pre-processing from
wrapping a pre-existing or newly created tensor value.
"""
# NB: In most (all?) cases, this does not actually do a clone.
# (WARNING: this means that if we mutate metadata on the fake
# tensor, the stored example value will update too!)
example_value = _clone_input(example_value, tx.fake_mode)
set_example_value(proxy.node, example_value)
# We bind the unbacked symints in the sizes/strides of the tensor lazily,
# so that subgraphs can access the unbacked symbol's proxy in the parent graph
# when lifting unbacked symbols of input tensors to subgraph inputs.
# We do it lazily because the tensor may not be used in subgraphs.
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
specialized_props = target_cls.specialize(example_value)
# TODO: not sure about this fake mode test
if (
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
and example_value.fake_mode is tx.fake_mode
):
if subclass_type:
tensor_type = subclass_type
elif isinstance(example_value, torch.nn.Parameter):
tensor_type = torch.nn.Parameter
elif isinstance(example_value, torch.nn.Buffer):
tensor_type = torch.nn.Buffer
else:
tensor_type = torch.Tensor
specialized_props["class_type"] = tensor_type
options.update(specialized_props)
return target_cls(proxy, **options)
def get_automatic_dynamic_shapes_mark_as():
if config.automatic_dynamic_shapes_mark_as == "dynamic":
return DimDynamic.DYNAMIC

View File

@ -1933,6 +1933,20 @@ class BuiltinVariable(VariableTracker):
"the middle of the graph, which aot_autograd does not currently know how to handle. "
)
elif name == "data":
# See comments on `test_set_data_on_scoped_tensor` for plans
# to support this.
if obj.source is None:
unimplemented_v2(
gb_type="Failed to mutate tensor data attribute",
context=f"setattr({obj}, {name}, {val})",
explanation="Dyanmo only supports mutating `.data`"
" of tensor created outside `torch.compile` region",
hints=[
"Don't mutate `.data` on this tensor, or move "
"the mutation out of `torch.compile` region",
],
)
# Remove the old reference in tracked fakes - if we don't do this
# new .data value size and shape differences will cause
# tracked fakes to produce incorrect guards. This is sound because the TensorVariable

View File

@ -161,6 +161,14 @@ class SuperVariable(VariableTracker):
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
inner_fn, source = self._resolved_getattr_and_source(self, name)
# This essentially simulates CPython's `super_getattro`:
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168
# where `inner_fn` is the VT for `res = _super_lookup_descr(...)`.
#
# However, `res`'s type needs to be checked for `tp_descr_get`, which
# must be applied if present. We currently don't have polyfills for all the
# relevant `tp_descr_get`, so we explicitly handle the cases we care
# about here (e.g., note the staticmethod, classmethod cases).
if inner_fn is object.__init__:
return LambdaVariable(identity)
elif inner_fn is torch.nn.Module.__init__:
@ -266,6 +274,29 @@ class SuperVariable(VariableTracker):
source = self.source and AttrSource(self.source, attr_name)
return VariableTracker.build(tx, attr_value, source)
elif inner_fn is torch._C._disabled_torch_function_impl:
# See `THPModule_disable_torch_function` for the C impl.
# The signature of _disabled_torch_function_impl is similar to
# `__torch_function__`, just without the first `cls` argument:
# * (func, types, args, kwargs)
func = args[0]
tf_kwargs = {}
tf_args = args[2].items
for hash_key_vt, value_vt in args[3].items.items():
key_str = hash_key_vt.vt.as_python_constant()
tf_kwargs[key_str] = value_vt
output_old = tx.output.torch_function_enabled
tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled
tx.output.torch_function_enabled = False
tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
try:
return func.call_function(tx, tf_args, tf_kwargs)
finally:
tx.output.torch_function_enabled = output_old
tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
tx_old
)
unimplemented(f"non-function or method super: {inner_fn}")

View File

@ -67,7 +67,7 @@ from ..utils import (
set_example_value,
tensortype_to_dtype,
)
from .base import VariableTracker
from .base import AttributeMutationNew, VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable
@ -789,9 +789,14 @@ class TensorVariable(VariableTracker):
tx = InstructionTranslator.current_tx()
py_cls = cls.as_python_constant()
return TensorWithTFOverrideVariable.from_tensor_var(
var = TensorWithTFOverrideVariable.from_tensor_var(
tx, self, py_cls, cls.source
)
# See NOTE [Side effect tracking for newly constructed tensor]
tx.output.side_effects._track_obj(
object(), var, mutation_type_cls=AttributeMutationNew
)
return var
def method_get_device(self):
if isinstance(self.device, torch.device):
@ -1443,14 +1448,37 @@ class TensorSubclassVariable(VariableTracker):
args: list[VariableTracker],
kwargs: dict[str, VariableTracker],
) -> VariableTracker:
if len(args) == 1 and isinstance(args[0], TensorVariable):
from .torch_function import TensorWithTFOverrideVariable
# Handle `Subclass(existing_tensor)` calls.
def impl():
if len(args) == 1 and isinstance(args[0], TensorVariable):
from .torch_function import TensorWithTFOverrideVariable
return TensorWithTFOverrideVariable.from_tensor_var(
tx, args[0], self.value, self.source
)
# This simulates `__new__` and _assumes_ it doesn't have
# side-effects that matter to Dynamo tracing. TODO: trace through
# `__new__`.
var = TensorWithTFOverrideVariable.from_tensor_var(
tx, args[0], self.value, self.source
)
return super().call_function(tx, args, kwargs)
# Let Dynamo trace through custom `__init__`
init_func = self.value.__init__
# TODO builder should be able to handle `torch.Tensor.__init__`,
# which is `object.__init__`, so that we can remove this check.
if init_func is not torch.Tensor.__init__:
cls_kwargs = kwargs or {}
VariableTracker.build(tx, init_func).call_function(
tx, [var], cls_kwargs
)
return var
return super().call_function(tx, args, kwargs)
var = impl()
# See NOTE [Side effect tracking for newly constructed tensor]
tx.output.side_effects._track_obj(
object(), var, mutation_type_cls=AttributeMutationNew
)
return var
def as_python_constant(self):
return self.value

View File

@ -62,6 +62,7 @@ from ..utils import (
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import GenericContextWrappingVariable
from .functions import UserMethodVariable
from .lazy import LazyVariableTracker
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
@ -592,12 +593,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
def from_tensor_var(cls, tx, tensor_var, class_type, cls_source):
# [Note: __torch_function__] coerce `tensor_var` into a
# TensorWithTFOverrideVariable. In eager, this is just a type change.
# This isn't sound if a __torch_function__ tensor subclass defines a
# constructor, but if only a __torch_function__ impl is defined, this is
# okay to call. It is up to the user whether this is correct behavior
# or not.
import torch
# This simulates shallow-copying the tensor object.
kwargs = dict(tensor_var.__dict__)
assert kwargs.pop("class_type") is torch.Tensor, (
"invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
@ -640,30 +638,48 @@ class TensorWithTFOverrideVariable(TensorVariable):
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
)
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Accessing overridden method/attribute {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
)
if tx.output.torch_function_enabled and hasattr(torch.Tensor, name):
if self.source:
install_guard(
AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
GuardBuilder.FUNCTION_MATCH
)
if hasattr(torch.Tensor, name):
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Accessing overridden method/attribute {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
)
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function(
tx,
get_fn,
TupleVariable([self.class_type_var(tx)]),
[self],
{},
)
if tx.output.torch_function_enabled:
if self.source:
install_guard(
AttrSource(
AttrSource(self.source, "__class__"), name
).make_guard(GuardBuilder.FUNCTION_MATCH)
)
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
return self.call_torch_function(
tx,
get_fn,
TupleVariable([self.class_type_var(tx)]),
[self],
{},
)
else:
return super().var_getattr(tx, name)
# `TensorVariable.var_getattr` doesn't handle user-defined
# functions/attributes well, so we explicitly handle them here.
#
# TODO move this logic into `TensorVariable`, or try to merge it
# with similar logic in `UserDefinedObjectVariable`.
try:
attr = inspect.getattr_static(self.class_type, name)
except AttributeError:
pass
else:
import types
if isinstance(attr, types.FunctionType):
cls_source = GlobalSource(self.global_mangled_class_name(tx))
func_source = AttrSource(cls_source, name)
install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH))
return UserMethodVariable(attr, self)
return super().var_getattr(tx, name)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
return call_torch_function(