Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 10:03:51 +08:00
set token scaling flag
@@ -818,7 +818,7 @@ class SFTTrainer(BaseTrainer):
         )
 
         # Loss function
-        if args.loss_type == "nll":
+        if args.loss_type == "nll" or args.use_liger_kernel:
             pass  # use the default loss
         elif args.loss_type == "dft":
             if compute_loss_func is not None:
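This hunk changes the loss dispatch: when the Liger kernel is enabled, the trainer now takes the default-loss path regardless of `loss_type`, since the fused kernel computes the loss itself (including DFT token scaling, wired up in the second hunk) instead of going through a custom loss function. For reference, a minimal sketch of DFT-style token scaling, assuming the usual formulation where each token's NLL is reweighted by the model's detached probability of that token; the function name and signature here are illustrative, not TRL's internals:

import torch
import torch.nn.functional as F

def dft_loss_sketch(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Illustrative sketch, not TRL's implementation.
    # Shift so position t predicts token t+1 (standard causal-LM alignment)
    logits = logits[:, :-1, :]
    labels = labels[:, 1:]
    log_probs = F.log_softmax(logits, dim=-1)
    mask = labels != ignore_index
    safe_labels = labels.masked_fill(~mask, 0)
    token_log_probs = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)
    # DFT: reweight each token's NLL by its detached probability p(y_t),
    # so the gradient is not detached through the scaling factor
    per_token_loss = -token_log_probs.detach().exp() * token_log_probs
    return (per_token_loss * mask).sum() / mask.sum()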
@@ -1095,9 +1095,10 @@ class SFTTrainer(BaseTrainer):
 
         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
-        # Request token accuracy from Liger kernel if used
+        # Request token accuracy from Liger kernel and set token scaling if using DFT loss
         if self.args.use_liger_kernel:
             inputs["return_token_accuracy"] = True
+            inputs["use_token_scaling"] = self.args.loss_type == "dft"
 
         (loss, outputs) = super().compute_loss(
             model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
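This hunk only tags the forward inputs; the actual per-token scaling is expected to happen inside Liger's fused loss path, which reads `use_token_scaling` from the batch. A hedged end-to-end usage sketch follows; the dataset, model name, and exact config fields are assumptions about the surrounding API, not guaranteed by this diff:

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Illustrative config; field availability depends on your TRL version
dataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(
    output_dir="qwen-dft",
    loss_type="dft",        # select the token-scaled DFT loss
    use_liger_kernel=True,  # after this commit, forwards use_token_scaling=True
)
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()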