Merge branch 'refactor_generate_5' into tool-call-finally

This commit is contained in:
Quentin Gallouédec
2025-10-17 22:29:58 -06:00
committed by GitHub
2 changed files with 9 additions and 18 deletions

View File

@ -1230,15 +1230,10 @@ class GRPOTrainer(BaseTrainer):
self.llm.sleep(level=1)
elif self.use_transformers_paged:
processor_kwargs = {
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@ -1274,12 +1269,11 @@ class GRPOTrainer(BaseTrainer):
"padding_side": "left",
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True, tools=self.tools
conversation=prompts, **processor_kwargs, tokenize=True, tools=self.tools, return_dict=True
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)

View File

@ -1189,15 +1189,13 @@ class RLOOTrainer(BaseTrainer):
self.llm.sleep(level=1)
elif self.use_transformers_paged:
processor_kwargs = {
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
conversation=prompts,
**processor_kwargs,
tokenize=True,
return_dict=True,
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)
@ -1232,12 +1230,11 @@ class RLOOTrainer(BaseTrainer):
"padding_side": "left",
"max_length": self.max_prompt_length,
"truncation": True,
"return_dict": True,
"add_special_tokens": False,
}
if is_conversational({"prompt": prompts[0]}):
generate_inputs = self.processing_class.apply_chat_template(
conversation=prompts, **processor_kwargs, tokenize=True
conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True
)
else:
generate_inputs = self.processing_class(text=prompts, **processor_kwargs)