mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR: 1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence. 2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types. 3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class. Resolves #75982. New tests are included in this PR. - #75982 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
05ac99042f
commit
c95a6b416b
@ -23,7 +23,15 @@ import optree
|
||||
from optree import PyTreeSpec as TreeSpec # direct import for type annotations
|
||||
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch.utils._pytree import KeyEntry as KeyEntry
|
||||
from torch.utils._pytree import (
|
||||
is_namedtuple as is_namedtuple,
|
||||
is_namedtuple_class as is_namedtuple_class,
|
||||
is_namedtuple_instance as is_namedtuple_instance,
|
||||
is_structseq as is_structseq,
|
||||
is_structseq_class as is_structseq_class,
|
||||
is_structseq_instance as is_structseq_instance,
|
||||
KeyEntry as KeyEntry,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -39,6 +47,7 @@ __all__ = [
|
||||
"keystr",
|
||||
"key_get",
|
||||
"register_pytree_node",
|
||||
"tree_is_leaf",
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_unflatten",
|
||||
@ -58,6 +67,12 @@ __all__ = [
|
||||
"treespec_dumps",
|
||||
"treespec_loads",
|
||||
"treespec_pprint",
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user