Add torch._lazy_clone to create COW tensors (#113397)

Part of #109833

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #113397
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113397
Approved by: https://github.com/ezyang
This commit is contained in:
Edward Z. Yang
2024-01-10 15:21:34 -05:00
committed by PyTorch MergeBot
parent 71343507cd
commit edec54b9de
8 changed files with 177 additions and 0 deletions

View File

@ -5072,6 +5072,130 @@ else:
t = torch.tensor((), device=device)
self.assertEqual(t.dtype, t.storage().dtype)
# Note [lazy_clone_ tests with inductor enabled]
# These `lazy_clone_` tests are written in a way that makes them pass in
# both eager mode and compiled mode (`PYTORCH_TEST_WITH_INDUCTOR=1`). There
# are cases where COW tensors can materialize at different times and in
# different ways in compiled mode versus eager mode, and those cases need to
# be avoided. There are two main wrinkles the be aware of.
#
# The first wrinkle is that these tests have to check the internal
# properties of tensors to make sure they materialize in the expected way,
# and those checks cause dynamo graph breaks. Depending on the situation, a
# graph break in-between two compiled graphs that operate on the same COW
# tensor can make the tensor materialize when it would not materialize in
# eager mode, causing the checks to fail. The strategy for avoiding this is
# to make all the operations on COW tensors get compiled into the same
# graph, by not doing any checks between the operations, and just do all the
# checks at the end of the test. If we really do want to perform checks
# between two operations, `op1` and `op2`, the solution is to create two
# different tests. One test performs just `op1` and then checks. The other
# test performs `op1` followed immediately by `op2` and then checks.
#
# The second wrinkle is that in eager mode, if we perform writes on two COW
# tensors where one is a lazy clone of the other, the first tensor to be
# written will be materialized with a new data pointer, and the second
# tensor will just reuse the original data pointer when it is materialized.
# But in compiled mode, if these writes happen in the same graph, the order
# in which the tensors materialize can be different than in eager mode. So
# in this case the strategy is to purposefully cause a graph break to happen
# in-between the two write operations, by adding checks between them, so
# that they have to materialize in the expected order.
@skipXLA
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_lazy_clone(self, device, dtype):
t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
t_orig_storage_addr = torch._C._storage_address(t)
orig_data_ptr = torch._C._data_address(t)
clone = t._lazy_clone()
# Lazy cloning a tensor should cause both it and its clone to become COW
# tensors. They should have different storages, but the same data
# pointer.
self.assertTrue(torch._C._is_cow_tensor(clone))
self.assertTrue(torch._C._is_cow_tensor(t))
self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
# See Note [lazy_clone_ tests with inductor enabled]
@skipXLA
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_lazy_clone_view(self, device, dtype):
t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
t_orig_storage_addr = torch._C._storage_address(t)
orig_data_ptr = torch._C._data_address(t)
clone = t._lazy_clone()
view = t.view([4])
# Viewing `t` should not cause a copy (materialize) to happen. All the
# tensors should still be COW and have the same data pointer. `view` and
# `t` should have the same storage, and `clone` should have a different
# storage.
self.assertTrue(torch._C._is_cow_tensor(t))
self.assertTrue(torch._C._is_cow_tensor(view))
self.assertTrue(torch._C._is_cow_tensor(clone))
self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
self.assertTrue(torch._C._data_address(t) == orig_data_ptr)
self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
self.assertTrue(torch._C._data_address(view) == orig_data_ptr)
# See Note [lazy_clone_ tests with inductor enabled]
@skipXLA
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_lazy_clone_view_materialize(self, device, dtype):
t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
t_orig_storage_addr = torch._C._storage_address(t)
orig_data_ptr = torch._C._data_address(t)
clone = t._lazy_clone()
view = t.view([4])
view += torch.ones(1, device=device, dtype=dtype)
# Writing to `t` should cause the storage under `t` and `view` to be
# copied (materialized), but should not affect `clone`.
self.assertFalse(torch._C._is_cow_tensor(t))
self.assertFalse(torch._C._is_cow_tensor(view))
self.assertTrue(torch._C._is_cow_tensor(clone))
self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
t_new_data_addr = torch._C._data_address(t)
self.assertTrue(t_new_data_addr != orig_data_ptr)
self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
clone += torch.ones(1, device=device, dtype=dtype)
# Writing to `clone` should materialize it, so it should no longer
# be COW. However, since `clone`'s storage is the only COW storage
# left that holds a reference to the original data pointer, this
# materialization should not actually cause a copy--it should
# just reuse the original data pointer.
self.assertFalse(torch._C._is_cow_tensor(t))
self.assertFalse(torch._C._is_cow_tensor(view))
self.assertFalse(torch._C._is_cow_tensor(clone))
self.assertTrue(torch._C._storage_address(t) == t_orig_storage_addr)
self.assertTrue(torch._C._storage_address(view) == t_orig_storage_addr)
self.assertTrue(torch._C._storage_address(clone) != t_orig_storage_addr)
self.assertTrue(torch._C._data_address(t) == t_new_data_addr)
self.assertTrue(torch._C._data_address(view) == t_new_data_addr)
self.assertTrue(torch._C._data_address(clone) == orig_data_ptr)
# FIXME: move to test distributions
@skipIfMps
@dtypesIfCUDA(torch.float, torch.double, torch.half)