mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: The current serialization is a bit hard to read: ``` Exporting with the dynamic shape spec: {getitem_123: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False)), getitem_118: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_117: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_116: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_115: (_DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True), _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False)), getitem_46: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=29, max=1792, _factory=False), _DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True)), _predict_module__base_model_model_ro_sparse_arch_ebc__output_dists_0__dist: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False), _DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True)), _predict_module__base_model_model_nro_sparse_arch_ebc__output_dists_0__dist: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=29, max=1792, _factory=False)... ``` Test Plan: UT Differential Revision: D83175131 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163891 Approved by: https://github.com/pianpwk
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
# Owner(s): ["oncall: export"]
|
|
|
|
from torch._dynamo.test_case import run_tests, TestCase
|
|
from torch.export.dynamic_shapes import _DimHint, _DimHintType, Dim
|
|
|
|
|
|
class TestDimHint(TestCase):
|
|
def test_dimhint_repr(self):
|
|
hint = _DimHint(_DimHintType.DYNAMIC)
|
|
self.assertEqual(repr(hint), "DimHint(DYNAMIC)")
|
|
|
|
hint_with_bounds = _DimHint(_DimHintType.AUTO, min=1, max=64)
|
|
self.assertEqual(repr(hint_with_bounds), "DimHint(AUTO, min=1, max=64)")
|
|
|
|
non_factory_hint = _DimHint(_DimHintType.STATIC, min=16, _factory=False)
|
|
self.assertEqual(repr(non_factory_hint), "DimHint(STATIC, min=16)")
|
|
|
|
def test_dimhint_factory(self):
|
|
factory = _DimHint(_DimHintType.AUTO)
|
|
self.assertTrue(factory._factory)
|
|
|
|
result = factory(min=8, max=32)
|
|
self.assertEqual(result.type, _DimHintType.AUTO)
|
|
self.assertEqual(result.min, 8)
|
|
self.assertEqual(result.max, 32)
|
|
self.assertFalse(result._factory)
|
|
|
|
with self.assertRaises(TypeError) as cm:
|
|
result(min=1, max=10)
|
|
self.assertIn("object is not callable", str(cm.exception))
|
|
|
|
bounded = Dim.DYNAMIC(min=4, max=16)
|
|
self.assertEqual(repr(bounded), "DimHint(DYNAMIC, min=4, max=16)")
|
|
|
|
with self.assertRaises(AssertionError):
|
|
factory(min=-1)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|