mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Extend TensorImpl with BackendMeta (#97429)
BackendMeta offers a binary interface for the backend to attach arbitrary data to TensorImpl. TensorImpl has exactly one "slot" for backend metadata, however backend is free to compose any structure that is opaque to the framework beyond iheriting standard BackendMeta base. Change-Id: I670fcdd16dd1c2b00f7eaa1cbc5b5dfea59a6221 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/97429 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
dd503376bd
commit
2d9b2bcfba
@ -1217,3 +1217,44 @@ TEST(TensorTest, ReshapeAlias) {
|
||||
torch::_reshape_alias((z * z), {9}, {1}).mean().backward();
|
||||
ASSERT_TRUE(torch::equal(y.grad(), z.grad()));
|
||||
}
|
||||
|
||||
TEST(TensorTest, BackendMetadata) {
|
||||
// Tests ability to assign custom backend metadata to tensor.
|
||||
|
||||
struct CustomBackendMetadata : public c10::BackendMeta {
|
||||
mutable bool cloned_{false}; // for testing this field will mutate when
|
||||
// clone() is called by shallow_copy_from.
|
||||
c10::intrusive_ptr<c10::BackendMeta> clone(
|
||||
const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
|
||||
cloned_ = true;
|
||||
return c10::BackendMeta::clone(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
at::Tensor y;
|
||||
c10::intrusive_ptr<c10::BackendMeta> tmeta{};
|
||||
CustomBackendMetadata* custom_tmeta{nullptr};
|
||||
|
||||
{
|
||||
auto x = torch::ones({3, 3});
|
||||
auto impl{x.unsafeGetTensorImpl()};
|
||||
ASSERT_TRUE(impl != nullptr);
|
||||
|
||||
tmeta = impl->get_backend_meta_intrusive_ptr();
|
||||
ASSERT_TRUE(tmeta == nullptr);
|
||||
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{
|
||||
std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata())};
|
||||
impl->set_backend_meta(new_tmeta);
|
||||
tmeta = impl->get_backend_meta_intrusive_ptr();
|
||||
ASSERT_TRUE(tmeta == new_tmeta);
|
||||
custom_tmeta = dynamic_cast<CustomBackendMetadata*>(tmeta.get());
|
||||
ASSERT_TRUE(custom_tmeta != nullptr);
|
||||
ASSERT_TRUE(custom_tmeta->cloned_ == false);
|
||||
y.unsafeGetTensorImpl()->shallow_copy_from(x.getIntrusivePtr());
|
||||
}
|
||||
|
||||
ASSERT_TRUE(
|
||||
tmeta == y.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr());
|
||||
ASSERT_TRUE(tmeta.get() == y.unsafeGetTensorImpl()->get_backend_meta());
|
||||
ASSERT_TRUE(custom_tmeta->cloned_ == true);
|
||||
}
|
||||
|
Reference in New Issue
Block a user