Extend TensorImpl with BackendMeta (#97429)

BackendMeta offers a binary interface for the backend to attach arbitrary data to TensorImpl. TensorImpl has exactly one "slot" for backend metadata, however backend is free to compose any structure that is opaque to the framework beyond iheriting standard BackendMeta base.

Change-Id: I670fcdd16dd1c2b00f7eaa1cbc5b5dfea59a6221

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97429
Approved by: https://github.com/ezyang
This commit is contained in:
Paweł Piskorski
2023-04-04 23:47:03 +00:00
committed by PyTorch MergeBot
parent dd503376bd
commit 2d9b2bcfba
2 changed files with 82 additions and 3 deletions

View File

@ -1217,3 +1217,44 @@ TEST(TensorTest, ReshapeAlias) {
torch::_reshape_alias((z * z), {9}, {1}).mean().backward();
ASSERT_TRUE(torch::equal(y.grad(), z.grad()));
}
TEST(TensorTest, BackendMetadata) {
// Tests ability to assign custom backend metadata to tensor.
struct CustomBackendMetadata : public c10::BackendMeta {
mutable bool cloned_{false}; // for testing this field will mutate when
// clone() is called by shallow_copy_from.
c10::intrusive_ptr<c10::BackendMeta> clone(
const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
cloned_ = true;
return c10::BackendMeta::clone(ptr);
}
};
at::Tensor y;
c10::intrusive_ptr<c10::BackendMeta> tmeta{};
CustomBackendMetadata* custom_tmeta{nullptr};
{
auto x = torch::ones({3, 3});
auto impl{x.unsafeGetTensorImpl()};
ASSERT_TRUE(impl != nullptr);
tmeta = impl->get_backend_meta_intrusive_ptr();
ASSERT_TRUE(tmeta == nullptr);
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{
std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata())};
impl->set_backend_meta(new_tmeta);
tmeta = impl->get_backend_meta_intrusive_ptr();
ASSERT_TRUE(tmeta == new_tmeta);
custom_tmeta = dynamic_cast<CustomBackendMetadata*>(tmeta.get());
ASSERT_TRUE(custom_tmeta != nullptr);
ASSERT_TRUE(custom_tmeta->cloned_ == false);
y.unsafeGetTensorImpl()->shallow_copy_from(x.getIntrusivePtr());
}
ASSERT_TRUE(
tmeta == y.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr());
ASSERT_TRUE(tmeta.get() == y.unsafeGetTensorImpl()->get_backend_meta());
ASSERT_TRUE(custom_tmeta->cloned_ == true);
}