Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 10:03:51 +08:00
Merge branch 'refactor_generate_5' into tool-call-finally
This commit is contained in:
@@ -1230,15 +1230,10 @@ class GRPOTrainer(BaseTrainer):
             self.llm.sleep(level=1)

         elif self.use_transformers_paged:
-            processor_kwargs = {
-                "max_length": self.max_prompt_length,
-                "truncation": True,
-                "return_dict": True,
-                "add_special_tokens": False,
-            }
+            processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False}
             if is_conversational({"prompt": prompts[0]}):
                 generate_inputs = self.processing_class.apply_chat_template(
-                    conversation=prompts, **processor_kwargs, tokenize=True
+                    conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True
                 )
             else:
                 generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1274,12 +1269,11 @@ class GRPOTrainer(BaseTrainer):
                 "padding_side": "left",
                 "max_length": self.max_prompt_length,
                 "truncation": True,
-                "return_dict": True,
                 "add_special_tokens": False,
             }
             if is_conversational({"prompt": prompts[0]}):
                 generate_inputs = self.processing_class.apply_chat_template(
-                    conversation=prompts, **processor_kwargs, tokenize=True, tools=self.tools
+                    conversation=prompts, **processor_kwargs, tokenize=True, tools=self.tools, return_dict=True
                 )
             else:
                 generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1189,15 +1189,13 @@ class RLOOTrainer(BaseTrainer):
             self.llm.sleep(level=1)

         elif self.use_transformers_paged:
-            processor_kwargs = {
-                "max_length": self.max_prompt_length,
-                "truncation": True,
-                "return_dict": True,
-                "add_special_tokens": False,
-            }
+            processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False}
             if is_conversational({"prompt": prompts[0]}):
                 generate_inputs = self.processing_class.apply_chat_template(
-                    conversation=prompts, **processor_kwargs, tokenize=True
+                    conversation=prompts,
+                    **processor_kwargs,
+                    tokenize=True,
+                    return_dict=True,
                 )
             else:
                 generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@@ -1232,12 +1230,11 @@ class RLOOTrainer(BaseTrainer):
                 "padding_side": "left",
                 "max_length": self.max_prompt_length,
                 "truncation": True,
-                "return_dict": True,
                 "add_special_tokens": False,
             }
             if is_conversational({"prompt": prompts[0]}):
                 generate_inputs = self.processing_class.apply_chat_template(
-                    conversation=prompts, **processor_kwargs, tokenize=True
+                    conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True
                 )
             else:
                 generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
Reference in New Issue
Block a user