mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamic Shape][BE] trim _DimHint serialization (#163891)
Summary: current serialization is a bit hard to read ``` Exporting with the dynamic shape spec: {getitem_123: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False)), getitem_118: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_117: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_116: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=489, max=31232, _factory=False)), getitem_115: ( _DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True), _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False)), getitem_46: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=29, max=1792, _factory=False), _DimHint(type=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True)), _predict_module__base_model_model_ro_sparse_arch_ebc__output_dists_0__dist: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=1, max=64, _factory=False), _DimHint(t ype=<_DimHintType.STATIC: 2>, min=None, max=None, _factory=True)), _predict_module__base_model_model_nro_sparse_arch_ebc__output_dists_0__dist: (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=29, max=1792, _factory=False)... ``` Test Plan: UT Differential Revision: D83175131 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163891 Approved by: https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
e4ffd718ec
commit
008051b13c
@ -57,6 +57,15 @@ class _DimHintType(Enum):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DimHint:
|
||||
"""
|
||||
Internal class for dynamic shape hints.
|
||||
- min and max are optional.
|
||||
- _factory is for UX only, below example:
|
||||
auto_hint = _DimHint.AUTO() # _factory=True
|
||||
bounded_hint = auto_hint(min=10, max=100) # Returns new instance with _factory=False
|
||||
bounded_hint(min=5, max=50) # Will fail, non-factory instance cannot be called
|
||||
"""
|
||||
|
||||
type: _DimHintType
|
||||
min: Optional[int] = None
|
||||
max: Optional[int] = None
|
||||
@ -82,6 +91,14 @@ class _DimHint:
|
||||
assert min is None or max is None or min <= max, "min must be <= max"
|
||||
return _DimHint(self.type, min=min, max=max, _factory=False)
|
||||
|
||||
def __repr__(self):
|
||||
parts = [self.type.name]
|
||||
if self.min is not None:
|
||||
parts.append(f"min={self.min}")
|
||||
if self.max is not None:
|
||||
parts.append(f"max={self.max}")
|
||||
return f"DimHint({', '.join(parts)})"
|
||||
|
||||
|
||||
class Dim:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user