mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Cast tensors when loading optimizer state dicts (#3658)
This commit is contained in:
committed by
Edward Z. Yang
parent
51ca3a1a48
commit
af9fd35d82
@ -1,4 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, Iterable
|
||||
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
@ -96,8 +96,33 @@ class Optimizer(object):
|
||||
id_map = {old_id: p for old_id, p in
|
||||
zip(chain(*(g['params'] for g in saved_groups)),
|
||||
chain(*(g['params'] for g in groups)))}
|
||||
state = defaultdict(
|
||||
dict, {id_map.get(k, k): v for k, v in state_dict['state'].items()})
|
||||
|
||||
def cast(param, value):
|
||||
"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if torch.is_tensor(value):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
if any(tp in type(param.data).__name__ for tp in {'Half', 'Float', 'Double'}):
|
||||
value = value.type_as(param.data)
|
||||
value = value.cuda(param.get_device()) if param.is_cuda else value.cpu()
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v) for k, v in value.items()}
|
||||
elif isinstance(value, Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
return value
|
||||
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
state = defaultdict(dict)
|
||||
for k, v in state_dict['state'].items():
|
||||
if k in id_map:
|
||||
param = id_map[k]
|
||||
state[param] = cast(param, v)
|
||||
else:
|
||||
state[k] = v
|
||||
|
||||
# Update parameter groups, setting their 'params' value
|
||||
def update_group(group, new_group):
|
||||
|
Reference in New Issue
Block a user