Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-16 07:19:51 +08:00)
Compare commits: v0.30.1...fix-dispat (12 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | c1ea6f2474 |  |
|  | 30631a65be |  |
|  | f9ecf75d24 |  |
|  | 2ea8986df6 |  |
|  | 4b95a1d12e |  |
|  | f199f6baed |  |
|  | 829d33af01 |  |
|  | ae573266b3 |  |
|  | ec2d94f02d |  |
|  | 641a22f87e |  |
|  | 373a6bea9a |  |
|  | 982e60560b |  |
```diff
@@ -43,6 +43,7 @@ from .utils import (
     parse_flag_from_env,
     retie_parameters,
 )
+from .utils.other import recursive_getattr
 
 
 logger = logging.getLogger(__name__)
@@ -395,7 +396,22 @@ def dispatch_model(
     else:
         weights_map = None
 
+    # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
+    # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
+    # original pointer) on each devices.
+    tied_params = find_tied_parameters(model)
+
+    tied_params_map = {}
+    for group in tied_params:
+        for param_name in group:
+            # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
+            # to care about views of tensors through storage_offset.
+            data_ptr = recursive_getattr(model, param_name).data_ptr()
+            tied_params_map[data_ptr] = {}
+
+    # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
+    # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
+
     attach_align_device_hook_on_blocks(
         model,
         execution_device=execution_device,
@@ -404,6 +420,7 @@ def dispatch_model(
         weights_map=weights_map,
         skip_keys=skip_keys,
         preload_module_classes=preload_module_classes,
+        tied_params_map=tied_params_map,
     )
 
     # warn if there is any params on the meta device
```
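For context: tied parameters are just multiple module attributes pointing at the same `torch.nn.Parameter`, so they share one storage and report the same `data_ptr()`. A minimal standalone sketch of the idea behind the pointer-keyed map built above (toy model, not the accelerate code itself):

```python
import torch
import torch.nn as nn

# Toy model with tied weights: the second layer's weight aliases the first one's.
model = nn.Sequential(nn.Linear(4, 4, bias=False), nn.Linear(4, 4, bias=False))
model[1].weight = model[0].weight

assert model[0].weight is model[1].weight
assert model[0].weight.data_ptr() == model[1].weight.data_ptr()

# One entry per tied group, keyed by the shared pointer; the inner dict will later
# hold at most one materialized copy per target device.
tied_params_map = {model[0].weight.data_ptr(): {}}
print(tied_params_map)
```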
```diff
@@ -27,6 +27,7 @@ from .utils import (
     set_module_tensor_to_device,
 )
 from .utils.modeling import get_non_persistent_buffers
+from .utils.other import recursive_getattr
 
 
 class ModelHook:
@@ -116,7 +117,9 @@ class SequentialHook(ModelHook):
         return module
 
 
-def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
+def add_hook_to_module(
+    module: nn.Module, hook: ModelHook, append: bool = False, init_hook_kwargs: Optional[Dict] = None
+):
     """
     Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
     this behavior and restore the original `forward` method, use `remove_hook_from_module`.
@@ -135,6 +138,8 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
             The hook to attach.
         append (`bool`, *optional*, defaults to `False`):
             Whether the hook should be chained with an existing one (if module already contains a hook) or not.
+        init_hook_kwargs (Optional[Dict], *optional*, defaults to `None`):
+            Optional arguments to pass to the hook initialization.
 
     Returns:
         `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
@@ -153,7 +158,10 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
         old_forward = module.forward
         module._old_forward = old_forward
 
-    module = hook.init_hook(module)
+    if init_hook_kwargs is None:
+        init_hook_kwargs = {}
+
+    module = hook.init_hook(module, **init_hook_kwargs)
     module._hf_hook = hook
 
     def new_forward(module, *args, **kwargs):
```
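The change above only threads an extra `init_hook_kwargs` dictionary through to the hook's `init_hook`. A stripped-down sketch of that plumbing with toy classes (hypothetical names, not the accelerate implementation):

```python
from typing import Dict, Optional


class ToyHook:
    # Stand-in for a ModelHook whose init_hook accepts extra shared state.
    def init_hook(self, module, tied_params_map: Optional[Dict] = None):
        self.tied_params_map = tied_params_map
        return module


def add_toy_hook(module, hook, init_hook_kwargs: Optional[Dict] = None):
    # Forward optional keyword arguments to the hook initialization, defaulting to {}.
    if init_hook_kwargs is None:
        init_hook_kwargs = {}
    module = hook.init_hook(module, **init_hook_kwargs)
    module._toy_hook = hook
    return module


class DummyModule:
    pass


m = add_toy_hook(DummyModule(), ToyHook(), init_hook_kwargs={"tied_params_map": {}})
print(m._toy_hook.tied_params_map)  # {}
```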
```diff
@@ -240,6 +248,7 @@ class AlignDevicesHook(ModelHook):
         self.input_device = None
         self.param_original_devices = {}
         self.buffer_original_devices = {}
+        self.tied_params_names = set()
 
     def __repr__(self):
         return (
@@ -248,10 +257,14 @@ class AlignDevicesHook(ModelHook):
             f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
         )
 
-    def init_hook(self, module):
+    def init_hook(self, module, tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None):
+        # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
+        if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
+            tied_params_map = None
+
         if not self.offload and self.execution_device is not None:
             for name, _ in named_module_tensors(module, recurse=self.place_submodules):
-                set_module_tensor_to_device(module, name, self.execution_device)
+                set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
         elif self.offload:
             self.original_devices = {
                 name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
```
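The `meta` special case above exists because meta tensors have no backing storage; as the diff comment notes, their `data_ptr()` is always zero, so pointer identity can no longer tell tied weights apart. A quick illustration of the behavior that comment describes:

```python
import torch

a = torch.empty(2, 2, device="meta")
b = torch.empty(2, 2, device="meta")

# No storage is allocated for meta tensors, so both report the same null pointer
# and a pointer-keyed map would conflate unrelated tensors.
print(a.data_ptr(), b.data_ptr())  # 0 0
```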
```diff
@@ -266,13 +279,25 @@ class AlignDevicesHook(ModelHook):
             for name, _ in named_module_tensors(
                 module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
             ):
+                # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
+                # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
+                # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
+                # to add on the fly pointers to `tied_params_map` in the pre_forward call.
+                if tied_params_map is not None and recursive_getattr(module, name).data_ptr() in tied_params_map:
+                    self.tied_params_names.add(name)
+
                 set_module_tensor_to_device(module, name, "meta")
 
         if not self.offload_buffers and self.execution_device is not None:
             for name, _ in module.named_buffers(recurse=self.place_submodules):
-                set_module_tensor_to_device(module, name, self.execution_device)
+                set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
         elif self.offload_buffers and self.execution_device is not None:
             for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
-                set_module_tensor_to_device(module, name, self.execution_device)
+                set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
+
+        # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
+        # for tied weights already loaded on the target execution device.
+        self.tied_params_map = tied_params_map
 
         return module
 
@@ -280,6 +305,8 @@ class AlignDevicesHook(ModelHook):
         if self.io_same_device:
             self.input_device = find_device([args, kwargs])
         if self.offload:
+            self.tied_pointers_to_remove = set()
+
             for name, _ in named_module_tensors(
                 module,
                 include_buffers=self.offload_buffers,
@@ -287,11 +314,32 @@ class AlignDevicesHook(ModelHook):
                 remove_non_persistent=True,
             ):
                 fp16_statistics = None
+                value = self.weights_map[name]
                 if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
-                    if self.weights_map[name].dtype == torch.int8:
+                    if value.dtype == torch.int8:
                         fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
 
+                # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
+                # that are loaded on device at this point, as we will need to remove them as well from the dictionary
+                # self.tied_params_map in order to allow to free memory.
+                if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
+                    self.tied_params_map[value.data_ptr()] = {}
+
+                if (
+                    value is not None
+                    and self.tied_params_map is not None
+                    and value.data_ptr() in self.tied_params_map
+                    and self.execution_device not in self.tied_params_map[value.data_ptr()]
+                ):
+                    self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
+
                 set_module_tensor_to_device(
-                    module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics
+                    module,
+                    name,
+                    self.execution_device,
+                    value=value,
+                    fp16_statistics=fp16_statistics,
+                    tied_params_map=self.tied_params_map,
                 )
 
         return send_to_device(args, self.execution_device), send_to_device(
@@ -311,6 +359,12 @@ class AlignDevicesHook(ModelHook):
                     module.state.SCB = None
                     module.state.CxB = None
 
+            # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
+            # this dictionary to allow the garbage collector to do its job.
+            for value_pointer, device in self.tied_pointers_to_remove:
+                del self.tied_params_map[value_pointer][device]
+            self.tied_pointers_to_remove = None
+
         if self.io_same_device and self.input_device is not None:
             output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
 
```
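Taken together, the `pre_forward`/`post_forward` changes implement a small bookkeeping loop: materialize each offloaded weight once per device, register it under its offloaded pointer so tied aliases reuse it, remember which entries were created, and drop them again after the forward so the memory can be freed. A schematic, self-contained sketch of that pattern (toy tensors, not the hook itself):

```python
import torch

execution_device = torch.device("cpu")   # stand-in for the real execution device
tied_params_map = {}                     # offloaded pointer -> {device: materialized tensor}
tied_pointers_to_remove = set()          # (pointer, device) pairs created for this forward

offloaded_weights = {"linear.weight": torch.randn(4, 4)}  # toy weights_map

# pre_forward-style pass: load weights, reusing anything already materialized.
for name, value in offloaded_weights.items():
    ptr = value.data_ptr()
    per_device = tied_params_map.setdefault(ptr, {})
    if execution_device not in per_device:
        per_device[execution_device] = value.to(execution_device)
        tied_pointers_to_remove.add((ptr, execution_device))

# post_forward-style cleanup: drop the copies created above so they can be garbage collected.
for ptr, device in tied_pointers_to_remove:
    del tied_params_map[ptr][device]
tied_pointers_to_remove = set()
```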
```diff
@@ -329,6 +383,7 @@ def attach_execution_device_hook(
     execution_device: Union[int, str, torch.device],
     skip_keys: Optional[Union[str, List[str]]] = None,
     preload_module_classes: Optional[List[str]] = None,
+    tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
 ):
     """
     Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
@@ -346,16 +401,24 @@
             of the forward. This should only be used for classes that have submodules which are registered but not
             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
+            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
+            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
+            instead of duplicating memory.
     """
     if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
-        add_hook_to_module(module, AlignDevicesHook(execution_device, skip_keys=skip_keys))
+        add_hook_to_module(
+            module,
+            AlignDevicesHook(execution_device, skip_keys=skip_keys),
+            init_hook_kwargs={"tied_params_map": tied_params_map},
+        )
 
     # Break the recursion if we get to a preload module.
     if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
         return
 
     for child in module.children():
-        attach_execution_device_hook(child, execution_device)
+        attach_execution_device_hook(child, execution_device, tied_params_map=tied_params_map)
 
 
 def attach_align_device_hook(
```
```diff
@@ -367,6 +430,7 @@ def attach_align_device_hook(
     module_name: str = "",
     skip_keys: Optional[Union[str, List[str]]] = None,
     preload_module_classes: Optional[List[str]] = None,
+    tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
 ):
     """
     Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
@@ -392,6 +456,10 @@
             of the forward. This should only be used for classes that have submodules which are registered but not
             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
+            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
+            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
+            instead of duplicating memory.
     """
     # Attach the hook on this module if it has any direct tensor.
     directs = named_module_tensors(module)
@@ -413,7 +481,7 @@
             place_submodules=full_offload,
             skip_keys=skip_keys,
         )
-        add_hook_to_module(module, hook, append=True)
+        add_hook_to_module(module, hook, append=True, init_hook_kwargs={"tied_params_map": tied_params_map})
 
     # We stop the recursion in case we hit the full offload.
     if full_offload:
@@ -431,6 +499,7 @@
             module_name=child_name,
             preload_module_classes=preload_module_classes,
             skip_keys=skip_keys,
+            tied_params_map=tied_params_map,
         )
 
 
```
```diff
@@ -455,6 +524,7 @@ def attach_align_device_hook_on_blocks(
     module_name: str = "",
     skip_keys: Optional[Union[str, List[str]]] = None,
     preload_module_classes: Optional[List[str]] = None,
+    tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
 ):
     """
     Attaches `AlignDevicesHook` to all blocks of a given model as needed.
@@ -481,6 +551,10 @@
             of the forward. This should only be used for classes that have submodules which are registered but not
             called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
             `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
+        tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
+            A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
+            device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
+            instead of duplicating memory.
     """
     # If one device and one offload, we've got one hook.
     if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
@@ -488,7 +562,7 @@
             hook = AlignDevicesHook(
                 execution_device=execution_device, io_same_device=True, skip_keys=skip_keys, place_submodules=True
             )
-            add_hook_to_module(module, hook)
+            add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
         else:
             attach_align_device_hook(
                 module,
@@ -498,6 +572,7 @@
                 offload_buffers=offload_buffers,
                 module_name=module_name,
                 skip_keys=skip_keys,
+                tied_params_map=tied_params_map,
             )
         return
 
```
```diff
@@ -514,8 +589,8 @@
             place_submodules=True,
             skip_keys=skip_keys,
         )
-        add_hook_to_module(module, hook)
-        attach_execution_device_hook(module, execution_device[module_name])
+        add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
+        attach_execution_device_hook(module, execution_device[module_name], tied_params_map=tied_params_map)
     elif module_name in execution_device and module_name in offload:
         attach_align_device_hook(
             module,
@@ -526,21 +601,23 @@
             module_name=module_name,
             skip_keys=skip_keys,
             preload_module_classes=preload_module_classes,
+            tied_params_map=tied_params_map,
         )
         if not hasattr(module, "_hf_hook"):
             hook = AlignDevicesHook(
                 execution_device=execution_device[module_name], io_same_device=(module_name == ""), skip_keys=skip_keys
             )
-            add_hook_to_module(module, hook)
+            add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
             attach_execution_device_hook(
                 module,
                 execution_device[module_name],
                 preload_module_classes=preload_module_classes,
                 skip_keys=skip_keys,
+                tied_params_map=tied_params_map,
             )
     elif module_name == "":
         hook = AlignDevicesHook(execution_device=execution_device.get(""), io_same_device=True, skip_keys=skip_keys)
-        add_hook_to_module(module, hook)
+        add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
 
     for child_name, child in module.named_children():
         child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
@@ -553,6 +630,7 @@
             module_name=child_name,
             preload_module_classes=preload_module_classes,
             skip_keys=skip_keys,
+            tied_params_map=tied_params_map,
         )
 
 
```
```diff
@@ -188,6 +188,7 @@ from .other import (
     is_port_in_use,
     merge_dicts,
     patch_environment,
+    recursive_getattr,
     save,
     wait_for_everyone,
     write_basic_config,
@@ -267,6 +267,7 @@ def set_module_tensor_to_device(
     value: Optional[torch.Tensor] = None,
     dtype: Optional[Union[str, torch.dtype]] = None,
     fp16_statistics: Optional[torch.HalfTensor] = None,
+    tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
 ):
     """
     A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
@@ -286,6 +287,10 @@ def set_module_tensor_to_device(
             the dtype of the existing parameter in the model.
         fp16_statistics (`torch.HalfTensor`, *optional*):
             The list of fp16 statistics to set on the module, used for 8 bit model serialization.
+        tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):
+            A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
+            execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
+            device for all others, instead of duplicating memory.
     """
     # Recurse if needed
     if "." in tensor_name:
```
```diff
@@ -302,6 +307,24 @@ def set_module_tensor_to_device(
     is_buffer = tensor_name in module._buffers
     old_value = getattr(module, tensor_name)
 
+    # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight
+    # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer.
+    if (
+        value is not None
+        and tied_params_map is not None
+        and value.data_ptr() in tied_params_map
+        and device in tied_params_map[value.data_ptr()]
+    ):
+        module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
+        return
+    elif (
+        tied_params_map is not None
+        and old_value.data_ptr() in tied_params_map
+        and device in tied_params_map[old_value.data_ptr()]
+    ):
+        module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
+        return
+
     if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
         raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
 
```
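The net effect of this early return is that the second (and later) tied parameters dispatched to a device are bound to the tensor that is already there rather than allocated again. A standalone illustration of the reuse (toy tensors only):

```python
import torch

device = torch.device("cpu")  # stand-in for a CUDA device
shared = torch.randn(4, 4)    # one storage shared by several tied parameters

tied_params_map = {shared.data_ptr(): {}}

# First dispatch: move/copy to the target device and register the result.
first = shared.to(device)
tied_params_map[shared.data_ptr()][device] = first

# Later tied parameter: look the tensor up instead of allocating a new copy.
ptr = shared.data_ptr()
if ptr in tied_params_map and device in tied_params_map[ptr]:
    second = tied_params_map[ptr][device]

assert second is first  # same object, no extra allocation
```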
```diff
@@ -367,6 +390,7 @@ def set_module_tensor_to_device(
                 new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
             else:
                 new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
 
         module._parameters[tensor_name] = new_value
         if fp16_statistics is not None:
             setattr(module._parameters[tensor_name], "SCB", fp16_statistics.to(device))
@@ -397,6 +421,22 @@ def set_module_tensor_to_device(
         else:
             torch.cuda.empty_cache()
 
+    # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
+    # order to avoid duplicating memory, see above.
+    if (
+        tied_params_map is not None
+        and old_value.data_ptr() in tied_params_map
+        and device not in tied_params_map[old_value.data_ptr()]
+    ):
+        tied_params_map[old_value.data_ptr()][device] = new_value
+    elif (
+        value is not None
+        and tied_params_map is not None
+        and value.data_ptr() in tied_params_map
+        and device not in tied_params_map[value.data_ptr()]
+    ):
+        tied_params_map[value.data_ptr()][device] = new_value
+
 
 def named_module_tensors(
     module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
@@ -832,6 +872,7 @@ def get_balanced_memory(
             The model to analyze.
         max_memory (`Dict`, *optional*):
             A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
+            Example: `max_memory={0: "1GB"}`.
         no_split_module_classes (`List[str]`, *optional*):
             A list of layer class names that should never be split across device (for instance any layer that has a
             residual connection).
@@ -989,6 +1030,7 @@ def infer_auto_device_map(
             The model to analyze.
         max_memory (`Dict`, *optional*):
             A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
+            Example: `max_memory={0: "1GB"}`.
         no_split_module_classes (`List[str]`, *optional*):
             A list of layer class names that should never be split across device (for instance any layer that has a
             residual connection).
```
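The `Example:` lines added to the two docstrings show the expected `max_memory` format: keys are device identifiers (GPU indices or `"cpu"`) and values are sizes. A hedged usage sketch, assuming a machine with at least one GPU (the tiny model is a placeholder):

```python
import torch.nn as nn

from accelerate import infer_auto_device_map

model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))  # placeholder model

# Cap GPU 0 at 1GB and host RAM at 4GB; layers that do not fit spill over to the next device.
device_map = infer_auto_device_map(model, max_memory={0: "1GB", "cpu": "4GB"})
print(device_map)
```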
```diff
@@ -18,7 +18,7 @@ import platform
 import re
 import socket
 from contextlib import contextmanager
-from functools import partial
+from functools import partial, reduce
 from types import MethodType
 from typing import OrderedDict
 
@@ -320,3 +320,20 @@ def check_os_kernel():
             "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
         )
         logger.warning(msg, main_process_only=True)
+
+
+def recursive_getattr(obj, attr: str):
+    """
+    Recursive `getattr`.
+
+    Args:
+        obj:
+            A class instance holding the attribute.
+        attr (`str`):
+            The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
+    """
+
+    def _getattr(obj, attr):
+        return getattr(obj, attr)
+
+    return reduce(_getattr, [obj] + attr.split("."))
```
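A quick usage sketch of the `recursive_getattr` helper added above, resolving a dotted attribute path on a small stand-in module:

```python
from functools import reduce

import torch.nn as nn


def recursive_getattr(obj, attr: str):
    # Same helper as in the hunk above: resolve dotted paths such as "linear1.weight".
    def _getattr(obj, attr):
        return getattr(obj, attr)

    return reduce(_getattr, [obj] + attr.split("."))


model = nn.Sequential()
model.add_module("linear1", nn.Linear(2, 2))

print(recursive_getattr(model, "linear1.weight").shape)  # torch.Size([2, 2])
```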
```diff
@@ -14,6 +14,7 @@
 import copy
 import os
 import unittest
 from collections import OrderedDict
+from tempfile import TemporaryDirectory
 
 import torch
@@ -363,6 +364,235 @@ class BigModelingTester(unittest.TestCase):
         dispatch_model(model, device_map)
         self.assertIs(model.linear2.weight, model.linear1.weight)
 
+    @require_multi_gpu
+    def test_dispatch_model_tied_weights_memory(self):
+        # Test that we do not duplicate tied weights at any point during dispatch_model call.
+
+        torch.cuda.empty_cache()  # Needed in case we run several tests in a row.
+
+        model = nn.Sequential(
+            OrderedDict(
+                [
+                    ("linear0", nn.Linear(5000, 5000, bias=False)),
+                    ("linear1", nn.Linear(5000, 5000, bias=False)),
+                    ("linear2", nn.Linear(5000, 5000, bias=False)),
+                    ("linear3", nn.Linear(5000, 5000, bias=False)),
+                    ("linear4", nn.Linear(5000, 5000, bias=False)),
+                ]
+            )
+        )
+        model.linear2.weight = model.linear0.weight
+        model.linear3.weight = model.linear0.weight
+        model.linear4.weight = model.linear0.weight
+
+        x = torch.randn(5, 5000)
+        with torch.no_grad():
+            expected = model(x)
+
+        # We should need only 5000 * 5000 * 32 // 8 * 1e-6 = 100 MB on the device 0 for the four linear weights.
+        device_map = {"linear0": 0, "linear1": 1, "linear2": 0, "linear3": 0, "linear4": 0}
+
+        # Just to intialize CUDA context.
+        a = torch.rand(5).to("cuda:0")  # noqa: F841
+
+        free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
+        required_memory_bytes = 5000 * 5000 * (32 // 8)
+
+        # Leaving 50 MB of free memory for possible buffers, etc.
+        n_vals = (free_memory_bytes - required_memory_bytes - int(50e6)) // (32 // 8)
+        foo = torch.rand(n_vals, device="cuda:0")  # noqa: F841
+
+        # If this does OOM: there is an issue in somewhere in dispatch_model, memory of tied weights is duplicated.
+        try:
+            dispatch_model(model, device_map)
+        except torch.cuda.OutOfMemoryError as e:
+            raise torch.cuda.OutOfMemoryError(
+                f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory. {e}"
+            )
+        except Exception as e:
+            raise e
+
+        with torch.no_grad():
+            output = model(x)
+        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
```
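The memory budget in the test comment above is plain arithmetic: one 5000x5000 float32 weight takes 5000 * 5000 * 4 bytes = 100 MB, and the point of the test is that four tied references to it must not multiply that figure. A one-line check:

```python
# 32-bit floats are 32 // 8 = 4 bytes per element.
print(5000 * 5000 * (32 // 8) / 1e6)  # 100.0 (MB)
```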
```diff
+
+    @require_cuda
+    def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(self):
+        # Test that we do not duplicate tied weights at any point during dispatch_model call.
+
+        torch.cuda.empty_cache()  # Needed in case we run several tests in a row.
+
+        class SubModule(torch.nn.Module):
+            def __init__(self, ref_to_parameter):
+                super().__init__()
+                self.parameter = ref_to_parameter
+
+            def forward(self, x):
+                return x + torch.max(self.parameter)
+
+        class LinearModuleAndSubModule(torch.nn.Linear):
+            def __init__(self, in_features, out_features):
+                super().__init__(in_features, out_features, bias=False)
+                self.weight_submodule = SubModule(self.weight)
+                self.weight_submodule2 = SubModule(self.weight)
+                self.weight_submodule3 = SubModule(self.weight)
+                self.weight_submodule4 = SubModule(self.weight)
+
+            def forward(self, x):
+                a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)
+                b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)
+                c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)
+                d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)
+                return a + b + c + d
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.compute = LinearModuleAndSubModule(5000, 5000)
+                self.compute1 = LinearModuleAndSubModule(5000, 5000)
+
+            def forward(self, x):
+                a = self.compute(x)
+                b = self.compute1(x)
+                return a + b
+
+        # We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on the device 0 for the whole model forward, and not 600 MB.
+        device_map = {"compute": 0, "compute1": "cpu"}
+
+        model = Model()
+
+        x = torch.randn(1, 5000)
+        with torch.no_grad():
+            expected = model(x)
+
+        # Just to intialize CUDA context.
+        a = torch.rand(5).to("cuda:0")  # noqa: F841
+
+        free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
+        required_memory_bytes = 2 * 5000 * 5000 * (32 // 8)  # 200 MB
+
+        # Leaving 150 MB of free memory for possible buffers, etc.
+        n_vals = (free_memory_bytes - required_memory_bytes - int(150e6)) // (32 // 8)
+        foo = torch.rand(n_vals, device="cuda:0")  # noqa: F841
+
+        free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+        dispatch_model(model, device_map)
+        free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+
+        self.assertTrue((free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130)
+
+        original_pointer = model.compute1._hf_hook.weights_map["weight"].data_ptr()
+
+        with torch.no_grad():
+            try:
+                output = model(x)
+            except torch.cuda.OutOfMemoryError as e:
+                raise torch.cuda.OutOfMemoryError(
+                    f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_cpu. {e}"
+                )
+            except Exception as e:
+                raise e
+
+        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
+
+        torch.cuda.empty_cache()
+
+        free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0]
+
+        # Check that we have no more references on GPU for the offloaded tied weight.
+        self.assertTrue(len(model.compute1.weight_submodule._hf_hook.tied_params_map[original_pointer]) == 0)
+        self.assertTrue(len(model.compute1._hf_hook.tied_params_map[original_pointer]) == 0)
+        self.assertTrue((free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130)
+
+    @require_cuda
+    def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self):
+        # Test that we do not duplicate tied weights at any point during dispatch_model call.
+
+        torch.cuda.empty_cache()  # Needed in case we run several tests in a row.
+
+        class SubModule(torch.nn.Module):
+            def __init__(self, ref_to_parameter):
+                super().__init__()
+                self.parameter = ref_to_parameter
+
+            def forward(self, x):
+                return x + torch.max(self.parameter)
+
+        class LinearModuleAndSubModule(torch.nn.Linear):
+            def __init__(self, in_features, out_features):
+                super().__init__(in_features, out_features, bias=False)
+                self.weight_submodule = SubModule(self.weight)
+                self.weight_submodule2 = SubModule(self.weight)
+                self.weight_submodule3 = SubModule(self.weight)
+                self.weight_submodule4 = SubModule(self.weight)
+
+            def forward(self, x):
+                a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)
+                b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)
+                c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)
+                d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)
+                return a + b + c + d
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.compute = LinearModuleAndSubModule(5000, 5000)
+                self.compute1 = LinearModuleAndSubModule(5000, 5000)
+
+            def forward(self, x):
+                a = self.compute(x)
+                b = self.compute1(x)
+                return a + b
+
+        # We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on the device 0 for the whole model forward, and not 600 MB.
+        device_map = {"compute": 0, "compute1": "disk"}
+
+        model = Model()
+
+        x = torch.randn(1, 5000)
+        with torch.no_grad():
+            expected = model(x)
+
+        # Just to intialize CUDA context.
+        a = torch.rand(5).to("cuda:0")  # noqa: F841
+
+        free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
+        required_memory_bytes = 2 * 5000 * 5000 * (32 // 8)  # 200 MB
+
+        # Leaving 150 MB of free memory for possible buffers, etc.
+        n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8)
+        foo = torch.rand(n_vals, device="cuda:0")  # noqa: F841
+
+        free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+        with TemporaryDirectory() as tmp_dir:
+            dispatch_model(model, device_map, offload_dir=tmp_dir)
+            free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
+
+            self.assertTrue((free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130)
+
+            original_pointer = model.compute1._hf_hook.weights_map["weight"].data_ptr()
+
+            with torch.no_grad():
+                try:
+                    output = model(x)
+                except torch.cuda.OutOfMemoryError as e:
+                    raise torch.cuda.OutOfMemoryError(
+                        f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}"
+                    )
+                except Exception as e:
+                    raise e
+
+            self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
+
+            torch.cuda.empty_cache()
+
+            free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0]
+
+            # Check that we have no more references on GPU for the offloaded tied weight.
+            self.assertTrue(len(model.compute1.weight_submodule._hf_hook.tied_params_map[original_pointer]) == 0)
+            self.assertTrue(len(model.compute1._hf_hook.tied_params_map[original_pointer]) == 0)
+            self.assertTrue((free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130)
 
     @require_multi_gpu
     def test_dispatch_model_multi_gpu(self):
         model = BiggerModelForTest()
```