!228 Fix device_map error under torch 2.1
Merge pull request !228 from 幽若/master-0603
@@ -297,7 +297,7 @@ def get_model():
     else:
         # zero3 does not support load model with device map
         if not is_deepspeed_zero3_enabled():
-            init_kwargs["device_map"] = {"": get_current_device(os.getenv("LOCAL_RANK"))}
+            init_kwargs["device_map"] = {"": get_current_device(os.getenv("LOCAL_RANK", 0))}

     if args.load_in_4bit:
         patch_bnb()
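Context for the hunk above: when no distributed launcher sets the LOCAL_RANK environment variable, os.getenv("LOCAL_RANK") returns None, and building the device_map from None fails; the added default keeps single-process runs on rank 0. A minimal sketch of the failure mode, with get_current_device as a hypothetical stand-in for openmind's helper:

import os

def get_current_device(local_rank):
    # hypothetical stand-in for openmind's helper: it expects a rank it can turn into an int
    return int(local_rank)

# Without the default, a single-process run (no launcher, LOCAL_RANK unset) passes None:
#   get_current_device(os.getenv("LOCAL_RANK"))      -> TypeError on int(None)
# With the default, it falls back to rank 0:
#   get_current_device(os.getenv("LOCAL_RANK", 0))   -> 0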
@@ -11,7 +11,7 @@
 # MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 # See the Mulan PSL v2 for more details.

-from openmind.utils import is_transformers_available, is_torch_available
+from openmind.utils import is_transformers_available, is_torch_available, version

 if is_transformers_available() and is_torch_available():
     from openmind.integrations.transformers.autoclasses import (
@@ -32,5 +32,9 @@ if is_transformers_available() and is_torch_available():

     from openmind.integrations.transformers.logging import patch_transformers_logging
     from openmind.integrations.transformers.bitsandbytes import patch_bnb
+    from openmind.integrations.transformers.modeling_utils import patch_modeling_utils
+
+    if version.check_package_version("torch>=2.1.0, <2.1.1"):
+        patch_modeling_utils()

     patch_transformers_logging()
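The added lines gate the patch on the installed torch build. check_package_version is openmind's own helper; assuming it matches packaging-style specifiers, the gate is roughly equivalent to the following sketch (names here are illustrative, not openmind's API):

# Assumed equivalent of the version gate above (sketch only; the real check lives in openmind.utils.version).
from importlib.metadata import version as pkg_version
from packaging.specifiers import SpecifierSet
from packaging.version import Version

def torch_in_patched_range() -> bool:
    installed = Version(pkg_version("torch").split("+")[0])  # drop local build tags like "+cpu"
    return installed in SpecifierSet(">=2.1.0, <2.1.1")      # only torch 2.1.0 gets the patched loader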
src/openmind/integrations/transformers/modeling_utils.py (new file, 178 lines)
@@ -0,0 +1,178 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Dict, List, Optional, Tuple

import torch
import transformers
from accelerate.utils import offload_weight, is_npu_available
from safetensors import safe_open
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.integrations.tensor_parallel import shard_and_distribute_module
from transformers.quantizers import HfQuantizer, get_module_from_name
from transformers.modeling_utils import (
    _infer_parameter_dtype,
    is_fsdp_enabled,
    is_local_dist_rank_0,
    _load_parameter_into_model,
)
from transformers.utils.quantization_config import QuantizationMethod

@torch.no_grad()
def _load_state_dict_into_meta_model_patch(
    model,
    state_dict: Dict,
    shard_file: str,
    expected_keys: List[str],
    reverse_renaming_mapping: Dict[str, str],
    device_map: Optional[Dict] = None,
    disk_offload_folder: Optional[str] = None,
    disk_offload_index: Optional[Dict] = None,
    cpu_offload_folder: Optional[str] = None,
    cpu_offload_index: Optional[Dict] = None,
    hf_quantizer: Optional[HfQuantizer] = None,
    is_safetensors: bool = False,
    keep_in_fp32_regex: Optional[re.Pattern] = None,
    unexpected_keys: Optional[List[str]] = None,  # passing `unexpected` for cleanup from quantization items
    device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
) -> Tuple[Optional[Dict], Optional[Dict]]:
    """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
    device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
    from `shard_file`, which is the actual state dict file on disk.
    This function takes care of correctly casting dtypes, devices, and sharding tensors in case of tensor parallelism.
    """

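    # Note: device_map typically arrives here as {"": <rank>} (see get_model above); on Ascend NPU the
    # block below rewrites integer indices to explicit "npu:<rank>" device strings before loading.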
    # in an npu environment, set the device_map like {"": "npu:0"}; in other cases, keep the original {"": 0}
    if is_npu_available():
        for k, v in device_map.items():
            if "npu" not in str(device_map.get(k)):
                device_map[k] = f"npu:{v}"

    tensor_device = "cpu"
    if device_map is not None and device_map.get("", None) is not None:
        if device_map[""] not in ("cpu", torch.device("cpu")):
            tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
    if device_map is not None:
        device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

    is_quantized = hf_quantizer is not None
    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
        QuantizationMethod.HQQ,
        QuantizationMethod.BITS_AND_BYTES,
    }
    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
    file_pointer = None
    if is_meta_state_dict:
        file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)

    for param_name, empty_param in state_dict.items():
        if param_name not in expected_keys:
            continue

        # we need to use serialized_param_name as file pointer is untouched
        if is_meta_state_dict:
            # This is the name of the parameter as it appears on disk file
            serialized_param_name = reverse_renaming_mapping[param_name]
            param = file_pointer.get_slice(serialized_param_name)
        else:
            param = empty_param.to(tensor_device)  # It is actually not empty!

        to_contiguous, casting_dtype = _infer_parameter_dtype(
            model,
            param_name,
            empty_param,
            keep_in_fp32_regex,
            hf_quantizer,
        )

        if device_mesh is not None:  # In this case, the param is already on the correct device!
            shard_and_distribute_module(
                model,
                param,
                empty_param,
                param_name,
                casting_dtype,
                to_contiguous,
                device_mesh.get_local_rank(),
                device_mesh,
            )
        else:
            param = param[...]
            if casting_dtype is not None:
                param = param.to(casting_dtype)
            if to_contiguous:
                param = param.contiguous()

            if device_map is None:
                param_device = "cpu"
            else:
                module_layer = re.search(device_map_regex, param_name)
                if not module_layer:
                    raise ValueError(f"{param_name} doesn't have any device set.")
                else:
                    param_device = device_map[module_layer.group()]

            if param_device == "disk":
                if not is_safetensors:
                    disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
            elif param_device == "cpu" and cpu_offload_index is not None:
                cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index)
            elif (
                not is_quantized
                or (not hf_quantizer.requires_parameters_quantization)
                or (
                    not hf_quantizer.check_quantized_param(
                        model,
                        param,
                        param_name,
                        state_dict,
                        param_device=param_device,
                        device_map=device_map,
                    )
                )
            ):
                if is_fsdp_enabled():
                    param_device = "cpu" if is_local_dist_rank_0() else "meta"

                _load_parameter_into_model(model, param_name, param.to(param_device))

            else:
                hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
                # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
                # and then cast it to CPU to avoid excessive memory usage on each GPU
                # in comparison to the sharded model across GPUs.
                if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
                    module, param_type = get_module_from_name(model, param_name)
                    value = getattr(module, param_type)
                    param_to = "cpu"
                    if is_fsdp_enabled() and not is_local_dist_rank_0():
                        param_to = "meta"
                    val_kwargs = {}
                    if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
                        val_kwargs["requires_grad"] = False
                    value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
                    setattr(module, param_type, value)

    if file_pointer is not None:
        file_pointer.__exit__(None, None, None)

    return disk_offload_index, cpu_offload_index


def patch_modeling_utils():
    transformers.modeling_utils._load_state_dict_into_meta_model = _load_state_dict_into_meta_model_patch
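For reference, a minimal sketch of how the patch is meant to take effect; the model id is a placeholder, and on torch 2.1.0 the openmind transformers integration already calls patch_modeling_utils() when the integration module is imported (see the version gate above):

import os
import transformers
from openmind.integrations.transformers.modeling_utils import patch_modeling_utils

patch_modeling_utils()  # swaps in _load_state_dict_into_meta_model_patch

model = transformers.AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",                                # placeholder model id
    device_map={"": int(os.getenv("LOCAL_RANK", 0))},     # the patched loader rewrites this to "npu:<rank>" on NPU
    torch_dtype="auto",
)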