mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__
(#149482)
This fixes most of https://github.com/huggingface/diffusers/issues/10795,
except for `torch.Tensor._make_subclass`, which will be fixed in a
subsequent patch.
The relevant tensor subclass from the aforementioned issue is defined
here: fbf6b856cc/src/diffusers/quantizers/gguf/utils.py (L398-L435)
.
There are two things to note about the tensor subclass:
1. it calls `super().__torch_function__`, which is
`torch._C._disabled_torch_function_impl`, so this patch updates
`SuperVariable.call_method` to handle it (we can't do a simpler
polyfill due to some bug with `var_getattr` raising
`NotImplementedError`, which forgot to restore symbolic context).
2. it sets and reads attributes (`quant_type`), and
defines new methods (`as_data`), so this patch adds support for those.
3. it has a `__init__`, which Dynamo needs to trace through in
`TensorSubclassVariable.call_function`.
Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482
Approved by: https://github.com/jansel, https://github.com/mlazos
This commit is contained in:
committed by
PyTorch MergeBot
parent
85df0dc246
commit
33535b3eee
@ -586,6 +586,22 @@ class MiscTests(torch._inductor.test_case.TestCase):
|
||||
ref = f(x)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
def test_newly_constructed_tensor_attr_mutation(self):
|
||||
def f(x):
|
||||
y = x + 10
|
||||
y.grad = x
|
||||
y.foo = 42
|
||||
return y
|
||||
|
||||
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||
x = torch.ones(5)
|
||||
|
||||
res = opt_f(x)
|
||||
ref = f(x)
|
||||
self.assertEqual(res, ref)
|
||||
self.assertEqual(res.grad, ref.grad)
|
||||
self.assertEqual(res.foo, ref.foo)
|
||||
|
||||
def test_closure_recompiles(self):
|
||||
cnt = CompileCounter()
|
||||
|
||||
|
@ -954,6 +954,140 @@ class SubclassTests(torch._dynamo.test_case.TestCase):
|
||||
res_act = fn_opt(input)
|
||||
self.assertEqual(res_exp, res_act)
|
||||
|
||||
def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self):
|
||||
# This is a slight variation of
|
||||
# https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435
|
||||
# which basically
|
||||
# 1. uses tensor subclass to attach quantization metadata onto tensors
|
||||
# 2. preserve them across torch ops
|
||||
# 3. use the metadata to dequantize the tensor
|
||||
# 4. convert it to a regular tensor.
|
||||
#
|
||||
# The test is meant to make sure Dynamo won't graph break over it.
|
||||
class GGUFParameter(torch.nn.Parameter):
|
||||
def __new__(cls, data, requires_grad=False, quant_type=None):
|
||||
data = data if data is not None else torch.empty(0)
|
||||
self = torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
return self
|
||||
|
||||
def __init__(self, *args, quant_type=None, **kwargs):
|
||||
self.quant_type = quant_type
|
||||
|
||||
def as_tensor(self):
|
||||
return torch.Tensor(self.data)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
result = super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
quant_type = None
|
||||
for arg in args:
|
||||
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
|
||||
quant_type = arg[0].quant_type
|
||||
break
|
||||
if isinstance(arg, GGUFParameter):
|
||||
quant_type = arg.quant_type
|
||||
break
|
||||
if isinstance(result, torch.Tensor):
|
||||
return cls(result, quant_type=quant_type)
|
||||
# Handle tuples and lists
|
||||
elif isinstance(result, (tuple, list)):
|
||||
# Preserve the original type (tuple or list)
|
||||
wrapped = [
|
||||
cls(x, quant_type=quant_type)
|
||||
if isinstance(x, torch.Tensor)
|
||||
else x
|
||||
for x in result
|
||||
]
|
||||
return type(result)(wrapped)
|
||||
else:
|
||||
return result
|
||||
|
||||
def f(x):
|
||||
tmp = x * 2
|
||||
tmp = tmp + tmp.quant_type
|
||||
tmp = tmp.as_tensor()
|
||||
return tmp * 3
|
||||
|
||||
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
||||
|
||||
x = GGUFParameter(torch.ones(2), quant_type=42)
|
||||
with traceable_subclass(GGUFParameter):
|
||||
res = f(x)
|
||||
ref = opt_f(x)
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
def test_newly_constructed_tensor_subclass_attr_mutation(self):
|
||||
# Make sure the attribute mutation for newly constructed tensor subclass
|
||||
# object (from constructor call) is handled both during Dynamo tracing
|
||||
# and codegen-ed to be visible outside `torch.compile`.
|
||||
class MySubclass(torch.Tensor):
|
||||
pass
|
||||
|
||||
def f():
|
||||
x = MySubclass(torch.ones(2))
|
||||
x.bar = 42
|
||||
return x, x * x.bar
|
||||
|
||||
opt_f = compile_full_eager(f)
|
||||
|
||||
with traceable_subclass(MySubclass):
|
||||
res = f()
|
||||
ref = opt_f()
|
||||
|
||||
self.assertEqual(res, ref)
|
||||
self.assertEqual(res[0].bar, ref[0].bar)
|
||||
|
||||
def test_as_subclass_attr_mutation(self):
|
||||
# Make sure the attribute mutation for newly constructed tensor subclass
|
||||
# object (from as_subclass call) is handled both during Dynamo tracing
|
||||
# and codegen-ed to be visible outside `torch.compile`.
|
||||
class MySubclass(torch.Tensor):
|
||||
pass
|
||||
|
||||
def f():
|
||||
x = torch.ones(2).as_subclass(MySubclass)
|
||||
x.bar = 42
|
||||
return x, x * x.bar
|
||||
|
||||
opt_f = compile_full_eager(f)
|
||||
|
||||
with traceable_subclass(MySubclass):
|
||||
res = f()
|
||||
ref = opt_f()
|
||||
|
||||
self.assertEqual(res, ref)
|
||||
self.assertEqual(res[0].bar, ref[0].bar)
|
||||
|
||||
def test_tensor_subclass_attr_codegen_tos(self):
|
||||
# This repros a very subtle interaction between
|
||||
# `TensorWithTFOverrideVariable` attribute mutation codegen and
|
||||
# `PyCodegen.top_of_stack`. It was uncovered from
|
||||
# `test_tensor_subclass_deepcopy`.
|
||||
class MySubclass(torch.Tensor):
|
||||
def __new__(cls, elem, *args, **kwargs):
|
||||
r = torch.Tensor._make_subclass(cls, torch.ones(0))
|
||||
r.elem = elem
|
||||
return r
|
||||
|
||||
def f(t):
|
||||
return MySubclass(t.elem.clone())
|
||||
|
||||
opt_f = compile_full_eager(f)
|
||||
|
||||
t = MySubclass(torch.ones(2))
|
||||
with traceable_subclass(MySubclass):
|
||||
res = f(t)
|
||||
ref = opt_f(t)
|
||||
|
||||
# TODO uncomment once we trace into `__new__`.
|
||||
# self.assertEqual(res, ref)
|
||||
# self.assertEqual(res.elem, ref.elem)
|
||||
self.assertEqual(type(res), type(ref))
|
||||
|
||||
def test_compile_with_fake_tensor_dynamic_dim(self):
|
||||
x = torch.randn([3, 4])
|
||||
|
||||
|
@ -0,0 +1 @@
|
||||
https://github.com/pytorch/pytorch/issues/149881
|
@ -295,9 +295,7 @@ class SideEffects:
|
||||
variable: VariableTracker,
|
||||
mutation_type_cls=ValueMutationExisting,
|
||||
):
|
||||
"""Start tracking a new variable for mutation"""
|
||||
assert variable.source is not None
|
||||
|
||||
"""Start tracking an existing or new variable for mutation"""
|
||||
if id(item) in self.id_to_variable:
|
||||
raise AssertionError(
|
||||
f"{variable} is already tracked for mutation. This could be "
|
||||
@ -576,12 +574,18 @@ class SideEffects:
|
||||
return [var for var in self.id_to_variable.values() if self.is_modified(var)]
|
||||
|
||||
def codegen_save_tempvars(self, cg: PyCodegen):
|
||||
# Make sure we codegen these modified VT to their source by default, so
|
||||
# that mutation and aliasing are properly accounted for.
|
||||
# We must codegen modified VT to their source by default, so that
|
||||
# mutation and aliasing are properly accounted for.
|
||||
#
|
||||
# Since newly constructed objects don't have a source, we manually
|
||||
# codegen their construction and store them to a newly assigned local
|
||||
# source. Note that `ValueMutationNew` isn't tracked by SideEffects.
|
||||
for var in self._get_modified_vars():
|
||||
if isinstance(var.mutation_type, AttributeMutationNew) and isinstance(
|
||||
var, variables.CellVariable
|
||||
):
|
||||
if not isinstance(var.mutation_type, AttributeMutationNew):
|
||||
assert var.source is not None
|
||||
continue
|
||||
|
||||
if isinstance(var, variables.CellVariable):
|
||||
# Cells created in the root frame are created either by
|
||||
# `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
|
||||
# `make_cell` for the non-root-frame cells here.
|
||||
@ -595,18 +599,38 @@ class SideEffects:
|
||||
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
|
||||
elif var.source is None:
|
||||
var.source = LocalCellSource(var.local_name)
|
||||
elif isinstance(var.mutation_type, AttributeMutationNew):
|
||||
if isinstance(var, variables.AutogradFunctionContextVariable):
|
||||
unimplemented_v2(
|
||||
gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
|
||||
context="",
|
||||
explanation="We cannot reconstruct a torch.autograd.Function's context object.",
|
||||
hints=[],
|
||||
)
|
||||
|
||||
elif isinstance(var, variables.TensorVariable):
|
||||
# NOTE: for historical reasons we never assigned local sources
|
||||
# to newly constructed tensor object, so we keep it that way.
|
||||
# They are always loaded from output of the fx graph, so one can
|
||||
# think of it as having a "OutputGraphSource" for codegen
|
||||
# purposes.
|
||||
#
|
||||
# However, tensor subclass objects are different, because the
|
||||
# reconstruction logic in `PyCodegen` loads the data tensor from
|
||||
# graph output and then calls `as_subclass`, meaning we must
|
||||
# assign a source to it to ensure we only reconstruct one
|
||||
# subclass instance.
|
||||
if isinstance(
|
||||
var, variables.torch_function.TensorWithTFOverrideVariable
|
||||
):
|
||||
# Don't codegen from temp source assigned from the 1st pass.
|
||||
cg(var, allow_cache=False)
|
||||
cg.add_cache(var)
|
||||
# `add_cache` generates STORE and consumes TOS, but we never
|
||||
# cleared it. TODO move this call into `add_cache`
|
||||
cg.clear_tos()
|
||||
var.source = LocalSource(cg.tempvars[var])
|
||||
elif isinstance(var, variables.AutogradFunctionContextVariable):
|
||||
unimplemented_v2(
|
||||
gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region",
|
||||
context="",
|
||||
explanation="We cannot reconstruct a torch.autograd.Function's context object.",
|
||||
hints=[],
|
||||
)
|
||||
else:
|
||||
# Reconstruct the bytecode for
|
||||
# base_cls.__new__(user_cls, *args)
|
||||
|
||||
if isinstance(var, variables.UserDefinedObjectVariable):
|
||||
|
||||
def load_new_method():
|
||||
@ -630,10 +654,6 @@ class SideEffects:
|
||||
|
||||
cg.add_cache(var)
|
||||
var.source = LocalSource(cg.tempvars[var])
|
||||
else:
|
||||
# The remaning cases here are `AttributeMutationExisting` and
|
||||
# `MutableSideEffects`, which have sources already.
|
||||
assert var.source is not None
|
||||
|
||||
for ctx, args in self.save_for_backward:
|
||||
cg(ctx.source)
|
||||
@ -993,7 +1013,7 @@ class SideEffects:
|
||||
else:
|
||||
cg.tx.output.update_co_names(name)
|
||||
cg(value)
|
||||
cg(var.source)
|
||||
cg(var)
|
||||
suffixes.append([create_instruction("STORE_ATTR", argval=name)])
|
||||
elif isinstance(var, variables.ListIteratorVariable):
|
||||
for _ in range(var.index):
|
||||
|
@ -510,7 +510,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._C._debug_set_fusion_group_inlining",
|
||||
"torch._C._demangle",
|
||||
"torch._C._disabled_torch_dispatch_impl",
|
||||
"torch._C._disabled_torch_function_impl",
|
||||
"torch._C._dispatch_call_boxed",
|
||||
"torch._C._dispatch_check_all_invariants",
|
||||
"torch._C._dispatch_check_invariants",
|
||||
|
@ -140,6 +140,7 @@ from ..utils import (
|
||||
wrap_fake_exception,
|
||||
)
|
||||
from .base import (
|
||||
AttributeMutationNew,
|
||||
typestr,
|
||||
ValueMutationExisting,
|
||||
ValueMutationNew,
|
||||
@ -2470,7 +2471,9 @@ def _wrap_fx_preexisting_tensor(
|
||||
f"wrapped by this instance of Dynamo. Found: {tensor}"
|
||||
)
|
||||
|
||||
return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls)
|
||||
return construct_tensor_variable(
|
||||
target_cls, tx, proxy, tensor, subclass_type, options
|
||||
)
|
||||
|
||||
|
||||
# This is 2 in the above comment (wrapping the output of a traced op)
|
||||
@ -2504,36 +2507,23 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
||||
import torch._utils
|
||||
|
||||
if isinstance(example_value, torch.Tensor):
|
||||
is_parameter = isinstance(example_value, torch.nn.Parameter)
|
||||
is_buffer = isinstance(example_value, torch.nn.Buffer)
|
||||
|
||||
# NB: In most (all?) cases, this does not actually do a clone.
|
||||
# (WARNING: this means that if we mutate metadata on the fake
|
||||
# tensor, the stored example value will update too!)
|
||||
example_value = _clone_input(example_value, tx.fake_mode)
|
||||
set_example_value(proxy.node, example_value)
|
||||
# We bind the unbacked symints in sizes/trdies of tensor lazily.
|
||||
# So that subgraphs can access the unbacked symbol's proxy in parent graph
|
||||
# when lifting unbacked symbols of input tensors to subgraph inputs.
|
||||
# We do it lazily because the tensor may not be used in subgraphs.
|
||||
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
|
||||
specialized_props = target_cls.specialize(example_value)
|
||||
# TODO: not sure about this fake mode test
|
||||
if (
|
||||
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
|
||||
and example_value.fake_mode is tx.fake_mode
|
||||
):
|
||||
tensor_type = subclass_type if subclass_type else torch.Tensor
|
||||
specialized_props["class_type"] = (
|
||||
torch.nn.Parameter
|
||||
if is_parameter
|
||||
else torch.nn.Buffer
|
||||
if is_buffer
|
||||
else tensor_type
|
||||
)
|
||||
|
||||
options.update(specialized_props)
|
||||
return target_cls(proxy, **options)
|
||||
var = construct_tensor_variable(
|
||||
target_cls, tx, proxy, example_value, subclass_type, options
|
||||
)
|
||||
# NOTE: [Side effect tracking for newly constructed tensor]
|
||||
# For newly constructed objects that have mutable attributes, we usually
|
||||
# construct their VariableTracker via `track_object_new`, but since
|
||||
# tensor variable construction is a bit different, we handle them
|
||||
# speically here. This ensures that codegen will actually generate the
|
||||
# attribute mutations on this tensor.
|
||||
#
|
||||
# NOTE we pass a dummy object as the `item` argument to avoid
|
||||
# constructing a dummy _tensor_ object. The object isn't used for
|
||||
# newly constructed VTs anyways.
|
||||
tx.output.side_effects._track_obj(
|
||||
proxy, var, mutation_type_cls=AttributeMutationNew
|
||||
)
|
||||
return var
|
||||
elif (
|
||||
hasattr(proxy.node.target, "__name__")
|
||||
and proxy.node.target.__name__ == "set_state"
|
||||
@ -2702,6 +2692,43 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
|
||||
)
|
||||
|
||||
|
||||
def construct_tensor_variable(
|
||||
target_cls, tx, proxy, example_value, subclass_type, options
|
||||
):
|
||||
"""
|
||||
Actually construct a tensor variable after all the pre-processing from
|
||||
wrapping a pre-existing or newly created tensor value.
|
||||
"""
|
||||
# NB: In most (all?) cases, this does not actually do a clone.
|
||||
# (WARNING: this means that if we mutate metadata on the fake
|
||||
# tensor, the stored example value will update too!)
|
||||
example_value = _clone_input(example_value, tx.fake_mode)
|
||||
set_example_value(proxy.node, example_value)
|
||||
# We bind the unbacked symints in sizes/trdies of tensor lazily.
|
||||
# So that subgraphs can access the unbacked symbol's proxy in parent graph
|
||||
# when lifting unbacked symbols of input tensors to subgraph inputs.
|
||||
# We do it lazily because the tensor may not be used in subgraphs.
|
||||
tx.output.current_tracer.track_unbacked_symbols(example_value, proxy)
|
||||
specialized_props = target_cls.specialize(example_value)
|
||||
# TODO: not sure about this fake mode test
|
||||
if (
|
||||
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor)
|
||||
and example_value.fake_mode is tx.fake_mode
|
||||
):
|
||||
if subclass_type:
|
||||
tensor_type = subclass_type
|
||||
elif isinstance(example_value, torch.nn.Parameter):
|
||||
tensor_type = torch.nn.Parameter
|
||||
elif isinstance(example_value, torch.nn.Buffer):
|
||||
tensor_type = torch.nn.Buffer
|
||||
else:
|
||||
tensor_type = torch.Tensor
|
||||
specialized_props["class_type"] = tensor_type
|
||||
|
||||
options.update(specialized_props)
|
||||
return target_cls(proxy, **options)
|
||||
|
||||
|
||||
def get_automatic_dynamic_shapes_mark_as():
|
||||
if config.automatic_dynamic_shapes_mark_as == "dynamic":
|
||||
return DimDynamic.DYNAMIC
|
||||
|
@ -1933,6 +1933,20 @@ class BuiltinVariable(VariableTracker):
|
||||
"the middle of the graph, which aot_autograd does not currently know how to handle. "
|
||||
)
|
||||
elif name == "data":
|
||||
# See comments on `test_set_data_on_scoped_tensor` for plans
|
||||
# to support this.
|
||||
if obj.source is None:
|
||||
unimplemented_v2(
|
||||
gb_type="Failed to mutate tensor data attribute",
|
||||
context=f"setattr({obj}, {name}, {val})",
|
||||
explanation="Dyanmo only supports mutating `.data`"
|
||||
" of tensor created outside `torch.compile` region",
|
||||
hints=[
|
||||
"Don't mutate `.data` on this tensor, or move "
|
||||
"the mutation out of `torch.compile` region",
|
||||
],
|
||||
)
|
||||
|
||||
# Remove the old reference in tracked fakes - if we don't do this
|
||||
# new .data value size and shape differences will cause
|
||||
# tracked fakes to produce incorrect guards. This is sound because the TensorVariable
|
||||
|
@ -161,6 +161,14 @@ class SuperVariable(VariableTracker):
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
inner_fn, source = self._resolved_getattr_and_source(self, name)
|
||||
# This essentially simulates CPython's `super_getattro`:
|
||||
# https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168
|
||||
# where `inner_fn` is the VT for `res = _super_lookup_descr(...)`.
|
||||
#
|
||||
# However, `res`'s type needs to be checked for `tp_descr_get`, and
|
||||
# applied if it has one. We currently don't have polyfills for all the
|
||||
# relevant `tp_descr_get`, so we explicitly handle the cases we care
|
||||
# about here (e.g., note the staticmethod, classmethod cases).
|
||||
if inner_fn is object.__init__:
|
||||
return LambdaVariable(identity)
|
||||
elif inner_fn is torch.nn.Module.__init__:
|
||||
@ -266,6 +274,29 @@ class SuperVariable(VariableTracker):
|
||||
|
||||
source = self.source and AttrSource(self.source, attr_name)
|
||||
return VariableTracker.build(tx, attr_value, source)
|
||||
elif inner_fn is torch._C._disabled_torch_function_impl:
|
||||
# See `THPModule_disable_torch_function` for the C impl.
|
||||
# The signature of _disabled_torch_function_impl is similar to
|
||||
# `__torch_function__`, just without the first `cls` argument:
|
||||
# * (func, types, args, kwargs)
|
||||
func = args[0]
|
||||
tf_kwargs = {}
|
||||
tf_args = args[2].items
|
||||
for hash_key_vt, value_vt in args[3].items.items():
|
||||
key_str = hash_key_vt.vt.as_python_constant()
|
||||
tf_kwargs[key_str] = value_vt
|
||||
|
||||
output_old = tx.output.torch_function_enabled
|
||||
tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled
|
||||
tx.output.torch_function_enabled = False
|
||||
tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
|
||||
try:
|
||||
return func.call_function(tx, tf_args, tf_kwargs)
|
||||
finally:
|
||||
tx.output.torch_function_enabled = output_old
|
||||
tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
|
||||
tx_old
|
||||
)
|
||||
|
||||
unimplemented(f"non-function or method super: {inner_fn}")
|
||||
|
||||
|
@ -67,7 +67,7 @@ from ..utils import (
|
||||
set_example_value,
|
||||
tensortype_to_dtype,
|
||||
)
|
||||
from .base import VariableTracker
|
||||
from .base import AttributeMutationNew, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .lists import SizeVariable
|
||||
|
||||
@ -789,9 +789,14 @@ class TensorVariable(VariableTracker):
|
||||
|
||||
tx = InstructionTranslator.current_tx()
|
||||
py_cls = cls.as_python_constant()
|
||||
return TensorWithTFOverrideVariable.from_tensor_var(
|
||||
var = TensorWithTFOverrideVariable.from_tensor_var(
|
||||
tx, self, py_cls, cls.source
|
||||
)
|
||||
# See NOTE [Side effect tracking for newly constructed tensor]
|
||||
tx.output.side_effects._track_obj(
|
||||
object(), var, mutation_type_cls=AttributeMutationNew
|
||||
)
|
||||
return var
|
||||
|
||||
def method_get_device(self):
|
||||
if isinstance(self.device, torch.device):
|
||||
@ -1443,14 +1448,37 @@ class TensorSubclassVariable(VariableTracker):
|
||||
args: list[VariableTracker],
|
||||
kwargs: dict[str, VariableTracker],
|
||||
) -> VariableTracker:
|
||||
if len(args) == 1 and isinstance(args[0], TensorVariable):
|
||||
from .torch_function import TensorWithTFOverrideVariable
|
||||
# Handle `Subclass(existing_tensor)` calls.
|
||||
def impl():
|
||||
if len(args) == 1 and isinstance(args[0], TensorVariable):
|
||||
from .torch_function import TensorWithTFOverrideVariable
|
||||
|
||||
return TensorWithTFOverrideVariable.from_tensor_var(
|
||||
tx, args[0], self.value, self.source
|
||||
)
|
||||
# This simulates `__new__` and _assumes_ it doesn't have
|
||||
# side-effects that matters to Dynamo tracing. TODO trace through
|
||||
# `__new__`.
|
||||
var = TensorWithTFOverrideVariable.from_tensor_var(
|
||||
tx, args[0], self.value, self.source
|
||||
)
|
||||
|
||||
return super().call_function(tx, args, kwargs)
|
||||
# Let Dynamo trace through custom `__init__`
|
||||
init_func = self.value.__init__
|
||||
# TODO builder should be able to handle `torch.Tensor.__init__`,
|
||||
# which is `object.__init__`, so that we can remove this check.
|
||||
if init_func is not torch.Tensor.__init__:
|
||||
cls_kwargs = kwargs or {}
|
||||
VariableTracker.build(tx, init_func).call_function(
|
||||
tx, [var], cls_kwargs
|
||||
)
|
||||
return var
|
||||
|
||||
return super().call_function(tx, args, kwargs)
|
||||
|
||||
var = impl()
|
||||
# See NOTE [Side effect tracking for newly constructed tensor]
|
||||
tx.output.side_effects._track_obj(
|
||||
object(), var, mutation_type_cls=AttributeMutationNew
|
||||
)
|
||||
return var
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.value
|
||||
|
@ -62,6 +62,7 @@ from ..utils import (
|
||||
from .base import VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .functions import UserMethodVariable
|
||||
from .lazy import LazyVariableTracker
|
||||
from .lists import TupleVariable
|
||||
from .tensor import TensorSubclassVariable, TensorVariable
|
||||
@ -592,12 +593,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
def from_tensor_var(cls, tx, tensor_var, class_type, cls_source):
|
||||
# [Note: __torch_function__] coerce `tensor_var` into a
|
||||
# TensorWithTFOverrideVariable. In eager, this is just a type change.
|
||||
# This isn't sound if a __torch_function__ tensor subclass defines a
|
||||
# constructor, but if only a __torch_function__ impl is defined, this is
|
||||
# okay to call. It is up to the user whether this is correct behavior
|
||||
# or not.
|
||||
import torch
|
||||
|
||||
# This simulates shallow-copying the tensor object.
|
||||
kwargs = dict(tensor_var.__dict__)
|
||||
assert kwargs.pop("class_type") is torch.Tensor, (
|
||||
"invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
|
||||
@ -640,30 +638,48 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
|
||||
)
|
||||
|
||||
if _is_attr_overidden(tx, self, name):
|
||||
unimplemented(
|
||||
f"Accessing overridden method/attribute {name} on a tensor"
|
||||
" subclass with a __torch_function__ override is not supported"
|
||||
)
|
||||
|
||||
if tx.output.torch_function_enabled and hasattr(torch.Tensor, name):
|
||||
if self.source:
|
||||
install_guard(
|
||||
AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
|
||||
GuardBuilder.FUNCTION_MATCH
|
||||
)
|
||||
if hasattr(torch.Tensor, name):
|
||||
if _is_attr_overidden(tx, self, name):
|
||||
unimplemented(
|
||||
f"Accessing overridden method/attribute {name} on a tensor"
|
||||
" subclass with a __torch_function__ override is not supported"
|
||||
)
|
||||
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
|
||||
|
||||
return self.call_torch_function(
|
||||
tx,
|
||||
get_fn,
|
||||
TupleVariable([self.class_type_var(tx)]),
|
||||
[self],
|
||||
{},
|
||||
)
|
||||
if tx.output.torch_function_enabled:
|
||||
if self.source:
|
||||
install_guard(
|
||||
AttrSource(
|
||||
AttrSource(self.source, "__class__"), name
|
||||
).make_guard(GuardBuilder.FUNCTION_MATCH)
|
||||
)
|
||||
get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__)
|
||||
|
||||
return self.call_torch_function(
|
||||
tx,
|
||||
get_fn,
|
||||
TupleVariable([self.class_type_var(tx)]),
|
||||
[self],
|
||||
{},
|
||||
)
|
||||
else:
|
||||
return super().var_getattr(tx, name)
|
||||
# `TensorVariable.var_getattr` doesn't handle user-defined
|
||||
# function/attribute well, so we explicitly handle them here.
|
||||
#
|
||||
# TODO move this logic into `TensorVariable`, or try to merge it
|
||||
# with similar logic in `UserDefinedObjectVariable`.
|
||||
try:
|
||||
attr = inspect.getattr_static(self.class_type, name)
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
import types
|
||||
|
||||
if isinstance(attr, types.FunctionType):
|
||||
cls_source = GlobalSource(self.global_mangled_class_name(tx))
|
||||
func_source = AttrSource(cls_source, name)
|
||||
install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
return UserMethodVariable(attr, self)
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
|
||||
return call_torch_function(
|
||||
|
Reference in New Issue
Block a user