mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
C++ API parity: at::Tensor::output_nr
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26216 Test Plan: Imported from OSS Differential Revision: D17427576 Pulled By: pbelevich fbshipit-source-id: 351c834c6c44a2a2f915e48a1e8aa8ad7f4274b3
This commit is contained in:
committed by
Facebook Github Bot
parent
97c8c18a21
commit
fc3e1a22da
@ -352,3 +352,16 @@ TEST(TensorTest, IsLeaf) {
|
||||
ASSERT_THROWS_WITH(y.is_leaf(), message);
|
||||
ASSERT_THROWS_WITH(x.is_leaf(), message);
|
||||
}
|
||||
|
||||
TEST(TensorTest, OutputNr) {
|
||||
auto x = torch::tensor({5}, at::TensorOptions().requires_grad(true));
|
||||
auto y = x * x;
|
||||
ASSERT_EQ(x.output_nr(), 0);
|
||||
ASSERT_EQ(y.output_nr(), 0);
|
||||
|
||||
x = at::tensor({5});
|
||||
y = x * x;
|
||||
const auto message = "output_nr is not implemented for Tensor";
|
||||
ASSERT_THROWS_WITH(y.output_nr(), message);
|
||||
ASSERT_THROWS_WITH(x.output_nr(), message);
|
||||
}
|
||||
|
Reference in New Issue
Block a user