mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Implement shallow copy functions for FunctionalTensorWrapper
. (#118783)
Fix: #115792 This PR implements 2 virtual functions of `TensorImpl` that are called when setting the `tensor.data`: - `shallow_copy_from`: which calls `copy_tensor_metadata`; and - `copy_tensor_metadata`: which copies all `FunctionalTensorWrapper` metadata and ~calls `dest->value_.set_data(src->value_)`~ assigns `dest->value_ = src->value_`, so as to copy also the inner tensor using the same method Before this PR, the inner tensor of a `FunctionalTensorWrapper` was being ignored. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118783 Approved by: https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
6d8f192fd0
commit
9436710afd
@ -352,6 +352,41 @@ const char* FunctionalTensorWrapper::tensorimpl_type_name() const {
|
||||
return "FunctionalTensorWrapper";
|
||||
}
|
||||
|
||||
void FunctionalTensorWrapper::copy_tensor_metadata(
|
||||
const FunctionalTensorWrapper* src_impl,
|
||||
FunctionalTensorWrapper* dest_impl,
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change) {
|
||||
TensorImpl::copy_tensor_metadata(
|
||||
src_impl,
|
||||
dest_impl,
|
||||
version_counter,
|
||||
allow_tensor_metadata_change);
|
||||
|
||||
// FunctionalTensorWrapper-specific fields.
|
||||
dest_impl->value_ = src_impl->value_;
|
||||
dest_impl->level_ = src_impl->level_;
|
||||
dest_impl->mutation_counter_ = src_impl->mutation_counter_;
|
||||
dest_impl->mutation_hidden_from_autograd_counter_ = src_impl->mutation_hidden_from_autograd_counter_;
|
||||
dest_impl->mutation_during_no_grad_or_inference_mode_ = src_impl->mutation_during_no_grad_or_inference_mode_;
|
||||
dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
|
||||
dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
|
||||
dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
|
||||
dest_impl->generation_ = src_impl->generation_;
|
||||
dest_impl->view_metas_ = src_impl->view_metas_;
|
||||
}
|
||||
|
||||
|
||||
void FunctionalTensorWrapper::copy_tensor_metadata_and_refresh(
|
||||
const FunctionalTensorWrapper* src_impl,
|
||||
FunctionalTensorWrapper* dest_impl,
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change) const {
|
||||
copy_tensor_metadata(src_impl, dest_impl, version_counter, allow_tensor_metadata_change);
|
||||
dest_impl->refresh_numel();
|
||||
dest_impl->refresh_contiguous();
|
||||
}
|
||||
|
||||
template <typename VariableVersion>
|
||||
c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_core(
|
||||
VariableVersion&& version_counter,
|
||||
@ -367,16 +402,11 @@ c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach_
|
||||
}
|
||||
|
||||
auto impl = c10::make_intrusive<FunctionalTensorWrapper>(value_);
|
||||
copy_tensor_metadata(
|
||||
copy_tensor_metadata_and_refresh(
|
||||
/*src_impl=*/this,
|
||||
/*dest_impl=*/impl.get(),
|
||||
/*version_counter=*/std::forward<VariableVersion>(version_counter),
|
||||
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
|
||||
impl->level_ = level_;
|
||||
impl->generation_ = generation_;
|
||||
impl->view_metas_ = view_metas_;
|
||||
impl->refresh_numel();
|
||||
impl->refresh_contiguous();
|
||||
return impl;
|
||||
}
|
||||
|
||||
@ -394,6 +424,18 @@ c10::intrusive_ptr<TensorImpl> FunctionalTensorWrapper::shallow_copy_and_detach(
|
||||
std::move(version_counter), allow_tensor_metadata_change);
|
||||
}
|
||||
|
||||
void FunctionalTensorWrapper::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
|
||||
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
|
||||
auto functional_impl =
|
||||
static_cast<FunctionalTensorWrapper*>(impl.get());
|
||||
copy_tensor_metadata_and_refresh(
|
||||
/*src_impl=*/functional_impl,
|
||||
/*dest_impl=*/this,
|
||||
/*version_counter=*/version_counter(),
|
||||
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
|
||||
}
|
||||
|
||||
|
||||
c10::Device FunctionalTensorWrapper::device_custom() const {
|
||||
return value_.unsafeGetTensorImpl()->device();
|
||||
}
|
||||
|
@ -211,6 +211,13 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
VariableVersion&& version_counter,
|
||||
bool allow_tensor_metadata_change) const;
|
||||
|
||||
void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
|
||||
void copy_tensor_metadata_and_refresh(
|
||||
const FunctionalTensorWrapper* src_impl,
|
||||
FunctionalTensorWrapper* dest_impl,
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change) const;
|
||||
|
||||
// Note that value is not taken by reference: internally, the wrapper will
|
||||
// change the value tensor that it points to over time.
|
||||
Tensor value_;
|
||||
@ -230,6 +237,13 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
|
||||
|
||||
size_t generation_ = 0;
|
||||
std::vector<at::functionalization::ViewMeta> view_metas_;
|
||||
|
||||
protected:
|
||||
static void copy_tensor_metadata(
|
||||
const FunctionalTensorWrapper* src_impl,
|
||||
FunctionalTensorWrapper* dest_impl,
|
||||
const c10::VariableVersion& version_counter,
|
||||
bool allow_tensor_metadata_change);
|
||||
};
|
||||
|
||||
// Utility functions for the functionalization pass.
|
||||
|
@ -1,5 +1,7 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import re
|
||||
|
||||
import torch
|
||||
import torch._lazy.metrics as metrics
|
||||
import torch._lazy.ts_backend
|
||||
@ -7,6 +9,8 @@ from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
torch._lazy.ts_backend.init()
|
||||
|
||||
NODE_TYPE_PATTERN = re.compile(r", NodeType=[^\n]+")
|
||||
|
||||
|
||||
class LazyFuncionalizationTest(TestCase):
|
||||
def test_lazy_init_with_view(self):
|
||||
@ -56,6 +60,39 @@ class LazyFuncionalizationTest(TestCase):
|
||||
self.assertEqual(cpu_out, lazy_out_1.to("cpu"))
|
||||
self.assertEqual(cpu_out, lazy_out_2.to("cpu"))
|
||||
|
||||
def test_data_assign(self):
|
||||
def text(lazyt):
|
||||
raw = torch._C._lazy._get_tensors_text([lazyt])
|
||||
return NODE_TYPE_PATTERN.sub("", raw)
|
||||
|
||||
origin = torch.rand(3, dtype=torch.float32)
|
||||
tensor = origin.to("lazy")
|
||||
|
||||
self.assertExpectedInline(
|
||||
text(tensor),
|
||||
"""\
|
||||
IR {
|
||||
%0 = [Float[3]] lazy_tensors::device_data(), device=CPU0, ROOT=0
|
||||
}
|
||||
""",
|
||||
)
|
||||
|
||||
# Modify the data-type of tensor, and assign it to 'data'.
|
||||
# This should update the inner tensor of FunctionalTensorWrapper,
|
||||
# changing the corresponding IR node.
|
||||
modified_tensor = tensor.to(torch.bfloat16)
|
||||
tensor.data = modified_tensor
|
||||
|
||||
self.assertExpectedInline(
|
||||
text(tensor),
|
||||
"""\
|
||||
IR {
|
||||
%0 = [Float[3]] lazy_tensors::device_data(), device=CPU0
|
||||
%1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0
|
||||
}
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user