[FX][export][dynamo] use tuple instead of list in normalized args_spec (#138212)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138212
Approved by: https://github.com/jansel
This commit is contained in:
Xuehai Pan
2024-10-24 23:52:49 +08:00
committed by PyTorch MergeBot
parent ce631939f0
commit 86d4b7d60b
2 changed files with 47 additions and 20 deletions

View File

@ -582,18 +582,25 @@ class _TargetArgsExpr(_TargetExpr):
def pytree_flatten(
args: Sequence[Any], kwargs: Mapping[Any, Any]
) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
if s.type is None:
return s
mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
return pytree.TreeSpec(
mapping.get(s.type, s.type),
s.context,
list(map(norm_spec, s.children_specs)),
)
type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict}
flat, spec = pytree.tree_flatten([args, kwargs])
spec = norm_spec(spec)
def convert_type(x: Any) -> Any:
cls = type(x)
convert_fn = type_mapping.get(cls)
if convert_fn is not None:
return pytree.tree_map(
convert_type,
convert_fn(x),
is_leaf=lambda x: type(x) in type_mapping,
)
return x
normalized_args_tree = pytree.tree_map(
convert_type,
(args, kwargs),
is_leaf=lambda x: type(x) in type_mapping,
)
flat, spec = pytree.tree_flatten(normalized_args_tree)
return flat, spec
def __repr__(self) -> str:

View File

@ -136,15 +136,35 @@ class OutputAdapter:
# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276
def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec:
_type = list if spec.type == tuple else spec.type
return pytree.TreeSpec(
_type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs))
# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame.
class _DummyLeaf: # use a class instead.
pass
def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec:
def replace_list_with_tuple(x: Any) -> Any:
if type(x) is list:
return pytree.tree_map(
replace_list_with_tuple,
tuple(x),
is_leaf=lambda x: type(x) is list,
)
return x
dummy_leaf = _DummyLeaf()
dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec)
dummy_tree = pytree.tree_map(
replace_list_with_tuple,
dummy_tree,
is_leaf=lambda x: type(x) is list,
)
return pytree.tree_structure(dummy_tree)
def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec:
if spec.type == list and spec.num_children == 1:
def _open_top_level_sequence_if_single_element(
spec: pytree.TreeSpec,
) -> pytree.TreeSpec:
if spec.type in (tuple, list) and spec.num_children == 1:
return spec.children_specs[0]
return spec
@ -167,10 +187,10 @@ def _assert_identical_pytree_spec(
pass_if_any_checks: Sequence[Callable[[], bool]] = [
lambda: spec1 == spec2,
# FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2),
lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2),
# FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
lambda: _open_top_level_list_if_single_element(spec1) == spec2,
lambda: spec1 == _open_top_level_list_if_single_element(spec2),
lambda: _open_top_level_sequence_if_single_element(spec1) == spec2,
lambda: spec1 == _open_top_level_sequence_if_single_element(spec2),
]
if not any(check() for check in pass_if_any_checks):