[pytree] align function signature between C++ and Python pytree (#112482)

Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
This commit is contained in:
Xuehai Pan
2023-11-05 00:44:30 +08:00
committed by PyTorch MergeBot
parent 7715b47f44
commit 4893a2814f
5 changed files with 321 additions and 190 deletions

View File

@ -21,21 +21,21 @@ SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
def register_dataclass_as_pytree_node(
typ: Any,
cls: Any,
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
serialized_type_name: Optional[str] = None,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(
typ
), f"Only dataclasses can be registered with this function: {typ}"
cls
), f"Only dataclasses can be registered with this function: {cls}"
serialized_type = f"{typ.__module__}.{typ.__name__}"
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = typ
serialized_type = f"{cls.__module__}.{cls.__name__}"
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
flattened = []
@ -48,7 +48,7 @@ def register_dataclass_as_pytree_node(
flat_names.append(name)
else:
none_names.append(name)
return flattened, (typ, flat_names, none_names)
return flattened, (cls, flat_names, none_names)
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
typ, flat_names, none_names = context
@ -69,7 +69,7 @@ def register_dataclass_as_pytree_node(
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
raise ValueError(
f"Both to_dumpable_context and from_dumpable_context for {typ} must "
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
"be None or registered."
)
@ -85,7 +85,7 @@ def register_dataclass_as_pytree_node(
)
_register_pytree_node(
typ,
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,