[autograd] disable backward/grad for complex scalar output (#92753)

Fixes https://github.com/pytorch/pytorch/issues/92750

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92753
Approved by: https://github.com/ezyang
Author: kshitij12345
Date: 2023-02-23 11:38:27 +00:00
Committed-by: PyTorch MergeBot
Parent: b5ff41a47a
Commit: 3b966a6ce3

16 changed files with 130 additions and 62 deletions
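In user-facing terms, the change makes implicit gradient creation reject complex scalar outputs: calling backward() on a complex scalar without an explicit gradient argument now throws instead of silently creating a grad of ones. A minimal LibTorch sketch of the new behavior (illustrative only, not part of this diff):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto x = torch::randn({5, 5}, torch::requires_grad());

      // Real scalar output: implicit gradient creation is still allowed.
      x.sum().backward();

      // Complex scalar output: backward() without an explicit gradient
      // argument now throws.
      auto y = (x * c10::Scalar(c10::complex<float>(0, 0.5))).sum();
      try {
        y.backward();
      } catch (const c10::Error& e) {
        // Expected message: "grad can be computed only for real scalar outputs"
        std::cout << e.what_without_backtrace() << "\n";
      }
    }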


@@ -1099,6 +1099,13 @@ TEST(TensorTest, BackwardNonScalarOutputs) {
       y.backward(), "grad can be implicitly created only for scalar outputs");
 }
 
+TEST(TensorTest, BackwardComplexScalarOutput) {
+  auto x = torch::randn({5, 5}, torch::requires_grad());
+  auto y = (x * c10::Scalar(c10::complex<float>(0, 0.5))).sum();
+  ASSERT_THROWS_WITH(
+      y.backward(), "grad can be computed only for real scalar outputs");
+}
+
 TEST(TensorTest, IsLeaf) {
   auto x = torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
   auto y = x * x;
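Not part of this commit, but for context: a hedged sketch of how a caller can still differentiate through a complex scalar under the new check, assuming the standard LibTorch API (torch::real and the explicit-gradient overload of backward):

    // Sketch (assumption, not taken from this commit): reduce the complex
    // scalar to a real one before relying on implicit gradient creation.
    auto x = torch::randn({5, 5}, torch::requires_grad());
    auto y = (x * c10::Scalar(c10::complex<float>(0, 0.5))).sum();

    // Differentiate the real part, which is a real scalar output.
    torch::real(y).backward();

    // Alternatively, supply an explicit gradient so that no implicit
    // grad needs to be created (behavior assumed; verify against the PR):
    // y.backward(torch::ones_like(y));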