[FX][export][dynamo] use tuple instead of list in normalized args_spec (#138212)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138212
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2024-10-24 23:52:49 +08:00
committed by PyTorch MergeBot
parent ce631939f0
commit 86d4b7d60b
2 changed files with 47 additions and 20 deletions

View File

@ -582,18 +582,25 @@ class _TargetArgsExpr(_TargetExpr):
def pytree_flatten(
args: Sequence[Any], kwargs: Mapping[Any, Any]
) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
if s.type is None:
return s
mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
return pytree.TreeSpec(
mapping.get(s.type, s.type),
s.context,
list(map(norm_spec, s.children_specs)),
)
type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict}
flat, spec = pytree.tree_flatten([args, kwargs])
spec = norm_spec(spec)
def convert_type(x: Any) -> Any:
cls = type(x)
convert_fn = type_mapping.get(cls)
if convert_fn is not None:
return pytree.tree_map(
convert_type,
convert_fn(x),
is_leaf=lambda x: type(x) in type_mapping,
)
return x
normalized_args_tree = pytree.tree_map(
convert_type,
(args, kwargs),
is_leaf=lambda x: type(x) in type_mapping,
)
flat, spec = pytree.tree_flatten(normalized_args_tree)
return flat, spec
def __repr__(self) -> str: