Revert D65167805 (#139371)

Summary:
This diff reverts D65167805
broke the release pipeline

Test Plan: NA

Differential Revision: D65245198

@diff-train-skip-merge (to silent facebook-github-bot until I have a stamp to land this)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139371
Approved by: https://github.com/malfet
This commit is contained in:
Huy Do
2024-10-31 07:25:25 +00:00
committed by PyTorch MergeBot
parent 86e6513c86
commit f98bc9a49d
5 changed files with 19 additions and 50 deletions

View File

@ -322,34 +322,6 @@ def forward(self, x):
self.assertEqual(node.inputs[0].name, "self")
self.assertEqual(node.inputs[1].name, "dim")
def test_serialize_infinite_sym_int(self) -> None:
class DynamicShapeSimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b, c) -> torch.Tensor:
d = (torch.matmul(a, b) + c) / 2
d_s0 = d.shape[0]
d_s1 = d.shape[1]
d_s3 = d_s0 * d_s1
e = d.view(d_s3)
return torch.cat([e, e])
inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
dim0_ac = torch.export.Dim("dim0_ac")
dim1_bc = torch.export.Dim("dim1_b")
dynamic_shapes = {
"a": {0: dim0_ac},
"b": {1: dim1_bc},
"c": {0: dim0_ac, 1: dim1_bc},
}
exported_module = export_for_training(
DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes
).run_decompositions()
serialized = ExportedProgramSerializer().serialize(exported_module)
for v in serialized.exported_program.range_constraints.values():
self.assertEqual(v.max_val, None)
def test_serialize_list_returns(self) -> None:
class MyModule(torch.nn.Module):
def __init__(self) -> None:

View File

@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
from torch._export.serde.union import _Union
# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (8, 1)
SCHEMA_VERSION = (7, 4)
TREESPEC_VERSION = 1
@ -331,8 +331,8 @@ class GraphSignature:
@dataclass
class RangeConstraint:
min_val: Optional[int]
max_val: Optional[int]
min_val: int
max_val: int
@dataclass

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>>
# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
Argument:
kind: union
fields:
@ -318,9 +318,9 @@ RangeConstraint:
kind: struct
fields:
min_val:
type: Optional[int]
type: int
max_val:
type: Optional[int]
type: int
ScalarType:
kind: enum
fields:
@ -436,6 +436,6 @@ UserOutputSpec:
arg:
type: Argument
SCHEMA_VERSION:
- 8
- 1
- 7
- 4
TREESPEC_VERSION: 1

View File

@ -64,15 +64,14 @@ def _staged_schema():
elif f.default_factory is not dataclasses.MISSING:
value = f.default_factory()
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
if value is not dataclasses.MISSING:
default = str(value)
ret["default"] = default
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
return ret
return {f.name: dump_field(f) for f in dataclasses.fields(ty)}

View File

@ -333,12 +333,12 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
return artifact
def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
# Convert simple sympy Integers into concrete int
if val in (sympy.oo, int_oo):
return None
return math.inf
if val in (-sympy.oo, -int_oo):
return None
return -math.inf
if isinstance(val, sympy.Integer):
return int(val)
@ -357,10 +357,8 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
raise RuntimeError(f"Got invalid adjustment {adjust}")
def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
def _int_to_sympy_int(val) -> sympy.Expr:
# Convert concrete int into simple sympy Integers
if val is None:
return default
if val == math.inf:
return int_oo
if val == -math.inf:
@ -1912,7 +1910,7 @@ class GraphModuleDeserializer(metaclass=Final):
lower = vr.lower
if vr.upper >= 2: # max is >= 2, not sym bool range
lower = max(2, lower)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower, -int_oo), vr.upper)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
if example_inputs is not None and len(example_inputs) > 0:
self.example_inputs = deserialize_torch_artifact(example_inputs)
@ -2329,7 +2327,7 @@ class ExportedProgramDeserializer(metaclass=Final):
symbol_name_to_range = {
k: symbolic_shapes.ValueRanges(
_int_to_sympy_int(v.min_val, -int_oo), _int_to_sympy_int(v.max_val, int_oo)
_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)
)
for k, v in exported_program.range_constraints.items()
}