Add cpu_offload_with_hook (#1045)
* Add cpu offload with hook
* Style
* add to init
* Apply suggestions from code review
* Add documentation
* Add tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
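In short: `cpu_offload_with_hook` offloads a model to the CPU, moves it to the execution device when its forward pass runs, and leaves it there until the returned hook's `offload` method is called. A minimal usage sketch based on the docstring and test in this diff (the `Linear` layers are stand-ins for real models, and a CUDA device is assumed so the default execution device is GPU 0):

```py
import torch

from accelerate.big_modeling import cpu_offload_with_hook

# Two stand-in models; any torch.nn.Module works.
model_1 = torch.nn.Linear(4, 5)
model_2 = torch.nn.Linear(5, 5)

# Each call returns the model plus a user-facing hook; the model starts on the CPU.
model_1, hook_1 = cpu_offload_with_hook(model_1)
# Chaining: model_1 is offloaded back to the CPU just before model_2's forward.
model_2, hook_2 = cpu_offload_with_hook(model_2, prev_module_hook=hook_1)

x = torch.randn(3, 4)
out = model_2(model_1(x))  # each model moves to the execution device at its forward
hook_2.offload()           # the last model in the chain is offloaded manually
```

Chaining hooks through `prev_module_hook` keeps at most one model of the pipeline on the device at a time.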
src/accelerate/__init__.py
@@ -7,6 +7,7 @@ __version__ = "0.17.0.dev0"
 from .accelerator import Accelerator
 from .big_modeling import (
     cpu_offload,
+    cpu_offload_with_hook,
     disk_offload,
     dispatch_model,
     init_empty_weights,
src/accelerate/big_modeling.py
@@ -19,7 +19,14 @@ from typing import Dict, List, Optional, Union
 import torch
 import torch.nn as nn
 
-from .hooks import AlignDevicesHook, add_hook_to_module, attach_align_device_hook, attach_align_device_hook_on_blocks
+from .hooks import (
+    AlignDevicesHook,
+    CpuOffload,
+    UserCpuOffloadHook,
+    add_hook_to_module,
+    attach_align_device_hook,
+    attach_align_device_hook_on_blocks,
+)
 from .utils import (
     OffloadedWeightsLoader,
     check_device_map,
@@ -184,6 +191,50 @@ def cpu_offload(
     return model
 
 
+def cpu_offload_with_hook(
+    model: torch.nn.Module,
+    execution_device: Optional[Union[int, str, torch.device]] = None,
+    prev_module_hook: Optional[UserCpuOffloadHook] = None,
+):
+    """
+    Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
+    [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again
+    when the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.
+
+    Args:
+        model (`torch.nn.Module`):
+            The model to offload.
+        execution_device (`str`, `int` or `torch.device`, *optional*):
+            The device on which the model should be executed. Will default to the MPS device if it's available,
+            then GPU 0 if there is a GPU, and finally to the CPU.
+        prev_module_hook (`UserCpuOffloadHook`, *optional*):
+            The hook sent back by this function for a previous model in the pipeline you are running. If passed,
+            its offload method will be called just before the forward of the model to which this hook is attached.
+
+    Example:
+
+    ```py
+    model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
+    model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
+    model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
+
+    hid_1 = model_1(input)
+    for i in range(50):
+        # model_1 is offloaded to the CPU just before this forward.
+        hid_2 = model_2(hid_1)
+        # model_2 is offloaded to the CPU just before this forward.
+        hid_3 = model_3(hid_2)
+
+    # For model_3, you need to manually call the hook's offload method.
+    hook_3.offload()
+    ```
+    """
+    hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
+    add_hook_to_module(model, hook, append=True)
+    user_hook = UserCpuOffloadHook(model, hook)
+    return model, user_hook
+
+
 def disk_offload(
     model: nn.Module,
     offload_dir: Union[str, os.PathLike],
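To make the difference with plain [`cpu_offload`] concrete, here is a hedged sketch of the behavior the docstring describes; the comments about weight movement are inferences from that docstring rather than library output, and a single CUDA GPU is assumed:

```py
import torch

from accelerate.big_modeling import cpu_offload, cpu_offload_with_hook

model_a = torch.nn.Linear(8, 8)
model_b = torch.nn.Linear(8, 8)

# Plain cpu_offload: weights round-trip CPU -> GPU -> CPU on every forward.
cpu_offload(model_a, execution_device=torch.device(0))

# Hook variant: weights move to the GPU at the first forward and stay there.
model_b, hook = cpu_offload_with_hook(model_b, execution_device=torch.device(0))

x = torch.randn(2, 8)
model_a(x)      # weights are offloaded again right after this call
model_b(x)      # weights stay on the GPU...
model_b(x)      # ...so this second call pays no transfer cost
hook.offload()  # until you offload them explicitly
```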
src/accelerate/hooks.py
@@ -18,7 +18,14 @@ from typing import Dict, List, Mapping, Optional, Union
 import torch
 import torch.nn as nn
 
-from .utils import PrefixedDataset, find_device, named_module_tensors, send_to_device, set_module_tensor_to_device
+from .utils import (
+    PrefixedDataset,
+    find_device,
+    is_mps_available,
+    named_module_tensors,
+    send_to_device,
+    set_module_tensor_to_device,
+)
 
 
 class ModelHook:
@@ -494,3 +501,61 @@ def attach_align_device_hook_on_blocks(
             module_name=child_name,
             preload_module_classes=preload_module_classes,
         )
+
+
+class CpuOffload(ModelHook):
+    """
+    Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU
+    after the forward; the user needs to call the `init_hook` method again for this.
+
+    Args:
+        execution_device (`str`, `int` or `torch.device`, *optional*):
+            The device on which the model should be executed. Will default to the MPS device if it's available,
+            then GPU 0 if there is a GPU, and finally to the CPU.
+        prev_module_hook (`UserCpuOffloadHook`, *optional*):
+            The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running.
+            If passed, its offload method will be called just before the forward of the model to which this hook
+            is attached.
+    """
+
+    def __init__(
+        self,
+        execution_device: Optional[Union[str, int, torch.device]] = None,
+        prev_module_hook: Optional["UserCpuOffloadHook"] = None,
+    ):
+        self.prev_module_hook = prev_module_hook
+
+        if execution_device is not None:
+            self.execution_device = execution_device
+        elif is_mps_available():
+            self.execution_device = torch.device("mps")
+        elif torch.cuda.is_available():
+            self.execution_device = torch.device(0)
+        else:
+            self.execution_device = torch.device("cpu")
+
+    def init_hook(self, module):
+        return module.to("cpu")
+
+    def pre_forward(self, module, *args, **kwargs):
+        module.to(self.execution_device)
+        if self.prev_module_hook is not None:
+            self.prev_module_hook.offload()
+        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
+
+
+class UserCpuOffloadHook:
+    """
+    A simple hook grouping a model and a `ModelHook`, which provides easy APIs to call the init method of the hook
+    or remove it entirely.
+    """
+
+    def __init__(self, model, hook):
+        self.model = model
+        self.hook = hook
+
+    def offload(self):
+        self.hook.init_hook(self.model)
+
+    def remove(self):
+        remove_hook_from_module(self.model)
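For readers new to accelerate's hook mechanism, here is a sketch of the lifecycle these two classes implement. In normal use `add_hook_to_module` wires `init_hook`/`pre_forward` into the module's forward, so the manual calls below are purely illustrative:

```py
import torch

from accelerate.hooks import CpuOffload, UserCpuOffloadHook

model = torch.nn.Linear(4, 4)
# Passing execution_device explicitly so this runs anywhere;
# None would fall back to MPS, then GPU 0, then the CPU.
hook = CpuOffload(execution_device="cpu")

model = hook.init_hook(model)  # parks the model on the CPU
args, kwargs = hook.pre_forward(model, torch.randn(2, 4))
# pre_forward moved the model and its inputs to hook.execution_device
# (and would first offload prev_module_hook's model, if one were set)
output = model(*args, **kwargs)

user_hook = UserCpuOffloadHook(model, hook)
user_hook.offload()  # re-runs init_hook, sending the model back to the CPU
```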
tests/test_big_modeling.py
@@ -21,6 +21,7 @@ import torch.nn as nn
 
 from accelerate.big_modeling import (
     cpu_offload,
+    cpu_offload_with_hook,
     disk_offload,
     dispatch_model,
     init_empty_weights,
@@ -484,3 +485,33 @@ class BigModelingTester(unittest.TestCase):
 
         output = new_model(x)
         self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
+
+    @require_cuda
+    def test_cpu_offload_with_hook(self):
+        model1 = torch.nn.Linear(4, 5)
+        model1, hook1 = cpu_offload_with_hook(model1)
+        self.assertEqual(model1.weight.device, torch.device("cpu"))
+
+        inputs = torch.randn(3, 4)
+        outputs = model1(inputs)
+        self.assertEqual(outputs.device, torch.device(0))
+        self.assertEqual(model1.weight.device, torch.device(0))
+
+        hook1.offload()
+        self.assertEqual(model1.weight.device, torch.device("cpu"))
+
+        model2 = torch.nn.Linear(5, 5)
+        model2, hook2 = cpu_offload_with_hook(model2, prev_module_hook=hook1)
+        self.assertEqual(model2.weight.device, torch.device("cpu"))
+
+        outputs = model1(inputs)
+        self.assertEqual(outputs.device, torch.device(0))
+        self.assertEqual(model1.weight.device, torch.device(0))
+
+        outputs = model2(outputs)
+        self.assertEqual(outputs.device, torch.device(0))
+        self.assertEqual(model1.weight.device, torch.device("cpu"))
+        self.assertEqual(model2.weight.device, torch.device(0))
+
+        hook2.offload()
+        self.assertEqual(model2.weight.device, torch.device("cpu"))