Compare commits

...

1 Commit

Author SHA1 Message Date
cb51740c77 Use partial state for everything 2023-08-10 13:36:43 +00:00

View File

@@ -83,7 +83,7 @@ XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
if is_accelerate_available():
-from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights, PartialState
from accelerate.utils import (
check_tied_parameters_on_same_device,
find_tied_parameters,
@@ -2327,12 +2327,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if device_map is None:
if torch.cuda.is_available():
-device_map = {"": torch.cuda.current_device()}
+# Rely on the `PartialState` to ensure no accidental re-forks
+device_map = {"": PartialState().device}
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
logger.info(
"The device_map was not initialized."
-"Setting device_map to {'':torch.cuda.current_device()}."
+f"Setting device_map to {{'':{PartialState().device}}}."
"If you want to use the model for inference, please set device_map ='auto' "
)
if low_cpu_mem_usage is None:
@@ -2402,12 +2403,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
torch_dtype = torch.float16
if device_map is None:
if torch.cuda.is_available():
-device_map = {"": torch.cuda.current_device()}
+# Rely on the `PartialState` to ensure no accidental re-forks
+device_map = {"": PartialState().device}
else:
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
logger.info(
"The device_map was not initialized."
-"Setting device_map to {'':torch.cuda.current_device()}."
+f"Setting device_map to {{'':{PartialState().device}}}."
"If you want to use the model for inference, please set device_map ='auto' "
)
if low_cpu_mem_usage is None: