mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[export] Ensure optional fields always have default value. (#121163)
Summary: Add additional check to make sure we can always unset an optional field. Test Plan: CI Differential Revision: D54504243 Pull Request resolved: https://github.com/pytorch/pytorch/pull/121163 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
35004b8ab4
commit
85c807b3fd
@ -51,7 +51,7 @@ class MemoryFormat(IntEnum):
|
||||
@dataclass
|
||||
class Device:
|
||||
type: str
|
||||
index: Optional[int]
|
||||
index: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
|
@ -1,5 +1,5 @@
|
||||
# @generated by update_schema.py
|
||||
# checksum<<ada9ec9136fafed82dd21a559f1dd5bab5efc97605692aa037d492c719d1096d>>
|
||||
# checksum<<7b0269f73d3ea9a084c796ac9323e50cc769759f9505c8448b3e330ef91556c0>>
|
||||
Argument:
|
||||
kind: union
|
||||
fields:
|
||||
@ -70,6 +70,7 @@ Device:
|
||||
type: str
|
||||
index:
|
||||
type: Optional[int]
|
||||
default: None
|
||||
ExportedProgram:
|
||||
kind: struct
|
||||
fields:
|
||||
@ -215,6 +216,7 @@ ModuleCallEntry:
|
||||
type: str
|
||||
signature:
|
||||
type: Optional[ModuleCallSignature]
|
||||
default: None
|
||||
ModuleCallSignature:
|
||||
kind: struct
|
||||
fields:
|
||||
@ -318,6 +320,7 @@ SymExpr:
|
||||
type: str
|
||||
hint:
|
||||
type: Optional[SymExprHint]
|
||||
default: None
|
||||
SymExprHint:
|
||||
kind: union
|
||||
fields:
|
||||
|
@ -54,15 +54,21 @@ def _staged_schema():
|
||||
raise AssertionError(f"Type {t} is not supported in export schema.")
|
||||
|
||||
def dump_field(f):
|
||||
ret = {"type": dump_type(f.type)}
|
||||
t = dump_type(f.type)
|
||||
ret = {"type": t}
|
||||
|
||||
value = None
|
||||
value = dataclasses.MISSING
|
||||
if f.default is not dataclasses.MISSING:
|
||||
value = f.default
|
||||
elif f.default_factory is not dataclasses.MISSING:
|
||||
value = f.default_factory()
|
||||
|
||||
if value is not None:
|
||||
if t.startswith("Optional[") and value is not None:
|
||||
raise AssertionError(
|
||||
f"Optional field {ty.__name__}.{f.name} must have default value to be None."
|
||||
)
|
||||
|
||||
if value is not dataclasses.MISSING:
|
||||
default = str(value)
|
||||
ret["default"] = default
|
||||
return ret
|
||||
|
Reference in New Issue
Block a user