diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 45047634f66d..2e2cf980f69a 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4536,7 +4536,7 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): immutable_inputs = torch.fx.immutable_collections.immutable_list(inputs) try: immutable_inputs.append(x) - except NotImplementedError: + except TypeError: pass return torch.fx.node.map_aggregate(immutable_inputs, f) diff --git a/test/test_fx.py b/test/test_fx.py index 01760ce5b26f..07401118c426 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3039,7 +3039,7 @@ class TestFX(JitTestCase): from torch.fx.immutable_collections import immutable_list x = immutable_list([3, 4]) - with self.assertRaisesRegex(NotImplementedError, "new_args"): + with self.assertRaisesRegex(TypeError, "new_args"): x[0] = 4 def test_partial_trace(self): diff --git a/torch/_higher_order_ops/aoti_call_delegate.py b/torch/_higher_order_ops/aoti_call_delegate.py index 0fb0e0ea4a50..286575726dc2 100644 --- a/torch/_higher_order_ops/aoti_call_delegate.py +++ b/torch/_higher_order_ops/aoti_call_delegate.py @@ -70,7 +70,7 @@ def call_delegate_cpu( input_args: list[torch.Tensor], ) -> list[torch.Tensor]: # FX creates this immutable_dict/list concept. Get rid of this. - map_types = { + map_types: dict[type, type] = { torch.fx.immutable_collections.immutable_dict: dict, torch.fx.immutable_collections.immutable_list: list, } diff --git a/torch/_higher_order_ops/executorch_call_delegate.py b/torch/_higher_order_ops/executorch_call_delegate.py index a6ee5205ff4e..2782ddce230b 100644 --- a/torch/_higher_order_ops/executorch_call_delegate.py +++ b/torch/_higher_order_ops/executorch_call_delegate.py @@ -74,7 +74,7 @@ def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): # pyre-ignore def call_delegate_cpu(lowered_module, *args): # FX creates this immutable_dict/list concept. Get rid of this. - map_types = { + map_types: dict[type, type] = { torch.fx.immutable_collections.immutable_dict: dict, torch.fx.immutable_collections.immutable_list: list, } diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 97cf5ca4348a..ab32d6af3e69 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -589,7 +589,11 @@ class _TargetArgsExpr(_TargetExpr): def pytree_flatten( args: Sequence[Any], kwargs: Mapping[Any, Any] ) -> tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: - type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict} + type_mapping: dict[type, type] = { + immutable_list: tuple, + list: tuple, + immutable_dict: dict, + } def convert_type(x: Any) -> Any: cls = type(x) diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index 74691bbe72ac..a4322a884d60 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -84,6 +84,7 @@ Several example transformations can be found at the repository. ''' +from torch.fx import immutable_collections from torch.fx._symbolic_trace import ( # noqa: F401 PH, ProxyableClassMeta, diff --git a/torch/fx/immutable_collections.py b/torch/fx/immutable_collections.py index 484f9c18f628..6c6204d520bc 100644 --- a/torch/fx/immutable_collections.py +++ b/torch/fx/immutable_collections.py @@ -1,6 +1,6 @@ -# mypy: allow-untyped-defs from collections.abc import Iterable -from typing import Any +from typing import Any, NoReturn, TypeVar +from typing_extensions import Self from torch.utils._pytree import ( _dict_flatten, @@ -18,97 +18,94 @@ from ._compatibility import compatibility __all__ = ["immutable_list", "immutable_dict"] -_help_mutation = """\ + +_help_mutation = """ If you are attempting to modify the kwargs or args of a torch.fx.Node object, instead create a new copy of it and assign the copy to the node: - new_args = ... # copy and mutate args + + new_args = ... # copy and mutate args node.args = new_args -""" +""".strip() -def _no_mutation(self, *args, **kwargs): - raise NotImplementedError( - f"'{type(self).__name__}' object does not support mutation. {_help_mutation}", +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +def _no_mutation(self: Any, *args: Any, **kwargs: Any) -> NoReturn: + raise TypeError( + f"{type(self).__name__!r} object does not support mutation. {_help_mutation}", ) -def _create_immutable_container(base, mutable_functions): - container = type("immutable_" + base.__name__, (base,), {}) - for attr in mutable_functions: - setattr(container, attr, _no_mutation) - return container +@compatibility(is_backward_compatible=True) +class immutable_list(list[_T]): + """An immutable version of :class:`list`.""" + + __delitem__ = _no_mutation + __iadd__ = _no_mutation + __imul__ = _no_mutation + __setitem__ = _no_mutation + append = _no_mutation + clear = _no_mutation + extend = _no_mutation + insert = _no_mutation + pop = _no_mutation + remove = _no_mutation + reverse = _no_mutation + sort = _no_mutation + + def __hash__(self) -> int: # type: ignore[override] + return hash(tuple(self)) + + def __reduce__(self) -> tuple[type[Self], tuple[tuple[_T, ...]]]: + return (type(self), (tuple(self),)) -immutable_list = _create_immutable_container( - list, - ( - "__delitem__", - "__iadd__", - "__imul__", - "__setitem__", - "append", - "clear", - "extend", - "insert", - "pop", - "remove", - "reverse", - "sort", - ), -) -immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),)) -immutable_list.__hash__ = lambda self: hash(tuple(self)) +@compatibility(is_backward_compatible=True) +class immutable_dict(dict[_KT, _VT]): + """An immutable version of :class:`dict`.""" -compatibility(is_backward_compatible=True)(immutable_list) + __delitem__ = _no_mutation + __ior__ = _no_mutation + __setitem__ = _no_mutation + clear = _no_mutation + pop = _no_mutation + popitem = _no_mutation + setdefault = _no_mutation + update = _no_mutation # type: ignore[assignment] -immutable_dict = _create_immutable_container( - dict, - ( - "__delitem__", - "__ior__", - "__setitem__", - "clear", - "pop", - "popitem", - "setdefault", - "update", - ), -) -immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),)) -immutable_dict.__hash__ = lambda self: hash(tuple(self.items())) -compatibility(is_backward_compatible=True)(immutable_dict) + def __hash__(self) -> int: # type: ignore[override] + return hash(frozenset(self.items())) + + def __reduce__(self) -> tuple[type[Self], tuple[tuple[tuple[_KT, _VT], ...]]]: + return (type(self), (tuple(self.items()),)) # Register immutable collections for PyTree operations -def _immutable_dict_flatten(d: dict[Any, Any]) -> tuple[list[Any], Context]: - return _dict_flatten(d) - - -def _immutable_dict_unflatten( - values: Iterable[Any], - context: Context, -) -> dict[Any, Any]: - return immutable_dict(_dict_unflatten(values, context)) - - -def _immutable_list_flatten(d: list[Any]) -> tuple[list[Any], Context]: +def _immutable_list_flatten(d: immutable_list[_T]) -> tuple[list[_T], Context]: return _list_flatten(d) def _immutable_list_unflatten( - values: Iterable[Any], + values: Iterable[_T], context: Context, -) -> list[Any]: +) -> immutable_list[_T]: return immutable_list(_list_unflatten(values, context)) -register_pytree_node( - immutable_dict, - _immutable_dict_flatten, - _immutable_dict_unflatten, - serialized_type_name="torch.fx.immutable_collections.immutable_dict", - flatten_with_keys_fn=_dict_flatten_with_keys, -) +def _immutable_dict_flatten(d: immutable_dict[Any, _VT]) -> tuple[list[_VT], Context]: + return _dict_flatten(d) + + +def _immutable_dict_unflatten( + values: Iterable[_VT], + context: Context, +) -> immutable_dict[Any, _VT]: + return immutable_dict(_dict_unflatten(values, context)) + + register_pytree_node( immutable_list, _immutable_list_flatten, @@ -116,3 +113,10 @@ register_pytree_node( serialized_type_name="torch.fx.immutable_collections.immutable_list", flatten_with_keys_fn=_list_flatten_with_keys, ) +register_pytree_node( + immutable_dict, + _immutable_dict_flatten, + _immutable_dict_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +)