Revert D65167805 (#139371)

Summary:
This diff reverts D65167805
broke the release pipeline

Test Plan: NA

Differential Revision: D65245198

@diff-train-skip-merge (to silent facebook-github-bot until I have a stamp to land this)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139371
Approved by: https://github.com/malfet
This commit is contained in:
Huy Do
2024-10-31 07:25:25 +00:00
committed by PyTorch MergeBot
parent 86e6513c86
commit f98bc9a49d
5 changed files with 19 additions and 50 deletions

View File

@ -322,34 +322,6 @@ def forward(self, x):
self.assertEqual(node.inputs[0].name, "self") self.assertEqual(node.inputs[0].name, "self")
self.assertEqual(node.inputs[1].name, "dim") self.assertEqual(node.inputs[1].name, "dim")
def test_serialize_infinite_sym_int(self) -> None:
class DynamicShapeSimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b, c) -> torch.Tensor:
d = (torch.matmul(a, b) + c) / 2
d_s0 = d.shape[0]
d_s1 = d.shape[1]
d_s3 = d_s0 * d_s1
e = d.view(d_s3)
return torch.cat([e, e])
inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
dim0_ac = torch.export.Dim("dim0_ac")
dim1_bc = torch.export.Dim("dim1_b")
dynamic_shapes = {
"a": {0: dim0_ac},
"b": {1: dim1_bc},
"c": {0: dim0_ac, 1: dim1_bc},
}
exported_module = export_for_training(
DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes
).run_decompositions()
serialized = ExportedProgramSerializer().serialize(exported_module)
for v in serialized.exported_program.range_constraints.values():
self.assertEqual(v.max_val, None)
def test_serialize_list_returns(self) -> None: def test_serialize_list_returns(self) -> None:
class MyModule(torch.nn.Module): class MyModule(torch.nn.Module):
def __init__(self) -> None: def __init__(self) -> None:

View File

@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
from torch._export.serde.union import _Union from torch._export.serde.union import _Union
# NOTE: Please update this value if any modifications are made to the schema # NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (8, 1) SCHEMA_VERSION = (7, 4)
TREESPEC_VERSION = 1 TREESPEC_VERSION = 1
@ -331,8 +331,8 @@ class GraphSignature:
@dataclass @dataclass
class RangeConstraint: class RangeConstraint:
min_val: Optional[int] min_val: int
max_val: Optional[int] max_val: int
@dataclass @dataclass

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py # @generated by update_schema.py
# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>> # checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
Argument: Argument:
kind: union kind: union
fields: fields:
@ -318,9 +318,9 @@ RangeConstraint:
kind: struct kind: struct
fields: fields:
min_val: min_val:
type: Optional[int] type: int
max_val: max_val:
type: Optional[int] type: int
ScalarType: ScalarType:
kind: enum kind: enum
fields: fields:
@ -436,6 +436,6 @@ UserOutputSpec:
arg: arg:
type: Argument type: Argument
SCHEMA_VERSION: SCHEMA_VERSION:
- 8 - 7
- 1 - 4
TREESPEC_VERSION: 1 TREESPEC_VERSION: 1

View File

@ -64,15 +64,14 @@ def _staged_schema():
elif f.default_factory is not dataclasses.MISSING: elif f.default_factory is not dataclasses.MISSING:
value = f.default_factory() value = f.default_factory()
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
if value is not dataclasses.MISSING: if value is not dataclasses.MISSING:
default = str(value) default = str(value)
ret["default"] = default ret["default"] = default
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
return ret return ret
return {f.name: dump_field(f) for f in dataclasses.fields(ty)} return {f.name: dump_field(f) for f in dataclasses.fields(ty)}

View File

@ -333,12 +333,12 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
return artifact return artifact
def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]: def _sympy_int_to_int(val: sympy.Expr, adjust: str):
# Convert simple sympy Integers into concrete int # Convert simple sympy Integers into concrete int
if val in (sympy.oo, int_oo): if val in (sympy.oo, int_oo):
return None return math.inf
if val in (-sympy.oo, -int_oo): if val in (-sympy.oo, -int_oo):
return None return -math.inf
if isinstance(val, sympy.Integer): if isinstance(val, sympy.Integer):
return int(val) return int(val)
@ -357,10 +357,8 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
raise RuntimeError(f"Got invalid adjustment {adjust}") raise RuntimeError(f"Got invalid adjustment {adjust}")
def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr: def _int_to_sympy_int(val) -> sympy.Expr:
# Convert concrete int into simple sympy Integers # Convert concrete int into simple sympy Integers
if val is None:
return default
if val == math.inf: if val == math.inf:
return int_oo return int_oo
if val == -math.inf: if val == -math.inf:
@ -1912,7 +1910,7 @@ class GraphModuleDeserializer(metaclass=Final):
lower = vr.lower lower = vr.lower
if vr.upper >= 2: # max is >= 2, not sym bool range if vr.upper >= 2: # max is >= 2, not sym bool range
lower = max(2, lower) lower = max(2, lower)
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower, -int_oo), vr.upper) self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
if example_inputs is not None and len(example_inputs) > 0: if example_inputs is not None and len(example_inputs) > 0:
self.example_inputs = deserialize_torch_artifact(example_inputs) self.example_inputs = deserialize_torch_artifact(example_inputs)
@ -2329,7 +2327,7 @@ class ExportedProgramDeserializer(metaclass=Final):
symbol_name_to_range = { symbol_name_to_range = {
k: symbolic_shapes.ValueRanges( k: symbolic_shapes.ValueRanges(
_int_to_sympy_int(v.min_val, -int_oo), _int_to_sympy_int(v.max_val, int_oo) _int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)
) )
for k, v in exported_program.range_constraints.items() for k, v in exported_program.range_constraints.items()
} }