Add cpu_offload_with_hook (#1045)

* Add cpu offload with hook

* Style

* add to init

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add documentation

* Add tests

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sylvain Gugger
2023-02-07 13:09:27 -05:00
committed by GitHub
parent 76c41f0df7
commit 71e81bab00
4 changed files with 150 additions and 2 deletions
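For orientation before the file diffs, a minimal usage sketch of the API this commit adds; the two `nn.Linear` models, their sizes, and the GPU index are illustrative placeholders, not part of the commit:

```py
import torch
from accelerate import cpu_offload_with_hook

# Two toy models standing in for pipeline components (illustrative only).
model_1 = torch.nn.Linear(4, 4)
model_2 = torch.nn.Linear(4, 4)

# Each call attaches a CpuOffload hook and returns (model, user_hook).
model_1, hook_1 = cpu_offload_with_hook(model_1, execution_device=0)
model_2, hook_2 = cpu_offload_with_hook(model_2, execution_device=0, prev_module_hook=hook_1)

x = torch.randn(2, 4)
h = model_1(x)    # model_1 moves to GPU 0 and stays there
h = model_2(h)    # model_2 moves to GPU 0; hook_1 sends model_1 back to the CPU
hook_2.offload()  # the last model has to be offloaded explicitly
```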

src/accelerate/__init__.py

@@ -7,6 +7,7 @@ __version__ = "0.17.0.dev0"
from .accelerator import Accelerator
from .big_modeling import (
    cpu_offload,
    cpu_offload_with_hook,
    disk_offload,
    dispatch_model,
    init_empty_weights,

src/accelerate/big_modeling.py

@@ -19,7 +19,14 @@ from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn

from .hooks import (
    AlignDevicesHook,
    CpuOffload,
    UserCpuOffloadHook,
    add_hook_to_module,
    attach_align_device_hook,
    attach_align_device_hook_on_blocks,
)
from .utils import (
    OffloadedWeightsLoader,
    check_device_map,
@@ -184,6 +191,50 @@ def cpu_offload(
    return model
def cpu_offload_with_hook(
    model: torch.nn.Module,
    execution_device: Optional[Union[int, str, torch.device]] = None,
    prev_module_hook: Optional[UserCpuOffloadHook] = None,
):
    """
    Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
    [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again
    when the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.

    Args:
        model (`torch.nn.Module`):
            The model to offload.
        execution_device (`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
        prev_module_hook (`UserCpuOffloadHook`, *optional*):
            The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
            offload method will be called just before the forward of the model to which this hook is attached.

    Example:

    ```py
    model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
    model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
    model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)

    hid_1 = model_1(input)
    for i in range(50):
        # model_1 is offloaded to the CPU at the first iteration, model_2 stays on the GPU for this whole loop.
        hid_2 = model_2(hid_1)
        # model_2 is offloaded to the CPU just before this forward.
        hid_3 = model_3(hid_2)
    # For model_3, you need to manually call the hook's offload method.
    hook_3.offload()
    ```
    """
    hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
    add_hook_to_module(model, hook, append=True)
    user_hook = UserCpuOffloadHook(model, hook)
    return model, user_hook
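To complement the docstring example above, a short sketch of the returned hook's lifecycle; the `Linear` model, batch shapes, and `cuda:0` device are assumptions for illustration. Unlike `cpu_offload`, which puts the weights back on the CPU after every forward, the model here stays on the execution device until `offload()` is called, and `remove()` detaches the hook entirely:

```py
import torch
from accelerate import cpu_offload_with_hook

model = torch.nn.Linear(8, 8)  # placeholder model
model, hook = cpu_offload_with_hook(model, execution_device="cuda:0")

for batch in (torch.randn(2, 8) for _ in range(3)):
    out = model(batch)  # the first call moves the model to cuda:0; later calls reuse it there

hook.offload()  # weights go back to the CPU until the next forward
hook.remove()   # detach the hook; the model behaves like a plain module again
```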
def disk_offload(
    model: nn.Module,
    offload_dir: Union[str, os.PathLike],

src/accelerate/hooks.py

@@ -18,7 +18,14 @@ from typing import Dict, List, Mapping, Optional, Union
import torch
import torch.nn as nn

from .utils import (
    PrefixedDataset,
    find_device,
    is_mps_available,
    named_module_tensors,
    send_to_device,
    set_module_tensor_to_device,
)


class ModelHook:
@@ -494,3 +501,61 @@ def attach_align_device_hook_on_blocks(
            module_name=child_name,
            preload_module_classes=preload_module_classes,
        )
class CpuOffload(ModelHook):
    """
    Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU
    after the forward; the user needs to call the `init_hook` method again for this.

    Args:
        execution_device (`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
        prev_module_hook (`UserCpuOffloadHook`, *optional*):
            The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
            passed, its offload method will be called just before the forward of the model to which this hook is
            attached.
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        prev_module_hook: Optional["UserCpuOffloadHook"] = None,
    ):
        self.prev_module_hook = prev_module_hook

        if execution_device is not None:
            self.execution_device = execution_device
        elif is_mps_available():
            self.execution_device = torch.device("mps")
        elif torch.cuda.is_available():
            self.execution_device = torch.device(0)
        else:
            self.execution_device = torch.device("cpu")

    def init_hook(self, module):
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        module.to(self.execution_device)
        if self.prev_module_hook is not None:
            self.prev_module_hook.offload()
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)


class UserCpuOffloadHook:
    """
    A simple hook grouping a model and a `ModelHook`, which provides easy APIs to call the init method of the hook or
    remove it entirely.
    """

    def __init__(self, model, hook):
        self.model = model
        self.hook = hook

    def offload(self):
        self.hook.init_hook(self.model)

    def remove(self):
        remove_hook_from_module(self.model)
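For readers of the hook internals, this is roughly what `cpu_offload_with_hook` wires up by hand; the `small_model` module and the device index are illustrative assumptions:

```py
import torch
from accelerate.hooks import CpuOffload, UserCpuOffloadHook, add_hook_to_module

small_model = torch.nn.Linear(4, 4)  # illustrative module

hook = CpuOffload(execution_device=0)               # defaults to mps, then cuda:0, then cpu if omitted
add_hook_to_module(small_model, hook, append=True)  # runs init_hook(), which moves the module to the CPU
user_hook = UserCpuOffloadHook(small_model, hook)

y = small_model(torch.randn(1, 4))  # pre_forward() moves the module and the inputs to device 0
user_hook.offload()                 # re-runs init_hook(), i.e. module.to("cpu")
```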

tests/test_big_modeling.py

@@ -21,6 +21,7 @@ import torch.nn as nn
from accelerate.big_modeling import (
    cpu_offload,
    cpu_offload_with_hook,
    disk_offload,
    dispatch_model,
    init_empty_weights,
@@ -484,3 +485,33 @@ class BigModelingTester(unittest.TestCase):
        output = new_model(x)
        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
    @require_cuda
    def test_cpu_offload_with_hook(self):
        model1 = torch.nn.Linear(4, 5)
        model1, hook1 = cpu_offload_with_hook(model1)
        self.assertEqual(model1.weight.device, torch.device("cpu"))

        inputs = torch.randn(3, 4)
        outputs = model1(inputs)
        self.assertEqual(outputs.device, torch.device(0))
        self.assertEqual(model1.weight.device, torch.device(0))

        hook1.offload()
        self.assertEqual(model1.weight.device, torch.device("cpu"))

        model2 = torch.nn.Linear(5, 5)
        model2, hook2 = cpu_offload_with_hook(model2, prev_module_hook=hook1)
        self.assertEqual(model2.weight.device, torch.device("cpu"))

        outputs = model1(inputs)
        self.assertEqual(outputs.device, torch.device(0))
        self.assertEqual(model1.weight.device, torch.device(0))

        outputs = model2(outputs)
        self.assertEqual(outputs.device, torch.device(0))
        self.assertEqual(model1.weight.device, torch.device("cpu"))
        self.assertEqual(model2.weight.device, torch.device(0))

        hook2.offload()
        self.assertEqual(model2.weight.device, torch.device("cpu"))