mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move step pre/post hook tests to OptimizerInfo (#119288)
Note that this increases coverage from 1 config (vanilla SGD) to all the configs (13 optimizers at around 6-7 each). The test time seems fine though! With the torch cuda synchronization: ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b6093c03)]$ python test/test_optim.py -k test_step_pre_hook -k test_step_post_hook /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" .................................................... ---------------------------------------------------------------------- Ran 52 tests in 13.680s OK ``` Excluding the torch cuda synchronization: ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (916f6fe3)]$ python test/test_optim.py -k test_step_pre_hook -k test_step_post_hook /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" .................................................... ---------------------------------------------------------------------- Ran 52 tests in 1.038s OK ``` The old tests: ``` (pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (916f6fe3)]$ python test/test_optim.py -k test_pre_hook -k test_post_hook /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead. _torch_pytree._register_pytree_node( /home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0 warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}" .. ---------------------------------------------------------------------- Ran 2 tests in 0.518s OK ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/119288 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #119283
This commit is contained in:
committed by
PyTorch MergeBot
parent
99ddfaf572
commit
7b3762e6bc
@ -1,10 +1,12 @@
|
||||
# Owner(s): ["module: optimizer"]
|
||||
import functools
|
||||
import math
|
||||
from typing import Any, Dict, Tuple
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from optim.test_optim import TestOptim, TestDifferentiableOptimizer # noqa: F401
|
||||
from optim.test_lrscheduler import TestLRScheduler # noqa: F401
|
||||
from optim.test_swa_utils import TestSWAUtils # noqa: F401
|
||||
@ -920,6 +922,76 @@ class TestOptimRenewed(TestCase):
|
||||
self.assertEqual(optimizer.state_dict(), optimizer_cuda.state_dict())
|
||||
|
||||
|
||||
@optims(optim_db, dtypes=[torch.float32])
|
||||
def test_step_post_hook(self, device, dtype, optim_info):
|
||||
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
||||
nonlocal data
|
||||
data += 2
|
||||
|
||||
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
|
||||
|
||||
def dummy_closure():
|
||||
return 1
|
||||
|
||||
closure = dummy_closure if optim_info.step_requires_closure else None
|
||||
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
|
||||
for optim_input in all_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
optim = optim_info.optim_cls(params, **optim_input.kwargs)
|
||||
data = 2
|
||||
hook_handle = optim.register_step_post_hook(post_hook)
|
||||
|
||||
optim.step(closure)
|
||||
optim.step(closure)
|
||||
# check if post hooks were registered
|
||||
self.assertEqual(data, 6)
|
||||
|
||||
# remove handles, take step and verify that hook is no longer registered
|
||||
hook_handle.remove()
|
||||
|
||||
optim.step(closure)
|
||||
self.assertEqual(data, 6)
|
||||
|
||||
|
||||
@optims(optim_db, dtypes=[torch.float32])
|
||||
def test_step_pre_hook(self, device, dtype, optim_info):
|
||||
def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
|
||||
nonlocal data
|
||||
data += 2
|
||||
|
||||
params = [torch.tensor([1, 1], device=device, dtype=dtype)]
|
||||
|
||||
def dummy_closure():
|
||||
return 1
|
||||
|
||||
closure = dummy_closure if optim_info.step_requires_closure else None
|
||||
|
||||
all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
|
||||
for optim_input in all_optim_inputs:
|
||||
if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
|
||||
and not optim_input.kwargs.get("foreach", False)):
|
||||
continue
|
||||
|
||||
optim = optim_info.optim_cls(params, **optim_input.kwargs)
|
||||
data = 5
|
||||
hook_handle = optim.register_step_pre_hook(pre_hook)
|
||||
|
||||
optim.step(closure)
|
||||
optim.step(closure)
|
||||
# check if pre hooks were registered
|
||||
self.assertEqual(data, 9)
|
||||
|
||||
# remove handles, take step and verify that hook is no longer registered
|
||||
hook_handle.remove()
|
||||
|
||||
optim.step(closure)
|
||||
self.assertEqual(data, 9)
|
||||
|
||||
|
||||
@optims(optim_db, dtypes=[torch.float32])
|
||||
def test_deepcopy_copies_all_public_attrs(self, device, dtype, optim_info):
|
||||
optim_cls = optim_info.optim_cls
|
||||
|
||||
Reference in New Issue
Block a user