Migrate fused optim load_state_dict to OptimizerInfo (#117890)
The new tests look like:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (29f899ef)]$ python test/test_optim.py -v -k test_cpu_load_state_dict
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
test_cpu_load_state_dict_impl_capturable_AdamW_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_capturable_Adam_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_capturable_SGD_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_fused_AdamW_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_fused_Adam_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_fused_SGD_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_capturable_AdamW_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_capturable_Adam_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_capturable_SGD_cuda_float32 (__main__.TestOptimRenewedCUDA) ... skipped 'SGD does not currently support capturable'
test_cpu_load_state_dict_impl_fused_AdamW_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_fused_Adam_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_fused_SGD_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok

----------------------------------------------------------------------
Ran 12 tests in 12.865s

OK (skipped=6)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117890
Approved by: https://github.com/albanD
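For context on which optimizers the `@optims` filter in this change parametrizes over, the snippet below is a small inspection helper, not part of the PR. It assumes `optim_db` / `OptimizerInfo` live in PyTorch's internal test utilities module that `test/test_optim.py` imports from (an internal path that may move); per the run above, the fused-capable entries are Adam, AdamW, and SGD.

```python
# Inspection-only sketch (not part of this PR): list the optimizers that the new
# test_cpu_load_state_dict is generated for, i.e. the optim_db entries whose
# OptimizerInfo lists "fused" among supported_impls.
from torch.testing._internal.common_optimizers import optim_db

fused_capable = [
    info.optim_cls.__name__
    for info in optim_db
    if "fused" in info.supported_impls
]
print(fused_capable)  # per the log above: Adam, AdamW, SGD
```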
committed by: PyTorch MergeBot
parent: 9a2c8f644b
commit: 95a6866220
@@ -326,6 +326,39 @@ class TestOptimRenewed(TestCase):
                 optimizer.step()
 
 
+    @onlyCUDA
+    @parametrize("impl", ["fused", "capturable"])
+    @optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
+    def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
+        # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
+        # How do we get there? Users typically create CUDA models on fused optimizers and then
+        # store checkpoints on CPU as CUDA memory is limited with torch.load(...map_location="cpu").
+        # Since this is a unit test, it is more expedient to simulate what the state_dict
+        # would look like, which is basically CPU tensors with fused/capturable flag = True.
+        optim_cls = optim_info.optim_cls
+        if optim_cls.__name__ == "SGD" and impl == "capturable":
+            # Capturable SGD does not exist
+            self.skipTest("SGD does not currently support capturable")
+
+        cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
+        for optim_input in cpu_optim_inputs:
+            param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
+            optimizer = optim_cls([param], **optim_input.kwargs)
+            param.grad = torch.rand_like(param)
+            optimizer.step()
+            optim_state_dict_cpu = deepcopy(optimizer.state_dict())
+            optim_state_dict_cpu["param_groups"][0][impl] = True
+
+            # load
+            optim_input.kwargs[impl] = True
+            param_cuda = param.clone().detach().to(device="cuda")
+            optimizer_cuda = optim_cls([param_cuda], **optim_input.kwargs)
+            optimizer_cuda.load_state_dict(optim_state_dict_cpu)
+            optimizer_cuda.zero_grad()
+            param_cuda.grad = torch.rand_like(param_cuda)
+            optimizer_cuda.step()
+
+
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
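For readers unfamiliar with the scenario the NOTE in the new test refers to (issue 103256), here is a minimal sketch of the user workflow the test simulates, not part of the PR: train with a fused CUDA optimizer, checkpoint, reload the checkpoint on CPU via `map_location="cpu"`, and load it back into a fused CUDA optimizer. The model, tensor sizes, and checkpoint path are illustrative, and a CUDA device is assumed.

```python
# Illustrative sketch of the workflow the test simulates; requires a CUDA build and GPU.
import torch

model = torch.nn.Linear(4, 4, device="cuda")
opt = torch.optim.Adam(model.parameters(), fused=True)

model(torch.randn(2, 4, device="cuda")).sum().backward()
opt.step()
torch.save(opt.state_dict(), "optim_ckpt.pt")  # hypothetical checkpoint path

# Later, often on a memory-constrained host: the state tensors land on CPU.
cpu_state = torch.load("optim_ckpt.pt", map_location="cpu")

# Loading CPU state into a fused CUDA optimizer is what the test exercises;
# load_state_dict is expected to place the state on the parameters' device.
opt2 = torch.optim.Adam(model.parameters(), fused=True)
opt2.load_state_dict(cpu_state)

opt2.zero_grad()
model(torch.randn(2, 4, device="cuda")).sum().backward()
opt2.step()
```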