Update (base update)

[ghstack-poisoned]
This commit is contained in:
Xuehai Pan
2025-11-07 23:33:46 +08:00
parent 625d31b578
commit a429c5de12

View File

@ -77,20 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
__name = ""
for __name in (
"is_namedtuple",
"is_namedtuple_class",
"is_namedtuple_instance",
"is_structseq",
"is_structseq_class",
"is_structseq_instance",
"namedtuple_fields",
"structseq_fields",
for __name, __func in (
("is_namedtuple", is_namedtuple),
("is_namedtuple_class", is_namedtuple_class),
("is_namedtuple_instance", is_namedtuple_instance),
("is_structseq", is_structseq),
("is_structseq_class", is_structseq_class),
("is_structseq_instance", is_structseq_instance),
("namedtuple_fields", namedtuple_fields),
("structseq_fields", structseq_fields),
):
__func = globals()[__name]
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
__func.__python_implementation__
)
globals()[__name] = substitute_in_graph(
__func, # type: ignore[arg-type]
can_constant_fold_through=True,
)(__func.__python_implementation__) # type: ignore[attr-defined]
del __func
del __name
@ -178,7 +178,7 @@ del _Asterisk
class PyTreeSpec:
"""Analog for :class:`optree.PyTreeSpec` in Python."""
_children: tuple[Self, ...]
_children: tuple[PyTreeSpec, ...]
_type: builtins.type | None
_metadata: Any
_entries: tuple[Any, ...]
@ -210,7 +210,7 @@ class PyTreeSpec:
object.__setattr__(self, "num_children", num_children)
def __repr__(self, /) -> str:
def helper(treespec: Self) -> str:
def helper(treespec: PyTreeSpec) -> str:
if treespec.is_leaf():
assert treespec.type is None
return _asterisk
@ -254,7 +254,7 @@ class PyTreeSpec:
return self.num_nodes == 1 and self.num_leaves == 1
def paths(self, /) -> list[tuple[Any, ...]]:
def helper(treespec: Self, path_prefix: list[Any]) -> None:
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
if treespec.is_leaf():
paths.append(path_prefix)
return
@ -272,7 +272,7 @@ class PyTreeSpec:
def accessors(self, /) -> list[optree.PyTreeAccessor]:
def helper(
treespec: Self,
treespec: PyTreeSpec,
entry_path_prefix: list[optree.PyTreeEntry],
) -> None:
if treespec.is_leaf():
@ -302,10 +302,10 @@ class PyTreeSpec:
helper(self, [])
return [optree.PyTreeAccessor(path) for path in entry_paths]
def children(self, /) -> list[Self]:
def children(self, /) -> list[PyTreeSpec]:
return list(self._children)
def child(self, index: int, /) -> Self:
def child(self, index: int, /) -> PyTreeSpec:
return self._children[index]
def entries(self, /) -> list[Any]:
@ -316,7 +316,7 @@ class PyTreeSpec:
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
def helper(
treespec: Self,
treespec: PyTreeSpec,
node: PyTree,
subtrees: list[PyTree],
) -> None: