@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import platform
import re
@@ -20,7 +19,6 @@ import socket
from contextlib import contextmanager
from functools import partial, reduce
from types import MethodType
from typing import OrderedDict

import torch
from packaging.version import Version
@@ -32,7 +30,6 @@ from ..state import PartialState
from .constants import FSDP_PYTORCH_VERSION
from .dataclasses import DistributedType
from .imports import is_deepspeed_available, is_torch_distributed_available, is_tpu_available
from .modeling import id_tensor_storage
from .transformer_engine import convert_model
from .versions import is_torch_version

@@ -118,41 +115,6 @@ def wait_for_everyone():
    PartialState().wait_for_everyone()
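# Illustrative usage sketch (assumes a distributed launch; `model` is a
# hypothetical variable): the barrier keeps non-main processes from racing
# ahead of a main-process-only save.
#
#     wait_for_everyone()
#     if PartialState().is_main_process:
#         save(model.state_dict(), "checkpoint.bin")
#     wait_for_everyone()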


def clean_state_dict_for_safetensors(state_dict: dict):
    """
    Cleans the state dictionary from a model and removes tensor aliasing if present.

    Args:
        state_dict (`dict`):
            The state dictionary from a model
    """
    ptrs = collections.defaultdict(list)
    # When bnb serialization is used, weights in the state dict can be strings
    for name, tensor in state_dict.items():
        if not isinstance(tensor, str):
            ptrs[id_tensor_storage(tensor)].append(name)

    # These are all pointers of tensors with shared memory
    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
    warn_names = set()
    for names in shared_ptrs.values():
        # When not all duplicates have been cleaned, we still remove those keys but emit a clear warning.
        # If the link between tensors was made at runtime then `from_pretrained` will not get
        # the key back, leading to a random tensor. A warning will be shown during reload
        # (if applicable), but since the file is not necessarily compatible with the config,
        # it is better to warn here as well.
        found_names = [name for name in names if name in state_dict]
        warn_names.update(found_names[1:])
        for name in found_names[1:]:
            del state_dict[name]
    if len(warn_names) > 0:
        logger.warning(
            f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
        )
    state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}
    return state_dict
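# Illustrative sketch of the aliasing removal above (the tensor names are
# hypothetical): two keys backed by the same storage collapse to the first
# one encountered, and surviving tensors are made contiguous.
#
#     shared = torch.ones(2, 2)
#     sd = {"encoder.weight": shared, "decoder.weight": shared}
#     cleaned = clean_state_dict_for_safetensors(sd)
#     assert list(cleaned) == ["encoder.weight"]  # alias removed, warning logged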


def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
    """
    Save the data to disk. Use in place of `torch.save()`.
@@ -170,8 +132,6 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Fal
    # Check if it's a model and remove duplicates
    if safe_serialization:
        save_func = partial(safe_save_file, metadata={"format": "pt"})
        if isinstance(obj, OrderedDict):
            obj = clean_state_dict_for_safetensors(obj)
    else:
        save_func = torch.save
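# Illustrative call of the branch above (`model` and the file name are
# hypothetical): a plain `state_dict()`, being an `OrderedDict`, is cleaned
# of shared tensors and then written via `safe_save_file` rather than
# `torch.save`.
#
#     save(model.state_dict(), "model.safetensors", safe_serialization=True)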