🧺 [2/N] Refactor _generate in GRPO/RLOO: Use prompt_ids from generation (#4152)

Quentin Gallouédec
2025-10-07 12:11:34 -06:00
committed by GitHub
parent 30cf68a97b
commit cc578b6b14
5 changed files with 117 additions and 62 deletions
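The payload returned by vLLM generation now carries the tokenized prompts alongside the completions and logprobs. A minimal sketch of the new shape, using (truncated) token IDs lifted from the docstring example further down; the dict literal is illustrative, not real model output:

```python
# Shape of the payload the vLLM client/server now returns (truncated, illustrative values).
output = {
    "prompt_ids": [[9707, 11, 15235, 0], [40451, 752, 264, 21646]],  # tokenized input prompts
    "completion_ids": [[11479, 752, 5046], [911, 19654, 382]],       # generated token IDs per prompt
    "logprobs": [[-5.19, -0.06, -4.86], [-0.01, -5.75, -1.52]],      # per-token logprobs of the completions
}

# Trainers can now reuse the server-side tokenization instead of re-deriving prompt IDs locally.
prompt_ids, completion_ids, logprobs = output["prompt_ids"], output["completion_ids"], output["logprobs"]
assert len(prompt_ids) == len(completion_ids) == len(logprobs)
```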

View File

@@ -74,36 +74,42 @@ class TestVLLMClientServer(TrlTestCase):
def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)["completion_ids"]
outputs = self.client.generate(prompts)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the output is a list
assert isinstance(outputs, list)
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of generated sequences is equal to the number of prompts
assert len(outputs) == len(prompts)
# Check that the number of sequences is equal to the number of prompts
assert len(prompt_ids) == len(prompts)
assert len(completion_ids) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
"completion_ids"
]
# Check that the output is a list
assert isinstance(outputs, list)
assert isinstance(completion_ids, list)
# Check that the number of generated sequences is 2 times the number of prompts
assert len(outputs) == 2 * len(prompts)
assert len(completion_ids) == 2 * len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
for seq in completion_ids:
assert len(seq) <= 32
def test_update_model_params(self):
@@ -148,36 +154,42 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)["completion_ids"]
outputs = self.client.generate(prompts)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the output is a list
assert isinstance(outputs, list)
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of generated sequences is equal to the number of prompts
assert len(outputs) == len(prompts)
# Check that the number of sequences is equal to the number of prompts
assert len(prompt_ids) == len(prompts)
assert len(completion_ids) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
"completion_ids"
]
# Check that the output is a list
assert isinstance(outputs, list)
assert isinstance(completion_ids, list)
# Check that the number of generated sequences is 2 times the number of prompts
assert len(outputs) == 2 * len(prompts)
assert len(completion_ids) == 2 * len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
for seq in completion_ids:
assert len(seq) <= 32
def test_update_model_params(self):
@@ -224,16 +236,22 @@ class TestVLLMClientServerTP(TrlTestCase):
def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)["completion_ids"]
outputs = self.client.generate(prompts)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the output is a list
assert isinstance(outputs, list)
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of generated sequences is equal to the number of prompts
assert len(outputs) == len(prompts)
# Check that the number of sequences is equal to the number of prompts
assert len(prompt_ids) == len(prompts)
assert len(completion_ids) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_update_model_params(self):
@@ -280,16 +298,22 @@ class TestVLLMClientServerDP(TrlTestCase):
def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)["completion_ids"]
outputs = self.client.generate(prompts)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
# Check that the output is a list
assert isinstance(outputs, list)
# Check that the outputs are lists
assert isinstance(prompt_ids, list)
assert isinstance(completion_ids, list)
# Check that the number of generated sequences is equal to the number of prompts
assert len(outputs) == len(prompts)
# Check that the number of sequences is equal to the number of prompts
assert len(prompt_ids) == len(prompts)
assert len(completion_ids) == len(prompts)
# Check that the generated sequences are lists of integers
for seq in outputs:
# Check that the sequences are lists of integers
for seq in prompt_ids:
assert all(isinstance(tok, int) for tok in seq)
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)
def test_update_model_params(self):
@@ -336,9 +360,13 @@ class TestVLLMClientServerDeviceParameter(TrlTestCase):
# Test basic functionality
prompts = ["Hello, AI!"]
outputs = client.generate(prompts)["completion_ids"]
assert isinstance(outputs, list)
assert len(outputs) == len(prompts)
outputs = client.generate(prompts)
prompt_ids = outputs["prompt_ids"]
completion_ids = outputs["completion_ids"]
assert isinstance(prompt_ids, list)
assert len(prompt_ids) == len(prompts)
assert isinstance(completion_ids, list)
assert len(completion_ids) == len(prompts)
client.close_communicator()
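The same shape checks recur across the server variants above (base URL, TP, DP, device parameter). A hypothetical helper that captures those assertions, assuming `prompt_ids` stays one-per-prompt while completions are repeated `n` times, as the server code below does:

```python
def check_generate_output(output: dict, num_prompts: int, n: int = 1) -> None:
    """Shape checks mirroring the tests above (illustrative helper, not part of TRL)."""
    prompt_ids, completion_ids = output["prompt_ids"], output["completion_ids"]
    # One prompt_ids entry per prompt; n completions per prompt.
    assert isinstance(prompt_ids, list) and len(prompt_ids) == num_prompts
    assert isinstance(completion_ids, list) and len(completion_ids) == n * num_prompts
    # Every sequence is a plain list of token IDs.
    for seq in prompt_ids + completion_ids:
        assert all(isinstance(tok, int) for tok in seq)
```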

View File

@@ -83,8 +83,12 @@ class VLLMClient:
>>> client = VLLMClient()
>>> client.generate(["Hello, AI!", "Tell me a joke"])
[[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
[911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
{'prompt_ids': [[9707, 11, 15235, 0],
[40451, 752, 264, 21646]],
'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733],
[911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]],
'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963],
[-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]}
>>> from transformers import AutoModelForCausalLM
@@ -212,6 +216,8 @@ class VLLMClient:
Returns:
`dict` with keys:
- `prompt_ids` (`list[list[int]]`):
List of lists of token IDs representing the tokenized input prompts.
- `completion_ids` (`list[list[int]]`):
List of lists of token IDs representing the model-generated completions for each prompt.
- `logprobs` (`list[list[float]]`):
@@ -246,7 +252,11 @@ class VLLMClient:
)
if response.status_code == 200:
json_response = response.json()
return {"completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]}
return {
"prompt_ids": json_response["prompt_ids"],
"completion_ids": json_response["completion_ids"],
"logprobs": json_response["logprobs"],
}
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
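A usage sketch of the updated client, assuming it is importable as in TRL (`from trl.extras.vllm_client import VLLMClient`), a server is already running (e.g. started with `trl vllm-serve`), and the default of one completion per prompt so the three lists stay aligned:

```python
from trl.extras.vllm_client import VLLMClient

client = VLLMClient()  # connects to the default host/port of a running server
out = client.generate(["Hello, AI!", "Tell me a joke"], temperature=0.8, max_tokens=32)

for p_ids, c_ids, lps in zip(out["prompt_ids"], out["completion_ids"], out["logprobs"]):
    # The prompt token IDs now come straight from the server, so no local re-tokenization is needed.
    print(f"{len(p_ids)} prompt tokens -> {len(c_ids)} completion tokens, mean logprob {sum(lps) / len(lps):.2f}")
```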

View File

@@ -499,6 +499,7 @@ def main(script_args: ScriptArguments):
generation_kwargs: dict = field(default_factory=dict)
class GenerateResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[float]]
@@ -532,6 +533,7 @@ def main(script_args: ScriptArguments):
Returns:
`GenerateResponse`:
- `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
- `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
- `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
generated completions.
@@ -543,7 +545,11 @@ def main(script_args: ScriptArguments):
Example response:
```json
{"completion_ids": [[101, 102, 103], [201, 202, 203]], "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]}
{
"prompt_ids": [[101, 102], [201, 202]],
"completion_ids": [[103, 104, 105], [203, 204, 205]],
"logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]
}
```
"""
request.images = request.images or [None] * len(request.prompts)
@@ -596,13 +602,14 @@ def main(script_args: ScriptArguments):
# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
prompt_ids = [output.prompt_token_ids for output in all_outputs]
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
logprobs: list[list[float]] = [
[sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
for outputs in all_outputs
for output in outputs.outputs
]
return {"completion_ids": completion_ids, "logprobs": logprobs}
return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
class InitCommunicatorRequest(BaseModel):
host: str
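On the server side, each vLLM `RequestOutput` corresponds to one prompt and holds `n` completions, which is why `prompt_ids` ends up with one entry per prompt while `completion_ids` is flattened. A standalone sketch of that flattening with lightweight stand-ins for vLLM's output objects (only the field names used in the diff above are modeled):

```python
from dataclasses import dataclass

@dataclass
class FakeCompletion:              # stand-in for vLLM's CompletionOutput
    token_ids: list[int]

@dataclass
class FakeRequestOutput:           # stand-in for vLLM's RequestOutput
    prompt_token_ids: list[int]
    outputs: list[FakeCompletion]  # n completions per prompt

# Two prompts, n=2 completions each (the shape after chain.from_iterable in the endpoint above).
all_outputs = [
    FakeRequestOutput([101, 102], [FakeCompletion([103, 104]), FakeCompletion([105])]),
    FakeRequestOutput([201, 202], [FakeCompletion([203]), FakeCompletion([204, 205])]),
]

prompt_ids = [output.prompt_token_ids for output in all_outputs]                          # one entry per prompt
completion_ids = [list(o.token_ids) for outputs in all_outputs for o in outputs.outputs]  # n entries per prompt
assert prompt_ids == [[101, 102], [201, 202]]
assert completion_ids == [[103, 104], [105], [203], [204, 205]]
```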

View File

@@ -1101,11 +1101,12 @@ class GRPOTrainer(BaseTrainer):
**kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
if self.max_prompt_length is not None:
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
# tokens are needed for generation.
@@ -1187,19 +1188,23 @@ class GRPOTrainer(BaseTrainer):
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
payload = (output["completion_ids"], output["logprobs"])
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
else:
payload = None
# Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
obj_list = [payload]
broadcast_object_list(obj_list, from_process=0)
all_completion_ids, all_logprobs = obj_list[0]
all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
# At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
prompt_ids = all_prompt_ids[process_slice]
completion_ids = all_completion_ids[process_slice]
logprobs = all_logprobs[process_slice]
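Because the server returns one `prompt_ids` entry per unique prompt while each process works on `num_generations` rows per prompt, the broadcast result is repeated and then sliced. A standalone sketch of that repeat-and-slice step with toy values (no `accelerate` involved):

```python
num_generations = 2

all_prompt_ids = [[101, 102], [201, 202]]  # one entry per unique prompt, from the main process
all_completion_ids = [[1], [2], [3], [4]]  # num_generations completions per prompt, already flattened

# Repeat each prompt num_generations times so rows align one-to-one with completions.
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
assert all_prompt_ids == [[101, 102], [101, 102], [201, 202], [201, 202]]

# Each process keeps its contiguous share, as process_slice does above (2 processes, 2 rows each).
process_index, rows_per_process = 1, 2
process_slice = slice(process_index * rows_per_process, (process_index + 1) * rows_per_process)
prompt_ids = all_prompt_ids[process_slice]          # [[201, 202], [201, 202]]
completion_ids = all_completion_ids[process_slice]  # [[3], [4]]
```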
@@ -1254,6 +1259,7 @@ class GRPOTrainer(BaseTrainer):
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
all_logprobs = [
[next(iter(lp.values())).logprob for lp in output.logprobs]
@@ -1266,9 +1272,11 @@ class GRPOTrainer(BaseTrainer):
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]
completion_ids = all_completion_ids[tp_slice]
logprobs = all_logprobs[tp_slice]
else:
prompt_ids = all_prompt_ids
completion_ids = all_completion_ids
logprobs = all_logprobs
@@ -1311,10 +1319,7 @@ class GRPOTrainer(BaseTrainer):
else:
# Regular generation path
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
with (
profiling_context(self, "transformers.generate"),
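In the regular generation path the padded batch from the tokenizer is now used directly, and the per-example `prompt_ids` lists are only built where they are needed (e.g. when trimming to `max_prompt_length`). A small sketch of the mask-based conversion that the `[p[m].tolist() ...]` line performs, assuming left padding:

```python
import torch

input_ids = torch.tensor([[0, 0, 7, 8], [5, 6, 7, 8]])       # left-padded prompts (pad_token_id = 0)
attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])

# Drop the padding positions to recover one plain list of token IDs per example,
# mirroring `prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]`.
prompt_ids_lists = [p[m].tolist() for p, m in zip(input_ids, attention_mask.bool())]
assert prompt_ids_lists == [[7, 8], [5, 6, 7, 8]]
```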

View File

@@ -1090,11 +1090,12 @@ class RLOOTrainer(BaseTrainer):
**kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
if self.max_prompt_length is not None:
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
# tokens are needed for generation.
@@ -1176,19 +1177,23 @@ class RLOOTrainer(BaseTrainer):
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
payload = (output["completion_ids"], output["logprobs"])
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
else:
payload = None
# Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
obj_list = [payload]
broadcast_object_list(obj_list, from_process=0)
all_completion_ids, _ = obj_list[0]
all_prompt_ids, all_completion_ids, _ = obj_list[0]
# At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
prompt_ids = all_prompt_ids[process_slice]
completion_ids = all_completion_ids[process_slice]
# Generate completions using colocated vLLM instances: each device holds a vLLM copy and works on its own batch of prompts
@@ -1241,6 +1246,7 @@ class RLOOTrainer(BaseTrainer):
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
if self.vllm_tensor_parallel_size > 1:
@@ -1248,8 +1254,10 @@ class RLOOTrainer(BaseTrainer):
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]
completion_ids = all_completion_ids[tp_slice]
else:
prompt_ids = all_prompt_ids
completion_ids = all_completion_ids
if self.args.vllm_enable_sleep_mode:
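With `vllm_tensor_parallel_size > 1`, every rank in the TP group generates outputs for the whole group's gathered prompts and then keeps only its own slice, in both GRPO and RLOO. A standalone sketch of that slicing with toy values (no `torch.distributed` involved):

```python
orig_size = 2            # number of rows this rank contributed to the gathered batch
local_rank_in_group = 1  # this rank's position inside its tensor-parallel group

all_prompt_ids = [[101], [102], [201], [202]]  # gathered across a 2-way TP group
all_completion_ids = [[11], [12], [21], [22]]

tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]          # [[201], [202]]
completion_ids = all_completion_ids[tp_slice]  # [[21], [22]]
assert prompt_ids == [[201], [202]] and completion_ids == [[21], [22]]
```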
@@ -1290,10 +1298,7 @@ class RLOOTrainer(BaseTrainer):
else:
# Regular generation path
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
with (
profiling_context(self, "transformers.generate"),