mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert D65167805 (#139371)
Summary: This diff reverts D65167805 broke the release pipeline Test Plan: NA Differential Revision: D65245198 @diff-train-skip-merge (to silent facebook-github-bot until I have a stamp to land this) Pull Request resolved: https://github.com/pytorch/pytorch/pull/139371 Approved by: https://github.com/malfet
This commit is contained in:
@ -322,34 +322,6 @@ def forward(self, x):
|
||||
self.assertEqual(node.inputs[0].name, "self")
|
||||
self.assertEqual(node.inputs[1].name, "dim")
|
||||
|
||||
def test_serialize_infinite_sym_int(self) -> None:
|
||||
class DynamicShapeSimpleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, a, b, c) -> torch.Tensor:
|
||||
d = (torch.matmul(a, b) + c) / 2
|
||||
d_s0 = d.shape[0]
|
||||
d_s1 = d.shape[1]
|
||||
d_s3 = d_s0 * d_s1
|
||||
e = d.view(d_s3)
|
||||
return torch.cat([e, e])
|
||||
|
||||
inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
|
||||
dim0_ac = torch.export.Dim("dim0_ac")
|
||||
dim1_bc = torch.export.Dim("dim1_b")
|
||||
dynamic_shapes = {
|
||||
"a": {0: dim0_ac},
|
||||
"b": {1: dim1_bc},
|
||||
"c": {0: dim0_ac, 1: dim1_bc},
|
||||
}
|
||||
exported_module = export_for_training(
|
||||
DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes
|
||||
).run_decompositions()
|
||||
serialized = ExportedProgramSerializer().serialize(exported_module)
|
||||
for v in serialized.exported_program.range_constraints.values():
|
||||
self.assertEqual(v.max_val, None)
|
||||
|
||||
def test_serialize_list_returns(self) -> None:
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
from torch._export.serde.union import _Union
|
||||
|
||||
# NOTE: Please update this value if any modifications are made to the schema
|
||||
SCHEMA_VERSION = (8, 1)
|
||||
SCHEMA_VERSION = (7, 4)
|
||||
TREESPEC_VERSION = 1
|
||||
|
||||
|
||||
@ -331,8 +331,8 @@ class GraphSignature:
|
||||
|
||||
@dataclass
|
||||
class RangeConstraint:
|
||||
min_val: Optional[int]
|
||||
max_val: Optional[int]
|
||||
min_val: int
|
||||
max_val: int
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1,5 +1,5 @@
|
||||
# @generated by update_schema.py
|
||||
# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>>
|
||||
# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
|
||||
Argument:
|
||||
kind: union
|
||||
fields:
|
||||
@ -318,9 +318,9 @@ RangeConstraint:
|
||||
kind: struct
|
||||
fields:
|
||||
min_val:
|
||||
type: Optional[int]
|
||||
type: int
|
||||
max_val:
|
||||
type: Optional[int]
|
||||
type: int
|
||||
ScalarType:
|
||||
kind: enum
|
||||
fields:
|
||||
@ -436,6 +436,6 @@ UserOutputSpec:
|
||||
arg:
|
||||
type: Argument
|
||||
SCHEMA_VERSION:
|
||||
- 8
|
||||
- 1
|
||||
- 7
|
||||
- 4
|
||||
TREESPEC_VERSION: 1
|
||||
|
@ -64,15 +64,14 @@ def _staged_schema():
|
||||
elif f.default_factory is not dataclasses.MISSING:
|
||||
value = f.default_factory()
|
||||
|
||||
if t.startswith("Optional[") and value is not None:
|
||||
raise AssertionError(
|
||||
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
||||
)
|
||||
|
||||
if value is not dataclasses.MISSING:
|
||||
default = str(value)
|
||||
ret["default"] = default
|
||||
|
||||
if t.startswith("Optional[") and value is not None:
|
||||
raise AssertionError(
|
||||
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
return {f.name: dump_field(f) for f in dataclasses.fields(ty)}
|
||||
|
@ -333,12 +333,12 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
|
||||
return artifact
|
||||
|
||||
|
||||
def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
|
||||
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
|
||||
# Convert simple sympy Integers into concrete int
|
||||
if val in (sympy.oo, int_oo):
|
||||
return None
|
||||
return math.inf
|
||||
if val in (-sympy.oo, -int_oo):
|
||||
return None
|
||||
return -math.inf
|
||||
if isinstance(val, sympy.Integer):
|
||||
return int(val)
|
||||
|
||||
@ -357,10 +357,8 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
|
||||
raise RuntimeError(f"Got invalid adjustment {adjust}")
|
||||
|
||||
|
||||
def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
|
||||
def _int_to_sympy_int(val) -> sympy.Expr:
|
||||
# Convert concrete int into simple sympy Integers
|
||||
if val is None:
|
||||
return default
|
||||
if val == math.inf:
|
||||
return int_oo
|
||||
if val == -math.inf:
|
||||
@ -1912,7 +1910,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||
lower = vr.lower
|
||||
if vr.upper >= 2: # max is >= 2, not sym bool range
|
||||
lower = max(2, lower)
|
||||
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower, -int_oo), vr.upper)
|
||||
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
|
||||
|
||||
if example_inputs is not None and len(example_inputs) > 0:
|
||||
self.example_inputs = deserialize_torch_artifact(example_inputs)
|
||||
@ -2329,7 +2327,7 @@ class ExportedProgramDeserializer(metaclass=Final):
|
||||
|
||||
symbol_name_to_range = {
|
||||
k: symbolic_shapes.ValueRanges(
|
||||
_int_to_sympy_int(v.min_val, -int_oo), _int_to_sympy_int(v.max_val, int_oo)
|
||||
_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)
|
||||
)
|
||||
for k, v in exported_program.range_constraints.items()
|
||||
}
|
||||
|
Reference in New Issue
Block a user