mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-14 22:25:03 +08:00
Update (base update)
[ghstack-poisoned]
This commit is contained in:
@ -77,20 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
|
||||
|
||||
|
||||
__name = ""
|
||||
for __name in (
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
"namedtuple_fields",
|
||||
"structseq_fields",
|
||||
for __name, __func in (
|
||||
("is_namedtuple", is_namedtuple),
|
||||
("is_namedtuple_class", is_namedtuple_class),
|
||||
("is_namedtuple_instance", is_namedtuple_instance),
|
||||
("is_structseq", is_structseq),
|
||||
("is_structseq_class", is_structseq_class),
|
||||
("is_structseq_instance", is_structseq_instance),
|
||||
("namedtuple_fields", namedtuple_fields),
|
||||
("structseq_fields", structseq_fields),
|
||||
):
|
||||
__func = globals()[__name]
|
||||
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
|
||||
__func.__python_implementation__
|
||||
)
|
||||
globals()[__name] = substitute_in_graph(
|
||||
__func, # type: ignore[arg-type]
|
||||
can_constant_fold_through=True,
|
||||
)(__func.__python_implementation__) # type: ignore[attr-defined]
|
||||
del __func
|
||||
del __name
|
||||
|
||||
@ -178,7 +178,7 @@ del _Asterisk
|
||||
class PyTreeSpec:
|
||||
"""Analog for :class:`optree.PyTreeSpec` in Python."""
|
||||
|
||||
_children: tuple[Self, ...]
|
||||
_children: tuple[PyTreeSpec, ...]
|
||||
_type: builtins.type | None
|
||||
_metadata: Any
|
||||
_entries: tuple[Any, ...]
|
||||
@ -210,7 +210,7 @@ class PyTreeSpec:
|
||||
object.__setattr__(self, "num_children", num_children)
|
||||
|
||||
def __repr__(self, /) -> str:
|
||||
def helper(treespec: Self) -> str:
|
||||
def helper(treespec: PyTreeSpec) -> str:
|
||||
if treespec.is_leaf():
|
||||
assert treespec.type is None
|
||||
return _asterisk
|
||||
@ -254,7 +254,7 @@ class PyTreeSpec:
|
||||
return self.num_nodes == 1 and self.num_leaves == 1
|
||||
|
||||
def paths(self, /) -> list[tuple[Any, ...]]:
|
||||
def helper(treespec: Self, path_prefix: list[Any]) -> None:
|
||||
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
|
||||
if treespec.is_leaf():
|
||||
paths.append(path_prefix)
|
||||
return
|
||||
@ -272,7 +272,7 @@ class PyTreeSpec:
|
||||
|
||||
def accessors(self, /) -> list[optree.PyTreeAccessor]:
|
||||
def helper(
|
||||
treespec: Self,
|
||||
treespec: PyTreeSpec,
|
||||
entry_path_prefix: list[optree.PyTreeEntry],
|
||||
) -> None:
|
||||
if treespec.is_leaf():
|
||||
@ -302,10 +302,10 @@ class PyTreeSpec:
|
||||
helper(self, [])
|
||||
return [optree.PyTreeAccessor(path) for path in entry_paths]
|
||||
|
||||
def children(self, /) -> list[Self]:
|
||||
def children(self, /) -> list[PyTreeSpec]:
|
||||
return list(self._children)
|
||||
|
||||
def child(self, index: int, /) -> Self:
|
||||
def child(self, index: int, /) -> PyTreeSpec:
|
||||
return self._children[index]
|
||||
|
||||
def entries(self, /) -> list[Any]:
|
||||
@ -316,7 +316,7 @@ class PyTreeSpec:
|
||||
|
||||
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
|
||||
def helper(
|
||||
treespec: Self,
|
||||
treespec: PyTreeSpec,
|
||||
node: PyTree,
|
||||
subtrees: list[PyTree],
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user