mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
For CriterionTests, have check_gradgrad actually only affect gradgrad checks. (#44060)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44060 Right now it skips grad checks as well. Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D23484018 Pulled By: gchanan fbshipit-source-id: 24a8f1af41f9918aaa62bc3cd78b139b2f8de1e1
This commit is contained in:
committed by
Facebook GitHub Bot
parent
42f9897983
commit
49215d7f26
@ -5070,9 +5070,6 @@ class CriterionTest(InputVariableMixin, TestBase):
|
||||
self._do_extra_tests(test_case, module, input, target)
|
||||
|
||||
def _do_extra_tests(self, test_case, module, input, target):
|
||||
if not self.check_gradgrad:
|
||||
return
|
||||
|
||||
test_case.assertFalse(target.requires_grad)
|
||||
|
||||
params = tuple(x for x in module.parameters())
|
||||
@ -5090,6 +5087,10 @@ class CriterionTest(InputVariableMixin, TestBase):
|
||||
# TODO: we don't pass `target` as part of inputs because we don't
|
||||
# currently compute the gradient w.r.t. target for loss functions.
|
||||
gradcheck(apply_fn, inputs)
|
||||
|
||||
if not self.check_gradgrad:
|
||||
return
|
||||
|
||||
gradgradcheck(apply_fn, inputs)
|
||||
|
||||
def test_cuda(self, test_case, dtype=None, extra_args=None):
|
||||
|
Reference in New Issue
Block a user