set token scaling flag

commit 1b9ec9f0cd
parent 622ee4ac5f
Author: Kashif Rasul
Date:   2025-10-18 13:31:30 +02:00


@@ -818,7 +818,7 @@ class SFTTrainer(BaseTrainer):
             )
         # Loss function
-        if args.loss_type == "nll":
+        if args.loss_type == "nll" or args.use_liger_kernel:
             pass  # use the default loss
         elif args.loss_type == "dft":
             if compute_loss_func is not None:
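This first hunk routes runs with use_liger_kernel through the default loss path, so the fused Liger cross-entropy computes the loss instead of a Python-side compute_loss_func. For context, DFT ("Dynamic Fine-Tuning") scales each token's negative log-likelihood by the model's own detached probability for that token. A minimal reference sketch of that scaling outside the Liger kernel (the function name is illustrative, not TRL's internal API; it assumes logits and labels are already shifted for causal LM):

import torch
import torch.nn.functional as F

def dft_token_scaled_loss(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    """Per-token NLL scaled by the detached predicted probability (DFT)."""
    log_probs = F.log_softmax(logits.float(), dim=-1)
    nll = F.nll_loss(
        log_probs.view(-1, log_probs.size(-1)),
        labels.view(-1),
        ignore_index=ignore_index,
        reduction="none",
    )
    # exp(-nll) recovers p(y_t); detach so the scale acts as a fixed
    # per-token weight rather than a gradient path
    scale = torch.exp(-nll).detach()
    mask = labels.view(-1) != ignore_index
    return (scale * nll)[mask].mean()

The second hunk below makes the fused kernel apply this same per-token scaling by forwarding a use_token_scaling flag.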
@@ -1095,9 +1095,10 @@ class SFTTrainer(BaseTrainer):
             # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
             inputs["use_cache"] = False
-        # Request token accuracy from Liger kernel if used
+        # Request token accuracy from Liger kernel and set token scaling if using DFT loss
         if self.args.use_liger_kernel:
             inputs["return_token_accuracy"] = True
+            inputs["use_token_scaling"] = self.args.loss_type == "dft"
         (loss, outputs) = super().compute_loss(
             model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
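With the two hunks combined, enabling DFT together with the Liger kernel should be a one-flag change on the config side. A hedged usage sketch, assuming the loss_type and use_liger_kernel fields shown in this diff; the model checkpoint and dataset are placeholders, not part of this commit:

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Placeholder dataset and model; swap in your own
dataset = load_dataset("trl-lib/Capybara", split="train")

args = SFTConfig(
    output_dir="sft-dft-liger",
    loss_type="dft",        # per-token scaled loss (field name from this diff)
    use_liger_kernel=True,  # fused Liger path now receives use_token_scaling
)

trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=args, train_dataset=dataset)
trainer.train()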