Track base of FunctionalTensor in inference mode. (#135141)

The idea behind the tracking is the following: whenever we see a tensor, if the tensor is a root tensor (it does not have any view metas), we consider it the base of all the tensors that share its storage.
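For illustration, here is a minimal Python sketch of that rule. The helper name record_base and the has_view_metas flag are hypothetical; in the commit itself the mapping lives on FunctionalTensorMode._storage_to_base and the root check is FunctionalTensorWrapper::isBaseTensor(), which tests that view_metas_ is empty.

import weakref

import torch

_storage_to_base = weakref.WeakKeyDictionary()  # untyped storage -> base tensor

def record_base(tensor: torch.Tensor, has_view_metas: bool) -> torch.Tensor:
    if not has_view_metas:
        # A root tensor (no view metas) is treated as the base of every
        # tensor that shares its storage.
        _storage_to_base[tensor.untyped_storage()] = tensor
        return tensor
    # A view recovers its base through the storage it shares with the root.
    return _storage_to_base[tensor.untyped_storage()]

In the commit this tracking only happens when torch.is_inference_mode_enabled() and enable_auto_functionalized_v2 are set, as the FunctionalTensor.__new__ hunk below shows.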

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141
Approved by: https://github.com/zou3519
Author: Laith Sakka
Date: 2024-09-05 13:34:28 -07:00
Committed by: PyTorch MergeBot
Parent: cc28634172
Commit: 66dd4577b1
7 changed files with 74 additions and 87 deletions

View File

@@ -707,7 +707,12 @@ bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_
}
bool isFunctionalTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}
bool isBaseTensor(const at::Tensor& tensor) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor));
return unsafeGetFunctionalWrapper(tensor)->isBaseTensor();
}
bool isFunctionalTensor(const std::optional<Tensor>& t) {

View File

@@ -165,6 +165,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
was_storage_changed_ = true;
}
// A FunctionalTensor is considered a base if it's not a view of another
// tensor.
bool isBaseTensor() const {
return view_metas_.empty();
}
c10::SymInt get_storage_size(bool before) {
return functional_storage_impl()->get_storage_size(before);
}
@@ -290,6 +296,8 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
return functional_impl;
}
TORCH_API bool isBaseTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(

View File

@@ -911,85 +911,20 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
with torch.inference_mode():
self.test_auto_functionalize_extra1()
# In inference mode we do not support inplacing views yet.
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
def test_inference_mode2_v2(self):
with torch.inference_mode(), torch.library._scoped_library(
"mylib", "FRAGMENT"
) as lib:
torch.library.define(
"mylib::foo",
"(Tensor(a!) x, Tensor(b!) y) -> ()",
tags=torch.Tag.pt2_compliant_tag,
lib=lib,
)
with torch.inference_mode():
self.test_auto_functionalize_extra2()
@torch.library.impl("mylib::foo", "cpu", lib=lib)
@torch._dynamo.disable
def foo_impl(x, y):
x.sin_()
y.sin_()
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
def test_inference_mode3_v2(self):
with torch.inference_mode():
self.test_auto_functionalize_extra3()
def f(x):
a = x[0]
b = x[1]
torch.ops.mylib.foo(a, b)
return
orig_args = [torch.randn(2)]
[aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
[inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
result3 = f(*eager_args)
self.assertEqual(inductor_args, eager_args)
self.assertEqual(inductor_args, aot_eager_args)
self.assertEqual(result3, result1)
self.assertEqual(result3, result2)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(
graph_aot,
"""\
def forward(self, arg0_1: "f32[2][1]cpu"):
select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]); select = select_1 = None
getitem_1: "f32[][]cpu" = auto_functionalized_v2[1]
getitem_2: "f32[][]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0); getitem_1 = None
select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1); select_scatter = getitem_2 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1); arg0_1 = select_scatter_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
# 2. Run with inductor backend
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(
graph_inductor,
"""\
def forward(self, arg0_1: "f32[2][1]cpu"):
select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
as_strided_default: "f32[1][1]cpu" = torch.ops.aten.as_strided.default(select, [1], [1], 0); select = None
clone_default: "f32[1][1]cpu" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None
as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default, [], [], 0); clone_default = None
as_strided_default_2: "f32[2][1]cpu" = torch.ops.aten.as_strided.default(select_1, [2], [1], 0); select_1 = None
clone_default_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None
as_strided_default_3: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default_1, [], [], 1); clone_default_1 = None
foo_default = torch.ops.mylib.foo.default(as_strided_default_1, as_strided_default_3); foo_default = None
select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, as_strided_default_1, 0, 0); as_strided_default_1 = None
select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, as_strided_default_3, 0, 1); select_scatter_default = as_strided_default_3 = None
copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None
return ()""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
def test_inference_mode4_v2(self):
with torch.inference_mode():
self.test_auto_functionalize_extra4()
@torch._inductor.config.patch(enable_auto_functionalized_v2=True)
def test_dynamic_v2(self):

View File

@@ -781,6 +781,9 @@ def gen_pyi(
"_is_functional_tensor": [
"def _is_functional_tensor(t: Tensor) -> _bool: ..."
],
"_is_functional_tensor_base": [
"def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
],
"_from_functional_tensor": [
"def _from_functional_tensor(t: Tensor) -> Tensor: ..."
],

View File

@@ -18,6 +18,13 @@ from torch.fx.experimental.proxy_tensor import (
)
def get_base(tensor):
if torch.is_inference_mode_enabled():
return tensor._inference_mode_base
else:
return tensor._base
@dataclass
class ViewInfo:
base_index: int
@@ -68,7 +75,7 @@ def write_view_information_to_args(
if tensor is None:
kwargs[f"{prefix}_base_index"] = None
elif tensor._base is None:
elif get_base(tensor) is None:
# if the tensor is the base (not a view), for simplicity we do not serialize view meta.
kwargs[f"{prefix}_base_index"] = base_index
else:
@@ -437,7 +444,7 @@ def do_auto_functionalize_v2(
arg_to_base_index: Dict[str, Any] = {}
def update_dict(tensor, arg_name, index=None):
base = tensor if tensor._base is None else tensor._base
base = tensor if get_base(tensor) is None else get_base(tensor)
def set_result(base_index):
if index is None:

View File

@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union
@@ -111,7 +112,10 @@ class FunctionalTensor(torch.Tensor):
torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
]
def __new__(cls, elem):
# Used by auto_functionalize to determine the base of tensors during inference mode.
_inference_mode_base: Optional["FunctionalTensor"] = None
def __new__(cls, elem, mode):
assert torch._is_functional_tensor(elem)
# In general, we'd like our functional tensor subclass to only be in charge of functionalization,
@@ -142,9 +146,9 @@ class FunctionalTensor(torch.Tensor):
cls,
elem.shape, # sizes
elem.stride() if not is_sparse_any(elem) else None, # strides
elem.storage_offset()
if not is_sparse_any(elem)
else None, # storage_offset
(
elem.storage_offset() if not is_sparse_any(elem) else None
), # storage_offset
None, # memory_format
elem.dtype, # dtype
elem.layout, # layout
@@ -158,6 +162,21 @@
)
torch._C._set_throw_on_mutable_data_ptr(out)
out.elem = elem
if (
torch.is_inference_mode_enabled()
and torch._inductor.config.enable_auto_functionalized_v2
):
if out.is_base_tensor():
out._inference_mode_base = None
# This assumes that the FunctionalTensor.elem does not change its storage after this point.
# Otherwise this would be invalid.
mode._storage_to_base[out.elem.untyped_storage()] = out
else:
out._inference_mode_base = mode._storage_to_base[
out.elem.untyped_storage()
]
assert out._inference_mode_base is not None
return out
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
@@ -209,6 +228,7 @@
@staticmethod
def to_functional(x):
# We will do the wrapping for the user.
assert not torch._is_functional_tensor(x)
# The only autograd metadata we care about on the FunctionalTensor is:
# - requires_grad (so autograd runs)
@@ -226,7 +246,7 @@
with functional_mode:
torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined]
out = FunctionalTensor(x_functional)
out = FunctionalTensor(x_functional, functional_mode)
torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined]
return out
@@ -234,6 +254,9 @@
torch._sync(self)
return torch._from_functional_tensor(self.elem)
def is_base_tensor(self) -> bool:
return torch._is_functional_tensor_base(self.elem)
def replace_(self, output) -> None:
torch._functionalize_replace(self.elem, output)
@@ -316,6 +339,10 @@ class FunctionalTensorMode(TorchDispatchMode):
# discovery. This flag distinguishes between the two stages.
self._allow_token_discovery = _allow_token_discovery
self._storage_to_base: weakref.WeakKeyDictionary[
torch.storage.UntypedStorage, Optional[FunctionalTensor]
] = weakref.WeakKeyDictionary()
# No-op if FunctionalTensorMode is already in use
def __enter__(self):
def _get_prev_mode():
@@ -366,6 +393,7 @@
if not issubclass(t, torch._subclasses.FakeTensor)
and t not in [torch.Tensor, FunctionalTensor]
]
if unrecognized_types:
not_implemented_log.debug(
"FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
@@ -417,16 +445,13 @@
if r is not NotImplemented:
return r
def assert_is_functional(x):
assert torch._is_functional_tensor(x)
def wrap(x):
# Only wrap our outputs in subclasses if the inner functionalization call
# also wrapped outputs into FunctionalTensorWrappers.
# When can this happen? e.g. `torch.div(2, 2)`
assert not isinstance(x, FunctionalTensor)
if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
return FunctionalTensor(x)
return FunctionalTensor(x, self)
return x
def unwrap(x):

View File

@@ -664,6 +664,10 @@ void initTorchFunctions(PyObject* module) {
!at::functionalization::impl::isFunctionalTensor(o));
at::functionalization::impl::replace_(t, o);
});
py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
return at::functionalization::impl::isBaseTensor(t);
});
py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);