mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-18 17:45:09 +08:00
Update (base update)
[ghstack-poisoned]
This commit is contained in:
@ -77,20 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
__name = ""
|
__name = ""
|
||||||
for __name in (
|
for __name, __func in (
|
||||||
"is_namedtuple",
|
("is_namedtuple", is_namedtuple),
|
||||||
"is_namedtuple_class",
|
("is_namedtuple_class", is_namedtuple_class),
|
||||||
"is_namedtuple_instance",
|
("is_namedtuple_instance", is_namedtuple_instance),
|
||||||
"is_structseq",
|
("is_structseq", is_structseq),
|
||||||
"is_structseq_class",
|
("is_structseq_class", is_structseq_class),
|
||||||
"is_structseq_instance",
|
("is_structseq_instance", is_structseq_instance),
|
||||||
"namedtuple_fields",
|
("namedtuple_fields", namedtuple_fields),
|
||||||
"structseq_fields",
|
("structseq_fields", structseq_fields),
|
||||||
):
|
):
|
||||||
__func = globals()[__name]
|
globals()[__name] = substitute_in_graph(
|
||||||
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
|
__func, # type: ignore[arg-type]
|
||||||
__func.__python_implementation__
|
can_constant_fold_through=True,
|
||||||
)
|
)(__func.__python_implementation__) # type: ignore[attr-defined]
|
||||||
del __func
|
del __func
|
||||||
del __name
|
del __name
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ del _Asterisk
|
|||||||
class PyTreeSpec:
|
class PyTreeSpec:
|
||||||
"""Analog for :class:`optree.PyTreeSpec` in Python."""
|
"""Analog for :class:`optree.PyTreeSpec` in Python."""
|
||||||
|
|
||||||
_children: tuple[Self, ...]
|
_children: tuple[PyTreeSpec, ...]
|
||||||
_type: builtins.type | None
|
_type: builtins.type | None
|
||||||
_metadata: Any
|
_metadata: Any
|
||||||
_entries: tuple[Any, ...]
|
_entries: tuple[Any, ...]
|
||||||
@ -210,7 +210,7 @@ class PyTreeSpec:
|
|||||||
object.__setattr__(self, "num_children", num_children)
|
object.__setattr__(self, "num_children", num_children)
|
||||||
|
|
||||||
def __repr__(self, /) -> str:
|
def __repr__(self, /) -> str:
|
||||||
def helper(treespec: Self) -> str:
|
def helper(treespec: PyTreeSpec) -> str:
|
||||||
if treespec.is_leaf():
|
if treespec.is_leaf():
|
||||||
assert treespec.type is None
|
assert treespec.type is None
|
||||||
return _asterisk
|
return _asterisk
|
||||||
@ -254,7 +254,7 @@ class PyTreeSpec:
|
|||||||
return self.num_nodes == 1 and self.num_leaves == 1
|
return self.num_nodes == 1 and self.num_leaves == 1
|
||||||
|
|
||||||
def paths(self, /) -> list[tuple[Any, ...]]:
|
def paths(self, /) -> list[tuple[Any, ...]]:
|
||||||
def helper(treespec: Self, path_prefix: list[Any]) -> None:
|
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
|
||||||
if treespec.is_leaf():
|
if treespec.is_leaf():
|
||||||
paths.append(path_prefix)
|
paths.append(path_prefix)
|
||||||
return
|
return
|
||||||
@ -272,7 +272,7 @@ class PyTreeSpec:
|
|||||||
|
|
||||||
def accessors(self, /) -> list[optree.PyTreeAccessor]:
|
def accessors(self, /) -> list[optree.PyTreeAccessor]:
|
||||||
def helper(
|
def helper(
|
||||||
treespec: Self,
|
treespec: PyTreeSpec,
|
||||||
entry_path_prefix: list[optree.PyTreeEntry],
|
entry_path_prefix: list[optree.PyTreeEntry],
|
||||||
) -> None:
|
) -> None:
|
||||||
if treespec.is_leaf():
|
if treespec.is_leaf():
|
||||||
@ -302,10 +302,10 @@ class PyTreeSpec:
|
|||||||
helper(self, [])
|
helper(self, [])
|
||||||
return [optree.PyTreeAccessor(path) for path in entry_paths]
|
return [optree.PyTreeAccessor(path) for path in entry_paths]
|
||||||
|
|
||||||
def children(self, /) -> list[Self]:
|
def children(self, /) -> list[PyTreeSpec]:
|
||||||
return list(self._children)
|
return list(self._children)
|
||||||
|
|
||||||
def child(self, index: int, /) -> Self:
|
def child(self, index: int, /) -> PyTreeSpec:
|
||||||
return self._children[index]
|
return self._children[index]
|
||||||
|
|
||||||
def entries(self, /) -> list[Any]:
|
def entries(self, /) -> list[Any]:
|
||||||
@ -316,7 +316,7 @@ class PyTreeSpec:
|
|||||||
|
|
||||||
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
|
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
|
||||||
def helper(
|
def helper(
|
||||||
treespec: Self,
|
treespec: PyTreeSpec,
|
||||||
node: PyTree,
|
node: PyTree,
|
||||||
subtrees: list[PyTree],
|
subtrees: list[PyTree],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user