Files
pytorch/test/export/test_dynamic_shapes.py
Chang Pan 008051b13c [Dynamic Shape][BE] trim _DimHint serialization (#163891)
Summary:
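The current `_DimHint` serialization is verbose and hard to read, as the log below shows. This change trims it to a compact form such as `DimHint(DYNAMIC, min=1, max=64)`.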
```
Exporting with the dynamic shape spec: {getitem_123: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False)), getitem_118: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_117: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_116: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_115: (_DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True), _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False)), getitem_46: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=29, max=1792, _factory=False), _DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True)), _predict_module__base_model_model_ro_sparse_arch_ebc__output_dists_0__dist: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False), _DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True)), _predict_module__base_model_model_nro_sparse_arch_ebc__output_dists_0__dist: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=29, max=1792, _factory=False)...
```

Test Plan: unit tests (UT), see `TestDimHint` in this file.

Differential Revision: D83175131

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163891
Approved by: https://github.com/pianpwk
2025-09-27 00:08:01 +00:00

41 lines
1.3 KiB
Python

# Owner(s): ["oncall: export"]
from torch._dynamo.test_case import run_tests, TestCase
from torch.export.dynamic_shapes import _DimHint, _DimHintType, Dim
class TestDimHint(TestCase):
def test_dimhint_repr(self):
hint = _DimHint(_DimHintType.DYNAMIC)
self.assertEqual(repr(hint), "DimHint(DYNAMIC)")
hint_with_bounds = _DimHint(_DimHintType.AUTO, min=1, max=64)
self.assertEqual(repr(hint_with_bounds), "DimHint(AUTO, min=1, max=64)")
non_factory_hint = _DimHint(_DimHintType.STATIC, min=16, _factory=False)
self.assertEqual(repr(non_factory_hint), "DimHint(STATIC, min=16)")
def test_dimhint_factory(self):
factory = _DimHint(_DimHintType.AUTO)
self.assertTrue(factory._factory)
result = factory(min=8, max=32)
self.assertEqual(result.type, _DimHintType.AUTO)
self.assertEqual(result.min, 8)
self.assertEqual(result.max, 32)
self.assertFalse(result._factory)
with self.assertRaises(TypeError) as cm:
result(min=1, max=10)
self.assertIn("object is not callable", str(cm.exception))
bounded = Dim.DYNAMIC(min=4, max=16)
self.assertEqual(repr(bounded), "DimHint(DYNAMIC, min=4, max=16)")
with self.assertRaises(AssertionError):
factory(min=-1)
# Run this module's tests when executed directly as a script.
if __name__ == "__main__":
    run_tests()