[export] Ensure optional fields always have default value. (#121163)

Summary: Add additional check to make sure we can always unset an optional field.

Test Plan: CI

Differential Revision: D54504243

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121163
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Zhengxu Chen
2024-03-05 17:16:46 +00:00
committed by PyTorch MergeBot
parent 35004b8ab4
commit 85c807b3fd
3 changed files with 14 additions and 5 deletions

View File

@ -51,7 +51,7 @@ class MemoryFormat(IntEnum):
@dataclass
class Device:
type: str
index: Optional[int]
index: Optional[int] = None
@dataclass(repr=False)

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<ada9ec9136fafed82dd21a559f1dd5bab5efc97605692aa037d492c719d1096d>>
# checksum<<7b0269f73d3ea9a084c796ac9323e50cc769759f9505c8448b3e330ef91556c0>>
Argument:
kind: union
fields:
@ -70,6 +70,7 @@ Device:
type: str
index:
type: Optional[int]
default: None
ExportedProgram:
kind: struct
fields:
@ -215,6 +216,7 @@ ModuleCallEntry:
type: str
signature:
type: Optional[ModuleCallSignature]
default: None
ModuleCallSignature:
kind: struct
fields:
@ -318,6 +320,7 @@ SymExpr:
type: str
hint:
type: Optional[SymExprHint]
default: None
SymExprHint:
kind: union
fields:

View File

@ -54,15 +54,21 @@ def _staged_schema():
raise AssertionError(f"Type {t} is not supported in export schema.")
def dump_field(f):
ret = {"type": dump_type(f.type)}
t = dump_type(f.type)
ret = {"type": t}
value = None
value = dataclasses.MISSING
if f.default is not dataclasses.MISSING:
value = f.default
elif f.default_factory is not dataclasses.MISSING:
value = f.default_factory()
if value is not None:
if t.startswith("Optional[") and value is not None:
raise AssertionError(
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
)
if value is not dataclasses.MISSING:
default = str(value)
ret["default"] = default
return ret