mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] migrate all assertRaises tests to OptimizerInfo test_errors (#116315)
Removes a part of the sparse adam test and the following three tests: `test_fused_optimizer_raises`, `test_duplicate_params_across_param_groups`, `test_duplicate_params_in_one_param_group` ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_fused_optimizer_raises -k test_duplicate_params_across_param_groups -k test_duplicate_params_in_one_param_group /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" ... ---------------------------------------------------------------------- Ran 3 tests in 0.023s OK ``` Increases coverage by testing the duplicate param tests on ALL the optims instead of just one each. Also fixes SparseAdam bug which was accidentally calling torch.unbind through list instead of putting params in a list. This bug was caught by migrating the weird warning stuff to just one easy warning context manager, which checks that nothing else gets raised. The new test_errors does not run slower than before, overhead is still king: ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_errors /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" .......................... ---------------------------------------------------------------------- Ran 26 tests in 10.337s OK ``` Compared to test_errors BEFORE my commit :p ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$ python test/test_optim.py -k test_errors /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" .............sssssssssssss ---------------------------------------------------------------------- Ran 26 tests in 11.980s OK (skipped=13) (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/116315 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
8abeacda6f
commit
44b98c09ca
@ -26,7 +26,6 @@ class TestOptimRenewed(TestCase):
|
||||
self.assertFalse(any(f for f in global_cliquey_flags if f in optim_input.kwargs))
|
||||
|
||||
|
||||
@onlyCPU
|
||||
@optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
|
||||
def test_errors(self, device, dtype, optim_info):
|
||||
optim_cls = optim_info.optim_cls
|
||||
@ -36,12 +35,20 @@ class TestOptimRenewed(TestCase):
|
||||
optim_input = error_input.optimizer_error_input
|
||||
params, kwargs = optim_input.params, optim_input.kwargs
|
||||
if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
|
||||
with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
|
||||
optim_cls(params, **kwargs)
|
||||
if issubclass(error_input.error_type, Warning):
|
||||
with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
|
||||
optim_cls(params, **kwargs)
|
||||
else:
|
||||
with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
|
||||
optim_cls(params, **kwargs)
|
||||
elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
|
||||
optim = optim_cls(params, **kwargs)
|
||||
with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
|
||||
optim.step()
|
||||
if issubclass(error_input.error_type, Warning):
|
||||
with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
|
||||
optim.step()
|
||||
else:
|
||||
with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
|
||||
optim.step()
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown error type {error_input.error_on}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user