mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[FX][export][dynamo] use tuple
instead of list
in normalized args_spec
(#138212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138212 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
ce631939f0
commit
86d4b7d60b
@ -582,18 +582,25 @@ class _TargetArgsExpr(_TargetExpr):
|
||||
def pytree_flatten(
|
||||
args: Sequence[Any], kwargs: Mapping[Any, Any]
|
||||
) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
|
||||
def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec:
|
||||
if s.type is None:
|
||||
return s
|
||||
mapping = {immutable_list: list, tuple: list, immutable_dict: dict}
|
||||
return pytree.TreeSpec(
|
||||
mapping.get(s.type, s.type),
|
||||
s.context,
|
||||
list(map(norm_spec, s.children_specs)),
|
||||
)
|
||||
type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict}
|
||||
|
||||
flat, spec = pytree.tree_flatten([args, kwargs])
|
||||
spec = norm_spec(spec)
|
||||
def convert_type(x: Any) -> Any:
|
||||
cls = type(x)
|
||||
convert_fn = type_mapping.get(cls)
|
||||
if convert_fn is not None:
|
||||
return pytree.tree_map(
|
||||
convert_type,
|
||||
convert_fn(x),
|
||||
is_leaf=lambda x: type(x) in type_mapping,
|
||||
)
|
||||
return x
|
||||
|
||||
normalized_args_tree = pytree.tree_map(
|
||||
convert_type,
|
||||
(args, kwargs),
|
||||
is_leaf=lambda x: type(x) in type_mapping,
|
||||
)
|
||||
flat, spec = pytree.tree_flatten(normalized_args_tree)
|
||||
return flat, spec
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
@ -136,15 +136,35 @@ class OutputAdapter:
|
||||
# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276
|
||||
|
||||
|
||||
def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec:
|
||||
_type = list if spec.type == tuple else spec.type
|
||||
return pytree.TreeSpec(
|
||||
_type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs))
|
||||
# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame.
|
||||
class _DummyLeaf: # use a class instead.
|
||||
pass
|
||||
|
||||
|
||||
def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec:
|
||||
def replace_list_with_tuple(x: Any) -> Any:
|
||||
if type(x) is list:
|
||||
return pytree.tree_map(
|
||||
replace_list_with_tuple,
|
||||
tuple(x),
|
||||
is_leaf=lambda x: type(x) is list,
|
||||
)
|
||||
return x
|
||||
|
||||
dummy_leaf = _DummyLeaf()
|
||||
dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec)
|
||||
dummy_tree = pytree.tree_map(
|
||||
replace_list_with_tuple,
|
||||
dummy_tree,
|
||||
is_leaf=lambda x: type(x) is list,
|
||||
)
|
||||
return pytree.tree_structure(dummy_tree)
|
||||
|
||||
|
||||
def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec:
|
||||
if spec.type == list and spec.num_children == 1:
|
||||
def _open_top_level_sequence_if_single_element(
|
||||
spec: pytree.TreeSpec,
|
||||
) -> pytree.TreeSpec:
|
||||
if spec.type in (tuple, list) and spec.num_children == 1:
|
||||
return spec.children_specs[0]
|
||||
return spec
|
||||
|
||||
@ -167,10 +187,10 @@ def _assert_identical_pytree_spec(
|
||||
pass_if_any_checks: Sequence[Callable[[], bool]] = [
|
||||
lambda: spec1 == spec2,
|
||||
# FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
|
||||
lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2),
|
||||
lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2),
|
||||
# FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
|
||||
lambda: _open_top_level_list_if_single_element(spec1) == spec2,
|
||||
lambda: spec1 == _open_top_level_list_if_single_element(spec2),
|
||||
lambda: _open_top_level_sequence_if_single_element(spec1) == spec2,
|
||||
lambda: spec1 == _open_top_level_sequence_if_single_element(spec2),
|
||||
]
|
||||
|
||||
if not any(check() for check in pass_if_any_checks):
|
||||
|
Reference in New Issue
Block a user