Added autograd support for C->C functions and enabled requires_grad=True for complex (#36932)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36932

Differential Revision: D21181230

Pulled By: anjali411

fbshipit-source-id: 295f2cd1e2b9918a8b2cb88cab0536b2407dc455
This commit is contained in:
anjali411
2020-04-24 12:20:26 -07:00
committed by Facebook GitHub Bot
parent 1beca4ac6a
commit 6e92579883
11 changed files with 58 additions and 23 deletions

View File

@ -12,13 +12,13 @@
using namespace torch::test;
template <typename T>
bool exactly_equal(at::Tensor left, T right) {
return left.item<T>() == right;
}
bool exactly_equal(at::Tensor left, T right) {
return left.item<T>() == right;
}
template <typename T>
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
return std::abs(left.item<T>() - right) < tolerance;
template <typename T>
bool almost_equal(at::Tensor left, T right, T tolerance = 1e-4) {
return std::abs(left.item<T>() - right) < tolerance;
}
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
@ -609,7 +609,7 @@ void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(c10::ScalarType d
AutoDefaultDtypeMode dtype_mode(default_dtype);
ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), default_dtype);
ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
ASSERT_EQ(torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(), default_dtype);
ASSERT_EQ(torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(), default_dtype);
}
@ -913,5 +913,5 @@ TEST(TensorTest, RequiresGradInplace) {
const auto int_tensor = torch::tensor({5}, at::TensorOptions().dtype(torch::kInt));
ASSERT_THROWS_WITH(int_tensor.requires_grad_(true),
"Only Tensors of floating point dtype can require gradients");
"Only Tensors of floating point and complex dtype can require gradients");
}