Compare commits

...

12 Commits

Author SHA1 Message Date
c1ea6f2474 remove outdated comment 2024-01-16 18:26:55 +01:00
30631a65be disk offloading do not reload tied parameters in memory 2024-01-16 12:51:19 +01:00
f9ecf75d24 Update tests/test_big_modeling.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2024-01-16 10:32:00 +01:00
2ea8986df6 Update tests/test_big_modeling.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2024-01-16 10:31:43 +01:00
4b95a1d12e cleanup 2024-01-15 18:36:13 +01:00
f199f6baed fix offload, submodules 2024-01-15 18:32:24 +01:00
829d33af01 style & tests pass 2024-01-12 18:40:23 +01:00
ae573266b3 style 2024-01-12 18:28:14 +01:00
ec2d94f02d cleanup 2024-01-12 18:23:25 +01:00
641a22f87e add test 2024-01-12 18:17:47 +01:00
373a6bea9a fix 2024-01-12 18:01:37 +01:00
982e60560b wip 2024-01-12 17:45:03 +01:00
6 changed files with 402 additions and 17 deletions

View File

@ -43,6 +43,7 @@ from .utils import (
parse_flag_from_env,
retie_parameters,
)
from .utils.other import recursive_getattr
logger = logging.getLogger(__name__)
@ -395,7 +396,22 @@ def dispatch_model(
else:
weights_map = None
# When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
# tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
# original pointer) on each device.
tied_params = find_tied_parameters(model)
tied_params_map = {}
for group in tied_params:
for param_name in group:
# data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
# to care about views of tensors through storage_offset.
data_ptr = recursive_getattr(model, param_name).data_ptr()
tied_params_map[data_ptr] = {}
# Note: To handle the disk offloading case, we cannot simply use weights_map[param_name].data_ptr() as the reference pointer,
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
attach_align_device_hook_on_blocks(
model,
execution_device=execution_device,
@ -404,6 +420,7 @@ def dispatch_model(
weights_map=weights_map,
skip_keys=skip_keys,
preload_module_classes=preload_module_classes,
tied_params_map=tied_params_map,
)
# warn if there are any params on the meta device
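For illustration of the bookkeeping built above, here is a minimal sketch (not part of this diff, assuming `recursive_getattr` is exported from `accelerate.utils` as the import changes further below suggest): each tied group found by `find_tied_parameters` contributes one `tied_params_map` key, the shared storage pointer, whose per-device entries are filled later during dispatch.

from collections import OrderedDict

import torch.nn as nn
from accelerate.utils import find_tied_parameters, recursive_getattr

# Toy model with two tied linear weights, mirroring the tests below.
model = nn.Sequential(
    OrderedDict([("linear1", nn.Linear(4, 4, bias=False)), ("linear2", nn.Linear(4, 4, bias=False))])
)
model.linear2.weight = model.linear1.weight

# One entry per tied group, keyed by the shared storage pointer; the inner dict
# (device -> already dispatched tensor) is filled lazily by set_module_tensor_to_device.
tied_params_map = {}
for group in find_tied_parameters(model):  # e.g. [["linear1.weight", "linear2.weight"]]
    for param_name in group:
        data_ptr = recursive_getattr(model, param_name).data_ptr()
        tied_params_map[data_ptr] = {}

print(tied_params_map)  # a single key: the data pointer shared by linear1.weight and linear2.weight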

View File

@ -27,6 +27,7 @@ from .utils import (
set_module_tensor_to_device,
)
from .utils.modeling import get_non_persistent_buffers
from .utils.other import recursive_getattr
class ModelHook:
@ -116,7 +117,9 @@ class SequentialHook(ModelHook):
return module
def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
def add_hook_to_module(
module: nn.Module, hook: ModelHook, append: bool = False, init_hook_kwargs: Optional[Dict] = None
):
"""
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
@ -135,6 +138,8 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
The hook to attach.
append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
init_hook_kwargs (Optional[Dict], *optional*, defaults to `None`):
Optional arguments to pass to the hook initialization.
Returns:
`torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
@ -153,7 +158,10 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
old_forward = module.forward
module._old_forward = old_forward
module = hook.init_hook(module)
if init_hook_kwargs is None:
init_hook_kwargs = {}
module = hook.init_hook(module, **init_hook_kwargs)
module._hf_hook = hook
def new_forward(module, *args, **kwargs):
@ -240,6 +248,7 @@ class AlignDevicesHook(ModelHook):
self.input_device = None
self.param_original_devices = {}
self.buffer_original_devices = {}
self.tied_params_names = set()
def __repr__(self):
return (
@ -248,10 +257,14 @@ class AlignDevicesHook(ModelHook):
f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
)
def init_hook(self, module):
def init_hook(self, module, tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None):
# In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
tied_params_map = None
if not self.offload and self.execution_device is not None:
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
elif self.offload:
self.original_devices = {
name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
@ -266,13 +279,25 @@ class AlignDevicesHook(ModelHook):
for name, _ in named_module_tensors(
module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
):
# When using disk offloading, we cannot rely on `weights_map[name].data_ptr()` as the reference pointer,
# as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
# As we have no reliable way to track the shared data pointer of tied weights in this case, we use `self.tied_params_names`
# (a set of parameter names) to add pointers to `tied_params_map` on the fly in the pre_forward call.
if tied_params_map is not None and recursive_getattr(module, name).data_ptr() in tied_params_map:
self.tied_params_names.add(name)
set_module_tensor_to_device(module, name, "meta")
if not self.offload_buffers and self.execution_device is not None:
for name, _ in module.named_buffers(recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
elif self.offload_buffers and self.execution_device is not None:
for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=tied_params_map)
# The hook's pre_forward/post_forward methods need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
# for tied weights already loaded on the target execution device.
self.tied_params_map = tied_params_map
return module
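As a quick check of the assumption stated at the top of `init_hook` (illustrative only, not part of this diff): tensors on the `meta` device carry no real storage, so their data pointer cannot be used to identify a tied group.

import torch

meta_tensor = torch.empty(3, 3, device="meta")
# Prints 0, per the comment above: meta tensors expose a null data pointer,
# which is why tied_params_map is dropped when execution_device is "meta".
print(meta_tensor.data_ptr())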
@ -280,6 +305,8 @@ class AlignDevicesHook(ModelHook):
if self.io_same_device:
self.input_device = find_device([args, kwargs])
if self.offload:
self.tied_pointers_to_remove = set()
for name, _ in named_module_tensors(
module,
include_buffers=self.offload_buffers,
@ -287,11 +314,32 @@ class AlignDevicesHook(ModelHook):
remove_non_persistent=True,
):
fp16_statistics = None
value = self.weights_map[name]
if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
if self.weights_map[name].dtype == torch.int8:
if value.dtype == torch.int8:
fp16_statistics = self.weights_map[name.replace("weight", "SCB")]
# In case we are using offloading with tied weights, we need to keep track of the offloaded weights
# that are loaded on the device at this point, as we will also need to remove them from the dictionary
# self.tied_params_map so that their memory can be freed.
if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
self.tied_params_map[value.data_ptr()] = {}
if (
value is not None
and self.tied_params_map is not None
and value.data_ptr() in self.tied_params_map
and self.execution_device not in self.tied_params_map[value.data_ptr()]
):
self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))
set_module_tensor_to_device(
module, name, self.execution_device, value=self.weights_map[name], fp16_statistics=fp16_statistics
module,
name,
self.execution_device,
value=value,
fp16_statistics=fp16_statistics,
tied_params_map=self.tied_params_map,
)
return send_to_device(args, self.execution_device), send_to_device(
@ -311,6 +359,12 @@ class AlignDevicesHook(ModelHook):
module.state.SCB = None
module.state.CxB = None
# We may have loaded tied weights into self.tied_params_map (to avoid loading them several times, e.g. in submodules): remove them from
# this dictionary to allow the garbage collector to do its job.
for value_pointer, device in self.tied_pointers_to_remove:
del self.tied_params_map[value_pointer][device]
self.tied_pointers_to_remove = None
if self.io_same_device and self.input_device is not None:
output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)
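A minimal sketch of why this cleanup matters (illustrative only, assuming a CUDA device is available): as long as `tied_params_map` keeps a reference to the on-device copy of an offloaded tied weight, that allocation cannot be reclaimed.

import torch

# Stand-ins for an offloaded tied weight and the execution device.
offloaded_weight = torch.randn(1024, 1024)
device = torch.device("cuda:0")
tied_params_map = {offloaded_weight.data_ptr(): {}}

# pre_forward: the weight is materialized on the execution device and registered in the map.
on_device_copy = offloaded_weight.to(device)
tied_params_map[offloaded_weight.data_ptr()][device] = on_device_copy

# post_forward: drop the map entry (and any other reference) so the allocation can be released,
# which is what the loop over self.tied_pointers_to_remove does above.
del tied_params_map[offloaded_weight.data_ptr()][device]
del on_device_copy
torch.cuda.empty_cache()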
@ -329,6 +383,7 @@ def attach_execution_device_hook(
execution_device: Union[int, str, torch.device],
skip_keys: Optional[Union[str, List[str]]] = None,
preload_module_classes: Optional[List[str]] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
"""
Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
@ -346,16 +401,24 @@ def attach_execution_device_hook(
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
instead of duplicating memory.
"""
if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
add_hook_to_module(module, AlignDevicesHook(execution_device, skip_keys=skip_keys))
add_hook_to_module(
module,
AlignDevicesHook(execution_device, skip_keys=skip_keys),
init_hook_kwargs={"tied_params_map": tied_params_map},
)
# Break the recursion if we get to a preload module.
if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
return
for child in module.children():
attach_execution_device_hook(child, execution_device)
attach_execution_device_hook(child, execution_device, tied_params_map=tied_params_map)
def attach_align_device_hook(
@ -367,6 +430,7 @@ def attach_align_device_hook(
module_name: str = "",
skip_keys: Optional[Union[str, List[str]]] = None,
preload_module_classes: Optional[List[str]] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
"""
Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
@ -392,6 +456,10 @@ def attach_align_device_hook(
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
instead of duplicating memory.
"""
# Attach the hook on this module if it has any direct tensor.
directs = named_module_tensors(module)
@ -413,7 +481,7 @@ def attach_align_device_hook(
place_submodules=full_offload,
skip_keys=skip_keys,
)
add_hook_to_module(module, hook, append=True)
add_hook_to_module(module, hook, append=True, init_hook_kwargs={"tied_params_map": tied_params_map})
# We stop the recursion in case we hit the full offload.
if full_offload:
@ -431,6 +499,7 @@ def attach_align_device_hook(
module_name=child_name,
preload_module_classes=preload_module_classes,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
@ -455,6 +524,7 @@ def attach_align_device_hook_on_blocks(
module_name: str = "",
skip_keys: Optional[Union[str, List[str]]] = None,
preload_module_classes: Optional[List[str]] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
"""
Attaches `AlignDevicesHook` to all blocks of a given model as needed.
@ -481,6 +551,10 @@ def attach_align_device_hook_on_blocks(
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
`dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
instead of duplicating memory.
"""
# If one device and one offload, we've got one hook.
if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
@ -488,7 +562,7 @@ def attach_align_device_hook_on_blocks(
hook = AlignDevicesHook(
execution_device=execution_device, io_same_device=True, skip_keys=skip_keys, place_submodules=True
)
add_hook_to_module(module, hook)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
else:
attach_align_device_hook(
module,
@ -498,6 +572,7 @@ def attach_align_device_hook_on_blocks(
offload_buffers=offload_buffers,
module_name=module_name,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
return
@ -514,8 +589,8 @@ def attach_align_device_hook_on_blocks(
place_submodules=True,
skip_keys=skip_keys,
)
add_hook_to_module(module, hook)
attach_execution_device_hook(module, execution_device[module_name])
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
attach_execution_device_hook(module, execution_device[module_name], tied_params_map=tied_params_map)
elif module_name in execution_device and module_name in offload:
attach_align_device_hook(
module,
@ -526,21 +601,23 @@ def attach_align_device_hook_on_blocks(
module_name=module_name,
skip_keys=skip_keys,
preload_module_classes=preload_module_classes,
tied_params_map=tied_params_map,
)
if not hasattr(module, "_hf_hook"):
hook = AlignDevicesHook(
execution_device=execution_device[module_name], io_same_device=(module_name == ""), skip_keys=skip_keys
)
add_hook_to_module(module, hook)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
attach_execution_device_hook(
module,
execution_device[module_name],
preload_module_classes=preload_module_classes,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
elif module_name == "":
hook = AlignDevicesHook(execution_device=execution_device.get(""), io_same_device=True, skip_keys=skip_keys)
add_hook_to_module(module, hook)
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})
for child_name, child in module.named_children():
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
@ -553,6 +630,7 @@ def attach_align_device_hook_on_blocks(
module_name=child_name,
preload_module_classes=preload_module_classes,
skip_keys=skip_keys,
tied_params_map=tied_params_map,
)
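A short usage sketch of the new `init_hook_kwargs` plumbing (illustrative only; in practice the map is built by `dispatch_model` as shown in the first file above):

import torch
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

module = nn.Linear(8, 8)
tied_params_map = {}  # data_ptr -> {device: already dispatched tensor}

hook = AlignDevicesHook(execution_device=torch.device("cpu"))
# The extra kwargs are forwarded to hook.init_hook(module, **init_hook_kwargs),
# which is how every AlignDevicesHook receives the shared tied_params_map.
add_hook_to_module(module, hook, init_hook_kwargs={"tied_params_map": tied_params_map})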

View File

@ -188,6 +188,7 @@ from .other import (
is_port_in_use,
merge_dicts,
patch_environment,
recursive_getattr,
save,
wait_for_everyone,
write_basic_config,

View File

@ -267,6 +267,7 @@ def set_module_tensor_to_device(
value: Optional[torch.Tensor] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
fp16_statistics: Optional[torch.HalfTensor] = None,
tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
@ -286,6 +287,10 @@ def set_module_tensor_to_device(
the dtype of the existing parameter in the model.
fp16_statistics (`torch.HalfTensor`, *optional*):
The list of fp16 statistics to set on the module, used for 8 bit model serialization.
tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):
A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
device for all others, instead of duplicating memory.
"""
# Recurse if needed
if "." in tensor_name:
@ -302,6 +307,24 @@ def set_module_tensor_to_device(
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
# Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weights
# in the tied group has already been dispatched to the device: avoid reallocating memory on the device and simply reuse the existing pointer.
if (
value is not None
and tied_params_map is not None
and value.data_ptr() in tied_params_map
and device in tied_params_map[value.data_ptr()]
):
module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
return
elif (
tied_params_map is not None
and old_value.data_ptr() in tied_params_map
and device in tied_params_map[old_value.data_ptr()]
):
module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
return
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
@ -367,6 +390,7 @@ def set_module_tensor_to_device(
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
else:
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module._parameters[tensor_name], "SCB", fp16_statistics.to(device))
@ -397,6 +421,22 @@ def set_module_tensor_to_device(
else:
torch.cuda.empty_cache()
# When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
# order to avoid duplicating memory, see above.
if (
tied_params_map is not None
and old_value.data_ptr() in tied_params_map
and device not in tied_params_map[old_value.data_ptr()]
):
tied_params_map[old_value.data_ptr()][device] = new_value
elif (
value is not None
and tied_params_map is not None
and value.data_ptr() in tied_params_map
and device not in tied_params_map[value.data_ptr()]
):
tied_params_map[value.data_ptr()][device] = new_value
def named_module_tensors(
module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
@ -832,6 +872,7 @@ def get_balanced_memory(
The model to analyze.
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available if unset.
Example: `max_memory={0: "1GB"}`.
no_split_module_classes (`List[str]`, *optional*):
A list of layer class names that should never be split across devices (for instance any layer that has a
residual connection).
@ -989,6 +1030,7 @@ def infer_auto_device_map(
The model to analyze.
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available if unset.
Example: `max_memory={0: "1GB"}`.
no_split_module_classes (`List[str]`, *optional*):
A list of layer class names that should never be split across devices (for instance any layer that has a
residual connection).
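A hedged, self-contained sketch of the reuse behavior added to `set_module_tensor_to_device` earlier in this file (names here are illustrative; the map is normally built and filled during `dispatch_model`): when the `value` pointer is already registered for the target device, the existing on-device tensor is reassigned instead of allocating a new copy.

import torch
import torch.nn as nn
from accelerate.utils import set_module_tensor_to_device

linear_a = nn.Linear(4, 4, bias=False)
linear_b = nn.Linear(4, 4, bias=False)
linear_b.weight = linear_a.weight  # tie the two weights

offloaded = torch.randn(4, 4)  # stands in for the tied weight loaded from the offload folder
device = torch.device("cpu")
tied_params_map = {offloaded.data_ptr(): {}}

# First call: places the value on `device` and records the resulting tensor in tied_params_map.
set_module_tensor_to_device(linear_a, "weight", device, value=offloaded, tied_params_map=tied_params_map)
# Second call: the shared pointer is already mapped for `device`, so the existing tensor is reused.
set_module_tensor_to_device(linear_b, "weight", device, value=offloaded, tied_params_map=tied_params_map)

assert linear_b.weight is linear_a.weight  # no duplicate allocation for the tied weight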

View File

@ -18,7 +18,7 @@ import platform
import re
import socket
from contextlib import contextmanager
from functools import partial
from functools import partial, reduce
from types import MethodType
from typing import OrderedDict
@ -320,3 +320,20 @@ def check_os_kernel():
"cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
)
logger.warning(msg, main_process_only=True)
def recursive_getattr(obj, attr: str):
"""
Recursive `getattr`.
Args:
obj:
A class instance holding the attribute.
attr (`str`):
The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
"""
def _getattr(obj, attr):
return getattr(obj, attr)
return reduce(_getattr, [obj] + attr.split("."))
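For example, a small usage sketch of the helper above:

import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
weight = recursive_getattr(model, "0.weight")  # equivalent to getattr(getattr(model, "0"), "weight")
assert weight is model[0].weight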

View File

@ -14,6 +14,7 @@
import copy
import os
import unittest
from collections import OrderedDict
from tempfile import TemporaryDirectory
import torch
@ -363,6 +364,235 @@ class BigModelingTester(unittest.TestCase):
dispatch_model(model, device_map)
self.assertIs(model.linear2.weight, model.linear1.weight)
@require_multi_gpu
def test_dispatch_model_tied_weights_memory(self):
# Test that we do not duplicate tied weights at any point during the dispatch_model call.
torch.cuda.empty_cache() # Needed in case we run several tests in a row.
model = nn.Sequential(
OrderedDict(
[
("linear0", nn.Linear(5000, 5000, bias=False)),
("linear1", nn.Linear(5000, 5000, bias=False)),
("linear2", nn.Linear(5000, 5000, bias=False)),
("linear3", nn.Linear(5000, 5000, bias=False)),
("linear4", nn.Linear(5000, 5000, bias=False)),
]
)
)
model.linear2.weight = model.linear0.weight
model.linear3.weight = model.linear0.weight
model.linear4.weight = model.linear0.weight
x = torch.randn(5, 5000)
with torch.no_grad():
expected = model(x)
# We should need only 5000 * 5000 * 32 // 8 * 1e-6 = 100 MB on device 0 for the four tied linear layers, since they share a single weight.
device_map = {"linear0": 0, "linear1": 1, "linear2": 0, "linear3": 0, "linear4": 0}
# Just to initialize the CUDA context.
a = torch.rand(5).to("cuda:0") # noqa: F841
free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
required_memory_bytes = 5000 * 5000 * (32 // 8)
# Leaving 50 MB of free memory for possible buffers, etc.
n_vals = (free_memory_bytes - required_memory_bytes - int(50e6)) // (32 // 8)
foo = torch.rand(n_vals, device="cuda:0") # noqa: F841
# If this OOMs, there is an issue somewhere in dispatch_model: memory for the tied weights is being duplicated.
try:
dispatch_model(model, device_map)
except torch.cuda.OutOfMemoryError as e:
raise torch.cuda.OutOfMemoryError(
f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory. {e}"
)
except Exception as e:
raise e
with torch.no_grad():
output = model(x)
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
@require_cuda
def test_dispatch_model_tied_weights_memory_with_nested_offload_cpu(self):
# Test that we do not duplicate tied weights at any point during the dispatch_model call.
torch.cuda.empty_cache() # Needed in case we run several tests in a row.
class SubModule(torch.nn.Module):
def __init__(self, ref_to_parameter):
super().__init__()
self.parameter = ref_to_parameter
def forward(self, x):
return x + torch.max(self.parameter)
class LinearModuleAndSubModule(torch.nn.Linear):
def __init__(self, in_features, out_features):
super().__init__(in_features, out_features, bias=False)
self.weight_submodule = SubModule(self.weight)
self.weight_submodule2 = SubModule(self.weight)
self.weight_submodule3 = SubModule(self.weight)
self.weight_submodule4 = SubModule(self.weight)
def forward(self, x):
a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)
b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)
c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)
d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)
return a + b + c + d
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.compute = LinearModuleAndSubModule(5000, 5000)
self.compute1 = LinearModuleAndSubModule(5000, 5000)
def forward(self, x):
a = self.compute(x)
b = self.compute1(x)
return a + b
# We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on device 0 for the whole model forward, and not 600 MB.
device_map = {"compute": 0, "compute1": "cpu"}
model = Model()
x = torch.randn(1, 5000)
with torch.no_grad():
expected = model(x)
# Just to initialize the CUDA context.
a = torch.rand(5).to("cuda:0") # noqa: F841
free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
required_memory_bytes = 2 * 5000 * 5000 * (32 // 8) # 200 MB
# Leaving 150 MB of free memory for possible buffers, etc.
n_vals = (free_memory_bytes - required_memory_bytes - int(150e6)) // (32 // 8)
foo = torch.rand(n_vals, device="cuda:0") # noqa: F841
free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
dispatch_model(model, device_map)
free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
self.assertTrue((free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130)
original_pointer = model.compute1._hf_hook.weights_map["weight"].data_ptr()
with torch.no_grad():
try:
output = model(x)
except torch.cuda.OutOfMemoryError as e:
raise torch.cuda.OutOfMemoryError(
f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_cpu. {e}"
)
except Exception as e:
raise e
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
torch.cuda.empty_cache()
free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0]
# Check that we have no more references on GPU for the offloaded tied weight.
self.assertTrue(len(model.compute1.weight_submodule._hf_hook.tied_params_map[original_pointer]) == 0)
self.assertTrue(len(model.compute1._hf_hook.tied_params_map[original_pointer]) == 0)
self.assertTrue((free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130)
@require_cuda
def test_dispatch_model_tied_weights_memory_with_nested_offload_disk(self):
# Test that we do not duplicate tied weights at any point during the dispatch_model call.
torch.cuda.empty_cache() # Needed in case we run several tests in a row.
class SubModule(torch.nn.Module):
def __init__(self, ref_to_parameter):
super().__init__()
self.parameter = ref_to_parameter
def forward(self, x):
return x + torch.max(self.parameter)
class LinearModuleAndSubModule(torch.nn.Linear):
def __init__(self, in_features, out_features):
super().__init__(in_features, out_features, bias=False)
self.weight_submodule = SubModule(self.weight)
self.weight_submodule2 = SubModule(self.weight)
self.weight_submodule3 = SubModule(self.weight)
self.weight_submodule4 = SubModule(self.weight)
def forward(self, x):
a = torch.nn.functional.linear(self.weight_submodule(x), self.weight)
b = torch.nn.functional.linear(self.weight_submodule2(x), self.weight)
c = torch.nn.functional.linear(self.weight_submodule3(x), self.weight)
d = torch.nn.functional.linear(self.weight_submodule4(x), self.weight)
return a + b + c + d
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.compute = LinearModuleAndSubModule(5000, 5000)
self.compute1 = LinearModuleAndSubModule(5000, 5000)
def forward(self, x):
a = self.compute(x)
b = self.compute1(x)
return a + b
# We should need only 2 * 5000 * 5000 * 32 // 8 * 1e-6 = 200 MB on device 0 for the whole model forward, and not 600 MB.
device_map = {"compute": 0, "compute1": "disk"}
model = Model()
x = torch.randn(1, 5000)
with torch.no_grad():
expected = model(x)
# Just to initialize the CUDA context.
a = torch.rand(5).to("cuda:0") # noqa: F841
free_memory_bytes = torch.cuda.mem_get_info("cuda:0")[0]
required_memory_bytes = 2 * 5000 * 5000 * (32 // 8) # 200 MB
# Leaving 200 MB of free memory for possible buffers, etc.
n_vals = (free_memory_bytes - required_memory_bytes - int(200e6)) // (32 // 8)
foo = torch.rand(n_vals, device="cuda:0") # noqa: F841
free_memory_bytes_before_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
with TemporaryDirectory() as tmp_dir:
dispatch_model(model, device_map, offload_dir=tmp_dir)
free_memory_bytes_after_dispatch = torch.cuda.mem_get_info("cuda:0")[0]
self.assertTrue((free_memory_bytes_after_dispatch - free_memory_bytes_before_dispatch) * 1e-6 < 130)
original_pointer = model.compute1._hf_hook.weights_map["weight"].data_ptr()
with torch.no_grad():
try:
output = model(x)
except torch.cuda.OutOfMemoryError as e:
raise torch.cuda.OutOfMemoryError(
f"OOM error in dispatch_model. This is a bug and should not happen, see test_dispatch_model_tied_weights_memory_with_nested_offload_disk. {e}"
)
except Exception as e:
raise e
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
torch.cuda.empty_cache()
free_memory_bytes_after_infer = torch.cuda.mem_get_info("cuda:0")[0]
# Check that we have no more references on GPU for the offloaded tied weight.
self.assertTrue(len(model.compute1.weight_submodule._hf_hook.tied_params_map[original_pointer]) == 0)
self.assertTrue(len(model.compute1._hf_hook.tied_params_map[original_pointer]) == 0)
self.assertTrue((free_memory_bytes_after_infer - free_memory_bytes_after_dispatch) * 1e-6 < 130)
@require_multi_gpu
def test_dispatch_model_multi_gpu(self):
model = BiggerModelForTest()