Add torch._lazy_clone to create COW tensors (#113397)

Part of #109833

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #113397
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113397
Approved by: https://github.com/ezyang
Edward Z. Yang
2024-01-10 15:21:34 -05:00
committed by PyTorch MergeBot
parent 71343507cd
commit edec54b9de
8 changed files with 177 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <c10/util/SmallBuffer.h>
#include <c10/core/impl/COW.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -10,6 +11,7 @@
#include <ATen/ops/_make_dual_native.h>
#include <ATen/ops/_new_zeros_with_same_feature_meta_native.h>
#include <ATen/ops/_unpack_dual_native.h>
#include <ATen/ops/_lazy_clone_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/zeros.h>
#endif
@@ -89,4 +91,19 @@ bool _has_same_storage_numel(const at::Tensor& base, const at::Tensor& other) {
  return base.storage().nbytes() / base.itemsize() == other.storage().nbytes() / other.itemsize();
}

Tensor _lazy_clone(Tensor const& self) {
  c10::StorageImpl* self_storage = self.storage().unsafeGetStorageImpl();
  // Convert `self`'s storage to copy-on-write in place, and get back a new
  // COW StorageImpl that shares the same data. `lazy_clone_storage` returns
  // null if the storage cannot be converted to COW.
  c10::intrusive_ptr<c10::StorageImpl> storage =
      c10::impl::cow::lazy_clone_storage(*self_storage);
  TORCH_CHECK(storage != nullptr);
  // Build the result tensor over the new storage, preserving `self`'s
  // dispatch keys, dtype, and (symbolic) sizes, strides, and offset.
  auto tensor = c10::make_intrusive<c10::TensorImpl>(
      c10::Storage(std::move(storage)),
      self.key_set(),
      self.dtype());
  tensor->set_sizes_and_strides(self.sym_sizes(),
                                self.sym_strides(),
                                self.sym_storage_offset());
  return Tensor(std::move(tensor));
}
} // namespace at::native

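In user-facing terms, the kernel above converts `self`'s storage to copy-on-write in place and returns a new tensor whose storage shares the same data. A minimal sketch of the resulting behavior, using the `torch._C` introspection bindings added later in this commit (eager mode assumed):

import torch

t = torch.tensor([1.0, 2.0, 3.0])
clone = t._lazy_clone()

# Both tensors are now COW: separate StorageImpls, one shared data pointer.
assert torch._C._is_cow_tensor(t)
assert torch._C._is_cow_tensor(clone)
assert torch._C._data_address(t) == torch._C._data_address(clone)

# Writing to `t` materializes it into a fresh buffer; `clone` still
# holds the original data and stays COW until it is written.
t.add_(1)
assert not torch._C._is_cow_tensor(t)
assert torch._C._is_cow_tensor(clone)
assert torch._C._data_address(t) != torch._C._data_address(clone)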
View File

@@ -1228,6 +1228,13 @@
    CompositeExplicitAutograd: copysign_out
  tags: pointwise

- func: _lazy_clone(Tensor self) -> Tensor
  # Like clone, but the copy takes place lazily, only if either the
  # input or the output is written.
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _lazy_clone

- func: logical_not(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method

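To make the schema comment concrete, a hedged comparison with eager `clone` (the function and method variants registered above behave identically):

import torch

t = torch.ones(4)
eager = t.clone()            # copies immediately: new data pointer right away
lazy = torch._lazy_clone(t)  # function variant; t._lazy_clone() also works

assert torch._C._data_address(eager) != torch._C._data_address(t)
assert torch._C._data_address(lazy) == torch._C._data_address(t)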
View File

@@ -375,6 +375,7 @@ aten::_int_mm
aten::_int_mm.out
aten::_is_all_true
aten::_is_any_true
aten::_lazy_clone
aten::_linalg_check_errors
aten::_linalg_det
aten::_linalg_det.result

View File

@@ -5072,6 +5072,130 @@ else:
        t = torch.tensor((), device=device)
        self.assertEqual(t.dtype, t.storage().dtype)

    # Note [lazy_clone_ tests with inductor enabled]
    # These `lazy_clone_` tests are written in a way that makes them pass in
    # both eager mode and compiled mode (`PYTORCH_TEST_WITH_INDUCTOR=1`). There
    # are cases where COW tensors can materialize at different times and in
    # different ways in compiled mode versus eager mode, and those cases need
    # to be avoided. There are two main wrinkles to be aware of.
    #
    # The first wrinkle is that these tests have to check the internal
    # properties of tensors to make sure they materialize in the expected way,
    # and those checks cause dynamo graph breaks. Depending on the situation,
    # a graph break in between two compiled graphs that operate on the same
    # COW tensor can make the tensor materialize when it would not materialize
    # in eager mode, causing the checks to fail. The strategy for avoiding
    # this is to make all the operations on COW tensors get compiled into the
    # same graph, by not doing any checks between the operations, and just
    # doing all the checks at the end of the test. If we really do want to
    # perform checks between two operations, `op1` and `op2`, the solution is
    # to create two different tests. One test performs just `op1` and then
    # checks. The other test performs `op1` followed immediately by `op2` and
    # then checks.
    #
    # The second wrinkle is that in eager mode, if we perform writes on two
    # COW tensors where one is a lazy clone of the other, the first tensor to
    # be written will be materialized with a new data pointer, and the second
    # tensor will just reuse the original data pointer when it is
    # materialized. But in compiled mode, if these writes happen in the same
    # graph, the order in which the tensors materialize can differ from eager
    # mode. So in this case the strategy is to purposefully cause a graph
    # break to happen in between the two write operations, by adding checks
    # between them, so that they have to materialize in the expected order.
    @skipXLA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_lazy_clone(self, device, dtype):
        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
        t_orig_storage_addr = torch._C._storage_address(t)
        orig_data_ptr = torch._C._data_address(t)
        clone = t._lazy_clone()

        # Lazy cloning a tensor should cause both it and its clone to become
        # COW tensors. They should have different storages, but the same data
        # pointer.
        self.assertTrue(torch._C._is_cow_tensor(clone))
        self.assertTrue(torch._C._is_cow_tensor(t))
        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
    # See Note [lazy_clone_ tests with inductor enabled]
    @skipXLA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_lazy_clone_view(self, device, dtype):
        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
        t_orig_storage_addr = torch._C._storage_address(t)
        orig_data_ptr = torch._C._data_address(t)
        clone = t._lazy_clone()
        view = t.view([4])

        # Viewing `t` should not cause a copy (materialize) to happen. All
        # the tensors should still be COW and have the same data pointer.
        # `view` and `t` should have the same storage, and `clone` should
        # have a different storage.
        self.assertTrue(torch._C._is_cow_tensor(t))
        self.assertTrue(torch._C._is_cow_tensor(view))
        self.assertTrue(torch._C._is_cow_tensor(clone))
        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
        self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
        self.assertTrue(torch._C._data_address(view) == orig_data_ptr)
    # See Note [lazy_clone_ tests with inductor enabled]
    @skipXLA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_lazy_clone_view_materialize(self, device, dtype):
        t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
        t_orig_storage_addr = torch._C._storage_address(t)
        orig_data_ptr = torch._C._data_address(t)
        clone = t._lazy_clone()
        view = t.view([4])
        view += torch.ones(1, device=device, dtype=dtype)

        # Writing to `t` should cause the storage under `t` and `view` to be
        # copied (materialized), but should not affect `clone`.
        self.assertFalse(torch._C._is_cow_tensor(t))
        self.assertFalse(torch._C._is_cow_tensor(view))
        self.assertTrue(torch._C._is_cow_tensor(clone))
        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
        t_new_data_addr = torch._C._data_address(t)
        self.assertTrue(t_new_data_addr != orig_data_ptr)
        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)

        clone += torch.ones(1, device=device, dtype=dtype)

        # Writing to `clone` should materialize it, so it should no longer
        # be COW. However, since `clone`'s storage is the only COW storage
        # left that holds a reference to the original data pointer, this
        # materialization should not actually cause a copy--it should just
        # reuse the original data pointer.
        self.assertFalse(torch._C._is_cow_tensor(t))
        self.assertFalse(torch._C._is_cow_tensor(view))
        self.assertFalse(torch._C._is_cow_tensor(clone))
        self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
        self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
        self.assertTrue(torch._C._data_address(t) == t_new_data_addr)
        self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
        self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)

    # FIXME: move to test distributions
    @skipIfMps
    @dtypesIfCUDA(torch.float, torch.double, torch.half)

View File

@@ -442,6 +442,10 @@
  self: grad
  result: auto_linear

- name: _lazy_clone(Tensor self) -> Tensor
  self: grad
  result: auto_linear

- name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
  self: _to_copy_backward(grad, self.options())
  result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format)

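The `self: grad` / `result: auto_linear` entry makes `_lazy_clone` an identity for autograd: the incoming gradient passes through unchanged. A small sketch of what that implies:

import torch

x = torch.randn(3, requires_grad=True)
y = torch._lazy_clone(x)
y.sum().backward()

# The backward of _lazy_clone is the identity, so d(sum)/dx is all ones.
assert torch.equal(x.grad, torch.ones(3))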
View File

@@ -326,6 +326,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.isneginf,
aten.isposinf,
aten.l1_loss,
aten._lazy_clone,
aten.leaky_relu_,
aten.leaky_relu_backward,
aten.lerp,

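Listing `aten._lazy_clone` here includes it in the core ATen decomposition table, so backends that consume those decompositions see the op decomposed away. One way to observe the registration (a sketch; that the entry is keyed on the `.default` overload is an assumption):

import torch
from torch._decomp import core_aten_decompositions

table = core_aten_decompositions()
assert torch.ops.aten._lazy_clone.default in table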
View File

@@ -1976,6 +1976,28 @@ Call this whenever a new thread is created in order to propagate values from
        return map;
      });

  py_module.def(
      "_storage_address",
      [](const at::Tensor& tensor) {
        return reinterpret_cast<std::intptr_t>(
            tensor.storage().unsafeGetStorageImpl());
      },
      "Gets the memory address of the Tensor's StorageImpl.");

  py_module.def(
      "_data_address",
      [](const at::Tensor& tensor) {
        return reinterpret_cast<std::intptr_t>(tensor.storage().data());
      },
      "Gets the memory address of the Tensor's data pointer.");

  py_module.def(
      "_is_cow_tensor",
      [](const at::Tensor& tensor) {
        return c10::impl::cow::is_cow_data_ptr(tensor.storage().data_ptr());
      },
      "Checks if a tensor's data pointer is COW.");

  const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
  THPDefaultCPUGenerator =
      (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);

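Taken together, these three bindings let the tests distinguish views (which share a StorageImpl) from lazy clones (which share only the underlying data). A brief sketch:

import torch

t = torch.arange(4.0)
view = t.view(2, 2)
lazy = t._lazy_clone()

# A view aliases the same StorageImpl; a lazy clone gets its own
# StorageImpl that initially points at the same data.
assert torch._C._storage_address(view) == torch._C._storage_address(t)
assert torch._C._storage_address(lazy) != torch._C._storage_address(t)
assert torch._C._data_address(lazy) == torch._C._data_address(t)
assert torch._C._is_cow_tensor(lazy)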
View File

@@ -347,6 +347,7 @@ def get_ignored_functions() -> Set[Callable]:
Tensor._has_symbolic_sizes_strides.__get__,
Tensor._conj,
Tensor._conj_physical,
Tensor._lazy_clone,
Tensor._neg_view,
Tensor._is_zerotensor,
Tensor._is_all_true,
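For reference, `get_ignored_functions()` lists callables that are exempt from the `__torch_function__` override machinery; this change adds the new method to that set. A quick check:

import torch
from torch.overrides import get_ignored_functions

assert torch.Tensor._lazy_clone in get_ignored_functions()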