mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
[pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations. - #112485 Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
7715b47f44
commit
4893a2814f
@ -21,21 +21,21 @@ SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
|
||||
|
||||
|
||||
def register_dataclass_as_pytree_node(
|
||||
typ: Any,
|
||||
cls: Any,
|
||||
flatten_fn: Optional[FlattenFunc] = None,
|
||||
unflatten_fn: Optional[UnflattenFunc] = None,
|
||||
*,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
to_dumpable_context: Optional[ToDumpableContextFn] = None,
|
||||
from_dumpable_context: Optional[FromDumpableContextFn] = None,
|
||||
serialized_type_name: Optional[str] = None,
|
||||
return_none_fields: bool = False,
|
||||
) -> None:
|
||||
assert dataclasses.is_dataclass(
|
||||
typ
|
||||
), f"Only dataclasses can be registered with this function: {typ}"
|
||||
cls
|
||||
), f"Only dataclasses can be registered with this function: {cls}"
|
||||
|
||||
serialized_type = f"{typ.__module__}.{typ.__name__}"
|
||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = typ
|
||||
serialized_type = f"{cls.__module__}.{cls.__name__}"
|
||||
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
|
||||
|
||||
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
||||
flattened = []
|
||||
@ -48,7 +48,7 @@ def register_dataclass_as_pytree_node(
|
||||
flat_names.append(name)
|
||||
else:
|
||||
none_names.append(name)
|
||||
return flattened, (typ, flat_names, none_names)
|
||||
return flattened, (cls, flat_names, none_names)
|
||||
|
||||
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
|
||||
typ, flat_names, none_names = context
|
||||
@ -69,7 +69,7 @@ def register_dataclass_as_pytree_node(
|
||||
|
||||
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
|
||||
raise ValueError(
|
||||
f"Both to_dumpable_context and from_dumpable_context for {typ} must "
|
||||
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
|
||||
"be None or registered."
|
||||
)
|
||||
|
||||
@ -85,7 +85,7 @@ def register_dataclass_as_pytree_node(
|
||||
)
|
||||
|
||||
_register_pytree_node(
|
||||
typ,
|
||||
cls,
|
||||
flatten_fn,
|
||||
unflatten_fn,
|
||||
serialized_type_name=serialized_type_name,
|
||||
|
||||
Reference in New Issue
Block a user