mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR: 1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence. 2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types. 3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class. Resolves #75982. New tests are included in this PR. - #75982 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
96176e32a9
commit
f08146b67b
@ -7,6 +7,7 @@ import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict, namedtuple, OrderedDict, UserDict
|
||||
from dataclasses import dataclass
|
||||
@ -732,6 +733,133 @@ class TestGenericPytree(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
pytree_impl.treespec_dumps("random_blurb")
|
||||
|
||||
@parametrize(
|
||||
"pytree",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_is_namedtuple(self, pytree):
|
||||
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
|
||||
|
||||
class DirectNamedTuple2(NamedTuple):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
class IndirectNamedTuple1(DirectNamedTuple1):
|
||||
pass
|
||||
|
||||
class IndirectNamedTuple2(DirectNamedTuple2):
|
||||
pass
|
||||
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1)))
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1)))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1)))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1)))
|
||||
self.assertFalse(pytree.is_namedtuple(time.gmtime()))
|
||||
self.assertFalse(pytree.is_namedtuple((0, 1)))
|
||||
self.assertFalse(pytree.is_namedtuple([0, 1]))
|
||||
self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2}))
|
||||
self.assertFalse(pytree.is_namedtuple({0, 1}))
|
||||
self.assertFalse(pytree.is_namedtuple(1))
|
||||
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_namedtuple(time.struct_time))
|
||||
self.assertFalse(pytree.is_namedtuple(tuple))
|
||||
self.assertFalse(pytree.is_namedtuple(list))
|
||||
|
||||
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2))
|
||||
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1))
|
||||
self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_namedtuple_class(time.struct_time))
|
||||
self.assertFalse(pytree.is_namedtuple_class(tuple))
|
||||
self.assertFalse(pytree.is_namedtuple_class(list))
|
||||
|
||||
@parametrize(
|
||||
"pytree",
|
||||
[
|
||||
subtest(py_pytree, name="py"),
|
||||
subtest(cxx_pytree, name="cxx"),
|
||||
],
|
||||
)
|
||||
def test_is_structseq(self, pytree):
|
||||
class FakeStructSeq(tuple):
|
||||
n_fields = 2
|
||||
n_sequence_fields = 2
|
||||
n_unnamed_fields = 0
|
||||
|
||||
__slots__ = ()
|
||||
__match_args__ = ("x", "y")
|
||||
|
||||
def __new__(cls, sequence):
|
||||
return super().__new__(cls, sequence)
|
||||
|
||||
@property
|
||||
def x(self):
|
||||
return self[0]
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
return self[1]
|
||||
|
||||
DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"])
|
||||
|
||||
class DirectNamedTuple2(NamedTuple):
|
||||
x: int
|
||||
y: int
|
||||
|
||||
self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1))))
|
||||
self.assertTrue(pytree.is_structseq(time.gmtime()))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1)))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1)))
|
||||
self.assertFalse(pytree.is_structseq((0, 1)))
|
||||
self.assertFalse(pytree.is_structseq([0, 1]))
|
||||
self.assertFalse(pytree.is_structseq({0: 1, 1: 2}))
|
||||
self.assertFalse(pytree.is_structseq({0, 1}))
|
||||
self.assertFalse(pytree.is_structseq(1))
|
||||
|
||||
self.assertFalse(pytree.is_structseq(FakeStructSeq))
|
||||
self.assertTrue(pytree.is_structseq(time.struct_time))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple1))
|
||||
self.assertFalse(pytree.is_structseq(DirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_structseq(tuple))
|
||||
self.assertFalse(pytree.is_structseq(list))
|
||||
|
||||
self.assertFalse(pytree.is_structseq_class(FakeStructSeq))
|
||||
self.assertTrue(
|
||||
pytree.is_structseq_class(time.struct_time),
|
||||
)
|
||||
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1))
|
||||
self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2))
|
||||
self.assertFalse(pytree.is_structseq_class(tuple))
|
||||
self.assertFalse(pytree.is_structseq_class(list))
|
||||
|
||||
# torch.return_types.* are all PyStructSequence types
|
||||
for cls in vars(torch.return_types).values():
|
||||
if isinstance(cls, type) and issubclass(cls, tuple):
|
||||
self.assertTrue(pytree.is_structseq(cls))
|
||||
self.assertTrue(pytree.is_structseq_class(cls))
|
||||
self.assertFalse(pytree.is_namedtuple(cls))
|
||||
self.assertFalse(pytree.is_namedtuple_class(cls))
|
||||
|
||||
inst = cls(range(cls.n_sequence_fields))
|
||||
self.assertTrue(pytree.is_structseq(inst))
|
||||
self.assertTrue(pytree.is_structseq(type(inst)))
|
||||
self.assertFalse(pytree.is_structseq_class(inst))
|
||||
self.assertTrue(pytree.is_structseq_class(type(inst)))
|
||||
self.assertFalse(pytree.is_namedtuple(inst))
|
||||
self.assertFalse(pytree.is_namedtuple_class(inst))
|
||||
else:
|
||||
self.assertFalse(pytree.is_structseq(cls))
|
||||
self.assertFalse(pytree.is_structseq_class(cls))
|
||||
self.assertFalse(pytree.is_namedtuple(cls))
|
||||
self.assertFalse(pytree.is_namedtuple_class(cls))
|
||||
|
||||
|
||||
class TestPythonPytree(TestCase):
|
||||
def test_deprecated_register_pytree_node(self):
|
||||
|
Reference in New Issue
Block a user