mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
set token scaling flag
This commit is contained in:
@ -818,7 +818,7 @@ class SFTTrainer(BaseTrainer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Loss function
|
# Loss function
|
||||||
if args.loss_type == "nll":
|
if args.loss_type == "nll" or args.use_liger_kernel:
|
||||||
pass # use the default loss
|
pass # use the default loss
|
||||||
elif args.loss_type == "dft":
|
elif args.loss_type == "dft":
|
||||||
if compute_loss_func is not None:
|
if compute_loss_func is not None:
|
||||||
@ -1095,9 +1095,10 @@ class SFTTrainer(BaseTrainer):
|
|||||||
|
|
||||||
# If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
|
# If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
|
||||||
inputs["use_cache"] = False
|
inputs["use_cache"] = False
|
||||||
# Request token accuracy from Liger kernel if used
|
# Request token accuracy from Liger kernel and set token scaling if using DFT loss
|
||||||
if self.args.use_liger_kernel:
|
if self.args.use_liger_kernel:
|
||||||
inputs["return_token_accuracy"] = True
|
inputs["return_token_accuracy"] = True
|
||||||
|
inputs["use_token_scaling"] = self.args.loss_type == "dft"
|
||||||
|
|
||||||
(loss, outputs) = super().compute_loss(
|
(loss, outputs) = super().compute_loss(
|
||||||
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
|
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
|
||||||
|
Reference in New Issue
Block a user