[pytree] Register normal class to register_dataclass (#147752)

Fixes https://github.com/pytorch/pytorch/pull/147532#discussion_r1964365330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147752
Approved by: https://github.com/zou3519
This commit is contained in:
angelayi
2025-04-01 23:28:20 +00:00
committed by PyTorch MergeBot
parent 203a27e0ce
commit 60fe0922f6
4 changed files with 138 additions and 27 deletions

View File

@ -523,9 +523,4 @@ def register_dataclass(
print(ep)
"""
from torch._export.utils import register_dataclass_as_pytree_node
return register_dataclass_as_pytree_node(
cls, serialized_type_name=serialized_type_name
)
pytree.register_dataclass(cls, serialized_type_name=serialized_type_name)