Files
pytorch/test/export/test_tree_utils.py
Xuehai Pan 76169cf691 [BE][Easy][9/19] enforce style for empty lines in import segments in test/[e-h]*/ (#129760)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129760
Approved by: https://github.com/ezyang
2024-07-17 14:25:29 +00:00

50 lines
1.8 KiB
Python

# Owner(s): ["oncall: export"]
from collections import OrderedDict
import torch
from torch._dynamo.test_case import TestCase
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.testing._internal.common_utils import run_tests
from torch.utils._pytree import tree_structure
class TestTreeUtils(TestCase):
def test_reorder_kwargs(self):
original_kwargs = {"a": torch.tensor(0), "b": torch.tensor(1)}
user_kwargs = {"b": torch.tensor(2), "a": torch.tensor(3)}
orig_spec = tree_structure(((), original_kwargs))
reordered_kwargs = reorder_kwargs(user_kwargs, orig_spec)
# Key ordering should be the same
self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),
self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),
def test_equivalence_check(self):
tree1 = {"a": torch.tensor(0), "b": torch.tensor(1), "c": None}
tree2 = OrderedDict(a=torch.tensor(0), b=torch.tensor(1), c=None)
spec1 = tree_structure(tree1)
spec2 = tree_structure(tree2)
def dict_ordered_dict_eq(type1, context1, type2, context2):
if type1 is None or type2 is None:
return type1 is type2 and context1 == context2
if issubclass(type1, (dict, OrderedDict)) and issubclass(
type2, (dict, OrderedDict)
):
return context1 == context2
return type1 is type2 and context1 == context2
self.assertTrue(is_equivalent(spec1, spec2, dict_ordered_dict_eq))
# Wrong ordering should still fail
tree3 = OrderedDict(b=torch.tensor(1), a=torch.tensor(0))
spec3 = tree_structure(tree3)
self.assertFalse(is_equivalent(spec1, spec3, dict_ordered_dict_eq))
# Script entry point: when this file is executed directly (not imported),
# call ``run_tests`` from ``torch.testing._internal.common_utils``.
if __name__ == "__main__":
    run_tests()