[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)

Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2025-03-06 22:53:58 +08:00
committed by PyTorch MergeBot
parent 96176e32a9
commit f08146b67b
8 changed files with 322 additions and 37 deletions

View File

@ -7,6 +7,7 @@ import os
import re
import subprocess
import sys
import time
import unittest
from collections import defaultdict, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
@ -732,6 +733,133 @@ class TestGenericPytree(TestCase):
with self.assertRaises(TypeError):
pytree_impl.treespec_dumps("random_blurb")
@parametrize(
"pytree",
[
subtest(py_pytree, name="py"),
subtest(cxx_pytree, name="cxx"),
],
)
def test_is_namedtuple(self, pytree):
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
class DirectNamedTuple2(NamedTuple):
x: int
y: int
class IndirectNamedTuple1(DirectNamedTuple1):
pass
class IndirectNamedTuple2(DirectNamedTuple2):
pass
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1)))
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1)))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1)))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1)))
self.assertFalse(pytree.is_namedtuple(time.gmtime()))
self.assertFalse(pytree.is_namedtuple((0, 1)))
self.assertFalse(pytree.is_namedtuple([0, 1]))
self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2}))
self.assertFalse(pytree.is_namedtuple({0, 1}))
self.assertFalse(pytree.is_namedtuple(1))
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2))
self.assertFalse(pytree.is_namedtuple(time.struct_time))
self.assertFalse(pytree.is_namedtuple(tuple))
self.assertFalse(pytree.is_namedtuple(list))
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2))
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1))
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2))
self.assertFalse(pytree.is_namedtuple_class(time.struct_time))
self.assertFalse(pytree.is_namedtuple_class(tuple))
self.assertFalse(pytree.is_namedtuple_class(list))
@parametrize(
"pytree",
[
subtest(py_pytree, name="py"),
subtest(cxx_pytree, name="cxx"),
],
)
def test_is_structseq(self, pytree):
class FakeStructSeq(tuple):
n_fields = 2
n_sequence_fields = 2
n_unnamed_fields = 0
__slots__ = ()
__match_args__ = ("x", "y")
def __new__(cls, sequence):
return super().__new__(cls, sequence)
@property
def x(self):
return self[0]
@property
def y(self):
return self[1]
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
class DirectNamedTuple2(NamedTuple):
x: int
y: int
self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
self.assertTrue(pytree.is_structseq(time.gmtime()))
self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1)))
self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1)))
self.assertFalse(pytree.is_structseq((0, 1)))
self.assertFalse(pytree.is_structseq([0, 1]))
self.assertFalse(pytree.is_structseq({0: 1, 1: 2}))
self.assertFalse(pytree.is_structseq({0, 1}))
self.assertFalse(pytree.is_structseq(1))
self.assertFalse(pytree.is_structseq(FakeStructSeq))
self.assertTrue(pytree.is_structseq(time.struct_time))
self.assertFalse(pytree.is_structseq(DirectNamedTuple1))
self.assertFalse(pytree.is_structseq(DirectNamedTuple2))
self.assertFalse(pytree.is_structseq(tuple))
self.assertFalse(pytree.is_structseq(list))
self.assertFalse(pytree.is_structseq_class(FakeStructSeq))
self.assertTrue(
pytree.is_structseq_class(time.struct_time),
)
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1))
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2))
self.assertFalse(pytree.is_structseq_class(tuple))
self.assertFalse(pytree.is_structseq_class(list))
# torch.return_types.* are all PyStructSequence types
for cls in vars(torch.return_types).values():
if isinstance(cls, type) and issubclass(cls, tuple):
self.assertTrue(pytree.is_structseq(cls))
self.assertTrue(pytree.is_structseq_class(cls))
self.assertFalse(pytree.is_namedtuple(cls))
self.assertFalse(pytree.is_namedtuple_class(cls))
inst = cls(range(cls.n_sequence_fields))
self.assertTrue(pytree.is_structseq(inst))
self.assertTrue(pytree.is_structseq(type(inst)))
self.assertFalse(pytree.is_structseq_class(inst))
self.assertTrue(pytree.is_structseq_class(type(inst)))
self.assertFalse(pytree.is_namedtuple(inst))
self.assertFalse(pytree.is_namedtuple_class(inst))
else:
self.assertFalse(pytree.is_structseq(cls))
self.assertFalse(pytree.is_structseq_class(cls))
self.assertFalse(pytree.is_namedtuple(cls))
self.assertFalse(pytree.is_namedtuple_class(cls))
class TestPythonPytree(TestCase):
def test_deprecated_register_pytree_node(self):