Compare commits

...

1 Commit

Author SHA1 Message Date
65b5d2e74a Try to make it work for BLOOM 2022-05-31 13:42:50 -04:00

View File

@@ -220,17 +220,23 @@ class AlignDevicesHook(ModelHook):
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
elif self.offload:
self.original_devices = {name: param.device for name, param in named_module_tensors(module)}
self.original_devices = {
name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
}
if self.weights_map is None:
self.weights_map = {
name: param.to("cpu")
for name, param in named_module_tensors(module, include_buffers=self.offload_buffers)
for name, param in named_module_tensors(
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
)
}
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
for name, _ in named_module_tensors(
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
):
set_module_tensor_to_device(module, name, "meta")
if not self.offload_buffers and self.execution_device is not None:
for name, _ in module.named_buffers(recurse=False):
for name, _ in module.named_buffers(recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
return module
@@ -238,14 +244,18 @@ class AlignDevicesHook(ModelHook):
if self.io_same_device:
self.input_device = find_device([args, kwargs])
if self.offload:
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
for name, _ in named_module_tensors(
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
):
set_module_tensor_to_device(module, name, self.execution_device, value=self.weights_map[name])
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
def post_forward(self, module, output):
if self.offload:
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
for name, _ in named_module_tensors(
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
):
set_module_tensor_to_device(module, name, "meta")
if self.io_same_device and self.input_device is not None:
@@ -403,18 +413,21 @@ def attach_align_device_hook_on_blocks(
add_hook_to_module(module, hook)
attach_execution_device_hook(module, execution_device[module_name])
elif module_name in execution_device:
attach_align_device_hook(
module,
if weights_map is not None:
prefix = f"{module_name}." if len(module_name) > 0 else ""
prefixed_weights_map = PrefixedDataset(weights_map, prefix)
else:
prefixed_weights_map = None
hook = AlignDevicesHook(
execution_device=execution_device[module_name],
offload=True,
weights_map=weights_map,
weights_map=prefixed_weights_map,
offload_buffers=offload_buffers,
module_name=module_name,
io_same_device=(module_name == ""),
place_submodules=True,
)
if not hasattr(module, "_hf_hook"):
hook = AlignDevicesHook(execution_device=execution_device[module_name], io_same_device=(module_name == ""))
add_hook_to_module(module, hook)
attach_execution_device_hook(module, execution_device[module_name])
add_hook_to_module(module, hook)
elif module_name == "":
hook = AlignDevicesHook(io_same_device=True)
add_hook_to_module(module, hook)