Disable slow gradcheck for nn.Transformer ModuleInfo (#145531)

Fixes https://github.com/pytorch/pytorch/issues/117140

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145531
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #145520
Author: soulitzer
Date: 2025-01-24 13:36:24 -05:00
Committed by: PyTorch MergeBot
Commit: c7ca1df37e (parent: 9e0ee152e5)
2 changed files with 19 additions and 4 deletions
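
For context, a minimal standalone sketch (not part of this commit) contrasting the default slow gradcheck with the fast mode this PR opts nn.Transformer's ModuleInfo into; the toy Linear module and `fn` below are hypothetical stand-ins for the much larger Transformer case:

# Standalone sketch, assuming only the public torch.autograd.gradcheck API.
# Slow mode reconstructs the full Jacobian entry by entry; fast mode checks a
# random projection of it (vector-Jacobian / Jacobian-vector products), which
# is far cheaper for large modules.
import torch

lin = torch.nn.Linear(4, 4, dtype=torch.double)
x = torch.randn(2, 4, dtype=torch.double, requires_grad=True)

def fn(inp):
    return lin(inp)

# Default (slow) gradcheck: compares full numerical vs. analytical Jacobians.
assert torch.autograd.gradcheck(fn, (x,), fast_mode=False)

# Fast gradcheck: the mode this PR enables for the nn.Transformer ModuleInfo.
assert torch.autograd.gradcheck(fn, (x,), fast_mode=True)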


@@ -482,11 +482,19 @@ class TestModule(TestCase):
                 output_flattened = torch.utils._pytree.tree_leaves(output)
                 return output_flattened
 
+            def do_check(flat_input):
+                self.assertTrue(
+                    check(
+                        fn_to_gradcheck,
+                        flat_input,
+                        nondet_tol=gradcheck_nondet_tol,
+                        fast_mode=module_info.gradcheck_fast_mode
+                    ))
             # check total derivative
             grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
             flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
 
-            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+            do_check(flat_input)
 
             # check partial derivatives
             old_params_requires_grad = [p.requires_grad for p in params]
@@ -501,14 +509,14 @@ class TestModule(TestCase):
                 p.requires_grad = old
                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+                do_check(flat_input)
                 p.requires_grad = False
 
             for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
                 obj.requires_grad = old
                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+                do_check(flat_input)
                 obj.requires_grad = False
 
     @modules(module_db, allowed_dtypes=[torch.double])
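
For reference, a minimal self-contained sketch of the pattern the diff above introduces, with a hypothetical FakeModuleInfo standing in for the real ModuleInfo and a toy fn_to_gradcheck standing in for the flattened module call; routing every call site through one do_check helper is what lets the per-module gradcheck_fast_mode flag take effect everywhere:

# Sketch only: FakeModuleInfo and fn_to_gradcheck are stand-ins, not the
# actual test-harness objects.
import torch

class FakeModuleInfo:
    gradcheck_fast_mode = True  # assumption: what the companion ModuleInfo change sets for nn.Transformer

module_info = FakeModuleInfo()
lin = torch.nn.Linear(3, 3, dtype=torch.double)
x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)

def fn_to_gradcheck(inp):
    # stand-in for the real flattened module forward used by TestModule
    return lin(inp)

def do_check(flat_input):
    # single choke point: every call site inherits the same tolerance and fast-mode setting
    assert torch.autograd.gradcheck(
        fn_to_gradcheck,
        flat_input,
        nondet_tol=0.0,
        fast_mode=module_info.gradcheck_fast_mode,
    )

do_check((x,))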