mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Revert D65167805 (#139371)
Summary: This diff reverts D65167805 broke the release pipeline Test Plan: NA Differential Revision: D65245198 @diff-train-skip-merge (to silent facebook-github-bot until I have a stamp to land this) Pull Request resolved: https://github.com/pytorch/pytorch/pull/139371 Approved by: https://github.com/malfet
This commit is contained in:
@ -322,34 +322,6 @@ def forward(self, x):
|
|||||||
self.assertEqual(node.inputs[0].name, "self")
|
self.assertEqual(node.inputs[0].name, "self")
|
||||||
self.assertEqual(node.inputs[1].name, "dim")
|
self.assertEqual(node.inputs[1].name, "dim")
|
||||||
|
|
||||||
def test_serialize_infinite_sym_int(self) -> None:
|
|
||||||
class DynamicShapeSimpleModel(torch.nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, a, b, c) -> torch.Tensor:
|
|
||||||
d = (torch.matmul(a, b) + c) / 2
|
|
||||||
d_s0 = d.shape[0]
|
|
||||||
d_s1 = d.shape[1]
|
|
||||||
d_s3 = d_s0 * d_s1
|
|
||||||
e = d.view(d_s3)
|
|
||||||
return torch.cat([e, e])
|
|
||||||
|
|
||||||
inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
|
|
||||||
dim0_ac = torch.export.Dim("dim0_ac")
|
|
||||||
dim1_bc = torch.export.Dim("dim1_b")
|
|
||||||
dynamic_shapes = {
|
|
||||||
"a": {0: dim0_ac},
|
|
||||||
"b": {1: dim1_bc},
|
|
||||||
"c": {0: dim0_ac, 1: dim1_bc},
|
|
||||||
}
|
|
||||||
exported_module = export_for_training(
|
|
||||||
DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes
|
|
||||||
).run_decompositions()
|
|
||||||
serialized = ExportedProgramSerializer().serialize(exported_module)
|
|
||||||
for v in serialized.exported_program.range_constraints.values():
|
|
||||||
self.assertEqual(v.max_val, None)
|
|
||||||
|
|
||||||
def test_serialize_list_returns(self) -> None:
|
def test_serialize_list_returns(self) -> None:
|
||||||
class MyModule(torch.nn.Module):
|
class MyModule(torch.nn.Module):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
from torch._export.serde.union import _Union
|
from torch._export.serde.union import _Union
|
||||||
|
|
||||||
# NOTE: Please update this value if any modifications are made to the schema
|
# NOTE: Please update this value if any modifications are made to the schema
|
||||||
SCHEMA_VERSION = (8, 1)
|
SCHEMA_VERSION = (7, 4)
|
||||||
TREESPEC_VERSION = 1
|
TREESPEC_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
@ -331,8 +331,8 @@ class GraphSignature:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RangeConstraint:
|
class RangeConstraint:
|
||||||
min_val: Optional[int]
|
min_val: int
|
||||||
max_val: Optional[int]
|
max_val: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# @generated by update_schema.py
|
# @generated by update_schema.py
|
||||||
# checksum<<8e27d48014d4ec1c773aef056c0c20b61bead54be8338e95d3347d3422472b9a>>
|
# checksum<<69912a674f9c3123e488399d2bc8fdcf1226005721e4ec3dd12da0e176c16e50>>
|
||||||
Argument:
|
Argument:
|
||||||
kind: union
|
kind: union
|
||||||
fields:
|
fields:
|
||||||
@ -318,9 +318,9 @@ RangeConstraint:
|
|||||||
kind: struct
|
kind: struct
|
||||||
fields:
|
fields:
|
||||||
min_val:
|
min_val:
|
||||||
type: Optional[int]
|
type: int
|
||||||
max_val:
|
max_val:
|
||||||
type: Optional[int]
|
type: int
|
||||||
ScalarType:
|
ScalarType:
|
||||||
kind: enum
|
kind: enum
|
||||||
fields:
|
fields:
|
||||||
@ -436,6 +436,6 @@ UserOutputSpec:
|
|||||||
arg:
|
arg:
|
||||||
type: Argument
|
type: Argument
|
||||||
SCHEMA_VERSION:
|
SCHEMA_VERSION:
|
||||||
- 8
|
- 7
|
||||||
- 1
|
- 4
|
||||||
TREESPEC_VERSION: 1
|
TREESPEC_VERSION: 1
|
||||||
|
@ -64,15 +64,14 @@ def _staged_schema():
|
|||||||
elif f.default_factory is not dataclasses.MISSING:
|
elif f.default_factory is not dataclasses.MISSING:
|
||||||
value = f.default_factory()
|
value = f.default_factory()
|
||||||
|
|
||||||
|
if t.startswith("Optional[") and value is not None:
|
||||||
|
raise AssertionError(
|
||||||
|
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
||||||
|
)
|
||||||
|
|
||||||
if value is not dataclasses.MISSING:
|
if value is not dataclasses.MISSING:
|
||||||
default = str(value)
|
default = str(value)
|
||||||
ret["default"] = default
|
ret["default"] = default
|
||||||
|
|
||||||
if t.startswith("Optional[") and value is not None:
|
|
||||||
raise AssertionError(
|
|
||||||
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
|
||||||
)
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
return {f.name: dump_field(f) for f in dataclasses.fields(ty)}
|
return {f.name: dump_field(f) for f in dataclasses.fields(ty)}
|
||||||
|
@ -333,12 +333,12 @@ def deserialize_torch_artifact(serialized: Union[Dict[str, Any], Tuple[Any, ...]
|
|||||||
return artifact
|
return artifact
|
||||||
|
|
||||||
|
|
||||||
def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
|
def _sympy_int_to_int(val: sympy.Expr, adjust: str):
|
||||||
# Convert simple sympy Integers into concrete int
|
# Convert simple sympy Integers into concrete int
|
||||||
if val in (sympy.oo, int_oo):
|
if val in (sympy.oo, int_oo):
|
||||||
return None
|
return math.inf
|
||||||
if val in (-sympy.oo, -int_oo):
|
if val in (-sympy.oo, -int_oo):
|
||||||
return None
|
return -math.inf
|
||||||
if isinstance(val, sympy.Integer):
|
if isinstance(val, sympy.Integer):
|
||||||
return int(val)
|
return int(val)
|
||||||
|
|
||||||
@ -357,10 +357,8 @@ def _sympy_int_to_int(val: sympy.Expr, adjust: str) -> Optional[int]:
|
|||||||
raise RuntimeError(f"Got invalid adjustment {adjust}")
|
raise RuntimeError(f"Got invalid adjustment {adjust}")
|
||||||
|
|
||||||
|
|
||||||
def _int_to_sympy_int(val: Optional[int], default) -> sympy.Expr:
|
def _int_to_sympy_int(val) -> sympy.Expr:
|
||||||
# Convert concrete int into simple sympy Integers
|
# Convert concrete int into simple sympy Integers
|
||||||
if val is None:
|
|
||||||
return default
|
|
||||||
if val == math.inf:
|
if val == math.inf:
|
||||||
return int_oo
|
return int_oo
|
||||||
if val == -math.inf:
|
if val == -math.inf:
|
||||||
@ -1912,7 +1910,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
|||||||
lower = vr.lower
|
lower = vr.lower
|
||||||
if vr.upper >= 2: # max is >= 2, not sym bool range
|
if vr.upper >= 2: # max is >= 2, not sym bool range
|
||||||
lower = max(2, lower)
|
lower = max(2, lower)
|
||||||
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower, -int_oo), vr.upper)
|
self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges(_int_to_sympy_int(lower), vr.upper)
|
||||||
|
|
||||||
if example_inputs is not None and len(example_inputs) > 0:
|
if example_inputs is not None and len(example_inputs) > 0:
|
||||||
self.example_inputs = deserialize_torch_artifact(example_inputs)
|
self.example_inputs = deserialize_torch_artifact(example_inputs)
|
||||||
@ -2329,7 +2327,7 @@ class ExportedProgramDeserializer(metaclass=Final):
|
|||||||
|
|
||||||
symbol_name_to_range = {
|
symbol_name_to_range = {
|
||||||
k: symbolic_shapes.ValueRanges(
|
k: symbolic_shapes.ValueRanges(
|
||||||
_int_to_sympy_int(v.min_val, -int_oo), _int_to_sympy_int(v.max_val, int_oo)
|
_int_to_sympy_int(v.min_val), _int_to_sympy_int(v.max_val)
|
||||||
)
|
)
|
||||||
for k, v in exported_program.range_constraints.items()
|
for k, v in exported_program.range_constraints.items()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user