pytorch/test/test_ops.py
anjali411 9f67176b82 Complex gradcheck logic (#43208)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43208

This PR adds gradcheck support for complex tensors. The logic used for complex gradcheck is described in Section 3.5.3 of https://arxiv.org/pdf/1701.00392.pdf.
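
A usage-level sketch (not part of this diff): with this change, a C -> C function whose backward is updated here, e.g. `torch.sin`, can be checked directly with `torch.autograd.gradcheck` on a `cdouble` input.

    import torch
    from torch.autograd import gradcheck

    # C -> C case: sin's backward is updated for complex inputs in this PR
    z = torch.randn(3, dtype=torch.cdouble, requires_grad=True)
    assert gradcheck(torch.sin, (z,))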

More concretely, this PR introduces the following changes:
1. Updates `get_numerical_jacobian` to take a scalar value for the vector (v) as input, and adds gradcheck logic for C -> C, C -> R, and R -> C functions. For R -> C functions, only the real part of the gradient is propagated.
2. Adds a backward definition for `torch.complex` and a test to verify it (see the sketch after this list).
3. Updates the backward definitions of `mul`, `sin`, `cos`, `sinh`, and `cosh`.
4. Adds tests for `torch.real`, `torch.imag`, `torch.view_as_real`, `torch.view_as_complex`, and `torch.conj`.
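
A minimal sketch of the kind of check item 2 refers to (illustrative only, not the exact test added): `torch.complex` maps two real tensors to a complex tensor, so checking it exercises the R -> C path, where only the real part of the gradient is propagated.

    import torch
    from torch.autograd import gradcheck

    # R -> C case: torch.complex(real, imag) builds a complex tensor from two
    # real double inputs; this PR adds its backward definition.
    real = torch.randn(3, dtype=torch.double, requires_grad=True)
    imag = torch.randn(3, dtype=torch.double, requires_grad=True)
    assert gradcheck(torch.complex, (real, imag))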

Follow-up tasks:
1. Add more thorough tests for R -> C cases. Specifically, add R -> C test variants for functions, e.g. `torch.mul(complex_tensor, real_tensor)`.
2. Add back the commented-out test in `common_methods_invocations.py`.
3. Add more special-case checking for complex gradcheck to make debugging easier.
4. Update the complex autograd note.
5. Disable complex autograd for operators not tested for complex.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D23655088

Pulled By: anjali411

fbshipit-source-id: caa75e09864b5f6ead0f988f6368dce64cf15deb
2020-09-20 22:05:04 -07:00

139 lines · 5.3 KiB · Python

from functools import partial, wraps

import torch

from torch.testing._internal.common_utils import \
    (TestCase, run_tests)
from torch.testing._internal.common_methods_invocations import \
    (op_db)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA, skipCUDAIfRocm)
from torch.autograd.gradcheck import gradcheck, gradgradcheck


# Tests that apply to all operators
class TestOpInfo(TestCase):
    exact_dtype = True

    # Verifies that ops have their unsupported dtypes
    #   registered correctly by testing that each claimed unsupported dtype
    #   throws a runtime error
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db, unsupported_dtypes_only=True)
    def test_unsupported_dtypes(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]
        with self.assertRaises(RuntimeError):
            op(sample.input, *sample.args, **sample.kwargs)

    # Verifies that ops have their supported dtypes
    #   registered correctly by testing that each claimed supported dtype
    #   does NOT throw a runtime error
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(op_db)
    def test_supported_dtypes(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # NOTE: only tests on first sample
        sample = samples[0]
        op(sample.input, *sample.args, **sample.kwargs)


class TestGradients(TestCase):
    exact_dtype = True

    # Copies inputs to inplace operations to avoid inplace modifications
    #   to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(t.clone(), *args, **kwargs)

        return _fn

    def _check_helper(self, device, dtype, op, variant, check):
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        samples = op.sample_inputs(device, dtype, requires_grad=True)
        for sample in samples:
            partial_fn = partial(variant, **sample.kwargs)
            if check == 'gradcheck':
                self.assertTrue(gradcheck(partial_fn, (sample.input,) + sample.args,
                                          check_grad_dtypes=True))
            elif check == 'gradgradcheck':
                self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args,
                                              gen_non_contig_grad_outputs=False,
                                              check_grad_dtypes=True))
                self.assertTrue(gradgradcheck(partial_fn, (sample.input,) + sample.args,
                                              gen_non_contig_grad_outputs=True,
                                              check_grad_dtypes=True))
            else:
                self.assertTrue(False, msg="Unknown check requested!")

    def _grad_test_helper(self, device, dtype, op, variant):
        return self._check_helper(device, dtype, op, variant, 'gradcheck')

    def _gradgrad_test_helper(self, device, dtype, op, variant):
        return self._check_helper(device, dtype, op, variant, 'gradgradcheck')

    # Tests that gradients are computed correctly
    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        self._grad_test_helper(device, dtype, op, op.get_op())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_method_grad(self, device, dtype, op):
        self._grad_test_helper(device, dtype, op, op.get_method())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradcheck marked to skip.")
        self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

    # TODO(@anjali411) enable this for torch.cdouble.
    # Test that gradients of gradients are computed correctly
    @dtypes(torch.double)
    @ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._gradgrad_test_helper(device, dtype, op, op.get_op())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_method_gradgrad(self, device, dtype, op):
        self._gradgrad_test_helper(device, dtype, op, op.get_method())

    # TODO(@anjali411) enable this for torch.cdouble.
    @dtypes(torch.double)
    @ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        if not op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradgradcheck marked to skip.")
        self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))


instantiate_device_type_tests(TestOpInfo, globals())
instantiate_device_type_tests(TestGradients, globals())

if __name__ == '__main__':
    run_tests()
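
Usage note (assuming the standard behavior of PyTorch's device-generic test harness): the two `instantiate_device_type_tests` calls generate per-device test classes such as `TestOpInfoCPU` and `TestGradientsCUDA` in this module, and the `run_tests()` entry point means the whole suite can be run directly with `python test/test_ops.py`.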