[dynamo] Support tensor subclass with overriden tensor methods and properties (#149484)

This fixes most of the "torch.compile X tensor-subclass" issues
encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. The
relevant tensor subclass definition is here:
298192ed60/ops.py (L18-L65).

A few things to note about the tensor subclass:
1. it overrides a lot of the `torch.Tensor` methods (e.g., `to`,
   `clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr`
   to support that.
2. it overrides the `shape` property, so this patch updates
   `TensorWithTFOverrideVariable.var_getattr` to support property as well.
3. it has calls to `torch.Tensor.size`, which returns `torch.Size`,
   which gets reconstructed in `torch.Tensor.__torch_function__`, so
   this patch adds support for calling `torch.Size(...)` on non-constant
   inputs.

Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483
This commit is contained in:
Ryan Guo
2025-04-01 17:29:40 -07:00
committed by PyTorch MergeBot
parent 0d4dbfd9ed
commit 3463ea1059
4 changed files with 137 additions and 47 deletions

View File

@ -597,8 +597,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
# This simulates shallow-copying the tensor object.
kwargs = dict(tensor_var.__dict__)
assert kwargs.pop("class_type") is torch.Tensor, (
"invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
input_tensor_type = kwargs.pop("class_type")
assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), (
f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var"
)
torch_fn_var = build_torch_function_fn(tx, class_type, cls_source)
var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs)
@ -638,13 +639,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
)
if hasattr(torch.Tensor, name):
if _is_attr_overidden(tx, self, name):
unimplemented(
f"Accessing overridden method/attribute {name} on a tensor"
" subclass with a __torch_function__ override is not supported"
)
# Handle non-overriden attributes inherited from `torch.Tensor`.
attr_is_overriden = _is_attr_overidden(tx, self, name)
if hasattr(torch.Tensor, name) and not attr_is_overriden:
if tx.output.torch_function_enabled:
if self.source:
install_guard(
@ -674,11 +671,23 @@ class TensorWithTFOverrideVariable(TensorVariable):
else:
import types
cls_source = GlobalSource(self.global_mangled_class_name(tx))
attr_source = AttrSource(cls_source, name)
if isinstance(attr, types.FunctionType):
cls_source = GlobalSource(self.global_mangled_class_name(tx))
func_source = AttrSource(cls_source, name)
install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH))
install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH))
return UserMethodVariable(attr, self)
elif isinstance(attr, property):
getter_source = AttrSource(attr_source, "fget")
getter = attr.fget
getter_var = UserMethodVariable(getter, self, source=getter_source)
return getter_var.call_function(tx, [], {})
elif attr_is_overriden:
unimplemented(
f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950
)
return super().var_getattr(tx, name)
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):