Compare commits

..

2 Commits

Author SHA1 Message Date
b345f4e7f7 warning was in info mode. 2024-03-26 20:15:07 +09:00
dbf00d5895 grumble 2024-03-25 22:04:26 +09:00
3 changed files with 12 additions and 22 deletions

View File

@ -107,7 +107,7 @@ from transformers import LlavaNextForConditionalGeneration, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_compute_dtype="torch.float16",
)
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", quantization_config=quantization_config, device_map="auto")

View File

@ -595,12 +595,13 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
state_dict._metadata = metadata
error_msgs = []
unexpected_keys, missing_keys = [], []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
@ -630,7 +631,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# it's safe to delete it.
del state_dict
return error_msgs
return error_msgs, unexpected_keys, missing_keys
def find_submodule_and_param_name(model, long_key, start_prefix):
@ -3901,7 +3902,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
remove_prefix_from_model,
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
error_msgs, unexpected_keys, missing_keys = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
offload_index = None
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True
@ -3974,8 +3975,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
mmsg, unexpected_keys, missing_keys = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
error_msgs += mmsg
# force memory release
del state_dict
gc.collect()
@ -4009,9 +4010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
archs = [] if model.config.architectures is None else model.config.architectures
warner = logger.warning if model.__class__.__name__ in archs else logger.info
warner(
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"

View File

@ -57,13 +57,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
package_version = temp_version
package_exists = True
else:
package_version = "N/A"
package_exists = False
except ImportError:
# If the package can't be imported, it's not available
package_version = "N/A"
elif return_version:
# For packages other than "torch", don't attempt the fallback
# However, we only mark the package as not available if the version is explicitly requested
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
logger.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
@ -174,14 +173,6 @@ _torch_version = "N/A"
_torch_available = False
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
if _torch_available and _torch_version == "N/A":
# Here we have the situation where the import package for torch exists, but we can't
# find the distribution package containing its version data. In this case, we import it and ask it directly.
import torch
_torch_version = torch.__version__
if "+" in _torch_version:
torch_version = _torch_version.split("+")[0]
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False