mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 00:24:53 +08:00
Fix casting bug in state_step for optimizers when loading state dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75214
Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
22d227fd29
commit
10bb0ffe69
@ -228,6 +228,12 @@ class TestOptim(TestCase):
|
||||
# Make sure state dict wasn't modified
|
||||
self.assertEqual(state_dict, state_dict_c)
|
||||
|
||||
# Make sure that the device of state['step'] is still the CPU
|
||||
new_state_dict = optimizer_cuda.state_dict()
|
||||
if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']):
|
||||
for state in new_state_dict['state'].values():
|
||||
self.assertEqual(state['step'].device.type, 'cpu')
|
||||
|
||||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_cuda.step(fn_cuda)
|
||||
|
||||
@ -151,17 +151,19 @@ class Optimizer(object):
|
||||
zip(chain.from_iterable((g['params'] for g in saved_groups)),
|
||||
chain.from_iterable((g['params'] for g in groups)))}
|
||||
|
||||
def cast(param, value):
|
||||
def cast(param, value, key=None):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
# Make sure state['step'] is not cast; see https://github.com/pytorch/pytorch/issues/74424
|
||||
if (key != "step"):
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v) for k, v in value.items()}
|
||||
return {k: cast(param, v, key=k) for k, v in value.items()}
|
||||
elif isinstance(value, container_abcs.Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user