Files
pytorch/test/test_optim.py
Jane Xu d78fe039eb Introduce OptimizerInfos + add a test_errors (#114178)
Introduce OptimizerInfos + use them to refactor out the error testing.

Why OptimizerInfos?
- cleaner, easier way to test all configs of optimizers
- would plug in well with the device-type test framework to auto-enable tests for devices like MPS and meta
- would allow for more granular testing. currently, lots of functionality is tested in `_test_basic_cases` and some of that should be broken down more.
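The idea can be sketched roughly like this (a hypothetical, heavily simplified shape; the real OptimizerInfo lives in torch.testing._internal.common_optimizers and carries many more fields — only optim_cls and optim_error_inputs_func appear in the test below):

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class OptimizerInfo:
    """Simplified sketch: metadata describing one optimizer's test configs."""
    optim_cls: type  # e.g. a torch.optim class
    optim_error_inputs_func: Optional[Callable] = None  # builds error cases per device/dtype


# A database of infos lets a decorator like @optims(...) select only the
# optimizers that define a given capability:
optim_db_sketch = [
    OptimizerInfo(optim_cls=object),  # defines no error inputs
    OptimizerInfo(optim_cls=object, optim_error_inputs_func=lambda device, dtype: []),
]
with_errors = [o for o in optim_db_sketch if o.optim_error_inputs_func is not None]
```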

What did I do for error testing?
- I moved out some error cases from `_test_basic_cases` into a new test_errors parametrized test.
- The new test has to live in TestOptimRenewed (bikeshedding on the name welcome) because parametrized tests need to accept device and dtype arguments and hook in correctly, which not all tests in TestOptim do.
- TestOptimRenewed is also migrating to the top-level test/test_optim.py now, because importing TestOptimRenewed from another module does not work: due to test instantiation, TestOptimRenewed gets replaced with per-device TestOptimRenewed<Device> classes for CPU, CUDA, and whatever other devices are enabled.
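To illustrate why the import fails, here is a minimal imitation of the mechanism (an assumption for illustration, not the real instantiate_device_type_tests implementation): instantiation registers per-device subclasses and removes the generic class from the module namespace.

```python
def instantiate_device_type_tests_sketch(generic_cls, scope, device_types=("cpu", "cuda")):
    """Hypothetical simplification of torch's instantiate_device_type_tests."""
    for device in device_types:
        name = f"{generic_cls.__name__}{device.upper()}"
        # Register a per-device subclass in the module namespace.
        scope[name] = type(name, (generic_cls,), {"device_type": device})
    # The generic class itself is removed, so importing it from elsewhere fails.
    scope.pop(generic_cls.__name__, None)


class TestOptimRenewed:  # stand-in for the generic test class
    pass


scope = {"TestOptimRenewed": TestOptimRenewed}
instantiate_device_type_tests_sketch(TestOptimRenewed, scope)
# scope now contains TestOptimRenewedCPU and TestOptimRenewedCUDA,
# but no longer TestOptimRenewed.
```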

Is there any change in test coverage?
- INCREASE: The error case where a single Parameter (vs. a container of them) is passed in has been expanded to all optims instead of only LBFGS
- DECREASE: Not much. The only change is that we no longer test two error cases under both foreach=True and foreach=False, which I think was redundant. (Highlighted in comments)

Possible but not urgent next step: test ALL possible error cases by going through all the constructors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114178
Approved by: https://github.com/albanD
2023-12-05 22:58:36 +00:00

37 lines
1.6 KiB
Python

# Owner(s): ["module: optimizer"]
from optim.test_optim import TestOptim, TestDifferentiableOptimizer # noqa: F401
from optim.test_lrscheduler import TestLRScheduler # noqa: F401
from optim.test_swa_utils import TestSWAUtils # noqa: F401
from torch.testing._internal.common_optimizers import optim_db, optims, OptimizerErrorEnum
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU
from torch.testing._internal.common_utils import run_tests, TestCase


class TestOptimRenewed(TestCase):
    @onlyCPU
    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)
        for error_input in error_inputs:
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    optim_cls(params, **kwargs)
            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                optim = optim_cls(params, **kwargs)
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    optim.step()
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")


instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


if __name__ == '__main__':
    run_tests()