[dynamo] Support method with different __self__ on user defined objects (#139953)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139953
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain
2024-11-07 09:58:55 -08:00
committed by PyTorch MergeBot
parent d18bca4961
commit a140e65e0f
3 changed files with 69 additions and 1 deletions

View File

@ -6220,6 +6220,38 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
with torch.device("cpu"):
self.assertEqual(res, split(x))
def test_method_overriding(self):
class DilateConv(torch.nn.Module):
def __init__(
self,
dilate_func=None,
):
super().__init__()
self.dilate_func = dilate_func
def forward(self, x):
return self.dilate_func() * torch.sin(x)
class MainModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod = DilateConv(self.dilate_func)
self.a = 4
def dilate_func(self):
return self.a
def forward(self, x):
return self.mod(x)
mod = MainModule()
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
x = torch.randn(4)
ref = mod(x)
res = opt_mod(x)
self.assertEqual(ref, res)
instantiate_parametrized_tests(ReproTests)

View File

@ -2978,3 +2978,26 @@ class SourcelessBuilder:
SourcelessBuilder._type_handlers = SourcelessBuilder.make_type_handlers()
class SourcelessUserDefinedObjectBuilder:
"""
SourceLessBuilder does not return a UserDefinedObjectVariable, but in some
cases it might be ok to return UserDefinedObjects. In such case, use this
builder.
"""
def __init__(self) -> None:
raise AssertionError("Use SourcelessUserDefinedObjectBuilder.create()")
@staticmethod
def create(tx: "InstructionTranslator", value) -> VariableTracker:
value_type = type(value)
if issubclass(value_type, MutableMapping):
return MutableMappingVariable(value, mutation_type=ValueMutationNew())
elif isinstance(value, torch.nn.Module):
return UnspecializedNNModuleVariable(
value, mutation_type=ValueMutationNew()
)
else:
return UserDefinedObjectVariable(value, mutation_type=ValueMutationNew())

View File

@ -1158,7 +1158,20 @@ class UserDefinedObjectVariable(UserDefinedVariable):
if isinstance(subobj, types.MethodType):
if dynamic_subobj.__self__ is not self.value:
unimplemented("__self__ mismatch for bound method")
if not isinstance(dynamic_subobj.__func__, types.FunctionType):
unimplemented(
f"Found a method whose __func__ is not of FunctionType - {dynamic_subobj}"
)
from .builder import SourcelessUserDefinedObjectBuilder
# This means that we are calling a method of some other object here.
object_vt = SourcelessUserDefinedObjectBuilder.create(
tx, dynamic_subobj.__self__
)
return variables.UserMethodVariable(
dynamic_subobj.__func__, object_vt
)
func = subobj.__func__
else:
assert isinstance(subobj, types.FunctionType)