!228 Fix the device_map error under torch 2.1

Merge pull request !228 from 幽若/master-0603
2025-06-12 02:08:24 +00:00
committed by i-robot
parent 105abd9d8a
commit 89512cad3e
3 changed files with 184 additions and 2 deletions


@@ -297,7 +297,7 @@ def get_model():
    else:
        # zero3 does not support load model with device map
        if not is_deepspeed_zero3_enabled():
            init_kwargs["device_map"] = {"": get_current_device(os.getenv("LOCAL_RANK"))}
            init_kwargs["device_map"] = {"": get_current_device(os.getenv("LOCAL_RANK", 0))}
    if args.load_in_4bit:
        patch_bnb()
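
The first `init_kwargs["device_map"]` line above is the old code and the second is its replacement: `os.getenv("LOCAL_RANK")` returns `None` when the variable is not set (for example in a single-process run that was not started by a distributed launcher), so no valid device index could be built; the change falls back to index 0. A minimal standalone sketch of that fallback, with illustrative variable names (the real code routes the value through openmind's `get_current_device`):

import os

# torchrun and similar launchers export LOCAL_RANK; a plain `python script.py`
# run does not, so without a default os.getenv("LOCAL_RANK") would return None.
local_rank = os.getenv("LOCAL_RANK", 0)  # fall back to rank 0 when the launcher did not set it
device_map = {"": int(local_rank)}       # place the whole model on that single device
print(device_map)                        # {'': 0}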


@@ -11,7 +11,7 @@
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
from openmind.utils import is_transformers_available, is_torch_available
from openmind.utils import is_transformers_available, is_torch_available, version
if is_transformers_available() and is_torch_available():
    from openmind.integrations.transformers.autoclasses import (

@@ -32,5 +32,9 @@ if is_transformers_available() and is_torch_available():
    from openmind.integrations.transformers.logging import patch_transformers_logging
    from openmind.integrations.transformers.bitsandbytes import patch_bnb
    from openmind.integrations.transformers.modeling_utils import patch_modeling_utils

    if version.check_package_version("torch>=2.1.0, <2.1.1"):
        patch_modeling_utils()

    patch_transformers_logging()
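
The `__init__` change applies the monkey patch only when the installed torch matches "torch>=2.1.0, <2.1.1", as reported by openmind's `version.check_package_version`; other torch releases keep the stock transformers loader. A rough sketch of an equivalent version gate (not openmind's implementation, just the same check written with the `packaging` library):

import torch
from packaging.version import Version

# Drop any local build tag such as "+cpu" before comparing.
torch_version = Version(torch.__version__.split("+")[0])

# Only the torch 2.1.0 release line needs the patched weight loader.
if Version("2.1.0") <= torch_version < Version("2.1.1"):
    print("patching transformers' _load_state_dict_into_meta_model")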


@@ -0,0 +1,178 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, List, Optional, Tuple

import torch
import transformers
from accelerate.utils import offload_weight, is_npu_available
from safetensors import safe_open
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.integrations.tensor_parallel import shard_and_distribute_module
from transformers.quantizers import HfQuantizer, get_module_from_name
from transformers.modeling_utils import (
    _infer_parameter_dtype,
    is_fsdp_enabled,
    is_local_dist_rank_0,
    _load_parameter_into_model,
)
from transformers.utils.quantization_config import QuantizationMethod


@torch.no_grad()
def _load_state_dict_into_meta_model_patch(
    model,
    state_dict: Dict,
    shard_file: str,
    expected_keys: List[str],
    reverse_renaming_mapping: Dict[str, str],
    device_map: Optional[Dict] = None,
    disk_offload_folder: Optional[str] = None,
    disk_offload_index: Optional[Dict] = None,
    cpu_offload_folder: Optional[str] = None,
    cpu_offload_index: Optional[Dict] = None,
    hf_quantizer: Optional[HfQuantizer] = None,
    is_safetensors: bool = False,
    keep_in_fp32_regex: Optional[re.Pattern] = None,
    unexpected_keys: Optional[List[str]] = None,  # passing `unexpected` for cleanup from quantization items
    device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
) -> Tuple[Optional[Dict], Optional[Dict]]:
    """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
    device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
    from `shard_file`, which is the actual state dict file on disk.

    This function takes care of correctly casting dtypes, devices, and sharding tensors in case of tensor parallelism.
    """
    # in npu environment, set the device_map like {"": "npu:0"}, in other case, keep the original {"": 0}
    if is_npu_available():
        for k, v in device_map.items():
            if "npu" not in str(device_map.get(k)):
                device_map[k] = f"npu:{v}"

    tensor_device = "cpu"
    if device_map is not None and device_map.get("", None) is not None:
        if device_map[""] not in ("cpu", torch.device("cpu")):
            tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]

    if device_map is not None:
        device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

    is_quantized = hf_quantizer is not None
    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
        QuantizationMethod.HQQ,
        QuantizationMethod.BITS_AND_BYTES,
    }
    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
    file_pointer = None
    if is_meta_state_dict:
        file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)

    for param_name, empty_param in state_dict.items():
        if param_name not in expected_keys:
            continue

        # we need to use serialized_param_name as file pointer is untouched
        if is_meta_state_dict:
            # This is the name of the parameter as it appears on disk file
            serialized_param_name = reverse_renaming_mapping[param_name]
            param = file_pointer.get_slice(serialized_param_name)
        else:
            param = empty_param.to(tensor_device)  # It is actually not empty!

        to_contiguous, casting_dtype = _infer_parameter_dtype(
            model,
            param_name,
            empty_param,
            keep_in_fp32_regex,
            hf_quantizer,
        )

        if device_mesh is not None:  # In this case, the param is already on the correct device!
            shard_and_distribute_module(
                model,
                param,
                empty_param,
                param_name,
                casting_dtype,
                to_contiguous,
                device_mesh.get_local_rank(),
                device_mesh,
            )
        else:
            param = param[...]
            if casting_dtype is not None:
                param = param.to(casting_dtype)
            if to_contiguous:
                param = param.contiguous()

            if device_map is None:
                param_device = "cpu"
            else:
                module_layer = re.search(device_map_regex, param_name)
                if not module_layer:
                    raise ValueError(f"{param_name} doesn't have any device set.")
                else:
                    param_device = device_map[module_layer.group()]

            if param_device == "disk":
                if not is_safetensors:
                    disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
            elif param_device == "cpu" and cpu_offload_index is not None:
                cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index)
            elif (
                not is_quantized
                or (not hf_quantizer.requires_parameters_quantization)
                or (
                    not hf_quantizer.check_quantized_param(
                        model,
                        param,
                        param_name,
                        state_dict,
                        param_device=param_device,
                        device_map=device_map,
                    )
                )
            ):
                if is_fsdp_enabled():
                    param_device = "cpu" if is_local_dist_rank_0() else "meta"
                _load_parameter_into_model(model, param_name, param.to(param_device))
            else:
                hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
                # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
                # and then cast it to CPU to avoid excessive memory usage on each GPU
                # in comparison to the sharded model across GPUs.
                if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
                    module, param_type = get_module_from_name(model, param_name)
                    value = getattr(module, param_type)
                    param_to = "cpu"
                    if is_fsdp_enabled() and not is_local_dist_rank_0():
                        param_to = "meta"
                    val_kwargs = {}
                    if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
                        val_kwargs["requires_grad"] = False
                    value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
                    setattr(module, param_type, value)

    if file_pointer is not None:
        file_pointer.__exit__(None, None, None)

    return disk_offload_index, cpu_offload_index


def patch_modeling_utils():
    transformers.modeling_utils._load_state_dict_into_meta_model = _load_state_dict_into_meta_model_patch
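
The new module closely tracks transformers' internal `_load_state_dict_into_meta_model`, and `patch_modeling_utils()` swaps the patched copy in by plain attribute assignment, so it has to run before any `from_pretrained` call; the package `__init__` above does that behind the torch version gate. The functional addition is the block at the top of the loader: on Ascend NPU machines every device_map entry that does not already name an NPU device is rewritten to an explicit "npu:<index>" string, so {"": 0} becomes {"": "npu:0"}, while other environments keep the original form. A standalone sketch of just that normalization step (the helper name is mine; in the patch the loop runs inline):

def normalize_device_map_for_npu(device_map: dict) -> dict:
    # Mirrors the loop at the top of the patched loader: entries that do not
    # already mention "npu" are rewritten to explicit "npu:<index>" strings.
    for key, value in device_map.items():
        if "npu" not in str(value):
            device_map[key] = f"npu:{value}"
    return device_map


print(normalize_device_map_for_npu({"": 0}))        # {'': 'npu:0'}
print(normalize_device_map_for_npu({"": "npu:1"}))  # already explicit, left unchanged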