mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Introduce OptimizerInfos + use them to refactor out the error testing. Why OptimizerInfos? - cleaner, easier way to test all configs of optimizers - would plug in well with devicetype to auto-enable tests for devices like MPS, meta - would allow for more granular testing. Currently, lots of functionality is tested in `_test_basic_cases` and some of that should be broken down more. What did I do for error testing? - I moved out some error cases from `_test_basic_cases` into a new test_errors parametrized test. - The new test has to live in TestOptimRenewed (bikeshedding welcome) because the parametrized tests need to take in device and dtype and hook correctly, and not all tests in TestOptim do that. - TestOptimRenewed also is migrating to the toplevel test/test_optim.py now because importing TestOptimRenewed does not work (because of test instantiation, TestOptimRenewed gets replaced with TestOptimRenewedDevice for CPU, CUDA, and whatever other device). Is there any change in test coverage? - INCREASE: The error case where a single Parameter (vs a container of them) is passed in has now been expanded to all optims instead of only LBFGS - DECREASE: Not much. The only thing is we no longer test two error cases for foreach=True AND foreach=False, which I think is redundant. (Highlighted in comments) Possible but not urgent next step: test ALL possible error cases by going through all the constructors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114178 Approved by: https://github.com/albanD
37 lines
1.6 KiB
Python
37 lines
1.6 KiB
Python
# Owner(s): ["module: optimizer"]
|
|
|
|
from optim.test_optim import TestOptim, TestDifferentiableOptimizer # noqa: F401
|
|
from optim.test_lrscheduler import TestLRScheduler # noqa: F401
|
|
from optim.test_swa_utils import TestSWAUtils # noqa: F401
|
|
from torch.testing._internal.common_optimizers import optim_db, optims, OptimizerErrorEnum
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
class TestOptimRenewed(TestCase):
    """Optimizer tests driven by the entries of ``optim_db``.

    The test methods take ``device``, ``dtype`` and ``optim_info`` arguments,
    so this generic class is meant to be passed to
    ``instantiate_device_type_tests`` rather than run directly.
    """

    # Decorated with `onlyCPU`; `optims` is given only the `optim_db` entries
    # that define an `optim_error_inputs_func`.
    @onlyCPU
    @optims([info for info in optim_db if info.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
        # For every error case the optimizer's info object produces, assert
        # that the expected exception type and message regex are raised,
        # either while constructing the optimizer or while calling `step()`.
        make_optimizer = optim_info.optim_cls

        for case in optim_info.optim_error_inputs_func(device=device, dtype=dtype):
            spec = case.optimizer_error_input
            ctor_params, ctor_kwargs = spec.params, spec.kwargs

            if case.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                # The constructor call itself must fail.
                with self.assertRaisesRegex(case.error_type, case.error_regex):
                    make_optimizer(ctor_params, **ctor_kwargs)
                continue

            if case.error_on == OptimizerErrorEnum.STEP_ERROR:
                # Construction happens outside the assertion context, so only
                # an error raised by `step()` satisfies the check.
                optimizer = make_optimizer(ctor_params, **ctor_kwargs)
                with self.assertRaisesRegex(case.error_type, case.error_regex):
                    optimizer.step()
                continue

            # Any other `error_on` value is not handled by this test.
            raise NotImplementedError(f"Unknown error type {case.error_on}")
|
|
|
|
|
|
# Hand the generic test class to the device-type test machinery, passing this
# module's namespace as the scope; MPS is explicitly allowed.
instantiate_device_type_tests(
    TestOptimRenewed,
    globals(),
    allow_mps=True,
)


# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    run_tests()
|