mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add C++-only int dim
overloads to std
-related operations (#40451)
Summary: Fixes gh-40287 The `int -> bool` conversion takes higher precedence than `int -> IntArrayRef`. So, calling `std(0)` in C++ would select the `std(unbiased=False)` overload instead. Pull Request resolved: https://github.com/pytorch/pytorch/pull/40451 Differential Revision: D22217926 Pulled By: ezyang fbshipit-source-id: 7520792fab5ab6665bddd03b6f57444c6c729af4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a208a272cb
commit
16f276cef9
@ -1046,3 +1046,25 @@ TEST(TensorTest, RequiresGradInplace) {
|
||||
ASSERT_THROWS_WITH(int_tensor.requires_grad_(true),
|
||||
"Only Tensors of floating point and complex dtype can require gradients");
|
||||
}
|
||||
|
||||
TEST(TensorTest, StdDimension) {
|
||||
// Test that std(0) doesn't select the std(unbiased=False) overload (gh-40287)
|
||||
auto x = torch::randn({4, 3});
|
||||
auto std = x.std(0);
|
||||
|
||||
ASSERT_EQ(x.var(0).numel(), 3);
|
||||
ASSERT_EQ(x.std(0).numel(), 3);
|
||||
|
||||
ASSERT_EQ(x.var(0, /*unbiased=*/true).numel(), 3);
|
||||
ASSERT_EQ(x.std(0, /*unbiased=*/true).numel(), 3);
|
||||
|
||||
ASSERT_EQ(torch::var(x, 0).numel(), 3);
|
||||
ASSERT_EQ(std::get<0>(torch::var_mean(x, 0)).numel(), 3);
|
||||
ASSERT_EQ(torch::std(x, 0).numel(), 3);
|
||||
ASSERT_EQ(std::get<0>(torch::std_mean(x, 0)).numel(), 3);
|
||||
|
||||
ASSERT_EQ(torch::var(x, 0, /*unbiased=*/true).numel(), 3);
|
||||
ASSERT_EQ(std::get<0>(torch::var_mean(x, 0, /*unbiased=*/true)).numel(), 3);
|
||||
ASSERT_EQ(torch::std(x, 0, /*unbiased=*/true).numel(), 3);
|
||||
ASSERT_EQ(std::get<0>(torch::std_mean(x, 0, /*unbiased=*/true)).numel(), 3);
|
||||
}
|
||||
|
Reference in New Issue
Block a user