mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamo] Support tensor subclass with overriden tensor methods and properties (#149484)
This fixes most of the "torch.compile X tensor-subclass" issues
encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. The
relevant tensor subclass definition is here:
298192ed60/ops.py (L18-L65)
.
A few things to note about the tensor subclass:
1. it overrides a lot of the `torch.Tensor` methods (e.g., `to`,
`clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr`
to support that.
2. it overrides the `shape` property, so this patch updates
`TensorWithTFOverrideVariable.var_getattr` to support property as well.
3. it has calls to `torch.Tensor.size`, which returns `torch.Size`,
which gets reconstructed in `torch.Tensor.__torch_function__`, so
this patch adds support for calling `torch.Size(...)` on non-constant
inputs.
Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483
This commit is contained in:
committed by
PyTorch MergeBot
parent
0d4dbfd9ed
commit
3463ea1059
@ -597,8 +597,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
|
||||
# This simulates shallow-copying the tensor object.
|
||||
kwargs = dict(tensor_var.__dict__)
|
||||
assert kwargs.pop("class_type") is torch.Tensor, (
|
||||
"invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
|
||||
input_tensor_type = kwargs.pop("class_type")
|
||||
assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), (
|
||||
f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var"
|
||||
)
|
||||
torch_fn_var = build_torch_function_fn(tx, class_type, cls_source)
|
||||
var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs)
|
||||
@ -638,13 +639,9 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
|
||||
)
|
||||
|
||||
if hasattr(torch.Tensor, name):
|
||||
if _is_attr_overidden(tx, self, name):
|
||||
unimplemented(
|
||||
f"Accessing overridden method/attribute {name} on a tensor"
|
||||
" subclass with a __torch_function__ override is not supported"
|
||||
)
|
||||
|
||||
# Handle non-overriden attributes inherited from `torch.Tensor`.
|
||||
attr_is_overriden = _is_attr_overidden(tx, self, name)
|
||||
if hasattr(torch.Tensor, name) and not attr_is_overriden:
|
||||
if tx.output.torch_function_enabled:
|
||||
if self.source:
|
||||
install_guard(
|
||||
@ -674,11 +671,23 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
||||
else:
|
||||
import types
|
||||
|
||||
cls_source = GlobalSource(self.global_mangled_class_name(tx))
|
||||
attr_source = AttrSource(cls_source, name)
|
||||
if isinstance(attr, types.FunctionType):
|
||||
cls_source = GlobalSource(self.global_mangled_class_name(tx))
|
||||
func_source = AttrSource(cls_source, name)
|
||||
install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
return UserMethodVariable(attr, self)
|
||||
|
||||
elif isinstance(attr, property):
|
||||
getter_source = AttrSource(attr_source, "fget")
|
||||
getter = attr.fget
|
||||
getter_var = UserMethodVariable(getter, self, source=getter_source)
|
||||
return getter_var.call_function(tx, [], {})
|
||||
|
||||
elif attr_is_overriden:
|
||||
unimplemented(
|
||||
f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950
|
||||
)
|
||||
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
|
||||
|
Reference in New Issue
Block a user