# Owner(s): ["oncall: fx"] import torch from torch.fx.experimental._dynamism import track_dynamism_across_examples from torch.testing._internal.common_utils import TestCase class TestDynamism(TestCase): def test_dynamic_tensor(self): ex1 = {"x": 1, "y": torch.ones(1, 1), "z": {0: torch.ones(1)}} ex2 = {"x": 2, "y": torch.ones(2, 1), "z": {0: torch.ones(2)}} ex3 = {"x": 3, "y": torch.ones(3, 1), "z": {0: torch.ones(3)}} ex4 = {"x": 4, "y": torch.ones(4, 1), "z": {0: torch.ones(4)}} ex5 = {"x": 5, "y": torch.ones(5, 1), "z": {0: torch.ones(5)}} examples = [ex1, ex2, ex3, ex4, ex5] result = track_dynamism_across_examples(examples) expected = { "x": {"L['x']": (True,)}, "y": {"L['y']": (True, False)}, "z": {"L['z'][0]": (True,)}, } self.assertEqual(result, expected) def test_dynamic_tensor_deeply_nested(self): ex1 = {"z": {"z": {"z": {"z": {0: torch.ones(1)}}}}} ex2 = {"z": {"z": {"z": {"z": {0: torch.ones(2)}}}}} ex3 = {"z": {"z": {"z": {"z": {0: torch.ones(3)}}}}} ex4 = {"z": {"z": {"z": {"z": {0: torch.ones(4)}}}}} ex5 = {"z": {"z": {"z": {"z": {0: torch.ones(5)}}}}} examples = [ex1, ex2, ex3, ex4, ex5] result = track_dynamism_across_examples(examples) expected = { "z": { "L['z']['z']['z']['z'][0]": (True,), }, } self.assertEqual(result, expected) def test_mixed_dynamism(self): ex1 = {"a": torch.ones(1, 2), "b": [torch.ones(1), 3], "c": {"d": 42}} ex2 = {"a": torch.ones(2, 2), "b": [torch.ones(2), 4], "c": {"d": 42}} ex3 = {"a": torch.ones(3, 2), "b": [torch.ones(3), 5], "c": {"d": 42}} ex4 = {"a": torch.ones(4, 2), "b": [torch.ones(4), 6], "c": {"d": 42}} ex5 = {"a": torch.ones(5, 2), "b": [torch.ones(5), 7], "c": {"d": 42}} examples = [ex1, ex2, ex3, ex4, ex5] result = track_dynamism_across_examples(examples) expected = { "a": {"L['a']": (True, False)}, "b": {"L['b'][0]": (True,), "L['b'][1]": (True,)}, "c": {"L['c']['d']": (False,)}, } self.assertEqual(result, expected) def test_nn_module(self): class Y(torch.nn.Module): def __init__(self, n_input, n_output): super().__init__() self.compress = torch.nn.Linear(n_input, n_output) self.x = n_input def forward(self, x): return self.compress(x) * self.x class M(torch.nn.Module): def __init__(self, n_input, n_output): self.n_input = n_input self.n_output = n_output super().__init__() self.y = Y(n_input, n_output) def forward(self, x): return self.y(x) model1 = M(3210, 30) model2 = M(3211, 30) result = track_dynamism_across_examples( [ {"self": model1}, {"self": model2}, ] ) expected = { "self": { "L['self']['_modules']['y']['_modules']['compress']['_parameters']['weight']": ( False, True, ), "L['self']['_modules']['y']['_modules']['compress']['_parameters']['bias']": ( False, ), "L['self']['_modules']['y']['_modules']['compress']['bias']": (False,), "L['self']['_modules']['y']['_modules']['compress']['in_features']": ( True, ), "L['self']['_modules']['y']['_modules']['compress']['out_features']": ( False, ), "L['self']['_modules']['y']['_modules']['compress']['weight']": ( False, True, ), "L['self']['_modules']['y']['x']": (True,), "L['self']['n_input']": (True,), "L['self']['n_output']": (False,), } } self.assertEqual(result, expected) def test_property_not_implemented(self): class ModuleWithNotImplementedProperty(torch.nn.Module): def __init__(self, x, y): super().__init__() self.linear = torch.nn.Linear(x, y) @property def not_implemented_property(self): raise NotImplementedError("This property is not implemented") module1 = ModuleWithNotImplementedProperty(10, 10) module2 = ModuleWithNotImplementedProperty(10, 10) result = track_dynamism_across_examples( [ {"self": module1}, {"self": module2}, ] ) expected = { "self": { "L['self']['_modules']['linear']['_parameters']['weight']": ( False, False, ), "L['self']['_modules']['linear']['_parameters']['bias']": (False,), "L['self']['_modules']['linear']['bias']": (False,), "L['self']['_modules']['linear']['in_features']": (False,), "L['self']['_modules']['linear']['out_features']": (False,), "L['self']['_modules']['linear']['weight']": (False, False), } } self.assertEqual(result, expected) if __name__ == "__main__": raise RuntimeError( "This test is not currently used and should be " "enabled in discover_tests.py if required." )