Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
add support for tensor learning rate (vs scalar) (#7633)
This change adds support for using a tensor learning rate value instead of a scalar one. We found this helpful when the optimizer is torch.compiled: in that case, changing a scalar LR value can trigger recompilation and degrade performance. The implementation lets the model script determine the type of LR value used by setting the initial value.

Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
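A rough usage sketch of that opt-in, not part of the commit itself: the model, config values, and batch size below are made up, and a Tensor lr requires a recent PyTorch version. The model script chooses tensor learning rates simply by passing a Tensor as the initial LR; the diff that keeps scheduler updates type-consistent follows below.

import torch
import deepspeed

model = torch.nn.Linear(8, 8)

# Choosing a Tensor initial LR (rather than a Python float) is what opts this
# run into tensor learning rates; scheduler updates then keep the value a
# tensor instead of replacing it with a float, which avoids torch.compile
# recompilation when the LR changes.
optimizer = torch.optim.AdamW(model.parameters(), lr=torch.tensor(1e-3))

# Illustrative DeepSpeed config; WarmupLR is one of the built-in schedulers
# that adjusts the LR on the optimizer's param groups each step.
ds_config = {
    "train_batch_size": 8,
    "scheduler": {"type": "WarmupLR", "params": {"warmup_num_steps": 100}},
}

engine, optimizer, _, lr_scheduler = deepspeed.initialize(model=model,
                                                          optimizer=optimizer,
                                                          config=ds_config)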
@@ -13,6 +13,7 @@ import argparse
from torch.optim import Optimizer
import math
from deepspeed.utils import logger
from torch import tensor, is_tensor

LR_SCHEDULE = 'lr_schedule'
LR_RANGE_TEST = 'LRRangeTest'
@@ -249,6 +250,9 @@ def get_lr_from_config(config):


def update_lr(param_groups, lrs):
    for param_group, lr in zip(param_groups, lrs):
        # new LR should match the type of current LR for scalar and Tensor LR support
        if is_tensor(param_group['lr']):
            lr = tensor([lr], device=param_group['lr'].device)
        param_group['lr'] = lr
    return [group['lr'] for group in param_groups]
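A minimal, self-contained sketch of the behavior added above; the function body is copied from the diff rather than imported from DeepSpeed, and the param-group values are made up. It shows that update_lr preserves whichever LR type the optimizer started with:

from torch import tensor, is_tensor

def update_lr(param_groups, lrs):
    for param_group, lr in zip(param_groups, lrs):
        # new LR should match the type of current LR for scalar and Tensor LR support
        if is_tensor(param_group['lr']):
            lr = tensor([lr], device=param_group['lr'].device)
        param_group['lr'] = lr
    return [group['lr'] for group in param_groups]

scalar_groups = [{'lr': 0.1}]
tensor_groups = [{'lr': tensor([0.1])}]

print(update_lr(scalar_groups, [0.05]))  # [0.05]             -> stays a float
print(update_lr(tensor_groups, [0.05]))  # [tensor([0.0500])] -> stays a tensor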