mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support method with different __self__ on user defined objects (#139953)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139953 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
d18bca4961
commit
a140e65e0f
@ -6220,6 +6220,38 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
with torch.device("cpu"):
|
||||
self.assertEqual(res, split(x))
|
||||
|
||||
def test_method_overriding(self):
|
||||
class DilateConv(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dilate_func=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dilate_func = dilate_func
|
||||
|
||||
def forward(self, x):
|
||||
return self.dilate_func() * torch.sin(x)
|
||||
|
||||
class MainModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mod = DilateConv(self.dilate_func)
|
||||
self.a = 4
|
||||
|
||||
def dilate_func(self):
|
||||
return self.a
|
||||
|
||||
def forward(self, x):
|
||||
return self.mod(x)
|
||||
|
||||
mod = MainModule()
|
||||
|
||||
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
ref = mod(x)
|
||||
res = opt_mod(x)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(ReproTests)
|
||||
|
||||
|
@ -2978,3 +2978,26 @@ class SourcelessBuilder:
|
||||
|
||||
|
||||
SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()
|
||||
|
||||
|
||||
class SourcelessUserDefinedObjectBuilder:
|
||||
"""
|
||||
SourceLessBuilder does not return a UserDefinedObjectVariable, but in some
|
||||
cases it might be ok to return UserDefinedObjects. In such case, use this
|
||||
builder.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()")
|
||||
|
||||
@staticmethod
|
||||
def create(tx: "InstructionTranslator", value) -> VariableTracker:
|
||||
value_type = type(value)
|
||||
if issubclass(value_type, MutableMapping):
|
||||
return MutableMappingVariable(value, mutation_type=ValueMutationNew())
|
||||
elif isinstance(value, torch.nn.Module):
|
||||
return UnspecializedNNModuleVariable(
|
||||
value, mutation_type=ValueMutationNew()
|
||||
)
|
||||
else:
|
||||
return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew())
|
||||
|
@ -1158,7 +1158,20 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
|
||||
if isinstance(subobj, types.MethodType):
|
||||
if dynamic_subobj.__self__ is not self.value:
|
||||
unimplemented("__self__ mismatch for bound method")
|
||||
if not isinstance(dynamic_subobj.__func__, types.FunctionType):
|
||||
unimplemented(
|
||||
f"Found a method whose __func__ is not of FunctionType - {dynamic_subobj}"
|
||||
)
|
||||
|
||||
from .builder import SourcelessUserDefinedObjectBuilder
|
||||
|
||||
# This means that we are calling a method of some other object here.
|
||||
object_vt = SourcelessUserDefinedObjectBuilder.create(
|
||||
tx, dynamic_subobj.__self__
|
||||
)
|
||||
return variables.UserMethodVariable(
|
||||
dynamic_subobj.__func__, object_vt
|
||||
)
|
||||
func = subobj.__func__
|
||||
else:
|
||||
assert isinstance(subobj, types.FunctionType)
|
||||
|
Reference in New Issue
Block a user