mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Make dynamism code robust to NotImplementedException (#148823)
In prod many models have `@property` methods that raise NotImplementedError. This PR updates our dynamism code to be more robust to these types of models. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148823 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
ff58ccec6c
commit
eb7bf4202d
@ -110,6 +110,42 @@ class TestDynamism(TestCase):
|
||||
}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_property_not_implemented(self):
|
||||
class ModuleWithNotImplementedProperty(torch.nn.Module):
|
||||
def __init__(self, x, y):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(x, y)
|
||||
|
||||
@property
|
||||
def not_implemented_property(self):
|
||||
raise NotImplementedError("This property is not implemented")
|
||||
|
||||
module1 = ModuleWithNotImplementedProperty(10, 10)
|
||||
module2 = ModuleWithNotImplementedProperty(10, 10)
|
||||
|
||||
result = track_dynamism_across_examples(
|
||||
[
|
||||
{"self": module1},
|
||||
{"self": module2},
|
||||
]
|
||||
)
|
||||
|
||||
expected = {
|
||||
"self": {
|
||||
"L['self']['_modules']['linear']['_parameters']['weight']": (
|
||||
False,
|
||||
False,
|
||||
),
|
||||
"L['self']['_modules']['linear']['_parameters']['bias']": (False,),
|
||||
"L['self']['_modules']['linear']['bias']": (False,),
|
||||
"L['self']['_modules']['linear']['in_features']": (False,),
|
||||
"L['self']['_modules']['linear']['out_features']": (False,),
|
||||
"L['self']['_modules']['linear']['weight']": (False, False),
|
||||
}
|
||||
}
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -29,14 +29,21 @@ def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]:
|
||||
self_dict["_modules"] = {}
|
||||
|
||||
for attr_name in dir(module):
|
||||
if not attr_name.startswith("_") and not callable(getattr(module, attr_name)):
|
||||
attr_value = getattr(module, attr_name)
|
||||
if (
|
||||
not isinstance(attr_value, torch.nn.Module)
|
||||
and isinstance(attr_value, (int, float, torch.Tensor))
|
||||
and type(attr_value) is not bool
|
||||
try:
|
||||
if not attr_name.startswith("_") and not callable(
|
||||
getattr(module, attr_name)
|
||||
):
|
||||
self_dict[attr_name] = attr_value
|
||||
attr_value = getattr(module, attr_name)
|
||||
if (
|
||||
not isinstance(attr_value, torch.nn.Module)
|
||||
and isinstance(attr_value, (int, float, torch.Tensor))
|
||||
and type(attr_value) is not bool
|
||||
):
|
||||
self_dict[attr_name] = attr_value
|
||||
except NotImplementedError:
|
||||
# Skip attributes that raise NotImplementedError since they won't
|
||||
# contain any dynamism anyways.
|
||||
continue
|
||||
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
self_dict["_parameters"][name] = param
|
||||
|
Reference in New Issue
Block a user