From 10bb0ffe698b51ce8f11fa0ab3870bac10d96072 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 4 Apr 2022 19:56:55 +0000 Subject: [PATCH] Fix casting bug in state_step for optimizers when loading state dict Pull Request resolved: https://github.com/pytorch/pytorch/pull/75214 Approved by: https://github.com/albanD --- test/test_optim.py | 6 ++++++ torch/optim/optimizer.py | 12 +++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/test/test_optim.py b/test/test_optim.py index c59d6a49bb49..8113899f621e 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -228,6 +228,12 @@ class TestOptim(TestCase): # Make sure state dict wasn't modified self.assertEqual(state_dict, state_dict_c) + # Make sure that device of state['step'] is still CPU + new_state_dict = optimizer_cuda.state_dict() + if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']): + for state in new_state_dict['state'].values(): + self.assertEqual(state['step'].device.type, 'cpu') + for _i in range(20): optimizer.step(fn) optimizer_cuda.step(fn_cuda) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 8acf2eaebc5c..9da9277dee32 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -151,17 +151,19 @@ class Optimizer(object): zip(chain.from_iterable((g['params'] for g in saved_groups)), chain.from_iterable((g['params'] for g in groups)))} - def cast(param, value): + def cast(param, value, key=None): r"""Make a deep copy of value, casting all tensors to device of param.""" if isinstance(value, torch.Tensor): # Floating-point types are a bit special here. They are the only ones # that are assumed to always match the type of params. 
- if param.is_floating_point(): - value = value.to(param.dtype) - value = value.to(param.device) + # Make sure state['step'] is not cast https://github.com/pytorch/pytorch/issues/74424 + if (key != "step"): + if param.is_floating_point(): + value = value.to(param.dtype) + value = value.to(param.device) return value elif isinstance(value, dict): - return {k: cast(param, v) for k, v in value.items()} + return {k: cast(param, v, key=k) for k, v in value.items()} elif isinstance(value, container_abcs.Iterable): return type(value)(cast(param, v) for v in value) else: