mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[resubmit][export] add _union_dataclass to support comparing dataclasses that inherits from union. (#156765)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156765 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
53057fc16a
commit
61f6aa36b9
@ -404,6 +404,62 @@ Example(s):
|
||||
next_version, _ = check(commit)
|
||||
self.assertEqual(next_version, [4, 1])
|
||||
|
||||
def test_schema_comparison(self):
|
||||
import torch._export.serde.schema as schema
|
||||
|
||||
sig = schema.ModuleCallSignature(
|
||||
inputs=[
|
||||
schema.Argument.create(as_none=True),
|
||||
schema.Argument.create(
|
||||
as_sym_int=schema.SymIntArgument.create(as_name="s0")
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
schema.Argument.create(
|
||||
as_sym_int=schema.SymIntArgument.create(as_name="s1")
|
||||
)
|
||||
],
|
||||
in_spec="foo",
|
||||
out_spec="bar",
|
||||
forward_arg_names=["None", "symint"],
|
||||
)
|
||||
# same content as sig
|
||||
sig_same = schema.ModuleCallSignature(
|
||||
inputs=[
|
||||
schema.Argument.create(as_none=True),
|
||||
schema.Argument.create(
|
||||
as_sym_int=schema.SymIntArgument.create(as_name="s0")
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
schema.Argument.create(
|
||||
as_sym_int=schema.SymIntArgument.create(as_name="s1")
|
||||
)
|
||||
],
|
||||
in_spec="foo",
|
||||
out_spec="bar",
|
||||
forward_arg_names=["None", "symint"],
|
||||
)
|
||||
# as_name of symint is different
|
||||
sig_diff = schema.ModuleCallSignature(
|
||||
inputs=[
|
||||
schema.Argument.create(as_none=True),
|
||||
schema.Argument.create(
|
||||
as_sym_int=schema.SymIntArgument.create(as_name="s0")
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
schema.Argument.create(
|
||||
as_sym_int=schema.SymIntArgument.create(as_name="s2")
|
||||
)
|
||||
],
|
||||
in_spec="foo",
|
||||
out_spec="bar",
|
||||
forward_arg_names=["None", "symint"],
|
||||
)
|
||||
self.assertEqual(sig, sig_same)
|
||||
self.assertNotEqual(sig, sig_diff)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from torch._export.serde.union import _Union
|
||||
from torch._export.serde.union import _Union, _union_dataclass
|
||||
|
||||
|
||||
# NOTE: Please update this value if any modifications are made to the schema
|
||||
@ -60,7 +60,7 @@ class Device:
|
||||
index: Annotated[Optional[int], 20] = None
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class SymExprHint(_Union):
|
||||
as_int: Annotated[int, 10]
|
||||
as_bool: Annotated[bool, 20]
|
||||
@ -77,19 +77,19 @@ class SymExpr:
|
||||
hint: Annotated[Optional[SymExprHint], 20] = None
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class SymInt(_Union):
|
||||
as_expr: Annotated[SymExpr, 10]
|
||||
as_int: Annotated[int, 20]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class SymFloat(_Union):
|
||||
as_expr: Annotated[SymExpr, 10]
|
||||
as_float: Annotated[float, 20]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class SymBool(_Union):
|
||||
as_expr: Annotated[SymExpr, 10]
|
||||
as_bool: Annotated[bool, 20]
|
||||
@ -112,7 +112,7 @@ class TensorMeta:
|
||||
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
|
||||
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
|
||||
# to the "as_int" field.
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class SymIntArgument(_Union):
|
||||
as_name: Annotated[str, 10]
|
||||
as_int: Annotated[int, 20]
|
||||
@ -124,7 +124,7 @@ class SymIntArgument(_Union):
|
||||
# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to
|
||||
# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints
|
||||
# to the "as_float" field.
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class SymFloatArgument(_Union):
|
||||
as_name: Annotated[str, 10]
|
||||
as_float: Annotated[float, 20]
|
||||
@ -136,7 +136,7 @@ class SymFloatArgument(_Union):
|
||||
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
|
||||
# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools
|
||||
# to the "as_bool" field.
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class SymBoolArgument(_Union):
|
||||
as_name: Annotated[str, 10]
|
||||
as_bool: Annotated[bool, 20]
|
||||
@ -156,7 +156,7 @@ class TokenArgument:
|
||||
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
|
||||
# type List[OptionalTensorArgument], with tensor values seiralized to the
|
||||
# "as_tensor" field, and None values serialized to the "as_none" field.
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class OptionalTensorArgument(_Union):
|
||||
as_tensor: Annotated[TensorArgument, 20]
|
||||
as_none: Annotated[bool, 10]
|
||||
@ -175,7 +175,7 @@ class CustomObjArgument:
|
||||
|
||||
|
||||
# This is actually a union type
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class Argument(_Union):
|
||||
as_none: Annotated[bool, 10]
|
||||
as_tensor: Annotated[TensorArgument, 20]
|
||||
@ -253,7 +253,7 @@ class UserInputSpec:
|
||||
arg: Annotated[Argument, 10]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class ConstantValue(_Union):
|
||||
as_none: Annotated[bool, 10]
|
||||
as_int: Annotated[int, 20]
|
||||
@ -298,7 +298,7 @@ class InputTokenSpec:
|
||||
arg: Annotated[TokenArgument, 10]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class InputSpec(_Union):
|
||||
user_input: Annotated[UserInputSpec, 10]
|
||||
parameter: Annotated[InputToParameterSpec, 20]
|
||||
@ -348,7 +348,7 @@ class OutputTokenSpec:
|
||||
arg: Annotated[TokenArgument, 10]
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
@_union_dataclass
|
||||
class OutputSpec(_Union):
|
||||
user_output: Annotated[UserOutputSpec, 10]
|
||||
loss_output: Annotated[LossOutputSpec, 20]
|
||||
|
@ -1,7 +1,12 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from collections.abc import Hashable
|
||||
from dataclasses import fields
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import TypeVar
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
|
||||
T = TypeVar("T", bound="_Union")
|
||||
|
||||
|
||||
class _UnionTag(str):
|
||||
@ -32,6 +37,18 @@ def _get_field_names(cls) -> set[str]:
|
||||
return {f.name for f in fields(cls)}
|
||||
|
||||
|
||||
# If you turn a schema class that inherits from union into a dataclass, please use
|
||||
# this decorator to configure it. It's safe, faster and allows code sharing.
|
||||
#
|
||||
# For example, _union_dataclass customizes the __eq__ method to only check the type
|
||||
# and value property instead of default implmentation of dataclass which goes
|
||||
# through every field in the dataclass.
|
||||
@dataclass_transform(eq_default=False)
|
||||
def _union_dataclass(cls: type[T]) -> type[T]:
|
||||
assert issubclass(cls, _Union), f"{cls} must inheirt from {_Union}."
|
||||
return dataclass(repr=False, eq=False)(cls)
|
||||
|
||||
|
||||
class _Union:
|
||||
_type: _UnionTag
|
||||
|
||||
@ -67,6 +84,11 @@ class _Union:
|
||||
raise AttributeError(f"Field {name} is not set.")
|
||||
return attr
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, _Union):
|
||||
return False
|
||||
return self.type == other.type and self.value == other.value
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
Reference in New Issue
Block a user