Move step pre/post hook tests to OptimizerInfo (#119288)

Note that this increases coverage from 1 config (vanilla SGD) to all the configs: 13 optimizers with around 6-7 configs each. (The 52 tests in the runs below are presumably the 2 hook tests instantiated once per optimizer per device; the per-config loop happens inside each test.) The test time still seems fine, though!
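
For context, these tests exercise the public step-hook API on `torch.optim.Optimizer`. A minimal sketch with vanilla SGD (illustrative only, not code from this PR):

```python
import torch

param = torch.zeros(2, requires_grad=True)
optim = torch.optim.SGD([param], lr=0.1)

def pre_hook(optimizer, args, kwargs):
    print("about to step")

def post_hook(optimizer, args, kwargs):
    print("stepped")

# Both registration methods return removable handles.
pre_handle = optim.register_step_pre_hook(pre_hook)
post_handle = optim.register_step_post_hook(post_hook)

param.grad = torch.ones_like(param)
optim.step()  # fires both hooks

pre_handle.remove()
post_handle.remove()
optim.step()  # fires neither
```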

With the torch cuda synchronization:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b6093c03)]$ python test/test_optim.py -k test_step_pre_hook -k test_step_post_hook
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
....................................................
----------------------------------------------------------------------
Ran 52 tests in 13.680s

OK
```

Excluding the torch cuda synchronization:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (916f6fe3)]$ python test/test_optim.py -k test_step_pre_hook -k test_step_post_hook
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
....................................................
----------------------------------------------------------------------
Ran 52 tests in 1.038s

OK
```
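
For readers wondering why the synchronization costs so much wall-clock time: CUDA kernels are launched asynchronously, so the host only pays for the queued GPU work when something forces it to wait. Assuming the toggled synchronization is an explicit `torch.cuda.synchronize()` call (my reading of the two commit hashes above), a generic illustration:

```python
import time
import torch

# Generic illustration, not code from this PR: the same loop timed with and
# without blocking on outstanding GPU work.
if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")

    start = time.perf_counter()
    for _ in range(50):
        y = x @ x                          # kernels are only queued here
    queued = time.perf_counter() - start   # mostly measures launch overhead

    start = time.perf_counter()
    for _ in range(50):
        y = x @ x
    torch.cuda.synchronize()               # host blocks until the GPU is done
    synced = time.perf_counter() - start   # measures actual device time
    print(f"without sync: {queued:.3f}s, with sync: {synced:.3f}s")
```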

The old tests:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (916f6fe3)]$ python test/test_optim.py -k test_pre_hook -k test_post_hook
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
..
----------------------------------------------------------------------
Ran 2 tests in 0.518s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119288
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #119283
Author: Jane Xu
Date: 2024-02-06 11:19:19 -08:00
Committed by: PyTorch MergeBot
Parent: 99ddfaf572
Commit: 7b3762e6bc
3 changed files with 100 additions and 44 deletions


```
@@ -1,10 +1,12 @@
# Owner(s): ["module: optimizer"]
import functools
import math
from typing import Any, Dict, Tuple
import unittest
from copy import deepcopy
import torch
from torch.optim import Optimizer
from optim.test_optim import TestOptim, TestDifferentiableOptimizer  # noqa: F401
from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
from optim.test_swa_utils import TestSWAUtils  # noqa: F401

@@ -920,6 +922,76 @@ class TestOptimRenewed(TestCase):
        self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_post_hook(self, device, dtype, optim_info):
        def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
        for optim_input in all_optim_inputs:
            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
                    and not optim_input.kwargs.get("foreach", False)):
                continue
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            data = 2
            hook_handle = optim.register_step_post_hook(post_hook)

            optim.step(closure)
            optim.step(closure)
            # check if post hooks were registered
            self.assertEqual(data, 6)

            # remove handles, take step and verify that hook is no longer registered
            hook_handle.remove()
            optim.step(closure)
            self.assertEqual(data, 6)

    @optims(optim_db, dtypes=[torch.float32])
    def test_step_pre_hook(self, device, dtype, optim_info):
        def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
            nonlocal data
            data += 2

        params = [torch.tensor([1, 1], device=device, dtype=dtype)]

        def dummy_closure():
            return 1

        closure = dummy_closure if optim_info.step_requires_closure else None

        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
        for optim_input in all_optim_inputs:
            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
                    and not optim_input.kwargs.get("foreach", False)):
                continue
            optim = optim_info.optim_cls(params, **optim_input.kwargs)
            data = 5
            hook_handle = optim.register_step_pre_hook(pre_hook)

            optim.step(closure)
            optim.step(closure)
            # check if pre hooks were registered
            self.assertEqual(data, 9)

            # remove handles, take step and verify that hook is no longer registered
            hook_handle.remove()
            optim.step(closure)
            self.assertEqual(data, 9)

    @optims(optim_db, dtypes=[torch.float32])
    def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
        optim_cls = optim_info.optim_cls
```
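
For reference, the "6-7 configs each" come from `_get_optim_inputs_including_global_cliquey_kwargs`, which expands each `OptimizerInfo` entry in `optim_db` into its constructor-kwarg variants plus global "cliquey" kwargs such as `foreach`. A rough, hypothetical illustration of what each hook test now loops over for a single optimizer (the kwarg list below is made up for the example):

```python
import torch

param = torch.zeros(2, requires_grad=True)

# Hypothetical per-optimizer variants; the real list comes from optim_db's
# OptimizerInfo entries plus the global cliquey kwargs.
adam_like_inputs = [
    {},                                  # defaults
    {"lr": 0.01},
    {"weight_decay": 0.1},
    {"weight_decay": 0.1, "maximize": True},
    {"amsgrad": True},
    {"foreach": True},                   # global cliquey kwarg
]

calls = 0

def post_hook(optimizer, args, kwargs):
    global calls
    calls += 1

for kwargs in adam_like_inputs:
    optim = torch.optim.Adam([param], **kwargs)
    handle = optim.register_step_post_hook(post_hook)
    param.grad = torch.ones_like(param)
    optim.step()   # hook fires once per step, for every config variant
    handle.remove()

assert calls == len(adam_like_inputs)
```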