For CriterionTests, have check_gradgrad actually only affect gradgrad checks. (#44060)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44060

Right now, `check_gradgrad=False` skips the first-order `gradcheck` as well; after this change it skips only `gradgradcheck`.
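
A minimal sketch of the fixed gating, for illustration only (the
`run_criterion_checks` helper and the use of `mse_loss` are assumptions,
not part of the test harness):

    import torch
    import torch.nn.functional as F
    from torch.autograd import gradcheck, gradgradcheck

    def run_criterion_checks(apply_fn, inputs, check_gradgrad=True):
        # The first-order check always runs.
        gradcheck(apply_fn, inputs)
        # Only the second-order check is gated by the flag.
        if not check_gradgrad:
            return
        gradgradcheck(apply_fn, inputs)

    # gradcheck compares analytic and numeric Jacobians, so use
    # double precision for numerical stability.
    inp = torch.randn(5, dtype=torch.double, requires_grad=True)
    target = torch.randn(5, dtype=torch.double)

    # Even with check_gradgrad=False, the first-order check still runs.
    run_criterion_checks(lambda x: F.mse_loss(x, target), (inp,),
                         check_gradgrad=False)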

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D23484018

Pulled By: gchanan

fbshipit-source-id: 24a8f1af41f9918aaa62bc3cd78b139b2f8de1e1
Author: Gregory Chanan
Date: 2020-09-03 12:27:40 -07:00
Committed by: Facebook GitHub Bot
Parent: 42f9897983
Commit: 49215d7f26


@@ -5070,9 +5070,6 @@ class CriterionTest(InputVariableMixin, TestBase):
         self._do_extra_tests(test_case, module, input, target)
 
     def _do_extra_tests(self, test_case, module, input, target):
-        if not self.check_gradgrad:
-            return
-
         test_case.assertFalse(target.requires_grad)
 
         params = tuple(x for x in module.parameters())
@@ -5090,6 +5087,10 @@ class CriterionTest(InputVariableMixin, TestBase):
         # TODO: we don't pass `target` as part of inputs because we don't
         # currently compute the gradient w.r.t. target for loss functions.
         gradcheck(apply_fn, inputs)
+
+        if not self.check_gradgrad:
+            return
+
         gradgradcheck(apply_fn, inputs)
 
     def test_cuda(self, test_case, dtype=None, extra_args=None):
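
For context, individual criterion tests opt out through the `check_gradgrad`
field of their declarations; a hedged sketch of such an entry (field names
assumed from the surrounding test harness, `SomeLoss` is hypothetical):

    dict(
        module_name='SomeLoss',  # hypothetical module name
        input_fn=lambda: torch.randn(10, dtype=torch.double),
        target_fn=lambda: torch.randn(10, dtype=torch.double),
        # After this change: gradcheck still runs; only gradgradcheck is skipped.
        check_gradgrad=False,
    )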