Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Track base of FunctionalTensor in inference mode. (#135141)
The idea behind the tracking is the following: whenever we see a tensor, if it is a root tensor (it does not have any view metas), we consider it the base of all the tensors that share its storage.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135141
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: cc28634172
Commit: 66dd4577b1
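A minimal sketch of the rule described in the commit message, using eager tensors and raw storage pointers as stand-ins for the wrapper-level bookkeeping this commit adds (the dictionary and helper names here are illustrative, not PyTorch internals):

```python
import torch

# The root tensor (no view metadata) acts as the base of every tensor that
# shares its storage; views map back to it through the storage they alias.
storage_to_base = {}

root = torch.randn(4)
storage_to_base[root.untyped_storage().data_ptr()] = root

view = root[1:]  # a view sharing `root`'s storage
assert storage_to_base[view.untyped_storage().data_ptr()] is root
```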
@@ -707,7 +707,12 @@ bool are_all_mutations_under_no_grad_or_inference_mode(const Tensor& functional_
}

bool isFunctionalTensor(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Functionalize);
}

bool isBaseTensor(const at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isFunctionalTensor(tensor));
  return unsafeGetFunctionalWrapper(tensor)->isBaseTensor();
}

bool isFunctionalTensor(const std::optional<Tensor>& t) {
@@ -165,6 +165,12 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
    was_storage_changed_ = true;
  }

  // A FunctionalTensor is considered a base if its not a view of another
  // tensor.
  bool isBaseTensor() const {
    return view_metas_.empty();
  }

  c10::SymInt get_storage_size(bool before) {
    return functional_storage_impl()->get_storage_size(before);
  }
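A conceptual restatement of the predicate above as a self-contained sketch (not the C++ implementation): a wrapper counts as a base exactly when it has recorded no view metas.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Wrapper:
    # Stand-in for FunctionalTensorWrapper's view_metas_
    view_metas: List[str] = field(default_factory=list)

    def is_base_tensor(self) -> bool:
        return len(self.view_metas) == 0


assert Wrapper().is_base_tensor()
assert not Wrapper(view_metas=["select(dim=0, index=1)"]).is_base_tensor()
```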
@@ -290,6 +296,8 @@ TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
  return functional_impl;
}

TORCH_API bool isBaseTensor(const at::Tensor& tensor);

TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
TORCH_API bool isFunctionalTensor(
@@ -911,85 +911,20 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3
        with torch.inference_mode():
            self.test_auto_functionalize_extra1()

    # In inference mode we do not support inplacing views yet.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode2_v2(self):
        with torch.inference_mode(), torch.library._scoped_library(
            "mylib", "FRAGMENT"
        ) as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor(a!) x, Tensor(b!) y) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
        with torch.inference_mode():
            self.test_auto_functionalize_extra2()

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch._dynamo.disable
            def foo_impl(x, y):
                x.sin_()
                y.sin_()
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode3_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra3()

            def f(x):
                a = x[0]
                b = x[1]
                torch.ops.mylib.foo(a, b)
                return

            orig_args = [torch.randn(2)]

            [aot_eager_args, result1, graph_aot] = self.run_aot_eager(f, orig_args)
            [inductor_args, result2, graph_inductor] = self.run_inductor(f, orig_args)
            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
            result3 = f(*eager_args)

            self.assertEqual(inductor_args, eager_args)
            self.assertEqual(inductor_args, aot_eager_args)

            self.assertEqual(result3, result1)
            self.assertEqual(result3, result2)

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_aot,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
    select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
    select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
    auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _y_base_index = 1, _all_bases = [select, select_1]); select = select_1 = None
    getitem_1: "f32[][]cpu" = auto_functionalized_v2[1]
    getitem_2: "f32[][]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
    select_scatter: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, getitem_1, 0, 0); getitem_1 = None
    select_scatter_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter, getitem_2, 0, 1); select_scatter = getitem_2 = None
    copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_1); arg0_1 = select_scatter_1 = copy_ = None
    return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )

            # 2. Run with inductor backend

            if torch._dynamo.config.assume_static_by_default:
                self.assertExpectedInline(
                    graph_inductor,
                    """\
def forward(self, arg0_1: "f32[2][1]cpu"):
    select: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
    select_1: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1)
    as_strided_default: "f32[1][1]cpu" = torch.ops.aten.as_strided.default(select, [1], [1], 0); select = None
    clone_default: "f32[1][1]cpu" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None
    as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default, [], [], 0); clone_default = None
    as_strided_default_2: "f32[2][1]cpu" = torch.ops.aten.as_strided.default(select_1, [2], [1], 0); select_1 = None
    clone_default_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_default_2); as_strided_default_2 = None
    as_strided_default_3: "f32[][]cpu" = torch.ops.aten.as_strided.default(clone_default_1, [], [], 1); clone_default_1 = None
    foo_default = torch.ops.mylib.foo.default(as_strided_default_1, as_strided_default_3); foo_default = None
    select_scatter_default: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(arg0_1, as_strided_default_1, 0, 0); as_strided_default_1 = None
    select_scatter_default_1: "f32[2][1]cpu" = torch.ops.aten.select_scatter.default(select_scatter_default, as_strided_default_3, 0, 1); select_scatter_default = as_strided_default_3 = None
    copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, select_scatter_default_1); arg0_1 = select_scatter_default_1 = copy_ = None
    return ()""",  # noqa: B950
                    ignore_comments=True,
                    ignore_empty_lines=True,
                )
    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_inference_mode4_v2(self):
        with torch.inference_mode():
            self.test_auto_functionalize_extra4()

    @torch._inductor.config.patch(enable_auto_functionalized_v2=True)
    def test_dynamic_v2(self):
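A standalone sketch of the pattern these tests exercise: a custom mutating op applied to two views of one tensor, compiled under inference mode with auto_functionalized_v2 enabled. It mirrors the op definition from the test above but is not the test itself, and it assumes a PyTorch build that includes this change.

```python
import torch

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    torch.library.define(
        "mylib::foo",
        "(Tensor(a!) x, Tensor(b!) y) -> ()",
        tags=torch.Tag.pt2_compliant_tag,
        lib=lib,
    )

    @torch.library.impl("mylib::foo", "cpu", lib=lib)
    @torch._dynamo.disable
    def foo_impl(x, y):
        x.sin_()
        y.sin_()

    def f(x):
        # Both arguments are views of the same base tensor `x`.
        torch.ops.mylib.foo(x[0], x[1])

    with torch.inference_mode(), torch._inductor.config.patch(
        enable_auto_functionalized_v2=True
    ):
        torch.compile(f, backend="inductor")(torch.randn(2))
```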
@@ -781,6 +781,9 @@ def gen_pyi(
            "_is_functional_tensor": [
                "def _is_functional_tensor(t: Tensor) -> _bool: ..."
            ],
            "_is_functional_tensor_base": [
                "def _is_functional_tensor_base(t: Tensor) -> _bool: ..."
            ],
            "_from_functional_tensor": [
                "def _from_functional_tensor(t: Tensor) -> Tensor: ..."
            ],
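A hedged usage sketch of the private bindings these stubs describe; `torch._to_functional_tensor` is an existing private binding that is not part of this diff but is the usual way to produce a functionalized tensor to inspect.

```python
import torch

t = torch._to_functional_tensor(torch.randn(3))
assert torch._is_functional_tensor(t)
# A freshly wrapped tensor has no view metas recorded, so it is its own base.
assert torch._is_functional_tensor_base(t)

plain = torch._from_functional_tensor(t)
assert not torch._is_functional_tensor(plain)
```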
@@ -18,6 +18,13 @@ from torch.fx.experimental.proxy_tensor import (
)


def get_base(tensor):
    if torch.is_inference_mode_enabled():
        return tensor._inference_mode_base
    else:
        return tensor._base


@dataclass
class ViewInfo:
    base_index: int
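A small illustration of the dispatch in get_base, with a stand-in object in place of a traced FunctionalTensor: outside inference mode the usual `_base` link is consulted, while inside inference mode the base recorded at wrapper construction (`_inference_mode_base`, populated later in this commit) is used instead.

```python
import types

import torch

fake = types.SimpleNamespace(_base=None, _inference_mode_base="recorded base")


def get_base(tensor):
    if torch.is_inference_mode_enabled():
        return tensor._inference_mode_base
    return tensor._base


assert get_base(fake) is None  # normal mode: falls back to _base
with torch.inference_mode():
    assert get_base(fake) == "recorded base"  # inference mode: recorded base
```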
@@ -68,7 +75,7 @@ def write_view_information_to_args(

        if tensor is None:
            kwargs[f"{prefix}_base_index"] = None
        elif tensor._base is None:
        elif get_base(tensor) is None:
            # if the tensor is the base (not view), for simplicity we do not serialize view meta.
            kwargs[f"{prefix}_base_index"] = base_index
        else:
@@ -437,7 +444,7 @@ def do_auto_functionalize_v2(
    arg_to_base_index: Dict[str, Any] = {}

    def update_dict(tensor, arg_name, index=None):
        base = tensor if tensor._base is None else tensor._base
        base = tensor if get_base(tensor) is None else get_base(tensor)

        def set_result(base_index):
            if index is None:
@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union

@@ -111,7 +112,10 @@ class FunctionalTensor(torch.Tensor):
        torch.ops.aten.unsafe_chunk.default, # type: ignore[has-type]
    ]

    def __new__(cls, elem):
    # Used by auto_functionalize to determine base of tensors during inference mode.
    _inference_mode_base: Optional["FunctionalTensor"] = None

    def __new__(cls, elem, mode):
        assert torch._is_functional_tensor(elem)

        # In general, we'd like our functional tensor subclass to only be in charge of functionalization,
@@ -142,9 +146,9 @@ class FunctionalTensor(torch.Tensor):
            cls,
            elem.shape, # sizes
            elem.stride() if not is_sparse_any(elem) else None, # strides
            elem.storage_offset()
            if not is_sparse_any(elem)
            else None, # storage_offset
            (
                elem.storage_offset() if not is_sparse_any(elem) else None
            ), # storage_offset
            None, # memory_format
            elem.dtype, # dtype
            elem.layout, # layout
@@ -158,6 +162,21 @@ class FunctionalTensor(torch.Tensor):
        )
        torch._C._set_throw_on_mutable_data_ptr(out)
        out.elem = elem

        if (
            torch.is_inference_mode_enabled()
            and torch._inductor.config.enable_auto_functionalized_v2
        ):
            if out.is_base_tensor():
                out._inference_mode_base = None
                # This assumes that the FunctionalTensor.elem does not change its storage after this point.
                # Otherwise this would be invalid.
                mode._storage_to_base[out.elem.untyped_storage()] = out
            else:
                out._inference_mode_base = mode._storage_to_base[
                    out.elem.untyped_storage()
                ]
                assert out._inference_mode_base is not None
        return out

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
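A toy walk-through of the registration above, with plain classes standing in for FunctionalTensor and UntypedStorage: the first wrapper whose inner tensor has no view metas is recorded as the base for its storage, and later wrappers sharing that storage resolve `_inference_mode_base` to it, much like the selects over arg0_1 in the test graphs earlier in this diff.

```python
import weakref


class FakeStorage:  # stand-in for torch.storage.UntypedStorage
    pass


class FakeWrapper:  # stand-in for FunctionalTensor
    def __init__(self, storage, is_base):
        self.storage = storage
        self._is_base = is_base
        self._inference_mode_base = None

    def is_base_tensor(self):
        return self._is_base


storage_to_base = weakref.WeakKeyDictionary()


def track(out):
    # Mirrors the branch in __new__ above.
    if out.is_base_tensor():
        out._inference_mode_base = None
        storage_to_base[out.storage] = out
    else:
        out._inference_mode_base = storage_to_base[out.storage]
        assert out._inference_mode_base is not None


s = FakeStorage()
arg0 = FakeWrapper(s, is_base=True)   # the input tensor
x0 = FakeWrapper(s, is_base=False)    # x[0], a view aliasing s
x1 = FakeWrapper(s, is_base=False)    # x[1], a view aliasing s
for w in (arg0, x0, x1):
    track(w)
assert x0._inference_mode_base is arg0 and x1._inference_mode_base is arg0
```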
@@ -209,6 +228,7 @@ class FunctionalTensor(torch.Tensor):
    @staticmethod
    def to_functional(x):
        # We will do the wrapping for the user.

        assert not torch._is_functional_tensor(x)
        # The only autograd metadata we care about on the FunctionalTensor is:
        # - requires_grad (so autograd runs)
@@ -226,7 +246,7 @@ class FunctionalTensor(torch.Tensor):

        with functional_mode:
            torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined]
            out = FunctionalTensor(x_functional)
            out = FunctionalTensor(x_functional, functional_mode)
            torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined]
            return out

@@ -234,6 +254,9 @@
        torch._sync(self)
        return torch._from_functional_tensor(self.elem)

    def is_base_tensor(self) -> bool:
        return torch._is_functional_tensor_base(self.elem)

    def replace_(self, output) -> None:
        torch._functionalize_replace(self.elem, output)

@@ -316,6 +339,10 @@ class FunctionalTensorMode(TorchDispatchMode):
        # discovery. This flag distinguishes between the two stages.
        self._allow_token_discovery = _allow_token_discovery

        self._storage_to_base: weakref.WeakKeyDictionary[
            torch.storage.UntypedStorage, Optional[FunctionalTensor]
        ] = weakref.WeakKeyDictionary()

    # No-op if FunctionalTensorMode is already in use
    def __enter__(self):
        def _get_prev_mode():
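A WeakKeyDictionary is used rather than a plain dict, presumably so the mode's bookkeeping does not keep storages alive on its own: once a storage becomes unreachable, its entry disappears. A minimal illustration with a plain object standing in for an UntypedStorage:

```python
import gc
import weakref


class Storage:  # stand-in for torch.storage.UntypedStorage
    pass


m: "weakref.WeakKeyDictionary[Storage, str]" = weakref.WeakKeyDictionary()
s = Storage()
m[s] = "base"
assert len(m) == 1

del s
gc.collect()
assert len(m) == 0  # the entry vanished with its key
```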
@@ -366,6 +393,7 @@ class FunctionalTensorMode(TorchDispatchMode):
            if not issubclass(t, torch._subclasses.FakeTensor)
            and t not in [torch.Tensor, FunctionalTensor]
        ]

        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
@@ -417,16 +445,13 @@ class FunctionalTensorMode(TorchDispatchMode):
                if r is not NotImplemented:
                    return r

        def assert_is_functional(x):
            assert torch._is_functional_tensor(x)

        def wrap(x):
            # Only wrap our outputs in subclasses if the inner functionalization call
            # also wrapped outputs into FunctionalTensorWrappers.
            # When can this happen? e.g. `torch.div(2, 2)`
            assert not isinstance(x, FunctionalTensor)
            if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
                return FunctionalTensor(x)
                return FunctionalTensor(x, self)
            return x

        def unwrap(x):
@@ -664,6 +664,10 @@ void initTorchFunctions(PyObject* module) {
        !at::functionalization::impl::isFunctionalTensor(o));
    at::functionalization::impl::replace_(t, o);
  });
  py_module.def("_is_functional_tensor_base", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    return at::functionalization::impl::isBaseTensor(t);
  });
  py_module.def("_functionalize_is_multi_output_view", [](const at::Tensor& t) {
    TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(t));
    auto t_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(t);