add support for tensor learning rate (vs scalar) (#7633)

This change adds support for using a tensor learning rate value instead of a
scalar one.
We found this helpful in cases where the optimizer is torch.compiled (in
such cases, changing a scalar LR value can trigger recompilation and
degrade performance).
The implementation lets the model script determine the type of LR value
used by setting the type of the initial value.
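
As a rough illustration (not part of this commit, and using a hypothetical
model/optimizer), a model script opts into a tensor LR simply by passing a
tensor as the initial value:

import torch

model = torch.nn.Linear(8, 8)  # placeholder model

# Scalar initial LR: the LR stays a Python float; replacing it with a new
# float later can retrace a torch.compile'd optimizer step.
scalar_opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Tensor initial LR: later updates keep the LR as a tensor, so a compiled
# optimizer step sees the same input types and avoids recompilation.
tensor_opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(1e-3))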

Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Nir Sonnenschein
2025-10-20 08:32:29 +03:00
committed by GitHub
parent 2734a6a15f
commit 407708cdb6


@@ -13,6 +13,7 @@ import argparse
from torch.optim import Optimizer
import math
from deepspeed.utils import logger
from torch import tensor, is_tensor
LR_SCHEDULE = 'lr_schedule'
LR_RANGE_TEST = 'LRRangeTest'
@@ -249,6 +250,9 @@ def get_lr_from_config(config):
def update_lr(param_groups, lrs):
    for param_group, lr in zip(param_groups, lrs):
        # new LR should match the type of current LR for scalar and Tensor LR support
        if is_tensor(param_group['lr']):
            lr = tensor([lr], device=param_group['lr'].device)
        param_group['lr'] = lr
    return [group['lr'] for group in param_groups]
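
For context, a minimal standalone sketch of the behavior of the helper above,
with hypothetical param groups (this restates the diff's logic inline rather
than showing an exact DeepSpeed call site):

import torch

# hypothetical param groups: one with a scalar LR, one with a tensor LR
param_groups = [
    {'lr': 0.01},
    {'lr': torch.tensor([0.01])},
]

new_lrs = [0.005, 0.005]
for group, lr in zip(param_groups, new_lrs):
    if torch.is_tensor(group['lr']):
        # keep the tensor type (and device) so a compiled step is not retraced
        lr = torch.tensor([lr], device=group['lr'].device)
    group['lr'] = lr

# the scalar group stays a float, the tensor group stays a Tensor
print([type(g['lr']).__name__ for g in param_groups])  # ['float', 'Tensor']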