Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
extract TestAutogradComplex into its own test file (#63400)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63400

This is the first step to break up test_autograd.py for #63205.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D30541499

Pulled By: dagitses

fbshipit-source-id: 8d9d32007938b9eade0e88f95a6a3190e7e2ef01
Committed by: Facebook GitHub Bot
Parent: be5b05c1dc
Commit: cdb46f4c6e
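Since the new file ends with its own run_tests() entry point, it can presumably be run on its own (python test/autograd/test_complex.py) as well as through the parent suite (python test/test_autograd.py), which re-imports the class; see the final test_autograd.py hunk below.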
test/autograd/test_complex.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import torch

from torch.testing._internal.common_utils import TestCase, run_tests, gradcheck


class TestAutogradComplex(TestCase):
    def test_view_func_for_complex_views(self):
        # case 1: both parent and child have view_func
        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
        y = x.detach().requires_grad_(True)

        x0 = x.clone()
        x1 = torch.view_as_complex(x0)
        x2 = torch.view_as_real(x1)
        x2.mul_(2)
        x2.sum().backward()

        y0 = y.clone()
        y0.mul_(2)
        y0.sum().backward()

        self.assertEqual(x.grad, y.grad)

        # case 2: parent has view_func but child does not
        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
        y = x.detach().requires_grad_(True)

        def fn(a):
            b = a.clone()
            b1 = torch.view_as_complex(b)
            b2 = b1.reshape(b1.numel())
            return b2

        x0 = fn(x)
        x0.mul_(2)
        x0.sum().backward()

        y0 = fn(y)
        y1 = y0.mul(2)
        y1.sum().backward()

        self.assertEqual(x.grad, y.grad)

        # case 3: parent does not have a view_func but child does
        x = torch.randn(10, dtype=torch.cdouble, requires_grad=True)
        y = x.detach().requires_grad_(True)

        def fn(a, dim0_size=5):
            b = a.clone()
            b1 = b.reshape(dim0_size, 2)
            b2 = torch.view_as_real(b1)
            return b2

        x0 = fn(x)
        x0.mul_(2)
        x0.sum().backward()

        y0 = fn(y)
        y1 = y0.mul(2)
        y1.sum().backward()

        self.assertEqual(x.grad, y.grad)

    def test_view_with_multi_output(self):
        x = torch.randn(2, 2, 2, dtype=torch.double)

        x1 = torch.view_as_complex(x)
        # Taking an invalid view should always be allowed as long as it is not
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
            res[0] += torch.rand(2, requires_grad=True)

        x.requires_grad_(True)
        x1 = torch.view_as_complex(x)
        # Taking an invalid view should always be allowed as long as it is not
        # modified inplace
        res = x1.unbind(0)

        with self.assertRaisesRegex(RuntimeError, "output of a function that returns multiple views"):
            res[0] += torch.rand(2, requires_grad=True)

    def as_identity(self):  # note: lacks the test_ prefix, so the default unittest loader will not run it
        # view_as_real and view_as_complex behavior should be like an identity
        def func(z):
            z_ = torch.view_as_complex(z)
            z_select = torch.select(z_, z_.dim() - 1, 0)
            z_select_real = torch.view_as_real(z_select)
            return z_select_real.sum()

        z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True)
        gradcheck(func, [z])
        func(z).backward()

        z1 = z.clone().detach().requires_grad_(True)
        torch.select(z1, z1.dim() - 2, 0).sum().backward()

        self.assertEqual(z.grad, z1.grad)


if __name__ == '__main__':
    run_tests()
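For context on what these tests exercise: view_as_complex and view_as_real are inverse, storage-sharing views, which is why an in-place update through one alias must produce the same gradients as an out-of-place update. A minimal standalone sketch of that round trip (shapes chosen for illustration, not taken from the commit):

    import torch

    # A float/double tensor whose last dimension has size 2 can be
    # reinterpreted as a complex tensor; view_as_real undoes it.
    x = torch.randn(3, 2, dtype=torch.double)
    c = torch.view_as_complex(x)  # shape (3,), dtype torch.complex128
    r = torch.view_as_real(c)     # shape (3, 2), dtype torch.float64

    assert torch.equal(r, x)

    # The three tensors share storage, so an in-place update is visible
    # through every alias; the tests above check that autograd replays
    # such view chains correctly when computing gradients.
    r.mul_(2)
    assert torch.equal(x, r)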
test/test_autograd.py

@@ -28,7 +28,6 @@ from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                  suppress_warnings, slowTest,
                                                  load_tests,
                                                  IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
                                                  TEST_WITH_ROCM, disable_gc,
                                                  gradcheck, gradgradcheck)
@@ -44,11 +43,6 @@ from torch.testing._internal.common_device_type import (instantiate_device_type_
                                                         deviceCountAtLeast, skipCUDAIfCudnnVersionLessThan,
                                                         skipCUDAIf, skipMeta)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

import pickle

PRECISION = 1e-4
@@ -6173,101 +6167,6 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
    test_case.assertEqual(self_variable.size(), self_variable.grad.size())


[removed: the entire TestAutogradComplex class, identical to the class in the new test/autograd/test_complex.py shown above]


class TestAutogradFunctional(TestCase):
    def _assert_same_struct(self, res, base):
        # base and res should be Tensors or tuple of Tensors with the same size
@@ -9640,6 +9539,11 @@ class TestMultithreadAutograd(TestCase):
         torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
         torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)
 
+# Import test cases from below autograd/ here. These are found
+# implicitly by the loader, so Flake8 thinks they are unused, hence
+# the suppressions.
+
+from autograd.test_complex import TestAutogradComplex  # noqa: F401
 
 # e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
 instantiate_device_type_tests(
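"Found implicitly by the loader" refers to unittest collecting any TestCase subclass bound in a module's namespace, whether defined there or merely imported, which is how test_autograd.py keeps running the extracted class. A minimal sketch of that mechanism, with hypothetical module names (not part of this commit):

    import unittest

    # Hypothetical layout: suite_module.py contains only
    #     from some_tests import MyTests  # noqa: F401
    # The loader scans module attributes for TestCase subclasses, so
    # MyTests is collected as if it were defined in suite_module.
    import suite_module

    suite = unittest.defaultTestLoader.loadTestsFromModule(suite_module)
    unittest.TextTestRunner().run(suite)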
@@ -48,7 +48,10 @@ TARGET_DET_LIST = [
     "distributed/test_pg_wrapper",
     "distributed/test_store",
     "distributions/test_distributions",
-    "test_autograd",
+    # test_autograd.py is not slow, so it does not belong here. But
+    # note that if you try to add it back it will run into
+    # https://bugs.python.org/issue40350 because it imports files
+    # under test/autograd/.
     "test_binary_ufuncs",
     "test_cpp_extensions_aot_ninja",
     "test_cpp_extensions_aot_no_ninja",