Revert "Extend TensorImpl with BackendMeta (#97429)"

This reverts commit bc38b278bf4c2890700f8fe751cfd15fcb01da60.

Reverted https://github.com/pytorch/pytorch/pull/97429 on behalf of https://github.com/huydhn due to Sorry for reverting your PR as I am trying to root cause a libtorch build failure on Windows starting from your change bc38b278bf.  AFAICT, there is no other change from the log.  I will reland this if the failure is unrelated
This commit is contained in:
PyTorch MergeBot
2023-04-04 05:13:18 +00:00
parent 8f2f1a0b32
commit 7eaaefafb3
2 changed files with 3 additions and 82 deletions

View File

@ -1217,44 +1217,3 @@ TEST(TensorTest, ReshapeAlias) {
torch::_reshape_alias((z * z), {9}, {1}).mean().backward();
ASSERT_TRUE(torch::equal(y.grad(), z.grad()));
}
TEST(TensorTest, BackendMetadata) {
// Tests ability to assign custom backend metadata to tensor.
struct CustomBackendMetadata : public c10::BackendMeta {
mutable bool cloned_{false}; // for testing this field will mutate when
// clone() is called by shallow_copy_from.
c10::intrusive_ptr<c10::BackendMeta> clone(
const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
cloned_ = true;
return c10::BackendMeta::clone(ptr);
}
};
at::Tensor y;
c10::intrusive_ptr<c10::BackendMeta> tmeta{};
CustomBackendMetadata* custom_tmeta{nullptr};
{
auto x = torch::ones({3, 3});
auto impl{x.unsafeGetTensorImpl()};
ASSERT_TRUE(impl != nullptr);
tmeta = impl->get_backend_meta_intrusive_ptr();
ASSERT_TRUE(tmeta == nullptr);
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{
std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata())};
impl->set_backend_meta(new_tmeta);
tmeta = impl->get_backend_meta_intrusive_ptr();
ASSERT_TRUE(tmeta == new_tmeta);
custom_tmeta = dynamic_cast<CustomBackendMetadata*>(tmeta.get());
ASSERT_TRUE(custom_tmeta != nullptr);
ASSERT_TRUE(custom_tmeta->cloned_ == false);
y.unsafeGetTensorImpl()->shallow_copy_from(x.getIntrusivePtr());
}
ASSERT_TRUE(
tmeta == y.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr());
ASSERT_TRUE(tmeta.get() == y.unsafeGetTensorImpl()->get_backend_meta());
ASSERT_TRUE(custom_tmeta->cloned_ == true);
}