This PR makes the following improvements:

- moves the custom skip list for test_normalize_operator_exhaustive in test_fx_experimental to the typical OpInfo skip architecture. The skips were updated to xfails, which identified some operators that no longer fail the test (see the sketch after this list)
- removes tests in test_jit.py that were redundant with OpInfo-based testing
- improves test_dtypes so its error messages are clear and it makes test_nondifferentiable redundant; the latter test has been removed
- removes OpInfo.supports_complex_autograd() in favor of a more accurate and general check for whether a particular dtype is in the backward dtypes of the operator
- improves gradchecks to verify that an operator does not support grad if it claims not to
- improves gradchecks to test the gradient of every input tensor that requires gradient
- removes the concept of "default test dtypes"
- removes excessive and mostly redundant out= testing for elementwise unary operators
- removes OpInfo metadata for whether an op supports nuanced "safe casting" to out behavior
- converts numerous skips to xfails
- fixes the metadata of numerous OpInfos based on the new checks
- moves jit-specific utilities in common_methods_invocations.py to jit_programming_utils.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75951
Approved by: https://github.com/ngimel
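For reference, a minimal sketch of the OpInfo-style xfail pattern the first bullet refers to. The operator name, the exact import location of DecorateInfo, and the test class name are illustrative assumptions rather than lines taken from this PR:

    import unittest
    from torch.testing._internal.common_methods_invocations import DecorateInfo, OpInfo

    # Entry as it might appear inside op_db (operator name is a placeholder):
    OpInfo('placeholder_op',
           skips=(
               # xfail rather than skip: the test still runs and must fail, so the
               # entry is flagged as stale once the operator starts passing.
               DecorateInfo(unittest.expectedFailure,
                            'TestNormalizeOperators',
                            'test_normalize_operator_exhaustive'),
           )),
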
# Owner(s): ["module: unknown"]

from functools import partial, wraps
from itertools import chain

import torch

from torch.testing._internal.common_utils import \
    (TestCase, is_iterable_of_tensors, run_tests, gradcheck, gradgradcheck)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, OpDTypes)

# TODO: fixme https://github.com/pytorch/pytorch/issues/68972
torch.set_default_dtype(torch.float32)

# gradcheck requires double precision
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
                         allowed_dtypes=[torch.double, torch.cdouble])
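
# Note: the `ops` decorator parametrizes each decorated test below once per
# (OpInfo, device, dtype) combination; the partial above additionally restricts
# dtypes to torch.double / torch.cdouble. Generated tests end up with names
# roughly like `test_fn_grad_<op_name>_<device>_float64` (illustrative; the exact
# naming comes from instantiate_device_type_tests at the bottom of this file).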


class TestGradients(TestCase):
    exact_dtype = True

    # Copies inputs to inplace operations to avoid inplace modifications
    # to leaves requiring gradient
    def _get_safe_inplace(self, inplace_variant):
        @wraps(inplace_variant)
        def _fn(t, *args, **kwargs):
            return inplace_variant(t.clone(), *args, **kwargs)

        return _fn
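
    # For example (illustrative), _get_safe_inplace(torch.Tensor.add_) returns a
    # wrapper that computes t.clone().add_(...), so the leaf tensors gradcheck
    # passes in are never mutated in place.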

    def _check_helper(self, device, dtype, op, variant, check, *, check_forward_ad=False, check_backward_ad=True,
                      check_batched_grad=None, check_batched_forward_grad=False):
        assert check in ('gradcheck', 'bwgrad_bwgrad', 'fwgrad_bwgrad')
        # NB: check_backward_ad does not affect gradgradcheck (always True)
        if variant is None:
            self.skipTest("Skipped! Variant not implemented.")
        if not op.supports_dtype(dtype, torch.device(device).type):
            self.skipTest(f"Skipped! {op.name} does not support dtype {str(dtype)}")

        def is_inplace(variant):
            if hasattr(variant, "__wrapped__"):
                return variant.__wrapped__ is op.get_inplace()
            return variant is op.get_inplace()

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(device, dtype, requires_grad=True, include_conjugated_inputs=include_conjugated_inputs)

        for sample in samples:
            if sample.broadcasts_input and is_inplace(variant):
                continue

            # Gradcheck expects tensors as its input, but autograd actually supports tensorlists
            #   and tensors passed as kwargs. The following creates a function that accepts just
            #   the tensors that require grad as varargs, and then recomposes them back into the
            #   original input.

            # Creates gradcheck inputs by identifying tensors requiring grad
            all_args = None
            if is_iterable_of_tensors(sample.input):
                all_args = chain(sample.input, sample.args, sample.kwargs.values())
            else:
                all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
            gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))

            def _input_recomposition_helper(inputs, inp, input_idx):
                if is_iterable_of_tensors(inp):
                    tensor_list = []
                    for x in inp:
                        if isinstance(x, torch.Tensor) and x.requires_grad:
                            tensor_list.append(inputs[input_idx])
                            input_idx = input_idx + 1
                        else:
                            tensor_list.append(x)
                    return tensor_list, input_idx
                elif isinstance(inp, torch.Tensor) and inp.requires_grad:
                    return inputs[input_idx], input_idx + 1
                else:
                    return inp, input_idx
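
            # For example (names here are made up for illustration): if sample.input
            # is [a, b] where only `a` requires grad and sample.kwargs is {"alpha": 2},
            # then gradcheck_args is (a,) and fn(new_a) rebuilds the positional input
            # [new_a, b] and kwargs {"alpha": 2} before calling the variant.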

            def fn(*inputs):
                # Puts inputs back into sample properly
                positional_args = []
                input_idx = 0
                inp, input_idx = _input_recomposition_helper(inputs, sample.input, input_idx)
                positional_args.append(inp)

                for x in sample.args:
                    inp, input_idx = _input_recomposition_helper(inputs, x, input_idx)
                    positional_args.append(inp)

                # Recreates kwargs
                kwargs = {}
                for k, v in sample.kwargs.items():
                    inp, input_idx = _input_recomposition_helper(inputs, v, input_idx)
                    kwargs[k] = inp

                output = op.gradcheck_wrapper(variant, *positional_args, **kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            if check == 'gradcheck':
                if check_batched_grad is None:
                    check_batched_grad = op.check_batched_grad
                self.assertTrue(gradcheck(fn, gradcheck_args,
                                          check_batched_grad=check_batched_grad,
                                          check_grad_dtypes=True,
                                          nondet_tol=op.gradcheck_nondet_tol,
                                          fast_mode=op.gradcheck_fast_mode,
                                          check_forward_ad=check_forward_ad,
                                          check_backward_ad=check_backward_ad,
                                          check_undefined_grad=True,
                                          check_batched_forward_grad=check_batched_forward_grad))
            elif check in ('bwgrad_bwgrad', 'fwgrad_bwgrad'):  # gradgrad check
                self.assertFalse(check_forward_ad, msg="Cannot run forward AD check for gradgradcheck")
                for gen_non_contig_grad_outputs in (False, True):
                    kwargs = {
                        "gen_non_contig_grad_outputs": gen_non_contig_grad_outputs,
                        "check_batched_grad": op.check_batched_gradgrad,
                        "check_grad_dtypes": True,
                        "nondet_tol": op.gradcheck_nondet_tol,
                        "fast_mode": op.gradcheck_fast_mode
                    }
                    if check == "fwgrad_bwgrad":
                        kwargs["check_fwd_over_rev"] = True
                        kwargs["check_rev_over_rev"] = False
                        kwargs["check_batched_grad"] = False
                        kwargs["check_undefined_grad"] = False

                    self.assertTrue(gradgradcheck(fn, gradcheck_args, **kwargs))
            else:
                self.assertTrue(False, msg="Unknown check requested!")
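
    # The `check` modes handled above are:
    #   'gradcheck'     - first-order gradcheck (optionally also exercising forward AD)
    #   'bwgrad_bwgrad' - gradgradcheck, i.e. reverse-over-reverse second derivatives
    #   'fwgrad_bwgrad' - gradgradcheck with check_fwd_over_rev=True, i.e. forward-over-reverse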

    def _grad_test_helper(self, device, dtype, op, variant, *, check_forward_ad=False, check_backward_ad=True,
                          check_batched_grad=None, check_batched_forward_grad=False):
        return self._check_helper(device, dtype, op, variant, 'gradcheck', check_forward_ad=check_forward_ad,
                                  check_backward_ad=check_backward_ad, check_batched_grad=check_batched_grad,
                                  check_batched_forward_grad=check_batched_forward_grad)

    def _skip_helper(self, op, device, dtype):
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Op doesn't support autograd for this dtype.")
        if not op.supports_autograd and not op.supports_forward_ad:
            self.skipTest("Skipped! autograd not supported.")

    # Tests that gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_grad(self, device, dtype, op):
        # This is verified by test_dtypes in test_ops.py
        if dtype not in op.supported_backward_dtypes(torch.device(device).type):
            self.skipTest("Skipped! Dtype is not in supported backward dtypes!")
        else:
            self._grad_test_helper(device, dtype, op, op.get_op())

    # Method grad (and gradgrad, see below) tests are disabled since they're
    #   costly and redundant with function grad (and gradgrad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_grad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._grad_test_helper(device, dtype, op, op.get_method())

    @_gradcheck_ops(op_db)
    def test_inplace_grad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant:
            self.skipTest("Op has no inplace variant!")

        # Verifies an operation doesn't support inplace autograd if it claims not to
        if not op.supports_inplace_autograd:
            inplace = self._get_safe_inplace(op.get_inplace())
            for sample in op.sample_inputs(device, dtype, requires_grad=True):
                if sample.broadcasts_input:
                    continue
                with self.assertRaises(Exception):
                    # Call the inplace variant on the sample's tensors (not on the
                    # SampleInput object itself) so the expected failure comes from
                    # autograd's lack of inplace support rather than from passing
                    # a SampleInput where a tensor is expected
                    result = inplace(sample.input, *sample.args, **sample.kwargs)
                    result.sum().backward()
        else:
            self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

    # Test that gradients of gradients are computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.supports_gradgrad:
            self.skipTest("Op claims it doesn't support gradgrad. This is not verified.")
        else:
            self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Test that forward-over-reverse gradgrad is computed correctly
    @_gradcheck_ops(op_db)
    def test_fn_fwgrad_bwgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        if op.supports_fwgrad_bwgrad:
            self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
        else:
            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = ("Running forward-over-backward gradgrad for an op that does not support it did not "
                        "raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")

    # Test that gradgrad raises an error for ops that claim not to support it
    @_gradcheck_ops(op_db)
    def test_fn_fail_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if op.supports_gradgrad:
            self.skipTest("Skipped! Operation does support gradgrad")

        err_msg = r"derivative for .* is not implemented"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')

    # Method gradgrad (and grad, see above) tests are disabled since they're
    #   costly and redundant with function gradgrad (and grad) tests
    # @_gradcheck_ops(op_db)
    # def test_method_gradgrad(self, device, dtype, op):
    #     self._skip_helper(op, device, dtype)
    #     self._gradgrad_test_helper(device, dtype, op, op.get_method())

    @_gradcheck_ops(op_db)
    def test_inplace_gradgrad(self, device, dtype, op):
        self._skip_helper(op, device, dtype)
        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")
        self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad")

    def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
        # TODO: clean up how attributes are passed to gradcheck from OpInfos
        def call_grad_test_helper():
            check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
                                          (op.check_inplace_batched_forward_grad and is_inplace))
            self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
                                   check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
        if op.supports_forward_ad:
            call_grad_test_helper()
        else:
            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = ("Running forward AD for an op that does not support it did not "
                        "raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                call_grad_test_helper()

    @_gradcheck_ops(op_db)
    def test_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        self._forward_grad_helper(device, dtype, op, op.get_op(), is_inplace=False)

    @_gradcheck_ops(op_db)
    def test_inplace_forward_mode_AD(self, device, dtype, op):
        self._skip_helper(op, device, dtype)

        if not op.inplace_variant or not op.supports_inplace_autograd:
            self.skipTest("Skipped! Operation does not support inplace autograd.")

        self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True)


instantiate_device_type_tests(TestGradients, globals())

if __name__ == '__main__':
    run_tests()
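
# Illustrative invocation (the file path and the unittest -k filter shown here are
# assumptions, not part of this file): run only the first-order gradcheck tests with
#   python path/to/this_test_file.py -k test_fn_grad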