[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)

Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2025-04-01 02:18:46 +08:00
committed by PyTorch MergeBot
parent 48e9ffc873
commit a10b765bf1
8 changed files with 345 additions and 57 deletions

View File

@ -23,7 +23,15 @@ import optree
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
import torch.utils._pytree as python_pytree
from torch.utils._pytree import KeyEntry as KeyEntry
from torch.utils._pytree import (
is_namedtuple as is_namedtuple,
is_namedtuple_class as is_namedtuple_class,
is_namedtuple_instance as is_namedtuple_instance,
is_structseq as is_structseq,
is_structseq_class as is_structseq_class,
is_structseq_instance as is_structseq_instance,
KeyEntry as KeyEntry,
)
__all__ = [
@ -39,6 +47,7 @@ __all__ = [
"keystr",
"key_get",
"register_pytree_node",
"tree_is_leaf",
"tree_flatten",
"tree_flatten_with_path",
"tree_unflatten",
@ -58,6 +67,12 @@ __all__ = [
"treespec_dumps",
"treespec_loads",
"treespec_pprint",
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
]