[resubmit][export] add _union_dataclass to support comparing dataclasses that inherits from union. (#156765)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156765
Approved by: https://github.com/zhxchen17
This commit is contained in:
Yidi Wu
2025-06-24 16:30:30 -07:00
committed by PyTorch MergeBot
parent 53057fc16a
commit 61f6aa36b9
3 changed files with 92 additions and 14 deletions

View File

@ -404,6 +404,62 @@ Example(s):
next_version, _ = check(commit)
self.assertEqual(next_version, [4, 1])
def test_schema_comparison(self):
import torch._export.serde.schema as schema
sig = schema.ModuleCallSignature(
inputs=[
schema.Argument.create(as_none=True),
schema.Argument.create(
as_sym_int=schema.SymIntArgument.create(as_name="s0")
),
],
outputs=[
schema.Argument.create(
as_sym_int=schema.SymIntArgument.create(as_name="s1")
)
],
in_spec="foo",
out_spec="bar",
forward_arg_names=["None", "symint"],
)
# same content as sig
sig_same = schema.ModuleCallSignature(
inputs=[
schema.Argument.create(as_none=True),
schema.Argument.create(
as_sym_int=schema.SymIntArgument.create(as_name="s0")
),
],
outputs=[
schema.Argument.create(
as_sym_int=schema.SymIntArgument.create(as_name="s1")
)
],
in_spec="foo",
out_spec="bar",
forward_arg_names=["None", "symint"],
)
# as_name of symint is different
sig_diff = schema.ModuleCallSignature(
inputs=[
schema.Argument.create(as_none=True),
schema.Argument.create(
as_sym_int=schema.SymIntArgument.create(as_name="s0")
),
],
outputs=[
schema.Argument.create(
as_sym_int=schema.SymIntArgument.create(as_name="s2")
)
],
in_spec="foo",
out_spec="bar",
forward_arg_names=["None", "symint"],
)
self.assertEqual(sig, sig_same)
self.assertNotEqual(sig, sig_diff)
if __name__ == "__main__":
run_tests()

View File

@ -5,7 +5,7 @@ from dataclasses import dataclass, field
from enum import IntEnum
from typing import Annotated, Optional
from torch._export.serde.union import _Union
from torch._export.serde.union import _Union, _union_dataclass
# NOTE: Please update this value if any modifications are made to the schema
@ -60,7 +60,7 @@ class Device:
index: Annotated[Optional[int], 20] = None
@dataclass(repr=False)
@_union_dataclass
class SymExprHint(_Union):
as_int: Annotated[int, 10]
as_bool: Annotated[bool, 20]
@ -77,19 +77,19 @@ class SymExpr:
hint: Annotated[Optional[SymExprHint], 20] = None
@dataclass(repr=False)
@_union_dataclass
class SymInt(_Union):
as_expr: Annotated[SymExpr, 10]
as_int: Annotated[int, 20]
@dataclass(repr=False)
@_union_dataclass
class SymFloat(_Union):
as_expr: Annotated[SymExpr, 10]
as_float: Annotated[float, 20]
@dataclass(repr=False)
@_union_dataclass
class SymBool(_Union):
as_expr: Annotated[SymExpr, 10]
as_bool: Annotated[bool, 20]
@ -112,7 +112,7 @@ class TensorMeta:
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
@_union_dataclass
class SymIntArgument(_Union):
as_name: Annotated[str, 10]
as_int: Annotated[int, 20]
@ -124,7 +124,7 @@ class SymIntArgument(_Union):
# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to
# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and ints
# to the "as_float" field.
@dataclass(repr=False)
@_union_dataclass
class SymFloatArgument(_Union):
as_name: Annotated[str, 10]
as_float: Annotated[float, 20]
@ -136,7 +136,7 @@ class SymFloatArgument(_Union):
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
# be List[SymboolArgument] and map the SymBools to the "as_name" field, and bools
# to the "as_bool" field.
@dataclass(repr=False)
@_union_dataclass
class SymBoolArgument(_Union):
as_name: Annotated[str, 10]
as_bool: Annotated[bool, 20]
@ -156,7 +156,7 @@ class TokenArgument:
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values seiralized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
@_union_dataclass
class OptionalTensorArgument(_Union):
as_tensor: Annotated[TensorArgument, 20]
as_none: Annotated[bool, 10]
@ -175,7 +175,7 @@ class CustomObjArgument:
# This is actually a union type
@dataclass(repr=False)
@_union_dataclass
class Argument(_Union):
as_none: Annotated[bool, 10]
as_tensor: Annotated[TensorArgument, 20]
@ -253,7 +253,7 @@ class UserInputSpec:
arg: Annotated[Argument, 10]
@dataclass(repr=False)
@_union_dataclass
class ConstantValue(_Union):
as_none: Annotated[bool, 10]
as_int: Annotated[int, 20]
@ -298,7 +298,7 @@ class InputTokenSpec:
arg: Annotated[TokenArgument, 10]
@dataclass(repr=False)
@_union_dataclass
class InputSpec(_Union):
user_input: Annotated[UserInputSpec, 10]
parameter: Annotated[InputToParameterSpec, 20]
@ -348,7 +348,7 @@ class OutputTokenSpec:
arg: Annotated[TokenArgument, 10]
@dataclass(repr=False)
@_union_dataclass
class OutputSpec(_Union):
user_output: Annotated[UserOutputSpec, 10]
loss_output: Annotated[LossOutputSpec, 20]

View File

@ -1,7 +1,12 @@
# mypy: allow-untyped-defs
import functools
from collections.abc import Hashable
from dataclasses import fields
from dataclasses import dataclass, fields
from typing import TypeVar
from typing_extensions import dataclass_transform
T = TypeVar("T", bound="_Union")
class _UnionTag(str):
@ -32,6 +37,18 @@ def _get_field_names(cls) -> set[str]:
return {f.name for f in fields(cls)}
# If you turn a schema class that inherits from union into a dataclass, please use
# this decorator to configure it. It's safe, faster and allows code sharing.
#
# For example, _union_dataclass customizes the __eq__ method to only check the type
# and value property instead of default implmentation of dataclass which goes
# through every field in the dataclass.
@dataclass_transform(eq_default=False)
def _union_dataclass(cls: type[T]) -> type[T]:
assert issubclass(cls, _Union), f"{cls} must inheirt from {_Union}."
return dataclass(repr=False, eq=False)(cls)
class _Union:
_type: _UnionTag
@ -67,6 +84,11 @@ class _Union:
raise AttributeError(f"Field {name} is not set.")
return attr
def __eq__(self, other: object) -> bool:
if not isinstance(other, _Union):
return False
return self.type == other.type and self.value == other.value
def __str__(self):
return self.__repr__()