Update (base update)

[ghstack-poisoned]
This commit is contained in:
Xuehai Pan
2025-11-07 23:33:46 +08:00
parent 625d31b578
commit a429c5de12

View File

@ -77,20 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
__name = "" __name = ""
for __name in ( for __name, __func in (
"is_namedtuple", ("is_namedtuple", is_namedtuple),
"is_namedtuple_class", ("is_namedtuple_class", is_namedtuple_class),
"is_namedtuple_instance", ("is_namedtuple_instance", is_namedtuple_instance),
"is_structseq", ("is_structseq", is_structseq),
"is_structseq_class", ("is_structseq_class", is_structseq_class),
"is_structseq_instance", ("is_structseq_instance", is_structseq_instance),
"namedtuple_fields", ("namedtuple_fields", namedtuple_fields),
"structseq_fields", ("structseq_fields", structseq_fields),
): ):
__func = globals()[__name] globals()[__name] = substitute_in_graph(
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)( __func, # type: ignore[arg-type]
__func.__python_implementation__ can_constant_fold_through=True,
) )(__func.__python_implementation__) # type: ignore[attr-defined]
del __func del __func
del __name del __name
@ -178,7 +178,7 @@ del _Asterisk
class PyTreeSpec: class PyTreeSpec:
"""Analog for :class:`optree.PyTreeSpec` in Python.""" """Analog for :class:`optree.PyTreeSpec` in Python."""
_children: tuple[Self, ...] _children: tuple[PyTreeSpec, ...]
_type: builtins.type | None _type: builtins.type | None
_metadata: Any _metadata: Any
_entries: tuple[Any, ...] _entries: tuple[Any, ...]
@ -210,7 +210,7 @@ class PyTreeSpec:
object.__setattr__(self, "num_children", num_children) object.__setattr__(self, "num_children", num_children)
def __repr__(self, /) -> str: def __repr__(self, /) -> str:
def helper(treespec: Self) -> str: def helper(treespec: PyTreeSpec) -> str:
if treespec.is_leaf(): if treespec.is_leaf():
assert treespec.type is None assert treespec.type is None
return _asterisk return _asterisk
@ -254,7 +254,7 @@ class PyTreeSpec:
return self.num_nodes == 1 and self.num_leaves == 1 return self.num_nodes == 1 and self.num_leaves == 1
def paths(self, /) -> list[tuple[Any, ...]]: def paths(self, /) -> list[tuple[Any, ...]]:
def helper(treespec: Self, path_prefix: list[Any]) -> None: def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
if treespec.is_leaf(): if treespec.is_leaf():
paths.append(path_prefix) paths.append(path_prefix)
return return
@ -272,7 +272,7 @@ class PyTreeSpec:
def accessors(self, /) -> list[optree.PyTreeAccessor]: def accessors(self, /) -> list[optree.PyTreeAccessor]:
def helper( def helper(
treespec: Self, treespec: PyTreeSpec,
entry_path_prefix: list[optree.PyTreeEntry], entry_path_prefix: list[optree.PyTreeEntry],
) -> None: ) -> None:
if treespec.is_leaf(): if treespec.is_leaf():
@ -302,10 +302,10 @@ class PyTreeSpec:
helper(self, []) helper(self, [])
return [optree.PyTreeAccessor(path) for path in entry_paths] return [optree.PyTreeAccessor(path) for path in entry_paths]
def children(self, /) -> list[Self]: def children(self, /) -> list[PyTreeSpec]:
return list(self._children) return list(self._children)
def child(self, index: int, /) -> Self: def child(self, index: int, /) -> PyTreeSpec:
return self._children[index] return self._children[index]
def entries(self, /) -> list[Any]: def entries(self, /) -> list[Any]:
@ -316,7 +316,7 @@ class PyTreeSpec:
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]: def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
def helper( def helper(
treespec: Self, treespec: PyTreeSpec,
node: PyTree, node: PyTree,
subtrees: list[PyTree], subtrees: list[PyTree],
) -> None: ) -> None: