# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from dataclasses import dataclass
from deepspeed import comm as dist
from typing import Dict, List, Callable

@dataclass
class fragment_address:
    numel: int
    start: int

@dataclass
class tensor_fragment:
    lp_fragment: torch.Tensor
    lp_fragment_address: fragment_address
    hp_fragment: torch.Tensor
    hp_fragment_address: fragment_address
    gradient_dict: Dict
    offload_gradient_dict: Dict
    use_offload: bool
    param_group_index: int
    optim_fragment: Dict = None

    def update_hp(self):
        self.hp_fragment.data.copy_(self.lp_fragment.data)

    def update_lp(self):
        self.lp_fragment.data.copy_(self.hp_fragment.data)

    def get_optim_state_fragment(self, key):
        if key in self.optim_fragment:
            return self.optim_fragment[key]
        else:
            raise ValueError(f'{key} not found in optimizer state fragment')

    def set_optim_state_fragment(self, flat_hp_partition, optim_fragment):
        self.optim_fragment = {
            key: value.narrow(0, self.hp_fragment_address.start, self.hp_fragment_address.numel)
            for key, value in optim_fragment.items()
            if torch.is_tensor(value) and value.shape == flat_hp_partition.shape
        }

    def get_hp_fragment_address(self):
        return self.hp_fragment_address

    def get_optim_state_keys(self):
        return list(self.optim_fragment.keys())

    def get_hp_fragment(self, optim_state_key=None):
        if optim_state_key is None:
            return self.hp_fragment
        return self.get_optim_state_fragment(optim_state_key)

    def get_lp_grad_fragment(self, index_in_param_group):
        if self.use_offload:
            gradient_dict = self.offload_gradient_dict
        else:
            gradient_dict = self.gradient_dict

        if self.param_group_index not in gradient_dict or gradient_dict[self.param_group_index] is None:
            raise ValueError("Gradients are only available immediately after backward and before engine step")

        return gradient_dict[self.param_group_index][index_in_param_group]

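# Example (illustrative sketch, not executed here): a `tensor_fragment` links a
# slice of a low-precision parameter to the matching slice of the flat fp32
# partition so the two can be kept in sync. Using `get_hp_fragment_mapping`
# (defined later in this file) with hypothetical tensors:
#
#   frag = get_hp_fragment_mapping(lp_param, lp_start=0, flat_hp_partition=flat_fp32,
#                                  gradient_dict={}, offload_gradient_dict={},
#                                  use_offload=False, param_group_index=0,
#                                  partition_start=0, partition_size=flat_fp32.numel())
#   frag.update_hp()  # copy lp values into the fp32 partition slice
#   frag.update_lp()  # copy fp32 master values back into the lp parameter
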
def map_to_flat_opt_states(flat_hp_tensor, lp_tensors, optim_state, opt_keys):
    # Rebuild each optimizer state tensor of the flat hp partition from the
    # per-parameter hp fragments, then re-point the fragments at the new buffer.
    for key in opt_keys:
        hp_param = flat_hp_tensor
        buffer = torch.zeros_like(hp_param)

        for lp in lp_tensors:
            if lp._hp_mapping is not None:
                hp_fragment_address = lp._hp_mapping.get_hp_fragment_address()
                hp_fragment = buffer.narrow(0, hp_fragment_address.start, hp_fragment_address.numel)
                hp_fragment.data.copy_(lp._hp_mapping.get_hp_fragment(optim_state_key=key).data)
                lp._hp_mapping.hp_fragment = hp_fragment

        optim_state[hp_param][key] = buffer

# The get/set helpers below are bound as methods on individual parameter
# tensors elsewhere in DeepSpeed, so `self` here is a parameter, not a class
# instance.
def get_full_hp_param(self, optim_state_key=None):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
        reduce_fragment.data.copy_(hp_fragment.data)
    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)

def set_full_hp_param(self, value, optim_state_key=None):
    if self._hp_mapping is not None:
        lp_frag_address = self._hp_mapping.lp_fragment_address
        value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
        hp_fragment.data.copy_(value_fragment.data)

def get_full_hp_grad(self):
    reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
    if self._hp_mapping is not None:
        lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
        hp_grad_fragment = lp_grad_fragment.to(torch.float32).flatten()

        lp_frag_address = self._hp_mapping.lp_fragment_address
        reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)

        if self.view(-1).shape == hp_grad_fragment.shape:
            reduce_buffer.data.copy_(hp_grad_fragment.data)
        else:
            reduce_fragment.data.copy_(hp_grad_fragment.data)

    dist.all_reduce(reduce_buffer, group=self._dp_group)
    return reduce_buffer.reshape_as(self)

def set_full_hp_grad(self, value):
    if self._hp_mapping is not None:
        lp_grad_fragment = self._hp_mapping.get_lp_grad_fragment(self._index_in_param_group)
        lp_frag_address = self._hp_mapping.lp_fragment_address
        value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
        lp_grad_fragment.data.copy_(value_fragment.data.reshape_as(lp_grad_fragment.data))
        if hasattr(self, '_zero_optimizer'):
            self._zero_optimizer.update_offload_overflow_tracker(value)

def safe_get_full_fp32_param(param):
    """Assemble and return the fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns:
        Union[torch.Tensor, None]: A tensor on accelerator device
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param()
    return None

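# Example (illustrative sketch): assuming a model partitioned by
# `deepspeed.initialize(...)`, every rank can assemble the full fp32 master
# weights of any parameter:
#
#   from deepspeed.utils import safe_get_full_fp32_param
#   for name, param in engine.module.named_parameters():
#       full_fp32 = safe_get_full_fp32_param(param)  # None if not DeepSpeed-managed
#       if full_fp32 is not None:
#           print(name, full_fp32.norm())
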
def safe_set_full_fp32_param(param, value):
    """Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value)

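# Example (sketch): overwrite the fp32 master copy of a parameter, e.g. to
# re-initialize a layer mid-training. `new_value` is a hypothetical full
# (un-partitioned) tensor of the parameter's shape; every rank should call
# this with the same value so the shards stay consistent:
#
#   from deepspeed.utils import safe_set_full_fp32_param
#   safe_set_full_fp32_param(param, new_value)
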
def safe_get_full_optimizer_state(param, optim_state_key):
    """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)

    Returns:
        Union[torch.Tensor, None]: A tensor on accelerator device
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_full_hp_param(param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_param(optim_state_key)
    return None

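# Example (sketch): inspect Adam's first-moment estimate for a parameter
# (`exp_avg` is the key torch Adam uses for its first moment):
#
#   from deepspeed.utils import safe_get_full_optimizer_state
#   exp_avg = safe_get_full_optimizer_state(param, 'exp_avg')
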
def safe_set_full_optimizer_state(param, value, optim_state_key):
    """Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): New value
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
    """
    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        param.set_full_hp_param(value, optim_state_key)

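# Example (sketch): reset Adam's second moment for one parameter, calling on
# every rank with the same full tensor:
#
#   from deepspeed.utils import safe_set_full_optimizer_state
#   zeros = torch.zeros_like(param, dtype=torch.float32)
#   safe_set_full_optimizer_state(param, zeros, 'exp_avg_sq')
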
# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
    """
    Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
    The return data type is that used for gradient accumulation. This is usually the param data type,
    but could also be different (e.g., bf16 param training with fp32 gradient accumulation).

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns:
        Union[torch.Tensor, None]: A tensor on accelerator device
    """
    if param.grad is not None:
        return param.grad

    # ZeRO stage 3 param
    if hasattr(param, 'ds_id'):
        return param._z3_optimizer.get_fp32_grad_for_param(param)

    # ZeRO stage 1, 2, and bf16_optimizer params
    if hasattr(param, '_hp_mapping'):
        return param.get_full_hp_grad()

    return None

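# Example (sketch): log gradient norms after `engine.backward(loss)` and before
# `engine.step()`, the only window in which partitioned gradients exist:
#
#   from deepspeed.utils import safe_get_full_grad
#   grad = safe_get_full_grad(param)
#   if grad is not None:
#       print(param.shape, grad.norm())
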
def safe_set_full_grad(param, value):
    """
    Update the partitioned gradient of a low-precision (e.g., fp16) parameter.
    To avoid precision issues, the update value should have the data type of
    gradient accumulation.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        value (``torch.Tensor``): The un-partitioned new gradient value.
    """
    if param.grad is not None:
        param.grad.copy_(value)
    elif hasattr(param, 'ds_id'):
        # ZeRO stage 3 param
        param._z3_optimizer.set_fp32_grad_for_param(value, param)
    elif hasattr(param, '_hp_mapping'):
        # ZeRO stage 1, 2, and bf16_optimizer params
        param.set_full_hp_grad(value)

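# Example (sketch): scale a parameter's gradient in place between backward and
# step, combining the getter and setter above:
#
#   from deepspeed.utils import safe_get_full_grad, safe_set_full_grad
#   grad = safe_get_full_grad(param)
#   if grad is not None:
#       safe_set_full_grad(param, 0.5 * grad)
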
### Local API START ###
def safe_get_local_grad(param):
    """
    Get the local gradient partition of a ZeRO-3 partitioned parameter.
    The return data type is that used for gradient accumulation. This is usually the param data type,
    but could also be different (e.g., bf16 param training with fp32 gradient accumulation).

    Args:
        param (``torch.nn.Parameter``): A model parameter

    Returns:
        Union[torch.Tensor, None]: A tensor on accelerator device
    """
    assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters'
    return param._z3_optimizer.get_local_fp32_grad_for_param(param)

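# Example (sketch): unlike the full-parameter APIs above, each rank reads only
# its own partition, with no cross-rank communication (ZeRO-3 only):
#
#   from deepspeed.utils import safe_get_local_grad
#   local_grad = safe_get_local_grad(param)
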
def safe_set_local_grad(param, value):
    """
    Update the local gradient partition of a ZeRO-3 partitioned parameter.
    To avoid precision issues, the update value should have the data type of
    gradient accumulation.

    Args:
        param (``torch.nn.Parameter``): A model parameter.
        value (``torch.Tensor``): New value of local gradient partition.
    """
    assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters'
    param._z3_optimizer.set_local_grad_for_param(value, param)

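# Example (sketch): zero this rank's gradient partition; the new value must
# match the shape of the local partition:
#
#   from deepspeed.utils import safe_get_local_grad, safe_set_local_grad
#   safe_set_local_grad(param, torch.zeros_like(safe_get_local_grad(param)))
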
def safe_get_local_fp32_param(param):
    """Get the local partition of a ZeRO-3 partitioned parameter in fp32 precision.

    Args:
        param (``torch.nn.Parameter``): A model parameter.

    Returns:
        Union[torch.Tensor, None]: A tensor on accelerator device
    """
    assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters'
    return param._z3_optimizer.get_local_fp32_param(param)

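# Example (sketch):
#
#   from deepspeed.utils import safe_get_local_fp32_param
#   local_fp32 = safe_get_local_fp32_param(param)  # this rank's fp32 shard
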
def safe_get_local_optimizer_state(param, optim_state_key):
    """Get the local optimizer state partition of a ZeRO-3 partitioned parameter in fp32 precision.

    Args:
        param (``torch.nn.Parameter``): A model parameter
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)

    Returns:
        Union[torch.Tensor, None]: A tensor on accelerator device
    """
    assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters'
    return param._z3_optimizer.get_local_fp32_param(param, optim_state_key)

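# Example (sketch):
#
#   from deepspeed.utils import safe_get_local_optimizer_state
#   local_exp_avg = safe_get_local_optimizer_state(param, 'exp_avg')
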
def safe_set_local_optimizer_state(param, value, optim_state_key):
    """Update the local optimizer state partition of a ZeRO-3 partitioned parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter.
        value (``torch.Tensor``): New value of local optimizer state partition.
        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer).
    """
    assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters'
    param._z3_optimizer.set_local_hp_param(value, param, optim_state_key)

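# Example (sketch): zero this rank's shard of Adam's first moment:
#
#   from deepspeed.utils import (safe_get_local_optimizer_state,
#                                safe_set_local_optimizer_state)
#   shard = safe_get_local_optimizer_state(param, 'exp_avg')
#   safe_set_local_optimizer_state(param, torch.zeros_like(shard), 'exp_avg')
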
def safe_set_local_fp32_param(param, value):
    """Update the local partition of a ZeRO-3 partitioned parameter.

    Args:
        param (``torch.nn.Parameter``): A model parameter.
        value (``torch.Tensor``): New value of local parameter partition.
    """
    assert hasattr(param, 'ds_id'), 'This API is only defined for ZeRO-3 partitioned parameters'
    param._z3_optimizer.set_local_hp_param(value, param)

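# Example (sketch): perturb this rank's fp32 shard in place:
#
#   from deepspeed.utils import safe_get_local_fp32_param, safe_set_local_fp32_param
#   shard = safe_get_local_fp32_param(param)
#   safe_set_local_fp32_param(param, shard + 0.01 * torch.randn_like(shard))
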
### Local API END ###


### VECTORIZED API BEGIN ###
def safe_update_full_grad_vectorized(param_list: List[torch.nn.Parameter], update_func: Callable):
    """
    Vectorized update of the partitioned gradients of a list of low-precision (e.g., fp16) parameters.
    To avoid precision issues, the update value should have the data type of
    gradient accumulation.

    Args:
        param_list (``List[torch.nn.Parameter]``): List of model parameters
        update_func (``Callable``): A function that takes the current full gradient value and the parameter, and returns the new gradient value.
    """
    partitioned_grad_params = []
    for p in param_list:
        if p.grad is not None:
            p.grad.copy_(update_func(p.grad, p))
        elif p.requires_grad:
            partitioned_grad_params.append(p)

    if not partitioned_grad_params:
        return

    if hasattr(partitioned_grad_params[0], 'ds_id'):
        # ZeRO stage 3 param
        partitioned_grad_params[0]._z3_optimizer.update_fp32_grad_for_param_vectorized(
            update_func, partitioned_grad_params)
    elif hasattr(partitioned_grad_params[0], '_hp_mapping'):
        # ZeRO stage 1, 2, and bf16_optimizer params
        for p in partitioned_grad_params:
            old_grad = safe_get_full_grad(p)
            new_grad = update_func(old_grad, p)
            p.set_full_hp_grad(new_grad)

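# Example (sketch): clamp every gradient elementwise in one call; the callback
# receives (full_grad, param) and returns the new gradient:
#
#   safe_update_full_grad_vectorized(list(engine.module.parameters()),
#                                    lambda grad, p: grad.clamp(-1.0, 1.0))
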
### VECTORIZED API END ###

def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                            param_group_index, partition_start, partition_size):
    lp_end = lp_param.numel() + lp_start
    hp_start = partition_start
    hp_end = partition_start + partition_size

    fragment_start = max(lp_start, hp_start)
    fragment_end = min(lp_end, hp_end)
    assert fragment_start < fragment_end, \
        f'fragment start {fragment_start} should be < fragment_end {fragment_end}'

    fragment_numel = fragment_end - fragment_start
    hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel)
    hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel)

    lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel)
    lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel)

    return tensor_fragment(lp_fragment=lp_fragment_tensor,
                           lp_fragment_address=lp_frag_address,
                           hp_fragment=hp_fragment_tensor,
                           hp_fragment_address=hp_frag_address,
                           gradient_dict=gradient_dict,
                           offload_gradient_dict=offload_gradient_dict,
                           use_offload=use_offload,
                           param_group_index=param_group_index)

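# Worked example (hypothetical numbers, consistent with the lp2 case in the
# notes below): a 16-element lp param occupying flat indices [4, 20) overlaps a
# rank's hp partition covering [7, 19):
#
#   fragment_start = max(4, 7)   = 7
#   fragment_end   = min(20, 19) = 19   # fragment_numel = 12
#   lp_frag_address = fragment_address(start=3, numel=12)   # 7 - 4
#   hp_frag_address = fragment_address(start=0, numel=12)   # 7 - 7
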
'''
Logic for lp_param to hp_param mapping

lp      lp0 lp1 lp2          lp3  lp4                <-------- indices/names
lp      [  ][  ][            ][  ][      ]           <-------- tensors
flat_lp [                                ]           <-------- flat lp params
flat_hp            [          ]                      <-------- flat hp partition on current rank
full_hp [                                          ] <-------- full flat hp params

lp2
  full numel = 16
  lp_frag
    numel = 12
    frag_start = 3
    frag_end = 15
  hp_frag
    numel = 12
    frag_start = 0
    frag_end = 11

  hp_frag.copy_(lp_frag)

lp3:
  full numel = 4
  lp_frag
    numel = 4
    start = 0
    end = 3
  hp_frag
    numel = 4
    start = 12
    end = 15

lp4:
  full numel = 12
  lp_frag
    numel = 4
    start = 0
    end = 3
  hp_frag
    numel = 4
    start = 16
    end = 19

Visual depiction of above
lp           {        }
flat_lp [                ]
flat_hp         (       )

flat_lp [    {    (   }    )   ]
             lx   hx  ly   hy
                  ly-hx

lp                {        }
flat_lp [                  ]
flat_hp      (         )

flat_lp [    (    {   )    }   ]
             hx   lx  hy   ly
                  hy-lx

lp            {     }
flat_lp [               ]
flat_hp     (             )

flat_lp [   (    {   }    )   ]
            hx   lx  ly   hy
                 ly-lx

lp -> (lx, hy)
flat_hp -> (hx, hy)
'''