mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Track base of FunctionalTensor in inference mode. (#135141)
The idea behind the tracking is the following: whenever we see a tensor, if the tensor is a root tensor (it does not have any view metas), then we consider it the base of all the tensors that share its storage. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141 Approved by: https://github.com/zou3519
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							cc28634172
						
					
				
				
					commit
					66dd4577b1
				
			| @ -707,7 +707,12 @@ bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_ | ||||
| } | ||||
|  | ||||
| bool isFunctionalTensor(const at::Tensor& tensor) { | ||||
|   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); | ||||
|    return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize); | ||||
| } | ||||
|  | ||||
| bool isBaseTensor(const at::Tensor& tensor) { | ||||
|   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor)); | ||||
|   return unsafeGetFunctionalWrapper(tensor)->isBaseTensor(); | ||||
| } | ||||
|  | ||||
| bool isFunctionalTensor(const std::optional<Tensor>& t) { | ||||
|  | ||||
| @ -165,6 +165,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { | ||||
|     was_storage_changed_ = true; | ||||
|   } | ||||
|  | ||||
|   // A FunctionalTensor is considered a base if it's not a view of another | ||||
|   // tensor. | ||||
|   bool isBaseTensor() const { | ||||
|     return view_metas_.empty(); | ||||
|   } | ||||
|  | ||||
|   c10::SymInt get_storage_size(bool before) { | ||||
|     return functional_storage_impl()->get_storage_size(before); | ||||
|   } | ||||
| @ -290,6 +296,8 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( | ||||
|   return functional_impl; | ||||
| } | ||||
|  | ||||
| TORCH_API bool isBaseTensor(const at::Tensor& tensor); | ||||
|  | ||||
| TORCH_API bool isFunctionalTensor(const at::Tensor& tensor); | ||||
| TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t); | ||||
| TORCH_API bool isFunctionalTensor( | ||||
|  | ||||
| @ -911,85 +911,20 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 | ||||
|         with torch.inference_mode(): | ||||
|             self.test_auto_functionalize_extra1() | ||||
|  | ||||
|     # In inference mode we do not support inplacing views yet. | ||||
|     @torch._inductor.config.patch(enable_auto_functionalized_v2=True) | ||||
|     def test_inference_mode2_v2(self): | ||||
|         with torch.inference_mode(), torch.library._scoped_library( | ||||
|             "mylib", "FRAGMENT" | ||||
|         ) as lib: | ||||
|             torch.library.define( | ||||
|                 "mylib::foo", | ||||
|                 "(Tensor(a!) x, Tensor(b!) y) -> ()", | ||||
|                 tags=torch.Tag.pt2_compliant_tag, | ||||
|                 lib=lib, | ||||
|             ) | ||||
|         with torch.inference_mode(): | ||||
|             self.test_auto_functionalize_extra2() | ||||
|  | ||||
|             @torch.library.impl("mylib::foo", "cpu", lib=lib) | ||||
|             @torch._dynamo.disable | ||||
|             def foo_impl(x, y): | ||||
|                 x.sin_() | ||||
|                 y.sin_() | ||||
|     @torch._inductor.config.patch(enable_auto_functionalized_v2=True) | ||||
|     def test_inference_mode3_v2(self): | ||||
|         with torch.inference_mode(): | ||||
|             self.test_auto_functionalize_extra3() | ||||
|  | ||||
|             def f(x): | ||||
|                 a = x[0] | ||||
|                 b = x[1] | ||||
|                 torch.ops.mylib.foo(a, b) | ||||
|                 return | ||||
|  | ||||
|             orig_args = [torch.randn(2)] | ||||
|  | ||||
|             [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args) | ||||
|             [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args) | ||||
|             eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) | ||||
|             result3 = f(*eager_args) | ||||
|  | ||||
|             self.assertEqual(inductor_args, eager_args) | ||||
|             self.assertEqual(inductor_args, aot_eager_args) | ||||
|  | ||||
|             self.assertEqual(result3, result1) | ||||
|             self.assertEqual(result3, result2) | ||||
|  | ||||
|             if torch._dynamo.config.assume_static_by_default: | ||||
|                 self.assertExpectedInline( | ||||
|                     graph_aot, | ||||
|                     """\ | ||||
| def forward(self, arg0_1: "f32[2][1]cpu"): | ||||
|         select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) | ||||
|         select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1) | ||||
|         auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]);  select = select_1 = None | ||||
|         getitem_1: "f32[][]cpu" = auto_functionalized_v2[1] | ||||
|         getitem_2: "f32[][]cpu" = auto_functionalized_v2[2];  auto_functionalized_v2 = None | ||||
|         select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0);  getitem_1 = None | ||||
|         select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1);  select_scatter = getitem_2 = None | ||||
|         copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1);  arg0_1 = select_scatter_1 = copy_ = None | ||||
|         return ()""",  # noqa: B950 | ||||
|                     ignore_comments=True, | ||||
|                     ignore_empty_lines=True, | ||||
|                 ) | ||||
|  | ||||
|             # 2. Run with inductor backend | ||||
|  | ||||
|             if torch._dynamo.config.assume_static_by_default: | ||||
|                 self.assertExpectedInline( | ||||
|                     graph_inductor, | ||||
|                     """\ | ||||
| def forward(self, arg0_1: "f32[2][1]cpu"): | ||||
|         select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) | ||||
|         select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1) | ||||
|         as_strided_default: "f32[1][1]cpu" = torch.ops.aten.as_strided.default(select, [1], [1], 0);  select = None | ||||
|         clone_default: "f32[1][1]cpu" = torch.ops.aten.clone.default(as_strided_default);  as_strided_default = None | ||||
|         as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default, [], [], 0);  clone_default = None | ||||
|         as_strided_default_2: "f32[2][1]cpu" = torch.ops.aten.as_strided.default(select_1, [2], [1], 0);  select_1 = None | ||||
|         clone_default_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_default_2);  as_strided_default_2 = None | ||||
|         as_strided_default_3: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default_1, [], [], 1);  clone_default_1 = None | ||||
|         foo_default = torch.ops.mylib.foo.default(as_strided_default_1, as_strided_default_3);  foo_default = None | ||||
|         select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, as_strided_default_1, 0, 0);  as_strided_default_1 = None | ||||
|         select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, as_strided_default_3, 0, 1);  select_scatter_default = as_strided_default_3 = None | ||||
|         copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1);  arg0_1 = select_scatter_default_1 = copy_ = None | ||||
|         return ()""",  # noqa: B950 | ||||
|                     ignore_comments=True, | ||||
|                     ignore_empty_lines=True, | ||||
|                 ) | ||||
|     @torch._inductor.config.patch(enable_auto_functionalized_v2=True) | ||||
|     def test_inference_mode4_v2(self): | ||||
|         with torch.inference_mode(): | ||||
|             self.test_auto_functionalize_extra4() | ||||
|  | ||||
|     @torch._inductor.config.patch(enable_auto_functionalized_v2=True) | ||||
|     def test_dynamic_v2(self): | ||||
|  | ||||
| @ -781,6 +781,9 @@ def gen_pyi( | ||||
|             "_is_functional_tensor": [ | ||||
|                 "def _is_functional_tensor(t: Tensor) -> _bool: ..." | ||||
|             ], | ||||
|             "_is_functional_tensor_base": [ | ||||
|                 "def _is_functional_tensor_base(t: Tensor) -> _bool: ..." | ||||
|             ], | ||||
|             "_from_functional_tensor": [ | ||||
|                 "def _from_functional_tensor(t: Tensor) -> Tensor: ..." | ||||
|             ], | ||||
|  | ||||
| @ -18,6 +18,13 @@ from torch.fx.experimental.proxy_tensor import ( | ||||
| ) | ||||
|  | ||||
|  | ||||
| def get_base(tensor): | ||||
|     if torch.is_inference_mode_enabled(): | ||||
|         return tensor._inference_mode_base | ||||
|     else: | ||||
|         return tensor._base | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class ViewInfo: | ||||
|     base_index: int | ||||
| @ -68,7 +75,7 @@ def write_view_information_to_args( | ||||
|  | ||||
|         if tensor is None: | ||||
|             kwargs[f"{prefix}_base_index"] = None | ||||
|         elif tensor._base is None: | ||||
|         elif get_base(tensor) is None: | ||||
|             # if the tensor is the base (not view), for simplicity we do not serialize view meta. | ||||
|             kwargs[f"{prefix}_base_index"] = base_index | ||||
|         else: | ||||
| @ -437,7 +444,7 @@ def do_auto_functionalize_v2( | ||||
|     arg_to_base_index: Dict[str, Any] = {} | ||||
|  | ||||
|     def update_dict(tensor, arg_name, index=None): | ||||
|         base = tensor if tensor._base is None else tensor._base | ||||
|         base = tensor if get_base(tensor) is None else get_base(tensor) | ||||
|  | ||||
|         def set_result(base_index): | ||||
|             if index is None: | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| # mypy: allow-untyped-defs | ||||
| import contextlib | ||||
| import warnings | ||||
| import weakref | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union | ||||
|  | ||||
| @ -111,7 +112,10 @@ class FunctionalTensor(torch.Tensor): | ||||
|         torch.ops.aten.unsafe_chunk.default,  # type: ignore[has-type] | ||||
|     ] | ||||
|  | ||||
|     def __new__(cls, elem): | ||||
|     # Used by auto_functionalize to determine base of tensors during inference mode. | ||||
|     _inference_mode_base: Optional["FunctionalTensor"] = None | ||||
|  | ||||
|     def __new__(cls, elem, mode): | ||||
|         assert torch._is_functional_tensor(elem) | ||||
|  | ||||
|         # In general, we'd like our functional tensor subclass to only be in charge of functionalization, | ||||
| @ -142,9 +146,9 @@ class FunctionalTensor(torch.Tensor): | ||||
|             cls, | ||||
|             elem.shape,  # sizes | ||||
|             elem.stride() if not is_sparse_any(elem) else None,  # strides | ||||
|             elem.storage_offset() | ||||
|             if not is_sparse_any(elem) | ||||
|             else None,  # storage_offset | ||||
|             ( | ||||
|                 elem.storage_offset() if not is_sparse_any(elem) else None | ||||
|             ),  # storage_offset | ||||
|             None,  # memory_format | ||||
|             elem.dtype,  # dtype | ||||
|             elem.layout,  # layout | ||||
| @ -158,6 +162,21 @@ class FunctionalTensor(torch.Tensor): | ||||
|         ) | ||||
|         torch._C._set_throw_on_mutable_data_ptr(out) | ||||
|         out.elem = elem | ||||
|  | ||||
|         if ( | ||||
|             torch.is_inference_mode_enabled() | ||||
|             and torch._inductor.config.enable_auto_functionalized_v2 | ||||
|         ): | ||||
|             if out.is_base_tensor(): | ||||
|                 out._inference_mode_base = None | ||||
|                 # This assumes that the FunctionalTensor.elem does not change its storage after this point. | ||||
|                 # Otherwise this would be invalid. | ||||
|                 mode._storage_to_base[out.elem.untyped_storage()] = out | ||||
|             else: | ||||
|                 out._inference_mode_base = mode._storage_to_base[ | ||||
|                     out.elem.untyped_storage() | ||||
|                 ] | ||||
|                 assert out._inference_mode_base is not None | ||||
|         return out | ||||
|  | ||||
|     def __torch_dispatch__(self, func, types, args=(), kwargs=None): | ||||
| @ -209,6 +228,7 @@ class FunctionalTensor(torch.Tensor): | ||||
|     @staticmethod | ||||
|     def to_functional(x): | ||||
|         # We will do the wrapping for the user. | ||||
|  | ||||
|         assert not torch._is_functional_tensor(x) | ||||
|         # The only autograd metadata we care about on the FunctionalTensor is: | ||||
|         # - requires_grad (so autograd runs) | ||||
| @ -226,7 +246,7 @@ class FunctionalTensor(torch.Tensor): | ||||
|  | ||||
|         with functional_mode: | ||||
|             torch._mirror_autograd_meta_to(x, x_functional)  # type: ignore[attr-defined] | ||||
|             out = FunctionalTensor(x_functional) | ||||
|             out = FunctionalTensor(x_functional, functional_mode) | ||||
|             torch._mirror_autograd_meta_to(x_functional, out)  # type: ignore[attr-defined] | ||||
|         return out | ||||
|  | ||||
| @ -234,6 +254,9 @@ class FunctionalTensor(torch.Tensor): | ||||
|         torch._sync(self) | ||||
|         return torch._from_functional_tensor(self.elem) | ||||
|  | ||||
|     def is_base_tensor(self) -> bool: | ||||
|         return torch._is_functional_tensor_base(self.elem) | ||||
|  | ||||
|     def replace_(self, output) -> None: | ||||
|         torch._functionalize_replace(self.elem, output) | ||||
|  | ||||
| @ -316,6 +339,10 @@ class FunctionalTensorMode(TorchDispatchMode): | ||||
|         # discovery. This flag distinguishes between the two stages. | ||||
|         self._allow_token_discovery = _allow_token_discovery | ||||
|  | ||||
|         self._storage_to_base: weakref.WeakKeyDictionary[ | ||||
|             torch.storage.UntypedStorage, Optional[FunctionalTensor] | ||||
|         ] = weakref.WeakKeyDictionary() | ||||
|  | ||||
|     # No-op if FunctionalTensorMode is already in use | ||||
|     def __enter__(self): | ||||
|         def _get_prev_mode(): | ||||
| @ -366,6 +393,7 @@ class FunctionalTensorMode(TorchDispatchMode): | ||||
|             if not issubclass(t, torch._subclasses.FakeTensor) | ||||
|             and t not in [torch.Tensor, FunctionalTensor] | ||||
|         ] | ||||
|  | ||||
|         if unrecognized_types: | ||||
|             not_implemented_log.debug( | ||||
|                 "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types | ||||
| @ -417,16 +445,13 @@ class FunctionalTensorMode(TorchDispatchMode): | ||||
|                 if r is not NotImplemented: | ||||
|                     return r | ||||
|  | ||||
|         def assert_is_functional(x): | ||||
|             assert torch._is_functional_tensor(x) | ||||
|  | ||||
|         def wrap(x): | ||||
|             # Only wrap our outputs in subclasses if the inner functionalization call | ||||
|             # also wrapped outputs into FunctionalTensorWrappers. | ||||
|             # When can this happen? e.g. `torch.div(2, 2)` | ||||
|             assert not isinstance(x, FunctionalTensor) | ||||
|             if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): | ||||
|                 return FunctionalTensor(x) | ||||
|                 return FunctionalTensor(x, self) | ||||
|             return x | ||||
|  | ||||
|         def unwrap(x): | ||||
|  | ||||
| @ -664,6 +664,10 @@ void initTorchFunctions(PyObject* module) { | ||||
|             !at::functionalization::impl::isFunctionalTensor(o)); | ||||
|         at::functionalization::impl::replace_(t, o); | ||||
|       }); | ||||
|   py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) { | ||||
|     TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); | ||||
|     return at::functionalization::impl::isBaseTensor(t); | ||||
|   }); | ||||
|   py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) { | ||||
|     TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t)); | ||||
|     auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t); | ||||
|  | ||||
		Reference in New Issue
	
	Block a user