Files
pytorch/test/fx/test_dynamism.py
Anthony Barbier c8d44a2296 Add __main__ guards to fx tests (#154715)
This PR is part of a series attempting to re-submit #134592 as smaller PRs.

In fx tests:

- Add and use a common raise_on_run_directly method, called when a user directly runs a test file that should not be run this way; it prints the file the user should have run instead.
- Raise a RuntimeError on tests which have been disabled (not run)
- Remove any remaining uses of "unittest.main()"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154715
Approved by: https://github.com/Skylion007
2025-06-04 14:38:50 +00:00

155 lines
5.6 KiB
Python

# Owner(s): ["oncall: fx"]
import torch
from torch.fx.experimental._dynamism import track_dynamism_across_examples
from torch.testing._internal.common_utils import TestCase
class TestDynamism(TestCase):
def test_dynamic_tensor(self):
ex1 = {"x": 1, "y": torch.ones(1, 1), "z": {0: torch.ones(1)}}
ex2 = {"x": 2, "y": torch.ones(2, 1), "z": {0: torch.ones(2)}}
ex3 = {"x": 3, "y": torch.ones(3, 1), "z": {0: torch.ones(3)}}
ex4 = {"x": 4, "y": torch.ones(4, 1), "z": {0: torch.ones(4)}}
ex5 = {"x": 5, "y": torch.ones(5, 1), "z": {0: torch.ones(5)}}
examples = [ex1, ex2, ex3, ex4, ex5]
result = track_dynamism_across_examples(examples)
expected = {
"x": {"L['x']": (True,)},
"y": {"L['y']": (True, False)},
"z": {"L['z'][0]": (True,)},
}
self.assertEqual(result, expected)
def test_dynamic_tensor_deeply_nested(self):
ex1 = {"z": {"z": {"z": {"z": {0: torch.ones(1)}}}}}
ex2 = {"z": {"z": {"z": {"z": {0: torch.ones(2)}}}}}
ex3 = {"z": {"z": {"z": {"z": {0: torch.ones(3)}}}}}
ex4 = {"z": {"z": {"z": {"z": {0: torch.ones(4)}}}}}
ex5 = {"z": {"z": {"z": {"z": {0: torch.ones(5)}}}}}
examples = [ex1, ex2, ex3, ex4, ex5]
result = track_dynamism_across_examples(examples)
expected = {
"z": {
"L['z']['z']['z']['z'][0]": (True,),
},
}
self.assertEqual(result, expected)
def test_mixed_dynamism(self):
ex1 = {"a": torch.ones(1, 2), "b": [torch.ones(1), 3], "c": {"d": 42}}
ex2 = {"a": torch.ones(2, 2), "b": [torch.ones(2), 4], "c": {"d": 42}}
ex3 = {"a": torch.ones(3, 2), "b": [torch.ones(3), 5], "c": {"d": 42}}
ex4 = {"a": torch.ones(4, 2), "b": [torch.ones(4), 6], "c": {"d": 42}}
ex5 = {"a": torch.ones(5, 2), "b": [torch.ones(5), 7], "c": {"d": 42}}
examples = [ex1, ex2, ex3, ex4, ex5]
result = track_dynamism_across_examples(examples)
expected = {
"a": {"L['a']": (True, False)},
"b": {"L['b'][0]": (True,), "L['b'][1]": (True,)},
"c": {"L['c']['d']": (False,)},
}
self.assertEqual(result, expected)
def test_nn_module(self):
class Y(torch.nn.Module):
def __init__(self, n_input, n_output):
super().__init__()
self.compress = torch.nn.Linear(n_input, n_output)
self.x = n_input
def forward(self, x):
return self.compress(x) * self.x
class M(torch.nn.Module):
def __init__(self, n_input, n_output):
self.n_input = n_input
self.n_output = n_output
super().__init__()
self.y = Y(n_input, n_output)
def forward(self, x):
return self.y(x)
model1 = M(3210, 30)
model2 = M(3211, 30)
result = track_dynamism_across_examples(
[
{"self": model1},
{"self": model2},
]
)
expected = {
"self": {
"L['self']['_modules']['y']['_modules']['compress']['_parameters']['weight']": (
False,
True,
),
"L['self']['_modules']['y']['_modules']['compress']['_parameters']['bias']": (
False,
),
"L['self']['_modules']['y']['_modules']['compress']['bias']": (False,),
"L['self']['_modules']['y']['_modules']['compress']['in_features']": (
True,
),
"L['self']['_modules']['y']['_modules']['compress']['out_features']": (
False,
),
"L['self']['_modules']['y']['_modules']['compress']['weight']": (
False,
True,
),
"L['self']['_modules']['y']['x']": (True,),
"L['self']['n_input']": (True,),
"L['self']['n_output']": (False,),
}
}
self.assertEqual(result, expected)
def test_property_not_implemented(self):
class ModuleWithNotImplementedProperty(torch.nn.Module):
def __init__(self, x, y):
super().__init__()
self.linear = torch.nn.Linear(x, y)
@property
def not_implemented_property(self):
raise NotImplementedError("This property is not implemented")
module1 = ModuleWithNotImplementedProperty(10, 10)
module2 = ModuleWithNotImplementedProperty(10, 10)
result = track_dynamism_across_examples(
[
{"self": module1},
{"self": module2},
]
)
expected = {
"self": {
"L['self']['_modules']['linear']['_parameters']['weight']": (
False,
False,
),
"L['self']['_modules']['linear']['_parameters']['bias']": (False,),
"L['self']['_modules']['linear']['bias']": (False,),
"L['self']['_modules']['linear']['in_features']": (False,),
"L['self']['_modules']['linear']['out_features']": (False,),
"L['self']['_modules']['linear']['weight']": (False, False),
}
}
self.assertEqual(result, expected)
if __name__ == "__main__":
    # Running this module directly is refused on purpose: per the message
    # below, it is not currently enabled in discover_tests.py.
    _not_runnable_message = (
        "This test is not currently used and should be "
        "enabled in discover_tests.py if required."
    )
    raise RuntimeError(_not_runnable_message)