DeepSpeed/deepspeed/runtime/swap_tensor/optimizer_utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Functionality of swapping tensors to/from (NVMe) storage devices.
"""

import os
import torch

from deepspeed import comm as dist
from deepspeed.utils.logging import logger
from deepspeed.runtime.swap_tensor.constants import *
from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, \
    MIN_AIO_BYTES, AIO_ALIGNED_BYTES, get_sized_buffers
from deepspeed.runtime.swap_tensor.utils import SwapBufferManager, SwapBufferPool
from deepspeed.accelerator import get_accelerator


class FlattenedTensorSwapInfo(object):

    def __init__(self, path, length, offset):
        self.path = path
        self.offset = offset
        self.length = length


class SwapTensorContext(object):

    def __init__(self, tensor, swap_folder):
        self.compute_tensor = tensor
        self.swap_tensor = torch.Tensor()
        self.swap_path = os.path.join(swap_folder, f'{OptimizerSwapper.parameter_id(tensor)}.tensor.swp')

    def release_memory(self):
        self.compute_tensor.data = torch.Tensor()
        self.swap_tensor.data = torch.Tensor()

    def set_buffers(self, compute_buffer, swap_buffer):
        self.compute_tensor.data = compute_buffer.data
        self.swap_tensor.data = swap_buffer.data
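
# Note added for illustration (not part of the upstream file): a SwapTensorContext pairs a
# compute tensor with its on-disk swap path. release_memory() drops both host copies by
# pointing .data at empty tensors, and set_buffers() later rebinds .data to slices of a
# pinned buffer, so the Python tensor objects persist while their storage comes and goes.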


class OptimizerStateSwapInfo(object):

    def __init__(self, parameter, numel, base_folder):
        self.tensors = []
        self.param_id = OptimizerSwapper.parameter_id(parameter)
        self.swap_folder = base_folder
        self.swapped_gradients = {}
        self.unswapped_gradients = {}
        self.tensor_numel = numel
        self.tensor_dtype = parameter.dtype
        self.tensor_device = parameter.device
        self.has_state_tensors = False
        self.swap_buffers = []
        self._add_tensors([parameter])

    def numel(self):
        return self.tensor_numel

    def has_gradients(self):
        return bool(self.swapped_gradients) or bool(self.unswapped_gradients)

    def _add_tensors(self, tensor_list):
        for t in tensor_list:
            self.tensors.append(SwapTensorContext(t, self.swap_folder))

    def add_state_tensors(self, tensor_list):
        self.has_state_tensors = True
        self._add_tensors(tensor_list)

    def num_tensors(self):
        return len(self.tensors)

    def device(self):
        return self.tensor_device

    def dtype(self):
        return self.tensor_dtype

    def release_memory(self):
        for t in self.tensors:
            t.release_memory()

    def get_compute_tensors(self):
        return [t.compute_tensor for t in self.tensors]

    def get_swap_paths(self):
        return [t.swap_path for t in self.tensors]

    def get_swap_buffers_and_paths(self, pinned):
        swap_buffers = []
        swap_paths = []
        select_tensors = [t for t in self.tensors if get_accelerator().is_pinned(t.compute_tensor) == pinned]
        for t in select_tensors:
            swap_buffers.append(t.swap_tensor if pinned else t.compute_tensor)
            swap_paths.append(t.swap_path)

        return swap_buffers, swap_paths

    def get_or_create_gradient_paths(self, offsets, lengths):
        gradient_paths = []
        for offset, length in zip(offsets, lengths):
            if offset not in self.swapped_gradients.keys():
                path = os.path.join(self.swap_folder, f'{self.param_id}_gradient_{offset}_{length}.tensor.swp')
                self.swapped_gradients[offset] = FlattenedTensorSwapInfo(path, length, offset)

            gradient_paths.append(self.swapped_gradients[offset].path)

        return gradient_paths

    def set_swap_buffers(self, buffers, aligned_numel):
        num_tensors = len(self.tensors)
        compute_lengths = [self.numel()] * num_tensors
        compute_buffers = get_sized_buffers(buffers, compute_lengths)
        swap_lengths = [aligned_numel] * num_tensors
        swap_buffers = get_sized_buffers(buffers, swap_lengths)
        for i, t in enumerate(self.tensors):
            t.set_buffers(compute_buffer=compute_buffers[i], swap_buffer=swap_buffers[i])

    def get_swap_gradient_buffers(self, swap_buffer):
        assert self.numel() <= swap_buffer.numel()
        return [swap_buffer.narrow(0, grad.offset, grad.length) for grad in self.swapped_gradients.values()]

    def get_swap_gradient_paths(self):
        return [grad.path for grad in self.swapped_gradients.values()]

    def get_unpinned_state_tensors(self):
        return [t.compute_tensor for t in self.tensors if not get_accelerator().is_pinned(t.compute_tensor)]

    def read_unswapped_gradients(self, dest_buffer):
        num_elem_count = 0
        for offset, grad_partition in self.unswapped_gradients.items():
            dst_tensor = dest_buffer.narrow(0, offset, grad_partition.numel())
            dst_tensor.data.copy_(grad_partition.data)
            num_elem_count += grad_partition.numel()

        return num_elem_count

    def write_unswapped_gradients(self, src_buffer):
        num_elem_count = 0
        for offset, grad_partition in self.unswapped_gradients.items():
            src_tensor = src_buffer.narrow(0, offset, grad_partition.numel())
            grad_partition.data.copy_(src_tensor.data)
            num_elem_count += grad_partition.numel()

        return num_elem_count

    def release_unswapped_gradients(self):
        self.unswapped_gradients = {}
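
# Hedged usage sketch (added for illustration; not part of the upstream file). It assumes
# a CPU tensor with a manually assigned `ds_id`, which the ZeRO runtime normally sets, and
# a hypothetical swap folder. It shows the round trip of gradient partitions that are too
# small to swap through read_unswapped_gradients()/write_unswapped_gradients():
#
#   param = torch.zeros(1024)
#   param.ds_id = 0                                           # hypothetical id; ZeRO assigns this
#   info = OptimizerStateSwapInfo(param, 1024, '/tmp/swap')   # hypothetical folder
#   info.unswapped_gradients[0] = torch.randn(256)
#   scratch = torch.empty(1024)
#   info.read_unswapped_gradients(scratch)       # copy partitions into `scratch`
#   info.write_unswapped_gradients(scratch)      # copy them back, leaving them unchanged
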
SWAPPER_DEBUG_MODE = False
SWAP_OUT_GRADIENT_TIMER = 'swap_out_gradient'


class OptimizerSwapper(object):

    @staticmethod
    def parameter_id(param):
        return param.ds_id

    def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device, dtype, timers):
        self.swap_config = swap_config
        self.aio_config = aio_config

        # NVMe swap management
        self.swap_params_info = {}
        self.swap_element_size = torch.tensor([], dtype=dtype).element_size()

        self.swap_folder = os.path.join(base_folder, 'optimizer', f'rank{dist.get_rank()}')
        os.makedirs(self.swap_folder, exist_ok=True)

        self.optimizer = optimizer

        # Read/Write alignment for each thread during Intra-request parallelism
        self.min_aio_bytes = max(MIN_AIO_BYTES, aio_config[AIO_BLOCK_SIZE])
        self.aligned_bytes = AIO_ALIGNED_BYTES * aio_config[AIO_INTRA_OP_PARALLELISM]
        self.numel_alignment = self.aligned_bytes // self.swap_element_size

        # Swap buffer management
        self.largest_numel = self._io_aligned_numel(largest_numel)
        self.dtype = dtype
        self.swap_buffer_manager = SwapBufferManager(num_elems=self.largest_numel,
                                                     count=swap_config.buffer_count,
                                                     dtype=dtype)

        # Timers
        self.timers = timers
        self.timer_names = set()

        # Print exclusion list
        self.print_exclude_list = [
            'optimizer',
            'swap_buffer_manager',
            'swap_params_info',
            'timers',
            'timer_names',
        ]

    def purge_state(self):
        for swap_info in self.swap_params_info.values():
            swap_info.tensors = [swap_info.tensors[0]]
            swap_info.has_state_tensors = False

    def is_swappable_tensor(self, tensor=None, numel=None):
        assert tensor is not None or numel is not None, "Either tensor or numel must be provided"
        if tensor is not None:
            return self.min_aio_bytes <= (tensor.numel() * self.swap_element_size)
        return self.min_aio_bytes <= (numel * self.swap_element_size)

    def init_timers(self):
        self.timer_names = set()

    def log_timers(self):
        if self.timer_names:
            self._log_timers(list(self.timer_names), force=True)

    def pre_backward(self):
        self.init_timers()

    def post_backward(self):
        pass

    def _flush_gradient_swapper(self, gradient_swapper):
        if gradient_swapper.has_buffers():
            self._start_timer(SWAP_OUT_GRADIENT_TIMER)
            pinned_buffers = gradient_swapper.release_buffers()
            self.swap_buffer_manager.free(pinned_buffers)
            self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
            self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)

        self.timer_names.update(gradient_swapper.get_timer_names())

    def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gradient_swapper):
        if OptimizerSwapper.parameter_id(parameter) not in self.swap_params_info.keys():
            return

        swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)]

        swappable_tensors = []
        swappable_offsets = []
        swappable_lengths = []

        aligned_gradients, aligned_offsets = self._adjust_for_misaligned_lengths(tensors=gradient_tensors,
                                                                                 offsets=gradient_offsets)

        self._start_timer(SWAP_OUT_GRADIENT_TIMER)
        for tensor, offset in zip(aligned_gradients, aligned_offsets):
            if not self.is_swappable_tensor(tensor=tensor):
                swap_info.unswapped_gradients[offset] = tensor
                continue

            swappable_tensors.append(tensor)
            swappable_offsets.append(offset)
            swappable_lengths.append(tensor.numel())

        if len(swappable_tensors) > 0:
            if not gradient_swapper.has_buffers():
                pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
                gradient_swapper.add_buffers(pinned_buffers)

            swappable_paths = swap_info.get_or_create_gradient_paths(swappable_offsets, swappable_lengths)
            gradient_swapper.swap_out_tensors(tensor_list=swappable_tensors, path_list=swappable_paths)

        self._stop_timer(SWAP_OUT_GRADIENT_TIMER)
        self.timer_names.add(SWAP_OUT_GRADIENT_TIMER)

    def _initialize_from_swapped_fp16_params(self, aio_handle, fp16_partitions_info, fp16_num_elems,
                                             fp16_pinned_buffers, fp32_parameters):
        assert len(fp32_parameters) == len(fp16_partitions_info)
        assert len(fp32_parameters) == len(fp16_num_elems)
        assert all([get_accelerator().is_pinned(buffer) for buffer in fp16_pinned_buffers])

        fp32_swap_paths = self._get_swap_paths(parameters=fp32_parameters, num_elems=fp16_num_elems)

        fp32_pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)

        fp16_buffer_numel = [buf.numel() for buf in fp16_pinned_buffers]
        assert all([numel >= self.largest_numel for numel in fp16_buffer_numel]), \
            f"numel of fp16 buffers {fp16_buffer_numel} is too small for initializing fp32 params {self.largest_numel}"

        fp32_swap_buffers = SwapBufferPool(fp32_pinned_buffers)
        fp16_swap_buffers = SwapBufferPool(fp16_pinned_buffers)

        curr_index = 0
        while curr_index < len(fp32_parameters):
            fp16_pinned_tensors = self._swap_in_fp16_params(aio_handle=aio_handle,
                                                            fp16_num_elems=fp16_num_elems[curr_index:],
                                                            fp16_partitions_info=fp16_partitions_info[curr_index:],
                                                            fp16_swap_buffers=fp16_swap_buffers)

            if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
                for i, tensor in enumerate(fp16_pinned_tensors):
                    true_index = curr_index + i
                    logger.info(
                        f'swap_in_fp16_param: fp32_id = {OptimizerSwapper.parameter_id(fp32_parameters[true_index])} index = {true_index} orig_num_elem = {fp16_num_elems[true_index]}, swap_num_elem = {fp16_pinned_tensors[i].numel()}'
                    )

            swap_out_count = self._swap_out_fp16_params(aio_handle=aio_handle,
                                                        fp32_swap_paths=fp32_swap_paths[curr_index:],
                                                        fp32_swap_buffers=fp32_swap_buffers,
                                                        fp16_pinned_tensors=fp16_pinned_tensors)
            assert swap_out_count == len(fp16_pinned_tensors), \
                f"{swap_out_count} does not match {len(fp16_pinned_tensors)}"

            fp16_swap_buffers.reset()
            fp32_swap_buffers.reset()
            curr_index += swap_out_count

        self.swap_buffer_manager.free(fp32_pinned_buffers)
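
    # Note added for illustration (not part of the upstream file): the loop above
    # double-buffers through two SwapBufferPool instances. Each pass swaps in as many
    # fp16 partitions as fit in the fp16 pinned buffers, writes them back out through
    # the fp32 pinned buffers to the fp32 optimizer swap paths, then resets both pools
    # and advances curr_index by the number of tensors actually written.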

    def _swap_in_fp16_params(self, aio_handle, fp16_num_elems, fp16_partitions_info, fp16_swap_buffers):
        assert len(fp16_num_elems) > 0

        swapped_fp16_tensors = []
        swap_tensors = []
        swap_paths = []
        unswapped_srcs = []
        unswapped_dsts = []

        for i, numel in enumerate(fp16_num_elems):
            pinned_tensor, _ = fp16_swap_buffers.allocate_tensor(numel, None, numel)
            if pinned_tensor is None:
                break

            swapped_fp16_tensors.append(pinned_tensor)
            offset = 0
            for tensor, partition_numel, partition_path in fp16_partitions_info[i]:
                dst_tensor = pinned_tensor.narrow(0, offset, partition_numel)
                if partition_path is None:
                    unswapped_srcs.append(tensor)
                    unswapped_dsts.append(dst_tensor)
                else:
                    swap_paths.append(partition_path)
                    swap_tensors.append(dst_tensor)
                offset += partition_numel

        assert len(swapped_fp16_tensors) + len(unswapped_srcs) > 0
        ret = swap_in_tensors(aio_handle, swap_tensors, swap_paths)
        for src, dst in zip(unswapped_srcs, unswapped_dsts):
            dst.data.copy_(src.data)

        assert len(swap_tensors) == aio_handle.wait()

        return swapped_fp16_tensors

    def _swap_out_fp16_params(self, aio_handle, fp32_swap_paths, fp32_swap_buffers, fp16_pinned_tensors):
        assert len(fp16_pinned_tensors) <= len(fp32_swap_paths)

        swap_out_count = 0
        for i, fp16_tensor in enumerate(fp16_pinned_tensors):
            if not fp32_swap_buffers.has_space(fp16_tensor.numel()):
                fp32_swap_buffers.swap_out(aio_handle)
                fp32_swap_buffers.reset()

            pinned_tensor, _ = fp32_swap_buffers.insert_tensor(fp16_tensor, fp32_swap_paths[i],
                                                               self._io_aligned_numel(fp16_tensor.numel()))
            assert pinned_tensor is not None
            swap_out_count += 1

        if len(fp32_swap_buffers.get_swap_tensors()) > 0:
            fp32_swap_buffers.swap_out(aio_handle)

        return swap_out_count

    def _initialize_parameters(self, parameters, src_tensors, aio_handle):
        assert len(parameters) == len(src_tensors)

        swap_paths = self._get_swap_paths(parameters=parameters, num_elems=[src.numel() for src in src_tensors])

        SWAP_INIT_TIMER = "swap_init_write"
        self._start_timer(SWAP_INIT_TIMER)

        pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
        assert pinned_buffers is not None

        self._swap_out_unpinned_tensors(aio_handle=aio_handle,
                                        unpinned_tensors=src_tensors,
                                        dest_paths=swap_paths,
                                        pinned_buffers=pinned_buffers)

        if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
            for i, tensor in enumerate(src_tensors):
                logger.info(
                    f'copy_in_fp16_param: fp32_id = {OptimizerSwapper.parameter_id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
                )

        self.swap_buffer_manager.free(pinned_buffers)

        self._stop_timer(SWAP_INIT_TIMER)
        self._log_timers([SWAP_INIT_TIMER])

    def _get_swap_paths(self, parameters, num_elems):
        swap_info_list = [
            self._create_param_swap_info(parameter=p, numel=numel) for p, numel in zip(parameters, num_elems)
        ]
        assert len(swap_info_list) == len(num_elems)

        swap_paths = [info.tensors[0].swap_path for info in swap_info_list]
        return swap_paths

    def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
        swap_buffer_count = len(pinned_buffers)
        unpinned_tensor_count = len(unpinned_tensors)

        for i in range(0, unpinned_tensor_count, swap_buffer_count):
            swap_tensor_count = min((unpinned_tensor_count - i), swap_buffer_count)

            src_tensors = unpinned_tensors[i:(i + swap_tensor_count)]
            compute_lengths = [t.numel() for t in src_tensors]
            compute_buffers = get_sized_buffers(pinned_buffers, compute_lengths)

            for dst, src in zip(compute_buffers, src_tensors):
                dst.data.copy_(src.data)

            swap_lengths = [self._io_aligned_numel(t.numel()) for t in src_tensors]
            swap_buffers = get_sized_buffers(pinned_buffers, swap_lengths)

            swap_paths = dest_paths[i:(i + swap_tensor_count)]
            swap_out_tensors(aio_handle, swap_buffers, swap_paths)

            assert aio_handle.wait() == swap_tensor_count

    def _adjust_for_misaligned_lengths(self, tensors, offsets):
        new_tensors = []
        new_offsets = []

        for orig_tensor, orig_offset in zip(tensors, offsets):
            if not self.is_swappable_tensor(tensor=orig_tensor):
                new_tensors.append(orig_tensor)
                new_offsets.append(orig_offset)
                continue

            remainder = orig_tensor.numel() % self.numel_alignment
            if remainder == 0:
                new_tensors.append(orig_tensor)
                new_offsets.append(orig_offset)
                continue

            # Split into two by making remainder a tensor
            aligned_length = (orig_tensor.numel() // self.numel_alignment) * self.numel_alignment
            new_tensors.append(orig_tensor.narrow(0, 0, aligned_length))
            new_offsets.append(orig_offset)

            # remainder tensor
            new_tensors.append(orig_tensor.narrow(0, aligned_length, remainder))
            new_offsets.append(orig_offset + aligned_length)

        return new_tensors, new_offsets
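
    # Hedged worked example (added for illustration, not in the upstream file): with a
    # hypothetical numel_alignment of 1024, a swappable gradient of 2500 elements at
    # offset 0 is split into an aligned slice of 2048 elements at offset 0 plus a
    # remainder slice of 452 elements at offset 2048; in typical configurations the
    # remainder falls below min_aio_bytes and is kept in memory via unswapped_gradients.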

    def _retrieve_unswapped_grad_partitions(self, swap_info, dest_buffer):
        UNSWAPPED_READ_GRADIENTS = 'unswapped_read_gradients'
        self._start_timer(UNSWAPPED_READ_GRADIENTS)
        tensor_count = len(swap_info.unswapped_gradients)
        num_elem_count = swap_info.read_unswapped_gradients(dest_buffer)
        self._stop_timer(UNSWAPPED_READ_GRADIENTS)
        self._log_timers([UNSWAPPED_READ_GRADIENTS])

        # It should be safe to discard unswapped gradient partitions
        swap_info.release_unswapped_gradients()

        if SWAPPER_DEBUG_MODE:
            logger.info(
                f'optimizer_retrieve_unswapped_gradients: param={swap_info.param_id} tensor_count={tensor_count} elem_count={num_elem_count}'
            )

    def _get_state_tensors(self, parameter):
        if parameter not in self.optimizer.state:
            return []

        tensor_list = []
        for state_name, value in self.optimizer.state[parameter].items():
            if torch.is_tensor(value) and self.is_swappable_tensor(tensor=value):
                value.ds_id = state_name + '-' + parameter.ds_id
                tensor_list.append(value)

        return tensor_list

    def _update_param_state_info(self, swap_info, parameter):
        if not swap_info.has_state_tensors:
            state_tensors = self._get_state_tensors(parameter)
            if state_tensors:
                swap_info.add_state_tensors(state_tensors)

    def _create_param_swap_info(self, parameter, numel):
        param_id = OptimizerSwapper.parameter_id(parameter)
        assert param_id not in self.swap_params_info

        self.swap_params_info[param_id] = OptimizerStateSwapInfo(parameter=parameter,
                                                                 numel=numel,
                                                                 base_folder=self.swap_folder)
        swap_info = self.swap_params_info[param_id]

        self._update_param_state_info(swap_info, parameter)

        return swap_info

    def _get_param_swap_info(self, parameter):
        param_id = OptimizerSwapper.parameter_id(parameter)
        swap_info = self.swap_params_info.get(param_id, None)

        if swap_info is not None:
            self._update_param_state_info(swap_info, parameter)

        return swap_info

    def _start_timer(self, name):
        if self.timers:
            self.timers(name).start()

    def _stop_timer(self, name):
        if self.timers:
            self.timers(name).stop()

    def _log_timers(self, name_list, force=False):
        if self.timers and (SWAPPER_DEBUG_MODE or force):
            self.timers.log(name_list)

    def _io_aligned_numel(self, numel):
        remainder = numel % self.numel_alignment
        return numel if remainder == 0 else (numel + self.numel_alignment - remainder)
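
    # Hedged worked example (added for illustration, not in the upstream file): with a
    # hypothetical alignment of 256 elements, _io_aligned_numel(1000) returns 1024
    # (1000 % 256 == 232, so 256 - 232 == 24 elements of padding), while
    # _io_aligned_numel(512) is returned unchanged. The padded size is what the pinned
    # swap buffers above are sliced to before being handed to the asynchronous I/O engine.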