Cast tensors when loading optimizer state dicts (#3658)

This commit is contained in:
Adam Paszke
2017-11-28 15:56:39 +01:00
committed by Edward Z. Yang
parent 51ca3a1a48
commit af9fd35d82
3 changed files with 56 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from collections import defaultdict
from collections import defaultdict, Iterable
import torch
from copy import deepcopy
@ -96,8 +96,33 @@ class Optimizer(object):
id_map = {old_id: p for old_id, p in
zip(chain(*(g['params'] for g in saved_groups)),
chain(*(g['params'] for g in groups)))}
state = defaultdict(
dict, {id_map.get(k, k): v for k, v in state_dict['state'].items()})
def cast(param, value):
"""Make a deep copy of value, casting all tensors to device of param."""
if torch.is_tensor(value):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
if any(tp in type(param.data).__name__ for tp in {'Half', 'Float', 'Double'}):
value = value.type_as(param.data)
value = value.cuda(param.get_device()) if param.is_cuda else value.cpu()
return value
elif isinstance(value, dict):
return {k: cast(param, v) for k, v in value.items()}
elif isinstance(value, Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict['state'].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):