mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Extend TensorImpl with BackendMeta (#97429)"
This reverts commit bc38b278bf4c2890700f8fe751cfd15fcb01da60.
Reverted https://github.com/pytorch/pytorch/pull/97429 on behalf of https://github.com/huydhn due to Sorry for reverting your PR as I am trying to root cause a libtorch build failure on Windows starting from your change bc38b278bf
. AFAICT, there is no other change from the log. I will reland this if the failure is unrelated
This commit is contained in:
@ -1217,44 +1217,3 @@ TEST(TensorTest, ReshapeAlias) {
|
||||
torch::_reshape_alias((z * z), {9}, {1}).mean().backward();
|
||||
ASSERT_TRUE(torch::equal(y.grad(), z.grad()));
|
||||
}
|
||||
|
||||
TEST(TensorTest, BackendMetadata) {
|
||||
// Tests ability to assign custom backend metadata to tensor.
|
||||
|
||||
struct CustomBackendMetadata : public c10::BackendMeta {
|
||||
mutable bool cloned_{false}; // for testing this field will mutate when
|
||||
// clone() is called by shallow_copy_from.
|
||||
c10::intrusive_ptr<c10::BackendMeta> clone(
|
||||
const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
|
||||
cloned_ = true;
|
||||
return c10::BackendMeta::clone(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
at::Tensor y;
|
||||
c10::intrusive_ptr<c10::BackendMeta> tmeta{};
|
||||
CustomBackendMetadata* custom_tmeta{nullptr};
|
||||
|
||||
{
|
||||
auto x = torch::ones({3, 3});
|
||||
auto impl{x.unsafeGetTensorImpl()};
|
||||
ASSERT_TRUE(impl != nullptr);
|
||||
|
||||
tmeta = impl->get_backend_meta_intrusive_ptr();
|
||||
ASSERT_TRUE(tmeta == nullptr);
|
||||
c10::intrusive_ptr<c10::BackendMeta> new_tmeta{
|
||||
std::unique_ptr<c10::BackendMeta>(new CustomBackendMetadata())};
|
||||
impl->set_backend_meta(new_tmeta);
|
||||
tmeta = impl->get_backend_meta_intrusive_ptr();
|
||||
ASSERT_TRUE(tmeta == new_tmeta);
|
||||
custom_tmeta = dynamic_cast<CustomBackendMetadata*>(tmeta.get());
|
||||
ASSERT_TRUE(custom_tmeta != nullptr);
|
||||
ASSERT_TRUE(custom_tmeta->cloned_ == false);
|
||||
y.unsafeGetTensorImpl()->shallow_copy_from(x.getIntrusivePtr());
|
||||
}
|
||||
|
||||
ASSERT_TRUE(
|
||||
tmeta == y.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr());
|
||||
ASSERT_TRUE(tmeta.get() == y.unsafeGetTensorImpl()->get_backend_meta());
|
||||
ASSERT_TRUE(custom_tmeta->cloned_ == true);
|
||||
}
|
||||
|
Reference in New Issue
Block a user