mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Reset grad state across unittests (#126345)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126345 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a31a60d85b
commit
d11e44c0d0
@ -2903,6 +2903,9 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
|
||||
if self._default_dtype_check_enabled:
|
||||
assert torch.get_default_dtype() == torch.float
|
||||
|
||||
# attempt to reset some global state at the end of the test
|
||||
self._prev_grad_state = torch.is_grad_enabled()
|
||||
|
||||
def tearDown(self):
|
||||
# There exists test cases that override TestCase.setUp
|
||||
# definition, so we cannot assume that _check_invariants
|
||||
@ -2917,6 +2920,10 @@ This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0"""
|
||||
if self._default_dtype_check_enabled:
|
||||
assert torch.get_default_dtype() == torch.float
|
||||
|
||||
# attribute may not be defined, per above
|
||||
if hasattr(self, '_prev_grad_state'):
|
||||
torch.set_grad_enabled(self._prev_grad_state)
|
||||
|
||||
@staticmethod
|
||||
def _make_crow_indices(n_rows, n_cols, nnz,
|
||||
*, device, dtype, random=True):
|
||||
|
Reference in New Issue
Block a user