mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[FX][export][dynamo] use tuple
instead of list
in normalized args_spec
(#138212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138212 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
ce631939f0
commit
86d4b7d60b
@ -582,18 +582,25 @@ class _TargetArgsExpr(_TargetExpr):
|
||||
def pytree_flatten(
|
||||
args: Sequence[Any], kwargs: Mapping[Any, Any]
|
||||
) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
|
||||
def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
|
||||
if s.type is None:
|
||||
return s
|
||||
mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
|
||||
return pytree.TreeSpec(
|
||||
mapping.get(s.type, s.type),
|
||||
s.context,
|
||||
list(map(norm_spec, s.children_specs)),
|
||||
)
|
||||
type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict}
|
||||
|
||||
flat, spec = pytree.tree_flatten([args, kwargs])
|
||||
spec = norm_spec(spec)
|
||||
def convert_type(x: Any) -> Any:
|
||||
cls = type(x)
|
||||
convert_fn = type_mapping.get(cls)
|
||||
if convert_fn is not None:
|
||||
return pytree.tree_map(
|
||||
convert_type,
|
||||
convert_fn(x),
|
||||
is_leaf=lambda x: type(x) in type_mapping,
|
||||
)
|
||||
return x
|
||||
|
||||
normalized_args_tree = pytree.tree_map(
|
||||
convert_type,
|
||||
(args, kwargs),
|
||||
is_leaf=lambda x: type(x) in type_mapping,
|
||||
)
|
||||
flat, spec = pytree.tree_flatten(normalized_args_tree)
|
||||
return flat, spec
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
Reference in New Issue
Block a user