[BE] migrate all assertRaises tests to OptimizerInfo test_errors (#116315)

Removes part of the SparseAdam test and the following three tests (the duplicate-param behavior they covered is sketched after the output below): `test_fused_optimizer_raises`, `test_duplicate_params_across_param_groups`, `test_duplicate_params_in_one_param_group`

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_fused_optimizer_raises -k test_duplicate_params_across_param_groups -k test_duplicate_params_in_one_param_group
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
...
----------------------------------------------------------------------
Ran 3 tests in 0.023s

OK
```
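For context, the duplicate-param cases boil down to the stock `torch.optim` behavior below (a minimal sketch; the exact regexes asserted by the tests live in the optimizer error inputs, so the messages here are illustrative):
```python
import warnings
import torch

w = torch.randn(2, 2, requires_grad=True)

# The same parameter in more than one param group errors at construction:
try:
    torch.optim.SGD([{"params": [w]}, {"params": [w]}], lr=0.1)
except ValueError as e:
    print(e)  # some parameters appear in more than one parameter group

# The same parameter repeated within a single param group only warns:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.optim.SGD([w, w], lr=0.1)
print(caught[0].category)  # UserWarning about duplicate parameters
```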

Increases coverage by running the duplicate-param tests on ALL the optims instead of just one each. Also fixes a SparseAdam bug where constructing the optimizer with a single Tensor accidentally called `torch.unbind` (via `list()` on the Tensor) instead of putting the params in a list. The bug was caught by migrating the ad-hoc warning handling to one simple warning context manager, which also checks that nothing else gets raised.
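A minimal repro of the underlying behavior (plain Python/PyTorch, not the SparseAdam code itself): calling `list()` on a Tensor iterates over it, which is equivalent to `torch.unbind` along dim 0, so each row becomes its own "parameter":
```python
import torch

param = torch.zeros(2, 3, requires_grad=True)

# list() on a Tensor iterates over dim 0, same as torch.unbind(param):
wrong = list(param)   # two (3,)-shaped non-leaf views
right = [param]       # one (2, 3) leaf tensor, as intended

print(len(wrong), wrong[0].shape)   # 2 torch.Size([3])
print(len(right), right[0].shape)   # 1 torch.Size([2, 3])
```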

The new test_errors does not run slower than before; overhead is still king:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129de)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
..........................
----------------------------------------------------------------------
Ran 26 tests in 10.337s

OK
```

Compared to test_errors BEFORE my commit :p
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
.............sssssssssssss
----------------------------------------------------------------------
Ran 26 tests in 11.980s

OK (skipped=13)
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa696)]$
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116315
Approved by: https://github.com/mikaylagawarecki

Relevant change in `test/test_optim.py`:
```
@@ -26,7 +26,6 @@ class TestOptimRenewed(TestCase):
         self.assertFalse(any(f for f in global_cliquey_flags if f in optim_input.kwargs))
 
-    @onlyCPU
     @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
     def test_errors(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
@@ -36,12 +35,20 @@ class TestOptimRenewed(TestCase):
             optim_input = error_input.optimizer_error_input
             params, kwargs = optim_input.params, optim_input.kwargs
             if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
-                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
-                    optim_cls(params, **kwargs)
+                if issubclass(error_input.error_type, Warning):
+                    with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
+                        optim_cls(params, **kwargs)
+                else:
+                    with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
+                        optim_cls(params, **kwargs)
             elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                 optim = optim_cls(params, **kwargs)
-                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
-                    optim.step()
+                if issubclass(error_input.error_type, Warning):
+                    with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
+                        optim.step()
+                else:
+                    with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
+                        optim.step()
             else:
                 raise NotImplementedError(f"Unknown error type {error_input.error_on}")
```
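With this branch in place, a warning case can be registered like any other error input. A hedged sketch of what that might look like (the `ErrorOptimizerInput`/`OptimizerInput` names and their exact constructor signatures in `torch/testing/_internal/common_optimizers.py` are assumptions here; the `error_on`/`error_type`/`error_regex` fields match what the test reads above):
```python
import torch

def optim_error_inputs_func_sgd(device, dtype):
    w = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
    return [
        # Hypothetical registration of a duplicate-param warning case;
        # test_errors routes it through assertWarnsRegex because
        # error_type is a Warning subclass.
        ErrorOptimizerInput(
            OptimizerInput(params=[w, w], kwargs=dict(lr=0.1), desc="duplicate params"),
            error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
            error_type=UserWarning,
            error_regex="a parameter group with duplicate parameters",
        ),
    ]
```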