Compare commits

...

1 commit

SHA1        Message              Date
170780d92f  Rm clean_state_dict  2024-01-26 16:59:39 +01:00

3 changed files with 0 additions and 44 deletions

src/accelerate/accelerator.py  View File

@@ -64,7 +64,6 @@ from .utils import (
     RNGType,
     TorchDynamoPlugin,
     check_os_kernel,
-    clean_state_dict_for_safetensors,
     compare_versions,
     convert_model,
     convert_outputs_to_fp32,
@@ -2567,8 +2566,6 @@ class Accelerator:
             raise RuntimeError("You can't save the model since some parameters are on the meta device.")
 
         state_dict = self.get_state_dict(model)
-        if safe_serialization:
-            state_dict = clean_state_dict_for_safetensors(state_dict)
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
 
         # Shard the model if it is too big.
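For orientation: this hunk sits inside Accelerator.save_model, which after this commit passes the output of get_state_dict straight to the serializer. A minimal usage sketch, assuming a toy model and a placeholder output directory:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(8, 2))

    # safe_serialization=True writes model.safetensors; False falls back to
    # torch.save and pytorch_model.bin (the weights_name branch above).
    accelerator.save_model(model, "out_dir", safe_serialization=True)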

src/accelerate/utils/__init__.py  View File

@@ -182,7 +182,6 @@ from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
 from .memory import find_executable_batch_size, release_memory
 from .other import (
     check_os_kernel,
-    clean_state_dict_for_safetensors,
     clear_environment,
     convert_bytes,
     extract_model_from_parallel,

src/accelerate/utils/other.py  View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
 import os
 import platform
 import re
@@ -20,7 +19,6 @@ import socket
 from contextlib import contextmanager
 from functools import partial, reduce
 from types import MethodType
-from typing import OrderedDict
 
 import torch
 from packaging.version import Version
@@ -32,7 +30,6 @@ from ..state import PartialState
 from .constants import FSDP_PYTORCH_VERSION
 from .dataclasses import DistributedType
 from .imports import is_deepspeed_available, is_torch_distributed_available, is_tpu_available
-from .modeling import id_tensor_storage
 from .transformer_engine import convert_model
 from .versions import is_torch_version
 
@@ -118,41 +115,6 @@ def wait_for_everyone():
     PartialState().wait_for_everyone()
-def clean_state_dict_for_safetensors(state_dict: dict):
-    """
-    Cleans the state dictionary from a model and removes tensor aliasing if present.
-
-    Args:
-        state_dict (`dict`):
-            The state dictionary from a model
-    """
-    ptrs = collections.defaultdict(list)
-    # When bnb serialization is used, weights in state dict can be strings
-    for name, tensor in state_dict.items():
-        if not isinstance(tensor, str):
-            ptrs[id_tensor_storage(tensor)].append(name)
-
-    # These are all pointers of tensors with shared memory
-    shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
-    warn_names = set()
-    for names in shared_ptrs.values():
-        # When not all duplicates have been cleaned, we still remove those keys but put a clear warning.
-        # If the link between tensors was done at runtime then `from_pretrained` will not get
-        # the key back leading to random tensor. A proper warning will be shown
-        # during reload (if applicable), but since the file is not necessarily compatible with
-        # the config, better show a proper warning.
-        found_names = [name for name in names if name in state_dict]
-        warn_names.update(found_names[1:])
-        for name in found_names[1:]:
-            del state_dict[name]
-
-    if len(warn_names) > 0:
-        logger.warning(
-            f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
-        )
-    state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}
-    return state_dict
 
 
 def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
     """
     Save the data to disk. Use in place of `torch.save()`.
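The helper removed above existed because safetensors refuses to serialize aliased tensors, i.e. several state-dict keys backed by one storage, as happens with tied weights. A minimal sketch of the case it handled (module and key names are illustrative, not from the diff):

    import torch

    # Tied weights: the output projection reuses the embedding matrix.
    emb = torch.nn.Embedding(10, 4)
    head = torch.nn.Linear(4, 10, bias=False)
    head.weight = emb.weight

    state_dict = {"emb.weight": emb.weight, "head.weight": head.weight}

    # Both keys point at the same storage; the helper kept the first key,
    # deleted the rest, and logged the warning quoted in the diff.
    assert state_dict["emb.weight"].data_ptr() == state_dict["head.weight"].data_ptr()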
@@ -170,8 +132,6 @@ def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = Fal
     # Check if it's a model and remove duplicates
     if safe_serialization:
         save_func = partial(safe_save_file, metadata={"format": "pt"})
-        if isinstance(obj, OrderedDict):
-            obj = clean_state_dict_for_safetensors(obj)
     else:
         save_func = torch.save
 
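The dispatch that survives this commit still swaps writers based on safe_serialization. A self-contained sketch of that pattern, assuming illustrative file names:

    from functools import partial

    import torch
    from safetensors.torch import save_file as safe_save_file

    state_dict = {"weight": torch.ones(2, 2)}
    safe_serialization = True

    if safe_serialization:
        # safetensors stores plain tensors plus string metadata; tensors must
        # be contiguous and must not share storage.
        save_func = partial(safe_save_file, metadata={"format": "pt"})
        save_func(state_dict, "checkpoint.safetensors")
    else:
        save_func = torch.save
        save_func(state_dict, "checkpoint.bin")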