[Dynamo] Catch unserialisable NN modules (#153503)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153503
Approved by: https://github.com/c00w, https://github.com/jansel
This commit is contained in:
Raymond Li
2025-05-16 02:55:28 +00:00
committed by PyTorch MergeBot
parent d1f1ff8610
commit 56e1c236bf
2 changed files with 30 additions and 5 deletions

View File

@ -3359,6 +3359,24 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
res = compiled_module(input_tensor)
self.assertEqual(ref, res)
def test_unhashable_nn_submodule(self):
class UnhashableModule(torch.nn.Module):
def __hash__(self):
raise TypeError("Unhashable module")
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.unhashable_attr = UnhashableModule()
def forward(self, x):
return x
mod = MyModule()
x = torch.randn(1)
compiled_mod = torch.compile(mod, backend="eager")
compiled_mod(x)
devices = ["cuda", "hpu"]
instantiate_device_type_tests(NNModuleTestsDevice, globals(), only_for=devices)

View File

@ -78,7 +78,7 @@ from torch.utils.weak import TensorWeakRef
from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented_v2
from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented_v2
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..pgo import (
auto_dynamic,
@ -1686,6 +1686,7 @@ class VariableBuilder:
self.install_guards(GuardBuilder.TYPE_MATCH)
if torch._dynamo.config.inline_inbuilt_nn_modules:
freezing = is_parameter_freezing()
# Guard against the case where user may overwrite named parameters
# / named buffers
# NOTE: This is not likely to happen but worth guarding to avoid
@ -1695,15 +1696,21 @@ class VariableBuilder:
and value.named_parameters.__func__
is og_module_named_parameters_fn_ptr
):
for _, p in value.named_parameters():
self.mark_static_input(p, guard=freezing)
try: # catch TypeErrors in named_parameters() from unserializable nn modules
for _, p in value.named_parameters():
self.mark_static_input(p, guard=freezing)
except TypeError as e:
raise_observed_exception(type(e), self.tx, args=list(e.args))
if (
callable(value.named_buffers)
and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr
):
for _, b in value.named_buffers():
self.mark_static_input(b, guard=freezing)
try: # catch TypeErrors in named_parameters() from unserializable nn modules
for _, b in value.named_buffers():
self.mark_static_input(b, guard=freezing)
except TypeError as e:
raise_observed_exception(type(e), self.tx, args=list(e.args))
if freezing:
# we need to add the module to tracing context