Make dynamism code robust to NotImplementedException (#148823)

In prod many models have `@property` methods that raise
NotImplementedError. This PR updates our dynamism code to be more robust
to these types of models.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148823
Approved by: https://github.com/laithsakka
This commit is contained in:
bobrenjc93
2025-03-14 10:57:44 -07:00
committed by PyTorch MergeBot
parent ff58ccec6c
commit eb7bf4202d
2 changed files with 50 additions and 7 deletions

View File

@ -110,6 +110,42 @@ class TestDynamism(TestCase):
}
self.assertEqual(result, expected)
def test_property_not_implemented(self):
class ModuleWithNotImplementedProperty(torch.nn.Module):
def __init__(self, x, y):
super().__init__()
self.linear = torch.nn.Linear(x, y)
@property
def not_implemented_property(self):
raise NotImplementedError("This property is not implemented")
module1 = ModuleWithNotImplementedProperty(10, 10)
module2 = ModuleWithNotImplementedProperty(10, 10)
result = track_dynamism_across_examples(
[
{"self": module1},
{"self": module2},
]
)
expected = {
"self": {
"L['self']['_modules']['linear']['_parameters']['weight']": (
False,
False,
),
"L['self']['_modules']['linear']['_parameters']['bias']": (False,),
"L['self']['_modules']['linear']['bias']": (False,),
"L['self']['_modules']['linear']['in_features']": (False,),
"L['self']['_modules']['linear']['out_features']": (False,),
"L['self']['_modules']['linear']['weight']": (False, False),
}
}
self.assertEqual(result, expected)
if __name__ == "__main__":
run_tests()

View File

@ -29,14 +29,21 @@ def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]:
self_dict["_modules"] = {}
for attr_name in dir(module):
if not attr_name.startswith("_") and not callable(getattr(module, attr_name)):
attr_value = getattr(module, attr_name)
if (
not isinstance(attr_value, torch.nn.Module)
and isinstance(attr_value, (int, float, torch.Tensor))
and type(attr_value) is not bool
try:
if not attr_name.startswith("_") and not callable(
getattr(module, attr_name)
):
self_dict[attr_name] = attr_value
attr_value = getattr(module, attr_name)
if (
not isinstance(attr_value, torch.nn.Module)
and isinstance(attr_value, (int, float, torch.Tensor))
and type(attr_value) is not bool
):
self_dict[attr_name] = attr_value
except NotImplementedError:
# Skip attributes that raise NotImplementedError since they won't
# contain any dynamism anyways.
continue
for name, param in module.named_parameters(recurse=False):
self_dict["_parameters"][name] = param