Reset grad state across unittests (#126345)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126345
Approved by: https://github.com/ezyang
This commit is contained in:
William Wen
2024-05-23 10:35:56 -07:00
committed by PyTorch MergeBot
parent a31a60d85b
commit d11e44c0d0
2 changed files with 8 additions and 7 deletions

View File

@ -2903,6 +2903,9 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
if self._default_dtype_check_enabled:
assert torch.get_default_dtype() == torch.float
# attempt to reset some global state at the end of the test
self._prev_grad_state = torch.is_grad_enabled()
def tearDown(self):
# There exist test cases that override the TestCase.setUp
# definition, so we cannot assume that _check_invariants
@ -2917,6 +2920,10 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
if self._default_dtype_check_enabled:
assert torch.get_default_dtype() == torch.float
# attribute may not be defined, per above
if hasattr(self, '_prev_grad_state'):
torch.set_grad_enabled(self._prev_grad_state)
@staticmethod
def _make_crow_indices(n_rows, n_cols, nnz,
*, device, dtype, random=True):