mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-14 14:14:32 +08:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 65b5d2e74a |
@ -220,17 +220,23 @@ class AlignDevicesHook(ModelHook):
|
||||
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
|
||||
set_module_tensor_to_device(module, name, self.execution_device)
|
||||
elif self.offload:
|
||||
self.original_devices = {name: param.device for name, param in named_module_tensors(module)}
|
||||
self.original_devices = {
|
||||
name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
|
||||
}
|
||||
if self.weights_map is None:
|
||||
self.weights_map = {
|
||||
name: param.to("cpu")
|
||||
for name, param in named_module_tensors(module, include_buffers=self.offload_buffers)
|
||||
for name, param in named_module_tensors(
|
||||
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
|
||||
)
|
||||
}
|
||||
|
||||
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
|
||||
for name, _ in named_module_tensors(
|
||||
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
|
||||
):
|
||||
set_module_tensor_to_device(module, name, "meta")
|
||||
if not self.offload_buffers and self.execution_device is not None:
|
||||
for name, _ in module.named_buffers(recurse=False):
|
||||
for name, _ in module.named_buffers(recurse=self.place_submodules):
|
||||
set_module_tensor_to_device(module, name, self.execution_device)
|
||||
return module
|
||||
|
||||
@ -238,14 +244,18 @@ class AlignDevicesHook(ModelHook):
|
||||
if self.io_same_device:
|
||||
self.input_device = find_device([args, kwargs])
|
||||
if self.offload:
|
||||
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
|
||||
for name, _ in named_module_tensors(
|
||||
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
|
||||
):
|
||||
set_module_tensor_to_device(module, name, self.execution_device, value=self.weights_map[name])
|
||||
|
||||
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
|
||||
|
||||
def post_forward(self, module, output):
|
||||
if self.offload:
|
||||
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
|
||||
for name, _ in named_module_tensors(
|
||||
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
|
||||
):
|
||||
set_module_tensor_to_device(module, name, "meta")
|
||||
|
||||
if self.io_same_device and self.input_device is not None:
|
||||
@ -403,18 +413,21 @@ def attach_align_device_hook_on_blocks(
|
||||
add_hook_to_module(module, hook)
|
||||
attach_execution_device_hook(module, execution_device[module_name])
|
||||
elif module_name in execution_device:
|
||||
attach_align_device_hook(
|
||||
module,
|
||||
if weights_map is not None:
|
||||
prefix = f"{module_name}." if len(module_name) > 0 else ""
|
||||
prefixed_weights_map = PrefixedDataset(weights_map, prefix)
|
||||
else:
|
||||
prefixed_weights_map = None
|
||||
hook = AlignDevicesHook(
|
||||
execution_device=execution_device[module_name],
|
||||
offload=True,
|
||||
weights_map=weights_map,
|
||||
weights_map=prefixed_weights_map,
|
||||
offload_buffers=offload_buffers,
|
||||
module_name=module_name,
|
||||
io_same_device=(module_name == ""),
|
||||
place_submodules=True,
|
||||
)
|
||||
if not hasattr(module, "_hf_hook"):
|
||||
hook = AlignDevicesHook(execution_device=execution_device[module_name], io_same_device=(module_name == ""))
|
||||
add_hook_to_module(module, hook)
|
||||
attach_execution_device_hook(module, execution_device[module_name])
|
||||
add_hook_to_module(module, hook)
|
||||
elif module_name == "":
|
||||
hook = AlignDevicesHook(io_same_device=True)
|
||||
add_hook_to_module(module, hook)
|
||||
|
||||
Reference in New Issue
Block a user