mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Dynamo] Catch unserialisable NN modules (#153503)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153503 Approved by: https://github.com/c00w, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
d1f1ff8610
commit
56e1c236bf
@ -3359,6 +3359,24 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
|
||||
res = compiled_module(input_tensor)
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_unhashable_nn_submodule(self):
|
||||
class UnhashableModule(torch.nn.Module):
|
||||
def __hash__(self):
|
||||
raise TypeError("Unhashable module")
|
||||
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.unhashable_attr = UnhashableModule()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
mod = MyModule()
|
||||
x = torch.randn(1)
|
||||
compiled_mod = torch.compile(mod, backend="eager")
|
||||
compiled_mod(x)
|
||||
|
||||
|
||||
devices = ["cuda", "hpu"]
|
||||
instantiate_device_type_tests(NNModuleTestsDevice, globals(), only_for=devices)
|
||||
|
@ -78,7 +78,7 @@ from torch.utils.weak import TensorWeakRef
|
||||
|
||||
from .. import config, graph_break_hints, mutation_guard, replay_record, trace_rules
|
||||
from ..device_interface import get_registered_device_interfaces
|
||||
from ..exc import InternalTorchDynamoError, unimplemented_v2
|
||||
from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented_v2
|
||||
from ..guards import GuardBuilder, install_guard, make_dupe_guard
|
||||
from ..pgo import (
|
||||
auto_dynamic,
|
||||
@ -1686,6 +1686,7 @@ class VariableBuilder:
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
if torch._dynamo.config.inline_inbuilt_nn_modules:
|
||||
freezing = is_parameter_freezing()
|
||||
|
||||
# Guard against the case where user may overwrite named parameters
|
||||
# / named buffers
|
||||
# NOTE: This is not likely to happen but worth guarding to avoid
|
||||
@ -1695,15 +1696,21 @@ class VariableBuilder:
|
||||
and value.named_parameters.__func__
|
||||
is og_module_named_parameters_fn_ptr
|
||||
):
|
||||
try: # catch TypeErrors in named_parameters() from unserializable nn modules
|
||||
for _, p in value.named_parameters():
|
||||
self.mark_static_input(p, guard=freezing)
|
||||
except TypeError as e:
|
||||
raise_observed_exception(type(e), self.tx, args=list(e.args))
|
||||
|
||||
if (
|
||||
callable(value.named_buffers)
|
||||
and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr
|
||||
):
|
||||
try: # catch TypeErrors in named_parameters() from unserializable nn modules
|
||||
for _, b in value.named_buffers():
|
||||
self.mark_static_input(b, guard=freezing)
|
||||
except TypeError as e:
|
||||
raise_observed_exception(type(e), self.tx, args=list(e.args))
|
||||
|
||||
if freezing:
|
||||
# we need to add the module to tracing context
|
||||
|
Reference in New Issue
Block a user