[export] Update min_val and max_val to Optional[int] in serialization. (#139223)

Summary: According to export team's discussion, we are upgrading min_val and max_val to optional fields which shouldn't break BC and allows the schema to express infinity.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_serialize_infinite_sym_int

Differential Revision: D65167805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139223
Approved by: https://github.com/yiming0416
This commit is contained in:
Zhengxu Chen
2024-10-30 21:14:17 +00:00
committed by PyTorch MergeBot
parent 6d5944c9f1
commit 03ec25053a
5 changed files with 50 additions and 19 deletions

View File

@ -322,6 +322,34 @@ def forward(self, x):
self.assertEqual(node.inputs[0].name, "self")
self.assertEqual(node.inputs[1].name, "dim")
def test_serialize_infinite_sym_int(self) -> None:
class DynamicShapeSimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b, c) -> torch.Tensor:
d = (torch.matmul(a, b) + c) / 2
d_s0 = d.shape[0]
d_s1 = d.shape[1]
d_s3 = d_s0 * d_s1
e = d.view(d_s3)
return torch.cat([e, e])
inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
dim0_ac = torch.export.Dim("dim0_ac")
dim1_bc = torch.export.Dim("dim1_b")
dynamic_shapes = {
"a": {0: dim0_ac},
"b": {1: dim1_bc},
"c": {0: dim0_ac, 1: dim1_bc},
}
exported_module = export_for_training(
DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes
).run_decompositions()
serialized = ExportedProgramSerializer().serialize(exported_module)
for v in serialized.exported_program.range_constraints.values():
self.assertEqual(v.max_val, None)
def test_serialize_list_returns(self) -> None:
class MyModule(torch.nn.Module):
def __init__(self) -> None:

View File

@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
from torch._export.serde.union import _Union
# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (7, 4)
SCHEMA_VERSION = (8, 1)
TREESPEC_VERSION = 1
@ -331,8 +331,8 @@ class GraphSignature:
@dataclass
class RangeConstraint:
min_val: int
max_val: int
min_val: Optional[int]
max_val: Optional[int]
@dataclass

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>>
Argument:
kind: union
fields:
@ -318,9 +318,9 @@ RangeConstraint:
kind: struct
fields:
min_val:
type: int
type: Optional[int]
max_val:
type: int
type: Optional[int]
ScalarType:
kind: enum
fields:
@ -436,6 +436,6 @@ UserOutputSpec:
arg:
type: Argument
SCHEMA_VERSION:
- 7
- 4
- 8
- 1
TREESPEC_VERSION: 1

View File

@ -64,14 +64,15 @@ def _staged_schema():
elif f.default_factory is not dataclasses.MISSING:
value = f.default_factory()
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
if value is not dataclasses.MISSING:
default = str(value)
ret["default"] = default
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
return ret
return {f.name: dump_field(f) for f in dataclasses.fields(ty)}

View File

@ -331,12 +331,12 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
return artifact
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
# Convert simple sympy Integers into concrete int
if val in (sympy.oo, int_oo):
return math.inf
return None
if val in (-sympy.oo, -int_oo):
return -math.inf
return None
if isinstance(val, sympy.Integer):
return int(val)
@ -355,8 +355,10 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str):
raise RuntimeError(f"Got invalid adjustment {adjust}")
def _int_to_sympy_int(val) -> sympy.Expr:
def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
# Convert concrete int into simple sympy Integers
if val is None:
return default
if val == math.inf:
return int_oo
if val == -math.inf:
@ -1908,7 +1910,7 @@ class GraphModuleDeserializer(metaclass=Final):
lower = vr.lower
if vr.upper >= 2: # max is >= 2, not sym bool range
lower = max(2, lower)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower, -int_oo), vr.upper)
if example_inputs is not None and len(example_inputs) > 0:
self.example_inputs = deserialize_torch_artifact(example_inputs)
@ -2325,7 +2327,7 @@ class ExportedProgramDeserializer(metaclass=Final):
symbol_name_to_range = {
k: symbolic_shapes.ValueRanges(
_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)
_int_to_sympy_int(v.min_val, -int_oo), _int_to_sympy_int(v.max_val, int_oo)
)
for k, v in exported_program.range_constraints.items()
}