mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Disable slow gradcheck for nn.Transformer ModuleInfo (#145531)
Fixes https://github.com/pytorch/pytorch/issues/117140 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145531 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #145520
This commit is contained in:
committed by
PyTorch MergeBot
parent
9e0ee152e5
commit
c7ca1df37e
@ -482,11 +482,19 @@ class TestModule(TestCase):
|
||||
output_flattened = torch.utils._pytree.tree_leaves(output)
|
||||
return output_flattened
|
||||
|
||||
def do_check(flat_input):
|
||||
self.assertTrue(
|
||||
check(
|
||||
fn_to_gradcheck,
|
||||
flat_input,
|
||||
nondet_tol=gradcheck_nondet_tol,
|
||||
fast_mode=module_info.gradcheck_fast_mode
|
||||
))
|
||||
|
||||
# check total derivative
|
||||
grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
|
||||
flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
|
||||
|
||||
self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
|
||||
do_check(flat_input)
|
||||
|
||||
# check partial derivatives
|
||||
old_params_requires_grad = [p.requires_grad for p in params]
|
||||
@ -501,14 +509,14 @@ class TestModule(TestCase):
|
||||
p.requires_grad = old
|
||||
grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
|
||||
flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
|
||||
self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
|
||||
do_check(flat_input)
|
||||
p.requires_grad = False
|
||||
|
||||
for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
|
||||
obj.requires_grad = old
|
||||
grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
|
||||
flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
|
||||
self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
|
||||
do_check(flat_input)
|
||||
obj.requires_grad = False
|
||||
|
||||
@modules(module_db, allowed_dtypes=[torch.double])
|
||||
|
Reference in New Issue
Block a user