mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43208 This PR adds gradcheck support for complex numbers. The logic used for complex gradcheck is described in Section 3.5.3 of https://arxiv.org/pdf/1701.00392.pdf. More concretely, this PR introduces the following changes: 1. Updates get_numerical_jacobian to take a scalar value for the vector (v) as input, and adds gradcheck logic for C -> C, C -> R, and R -> C functions. For R -> C functions, only the real part of the gradient is propagated. 2. Adds a backward definition for `torch.complex`, along with a test that verifies it. 3. Updates the backward formulas for `mul`, `sin`, `cos`, `sinh`, and `cosh`. 4. Adds tests for `torch.real`, `torch.imag`, `torch.view_as_real`, `torch.view_as_complex`, and `torch.conj`. Follow-up tasks: 1. Add more thorough tests for R -> C cases; specifically, add R -> C variants of the function tests, e.g. `torch.mul(complex_tensor, real_tensor)`. 2. Add back the commented-out test in `common_methods_invocations.py`. 3. Add more special-case checks to complex gradcheck to make debugging easier. 4. Update the complex autograd note. 5. Disable complex autograd for operators that have not been tested with complex inputs. Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D23655088 Pulled By: anjali411 fbshipit-source-id: caa75e09864b5f6ead0f988f6368dce64cf15deb
139 lines
5.3 KiB
Python
139 lines
5.3 KiB
Python
from functools import partial, wraps
|
|
|
|
import torch
|
|
|
|
from torch.testing._internal.common_utils import \
|
|
(TestCase, run_tests)
|
|
from torch.testing._internal.common_methods_invocations import \
|
|
(op_db)
|
|
from torch.testing._internal.common_device_type import \
|
|
(instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA, skipCUDAIfRocm)
|
|
from torch.autograd.gradcheck import gradcheck, gradgradcheck
|
|
|
|
|
|
# Tests that apply to all operators
|
|
|
|
class TestOpInfo(TestCase):
    """Checks that each op's declared dtype support matches its actual behavior.

    Only the first sample produced by ``op.sample_inputs`` is exercised.
    """
    exact_dtype = True

    def _first_sample_or_skip(self, device, dtype, op):
        """Return the first sample input of ``op``; skip the test if it has none."""
        sample_list = op.sample_inputs(device, dtype)
        if len(sample_list) == 0:
            self.skipTest("Skipped! No sample inputs!")
        # NOTE: only the first sample is tested
        return sample_list[0]

    # Every dtype that an op claims NOT to support must make the op raise a
    # RuntimeError.
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db, unsupported_dtypes_only=True)
    def test_unsupported_dtypes(self, device, dtype, op):
        first = self._first_sample_or_skip(device, dtype, op)
        with self.assertRaises(RuntimeError):
            op(first.input, *first.args, **first.kwargs)

    # Every dtype that an op claims to support must run without raising.
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db)
    def test_supported_dtypes(self, device, dtype, op):
        first = self._first_sample_or_skip(device, dtype, op)
        op(first.input, *first.args, **first.kwargs)
|
|
|
|
|
|
class TestGradients(TestCase):
    """Checks first- and second-order gradients of every op in ``op_db``.

    Each op is exercised through its function (``op.get_op()``), method
    (``op.get_method()``) and inplace (``op.get_inplace()``) variants on every
    sample returned by ``op.sample_inputs(..., requires_grad=True)``.
    """
    exact_dtype = True

    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        """Wrap ``inplace_variant`` so that it operates on a clone of its first arg.

        Args:
            inplace_variant: the inplace callable, or None if the op has none.

        Returns:
            A wrapper callable, or None when ``inplace_variant`` is None so that
            ``_check_helper`` skips the test ("Variant not implemented").
        """
        # functools.wraps(None) succeeds and would return a non-None wrapper,
        # bypassing the skip in _check_helper and later failing with
        # "'NoneType' object is not callable". Propagate None instead.
        if inplace_variant is None:
            return None

        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(t.clone(), *args, **kwargs)

        return _fn

    def _check_helper(self, device, dtype, op, variant, check):
        """Run ``check`` against ``variant`` for every gradient-enabled sample.

        Args:
            device: device string the test was instantiated for.
            dtype: dtype used to generate sample inputs.
            op: the op-info object under test.
            variant: callable to check, or None if that variant does not exist.
            check: either 'gradcheck' or 'gradgradcheck'.

        Skips the test when ``variant`` is None or ``op`` does not support
        ``dtype`` on this device type; fails on an unrecognized ``check``.
        """
        # Validate up front so a bad check name fails even when the op
        # produces no samples (the loop below would otherwise never run).
        if check not in ('gradcheck', 'gradgradcheck'):
            self.fail("Unknown check requested!")
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {dtype}")

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        for sample in samples:
            partial_fn = partial(variant, **sample.kwargs)
            inputs = (sample.input,) + sample.args
            if check == 'gradcheck':
                self.assertTrue(gradcheck(partial_fn, inputs,
                                          check_grad_dtypes=True))
            else:
                # Exercise both contiguous and non-contiguous grad_outputs.
                for non_contig in (False, True):
                    self.assertTrue(gradgradcheck(partial_fn, inputs,
                                                  gen_non_contig_grad_outputs=non_contig,
                                                  check_grad_dtypes=True))

    def _grad_test_helper(self, device, dtype, op, variant):
        """Run gradcheck on ``variant`` of ``op``."""
        return self._check_helper(device, dtype, op, variant, 'gradcheck')

    def _gradgrad_test_helper(self, device, dtype, op, variant):
        """Run gradgradcheck on ``variant`` of ``op``."""
        return self._check_helper(device, dtype, op, variant, 'gradgradcheck')

    # Tests that gradients are computed correctly
    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        self._grad_test_helper(device, dtype, op, op.get_op())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_method_grad(self, device, dtype, op):
        self._grad_test_helper(device, dtype, op, op.get_method())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradcheck marked to skip.")
        self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

    # TODO(@anjali411) enable this for torch.cdouble.
    # Test that gradients of gradients are computed correctly
    @dtypes(torch.double)
    @ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._gradgrad_test_helper(device, dtype, op, op.get_op())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_method_gradgrad(self, device, dtype, op):
        self._gradgrad_test_helper(device, dtype, op, op.get_method())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradgradcheck marked to skip.")
        self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
|
|
|
|
|
# Expand the generic test classes defined above into device-specific tests.
# NOTE(review): globals() is passed presumably so the generated classes are
# registered in (and the generic ones removed from) this module's namespace —
# confirm against instantiate_device_type_tests.
instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals())

# Allow running this file directly as a script.
if __name__ == '__main__':
    run_tests()
|