mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Added autograd support for C->C functions and enabled requires_grad=True for complex (#36932)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36932 Differential Revision: D21181230 Pulled By: anjali411 fbshipit-source-id: 295f2cd1e2b9918a8b2cb88cab0536b2407dc455
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1beca4ac6a
commit
6e92579883
@ -12,13 +12,13 @@
|
||||
using namespace torch::test;
|
||||
|
||||
template <typename T>
|
||||
bool exactly_equal(at::Tensor left, T right) {
|
||||
return left.item<T>() == right;
|
||||
}
|
||||
bool exactly_equal(at::Tensor left, T right) {
|
||||
return left.item<T>() == right;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
|
||||
return std::abs(left.item<T>() - right) < tolerance;
|
||||
template <typename T>
|
||||
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
|
||||
return std::abs(left.item<T>() - right) < tolerance;
|
||||
}
|
||||
|
||||
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
|
||||
@ -609,7 +609,7 @@ void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(c10::ScalarType d
|
||||
AutoDefaultDtypeMode dtype_mode(default_dtype);
|
||||
|
||||
ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype);
|
||||
ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), default_dtype);
|
||||
}
|
||||
@ -913,5 +913,5 @@ TEST(TensorTest, RequiresGradInplace) {
|
||||
|
||||
const auto int_tensor = torch::tensor({5}, at::TensorOptions().dtype(torch::kInt));
|
||||
ASSERT_THROWS_WITH(int_tensor.requires_grad_(true),
|
||||
"Only Tensors of floating point dtype can require gradients");
|
||||
"Only Tensors of floating point and complex dtype can require gradients");
|
||||
}
|
||||
|
Reference in New Issue
Block a user