mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129760
Approved by: https://github.com/ezyang
50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
# Owner(s): ["oncall: export"]
|
|
from collections import OrderedDict
|
|
|
|
import torch
|
|
from torch._dynamo.test_case import TestCase
|
|
from torch.export._tree_utils import is_equivalent, reorder_kwargs
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.utils._pytree import tree_structure
|
|
|
|
|
|
class TestTreeUtils(TestCase):
|
|
def test_reorder_kwargs(self):
|
|
original_kwargs = {"a": torch.tensor(0), "b": torch.tensor(1)}
|
|
user_kwargs = {"b": torch.tensor(2), "a": torch.tensor(3)}
|
|
orig_spec = tree_structure(((), original_kwargs))
|
|
|
|
reordered_kwargs = reorder_kwargs(user_kwargs, orig_spec)
|
|
|
|
# Key ordering should be the same
|
|
self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),
|
|
self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),
|
|
|
|
def test_equivalence_check(self):
|
|
tree1 = {"a": torch.tensor(0), "b": torch.tensor(1), "c": None}
|
|
tree2 = OrderedDict(a=torch.tensor(0), b=torch.tensor(1), c=None)
|
|
spec1 = tree_structure(tree1)
|
|
spec2 = tree_structure(tree2)
|
|
|
|
def dict_ordered_dict_eq(type1, context1, type2, context2):
|
|
if type1 is None or type2 is None:
|
|
return type1 is type2 and context1 == context2
|
|
|
|
if issubclass(type1, (dict, OrderedDict)) and issubclass(
|
|
type2, (dict, OrderedDict)
|
|
):
|
|
return context1 == context2
|
|
|
|
return type1 is type2 and context1 == context2
|
|
|
|
self.assertTrue(is_equivalent(spec1, spec2, dict_ordered_dict_eq))
|
|
|
|
# Wrong ordering should still fail
|
|
tree3 = OrderedDict(b=torch.tensor(1), a=torch.tensor(0))
|
|
spec3 = tree_structure(tree3)
|
|
self.assertFalse(is_equivalent(spec1, spec3, dict_ordered_dict_eq))
|
|
|
|
|
|
# Script entry point: hand control to the shared PyTorch test runner only when
# this file is executed directly, never when it is imported.
if __name__ == "__main__":
    run_tests()
|