Implement shallow copy functions for FunctionalTensorWrapper. (#118783)

Fix: #115792

This PR implements 2 virtual functions of `TensorImpl` that are called when setting the
`tensor.data`:

- `shallow_copy_from`: which calls `copy_tensor_metadata`; and

- `copy_tensor_metadata`: which copies all `FunctionalTensorWrapper` metadata and assigns
`dest->value_ = src->value_` (rather than calling `dest->value_.set_data(src->value_)`), so
that the inner tensor is also copied by the same method

Before this PR, the inner tensor of a `FunctionalTensorWrapper` was being ignored.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118783
Approved by: https://github.com/bdhirsh
This commit is contained in:
Yukio Siraichi
2024-02-08 10:46:46 -03:00
committed by PyTorch MergeBot
parent 6d8f192fd0
commit 9436710afd
3 changed files with 99 additions and 6 deletions

View File

@ -352,6 +352,41 @@ const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
return "FunctionalTensorWrapper";
}
// Copies every piece of tensor metadata from src_impl into dest_impl:
// first the base TensorImpl metadata, then all wrapper-specific fields.
// Note that value_ is copied by plain assignment, so dest_impl ends up
// wrapping the same inner tensor as src_impl.
void FunctionalTensorWrapper::copy_tensor_metadata(
    const FunctionalTensorWrapper* src_impl,
    FunctionalTensorWrapper* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) {
  // Base-class metadata (sizes, strides, version counter, ...).
  TensorImpl::copy_tensor_metadata(
      src_impl, dest_impl, version_counter, allow_tensor_metadata_change);
  // FunctionalTensorWrapper-specific fields.
  dest_impl->value_ = src_impl->value_;
  dest_impl->level_ = src_impl->level_;
  // Mutation bookkeeping counters/flags.
  dest_impl->mutation_counter_ = src_impl->mutation_counter_;
  dest_impl->mutation_hidden_from_autograd_counter_ =
      src_impl->mutation_hidden_from_autograd_counter_;
  dest_impl->mutation_during_no_grad_or_inference_mode_ =
      src_impl->mutation_during_no_grad_or_inference_mode_;
  dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
  dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
  dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
  // View/sync state.
  dest_impl->generation_ = src_impl->generation_;
  dest_impl->view_metas_ = src_impl->view_metas_;
}
// Same as copy_tensor_metadata(), but additionally recomputes dest_impl's
// cached numel and contiguity information, which are derived from the
// metadata that was just copied.
void FunctionalTensorWrapper::copy_tensor_metadata_and_refresh(
    const FunctionalTensorWrapper* src_impl,
    FunctionalTensorWrapper* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  FunctionalTensorWrapper::copy_tensor_metadata(
      src_impl, dest_impl, version_counter, allow_tensor_metadata_change);
  // The cached quantities must be refreshed after the raw metadata copy.
  dest_impl->refresh_numel();
  dest_impl->refresh_contiguous();
}
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
VariableVersion&& version_counter,
@ -367,16 +402,11 @@ c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_
}
auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
copy_tensor_metadata(
copy_tensor_metadata_and_refresh(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
/*version_counter=*/std::forward<VariableVersion>(version_counter),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->level_ = level_;
impl->generation_ = generation_;
impl->view_metas_ = view_metas_;
impl->refresh_numel();
impl->refresh_contiguous();
return impl;
}
@ -394,6 +424,18 @@ c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
std::move(version_counter), allow_tensor_metadata_change);
}
// Called when assigning to `tensor.data`. Propagates all metadata from
// `impl` (which must itself be a FunctionalTensorWrapper with a compatible
// key set) into this wrapper, including the inner value_ tensor.
void FunctionalTensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  auto* src = static_cast<FunctionalTensorWrapper*>(impl.get());
  copy_tensor_metadata_and_refresh(
      /*src_impl=*/src,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
}
// Reports the device of the wrapped (inner) tensor rather than the
// wrapper's own device.
c10::Device FunctionalTensorWrapper::device_custom() const {
  auto* inner_impl = value_.unsafeGetTensorImpl();
  return inner_impl->device();
}

View File

@ -211,6 +211,13 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const;
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
// Copies all metadata (base TensorImpl plus FunctionalTensorWrapper
// fields, including value_) from src_impl into dest_impl, then refreshes
// dest_impl's cached numel and contiguity.
void copy_tensor_metadata_and_refresh(
const FunctionalTensorWrapper* src_impl,
FunctionalTensorWrapper* dest_impl,
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const;
// Note that value is not taken by reference: internally, the wrapper will
// change the value tensor that it points to over time.
Tensor value_;
@ -230,6 +237,13 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
size_t generation_ = 0;
std::vector<at::functionalization::ViewMeta> view_metas_;
protected:
// Copies base TensorImpl metadata and every FunctionalTensorWrapper
// field from src_impl into dest_impl (does not refresh cached
// numel/contiguity; see copy_tensor_metadata_and_refresh).
static void copy_tensor_metadata(
const FunctionalTensorWrapper* src_impl,
FunctionalTensorWrapper* dest_impl,
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change);
};
// Utility functions for the functionalization pass.

View File

@ -1,5 +1,7 @@
# Owner(s): ["oncall: jit"]
import re
import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
@ -7,6 +9,8 @@ from torch.testing._internal.common_utils import run_tests, TestCase
torch._lazy.ts_backend.init()
NODE_TYPE_PATTERN = re.compile(r", NodeType=[^\n]+")
class LazyFuncionalizationTest(TestCase):
def test_lazy_init_with_view(self):
@ -56,6 +60,39 @@ class LazyFuncionalizationTest(TestCase):
self.assertEqual(cpu_out, lazy_out_1.to("cpu"))
self.assertEqual(cpu_out, lazy_out_2.to("cpu"))
def test_data_assign(self):
    # Assigning to `tensor.data` should update the inner tensor of the
    # FunctionalTensorWrapper, which shows up as a change in the lazy IR.
    def ir_text(lazy_tensor):
        # Strip the volatile "NodeType=..." annotations so the rendered
        # IR is stable enough to compare against an expected string.
        raw = torch._C._lazy._get_tensors_text([lazy_tensor])
        return NODE_TYPE_PATTERN.sub("", raw)

    source = torch.rand(3, dtype=torch.float32)
    tensor = source.to("lazy")
    self.assertExpectedInline(
        ir_text(tensor),
        """\
IR {
%0 = [Float[3]] lazy_tensors::device_data(), device=CPU0, ROOT=0
}
""",
    )
    # Change the tensor's dtype and assign the result to `data`. This
    # should replace the wrapper's inner tensor, adding an _to_copy node
    # to the IR graph.
    converted = tensor.to(torch.bfloat16)
    tensor.data = converted
    self.assertExpectedInline(
        ir_text(tensor),
        """\
IR {
%0 = [Float[3]] lazy_tensors::device_data(), device=CPU0
%1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0
}
""",  # noqa: B950
    )
# Run the test suite when executed directly.
if __name__ == "__main__":
    run_tests()