Migrate fused optim load_state_dict to OptimizerInfo (#117890)

The new tests look like:

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (29f899ef)]$ python test/test_optim.py -v -k test_cpu_load_state_dict
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
test_cpu_load_state_dict_impl_capturable_AdamW_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_capturable_Adam_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_capturable_SGD_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_fused_AdamW_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_fused_Adam_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_fused_SGD_cpu_float32 (__main__.TestOptimRenewedCPU) ... skipped 'Only runs on cuda'
test_cpu_load_state_dict_impl_capturable_AdamW_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_capturable_Adam_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_capturable_SGD_cuda_float32 (__main__.TestOptimRenewedCUDA) ... skipped 'SGD does not currently support capturable'
test_cpu_load_state_dict_impl_fused_AdamW_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_fused_Adam_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok
test_cpu_load_state_dict_impl_fused_SGD_cuda_float32 (__main__.TestOptimRenewedCUDA) ... ok

----------------------------------------------------------------------
Ran 12 tests in 12.865s

OK (skipped=6)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117890
Approved by: https://github.com/albanD
This commit is contained in:
Jane Xu
2024-01-22 09:51:35 -08:00
committed by PyTorch MergeBot
parent 9a2c8f644b
commit 95a6866220
2 changed files with 33 additions and 26 deletions

View File

@ -326,6 +326,39 @@ class TestOptimRenewed(TestCase):
optimizer.step()
@onlyCUDA
@parametrize("impl", ["fused", "capturable"])
@optims([optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float32])
def test_cpu_load_state_dict(self, device, dtype, impl, optim_info):
# Regression test for https://github.com/pytorch/pytorch/issues/103256:
# loading a CPU checkpoint saved from a fused/capturable optimizer into a
# CUDA optimizer must not crash on the subsequent step().
# NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
# How do we get there? Users typically create CUDA models on fused optimizers and then
# store checkpoints on CPU as CUDA memory is limited with torch.load(...map_location="cpu").
# Since this is a unit test, it is more expedient to simulate what the state_dict
# would look like, which is basically CPU tensors with fused/capturable flag = True.
optim_cls = optim_info.optim_cls
if optim_cls.__name__ == "SGD" and impl == "capturable":
# Capturable SGD does not exist
self.skipTest("SGD does not currently support capturable")
# Build every optimizer config on CPU so the resulting state_dict holds CPU tensors.
cpu_optim_inputs = optim_info.optim_inputs_func(device="cpu")
for optim_input in cpu_optim_inputs:
param = torch.tensor([0.1, 0.2], dtype=dtype, device="cpu")
optimizer = optim_cls([param], **optim_input.kwargs)
param.grad = torch.rand_like(param)
# One step so the optimizer populates per-parameter state before snapshotting.
optimizer.step()
# deepcopy so later mutation of the flag cannot alias the live optimizer's dict.
optim_state_dict_cpu = deepcopy(optimizer.state_dict())
# Flip the fused/capturable flag on the CPU snapshot to mimic a checkpoint
# saved from a fused/capturable CUDA run.
optim_state_dict_cpu["param_groups"][0][impl] = True
# load
optim_input.kwargs[impl] = True
param_cuda = param.clone().detach().to(device="cuda")
optimizer_cuda = optim_cls([param_cuda], **optim_input.kwargs)
optimizer_cuda.load_state_dict(optim_state_dict_cpu)
optimizer_cuda.zero_grad()
param_cuda.grad = torch.rand_like(param_cuda)
# Success criterion is simply that load_state_dict + step() complete without
# raising; no numeric assertion is made here.
optimizer_cuda.step()
@optims(optim_db, dtypes=[torch.float32])
def test_step_is_noop_when_params_have_no_grad(self, device, dtype, optim_info):
optim_cls = optim_info.optim_cls