Fix casting bug in state_step for optimizers when loading state dict

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75214

Approved by: https://github.com/albanD
Mikayla Gawarecki 2022-04-04 19:56:55 +00:00, committed by PyTorch MergeBot
parent 22d227fd29
commit 10bb0ffe69
2 changed files with 13 additions and 5 deletions
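
For context, a minimal sketch of the behavior this commit fixes (a hypothetical repro of https://github.com/pytorch/pytorch/issues/74424; it assumes a CUDA build and an optimizer, such as Adam, whose state['step'] is stored as a tensor):

import torch

model = torch.nn.Linear(2, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(1, 2)).sum().backward()
opt.step()  # populates state, including a CPU tensor state['step']

sd = opt.state_dict()

# Load the CPU state dict into an equivalent optimizer over CUDA parameters.
model_cuda = torch.nn.Linear(2, 2).cuda()
opt_cuda = torch.optim.Adam(model_cuda.parameters())
opt_cuda.load_state_dict(sd)

step = next(iter(opt_cuda.state.values()))['step']
# Before this fix, load_state_dict cast step to the parameter's dtype and
# device (a CUDA float tensor); after it, step keeps its original CPU tensor,
# which is what the new test below asserts.
print(step.device)  # expected after the fix: cpu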

test/test_optim.py

@@ -228,6 +228,12 @@ class TestOptim(TestCase):
         # Make sure state dict wasn't modified
         self.assertEqual(state_dict, state_dict_c)
 
+        # Make sure that device of state['step'] is still CPU
+        new_state_dict = optimizer_cuda.state_dict()
+        if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']):
+            for state in new_state_dict['state'].values():
+                self.assertEqual(state['step'].device.type, 'cpu')
+
         for _i in range(20):
             optimizer.step(fn)
             optimizer_cuda.step(fn_cuda)
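
The torch.is_tensor guard above is needed because optimizers differ in how they store step: some keep a plain Python int, which cast never treats as a tensor, while others use a tensor. A quick way to check for a given optimizer, as a sketch whose result depends on the optimizer class and PyTorch version:

import torch

p = torch.nn.Parameter(torch.randn(2))
opt = torch.optim.Adam([p])
p.grad = torch.randn(2)
opt.step()
print(type(opt.state[p]['step']))  # a Tensor in recent PyTorch; an int in older versions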

torch/optim/optimizer.py

@@ -151,17 +151,19 @@ class Optimizer(object):
                   zip(chain.from_iterable((g['params'] for g in saved_groups)),
                       chain.from_iterable((g['params'] for g in groups)))}
 
-        def cast(param, value):
+        def cast(param, value, key=None):
             r"""Make a deep copy of value, casting all tensors to device of param."""
             if isinstance(value, torch.Tensor):
                 # Floating-point types are a bit special here. They are the only ones
                 # that are assumed to always match the type of params.
-                if param.is_floating_point():
-                    value = value.to(param.dtype)
-                value = value.to(param.device)
+                # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
+                if (key != "step"):
+                    if param.is_floating_point():
+                        value = value.to(param.dtype)
+                    value = value.to(param.device)
                 return value
             elif isinstance(value, dict):
-                return {k: cast(param, v) for k, v in value.items()}
+                return {k: cast(param, v, key=k) for k, v in value.items()}
             elif isinstance(value, container_abcs.Iterable):
                 return type(value)(cast(param, v) for v in value)
             else:
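
Outside the diff context, here is the fixed helper as a self-contained sketch. In the real source, cast is a closure inside Optimizer.load_state_dict and container_abcs is PyTorch's alias for collections.abc; both are inlined here so the snippet runs on its own:

import collections.abc as container_abcs
import torch

def cast(param, value, key=None):
    r"""Make a deep copy of value, casting all tensors to device of param."""
    if isinstance(value, torch.Tensor):
        # Leave state['step'] untouched so it keeps its original device and dtype
        # (https://github.com/pytorch/pytorch/issues/74424).
        if key != "step":
            if param.is_floating_point():
                value = value.to(param.dtype)
            value = value.to(param.device)
        return value
    elif isinstance(value, dict):
        # Pass each dict key down so a nested {'step': ...} entry is recognized.
        return {k: cast(param, v, key=k) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(cast(param, v) for v in value)
    else:
        return value

# Usage sketch with hypothetical values (the .to() calls need a CUDA device):
#   param = torch.zeros(2, dtype=torch.float16, device='cuda')
#   state = {'step': torch.tensor(10.), 'exp_avg': torch.zeros(2)}
#   cast(param, state)  # exp_avg -> float16 on CUDA; step stays float32 on CPU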