mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
🧺 [2/N] Refactor _generate
in GRPO/RLOO: Use prompt_ids
from generation (#4152)
This commit is contained in:
committed by
GitHub
parent
30cf68a97b
commit
cc578b6b14
@ -74,36 +74,42 @@ class TestVLLMClientServer(TrlTestCase):
|
||||
|
||||
def test_generate(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
outputs = self.client.generate(prompts)["completion_ids"]
|
||||
outputs = self.client.generate(prompts)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the output is a list
|
||||
assert isinstance(outputs, list)
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of generated sequences is equal to the number of prompts
|
||||
assert len(outputs) == len(prompts)
|
||||
# Check that the number of sequences are equal to the number of prompts
|
||||
assert len(prompt_ids) == len(prompts)
|
||||
assert len(completion_ids) == len(prompts)
|
||||
|
||||
# Check that the generated sequences are lists of integers
|
||||
for seq in outputs:
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_generate_with_params(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
|
||||
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
|
||||
"completion_ids"
|
||||
]
|
||||
|
||||
# Check that the output is a list
|
||||
assert isinstance(outputs, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of generated sequences is 2 times the number of prompts
|
||||
assert len(outputs) == 2 * len(prompts)
|
||||
assert len(completion_ids) == 2 * len(prompts)
|
||||
|
||||
# Check that the generated sequences are lists of integers
|
||||
for seq in outputs:
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
# Check that the length of the generated sequences is less than or equal to 32
|
||||
for seq in outputs:
|
||||
for seq in completion_ids:
|
||||
assert len(seq) <= 32
|
||||
|
||||
def test_update_model_params(self):
|
||||
@ -148,36 +154,42 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
|
||||
|
||||
def test_generate(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
outputs = self.client.generate(prompts)["completion_ids"]
|
||||
outputs = self.client.generate(prompts)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the output is a list
|
||||
assert isinstance(outputs, list)
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of generated sequences is equal to the number of prompts
|
||||
assert len(outputs) == len(prompts)
|
||||
# Check that the number of sequences are equal to the number of prompts
|
||||
assert len(prompt_ids) == len(prompts)
|
||||
assert len(completion_ids) == len(prompts)
|
||||
|
||||
# Check that the generated sequences are lists of integers
|
||||
for seq in outputs:
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_generate_with_params(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
|
||||
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
|
||||
"completion_ids"
|
||||
]
|
||||
|
||||
# Check that the output is a list
|
||||
assert isinstance(outputs, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of generated sequences is 2 times the number of prompts
|
||||
assert len(outputs) == 2 * len(prompts)
|
||||
assert len(completion_ids) == 2 * len(prompts)
|
||||
|
||||
# Check that the generated sequences are lists of integers
|
||||
for seq in outputs:
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
# Check that the length of the generated sequences is less than or equal to 32
|
||||
for seq in outputs:
|
||||
for seq in completion_ids:
|
||||
assert len(seq) <= 32
|
||||
|
||||
def test_update_model_params(self):
|
||||
@ -224,16 +236,22 @@ class TestVLLMClientServerTP(TrlTestCase):
|
||||
|
||||
def test_generate(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
outputs = self.client.generate(prompts)["completion_ids"]
|
||||
outputs = self.client.generate(prompts)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the output is a list
|
||||
assert isinstance(outputs, list)
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of generated sequences is equal to the number of prompts
|
||||
assert len(outputs) == len(prompts)
|
||||
# Check that the number of sequences are equal to the number of prompts
|
||||
assert len(prompt_ids) == len(prompts)
|
||||
assert len(completion_ids) == len(prompts)
|
||||
|
||||
# Check that the generated sequences are lists of integers
|
||||
for seq in outputs:
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_update_model_params(self):
|
||||
@ -280,16 +298,22 @@ class TestVLLMClientServerDP(TrlTestCase):
|
||||
|
||||
def test_generate(self):
|
||||
prompts = ["Hello, AI!", "Tell me a joke"]
|
||||
outputs = self.client.generate(prompts)["completion_ids"]
|
||||
outputs = self.client.generate(prompts)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
|
||||
# Check that the output is a list
|
||||
assert isinstance(outputs, list)
|
||||
# Check that the outputs are lists
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert isinstance(completion_ids, list)
|
||||
|
||||
# Check that the number of generated sequences is equal to the number of prompts
|
||||
assert len(outputs) == len(prompts)
|
||||
# Check that the number of sequences are equal to the number of prompts
|
||||
assert len(prompt_ids) == len(prompts)
|
||||
assert len(completion_ids) == len(prompts)
|
||||
|
||||
# Check that the generated sequences are lists of integers
|
||||
for seq in outputs:
|
||||
# Check that the sequences are lists of integers
|
||||
for seq in prompt_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
for seq in completion_ids:
|
||||
assert all(isinstance(tok, int) for tok in seq)
|
||||
|
||||
def test_update_model_params(self):
|
||||
@ -336,9 +360,13 @@ class TestVLLMClientServerDeviceParameter(TrlTestCase):
|
||||
|
||||
# Test basic functionality
|
||||
prompts = ["Hello, AI!"]
|
||||
outputs = client.generate(prompts)["completion_ids"]
|
||||
assert isinstance(outputs, list)
|
||||
assert len(outputs) == len(prompts)
|
||||
outputs = client.generate(prompts)
|
||||
prompt_ids = outputs["prompt_ids"]
|
||||
completion_ids = outputs["completion_ids"]
|
||||
assert isinstance(prompt_ids, list)
|
||||
assert len(prompt_ids) == len(prompts)
|
||||
assert isinstance(completion_ids, list)
|
||||
assert len(completion_ids) == len(prompts)
|
||||
|
||||
client.close_communicator()
|
||||
|
||||
|
@ -83,8 +83,12 @@ class VLLMClient:
|
||||
|
||||
>>> client = VLLMClient()
|
||||
>>> client.generate(["Hello, AI!", "Tell me a joke"])
|
||||
[[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
|
||||
[911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
|
||||
{'prompt_ids': [[9707, 11, 15235, 0],
|
||||
[40451, 752, 264, 21646]],
|
||||
'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733],
|
||||
[911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]],
|
||||
'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963],
|
||||
[-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]}
|
||||
|
||||
>>> from transformers import AutoModelForCausalLM
|
||||
|
||||
@ -212,6 +216,8 @@ class VLLMClient:
|
||||
|
||||
Returns:
|
||||
`dict` with keys:
|
||||
- `prompt_ids` (`list[list[int]]`):
|
||||
List of lists of token IDs representing the tokenized input prompts.
|
||||
- `completion_ids` (`list[list[int]]`):
|
||||
List of lists of token IDs representing the model-generated completions for each prompt.
|
||||
- `logprobs` (`list[list[float]]`):
|
||||
@ -246,7 +252,11 @@ class VLLMClient:
|
||||
)
|
||||
if response.status_code == 200:
|
||||
json_response = response.json()
|
||||
return {"completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]}
|
||||
return {
|
||||
"prompt_ids": json_response["prompt_ids"],
|
||||
"completion_ids": json_response["completion_ids"],
|
||||
"logprobs": json_response["logprobs"],
|
||||
}
|
||||
else:
|
||||
raise Exception(f"Request failed: {response.status_code}, {response.text}")
|
||||
|
||||
|
@ -499,6 +499,7 @@ def main(script_args: ScriptArguments):
|
||||
generation_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
prompt_ids: list[list[int]]
|
||||
completion_ids: list[list[int]]
|
||||
logprobs: list[list[float]]
|
||||
|
||||
@ -532,6 +533,7 @@ def main(script_args: ScriptArguments):
|
||||
|
||||
Returns:
|
||||
`GenerateResponse`:
|
||||
- `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
|
||||
- `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
|
||||
- `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
|
||||
generated completions.
|
||||
@ -543,7 +545,11 @@ def main(script_args: ScriptArguments):
|
||||
|
||||
Example response:
|
||||
```json
|
||||
{"completion_ids": [[101, 102, 103], [201, 202, 203]], "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]}
|
||||
{
|
||||
"prompt_ids": [[101, 102], [201, 202]],
|
||||
"completion_ids": [[103, 104, 105], [203, 204, 205]],
|
||||
"logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]
|
||||
}
|
||||
```
|
||||
"""
|
||||
request.images = request.images or [None] * len(request.prompts)
|
||||
@ -596,13 +602,14 @@ def main(script_args: ScriptArguments):
|
||||
|
||||
# Flatten and combine all results
|
||||
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
|
||||
prompt_ids = [output.prompt_token_ids for output in all_outputs]
|
||||
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
|
||||
logprobs: list[list[float]] = [
|
||||
[sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
|
||||
for outputs in all_outputs
|
||||
for output in outputs.outputs
|
||||
]
|
||||
return {"completion_ids": completion_ids, "logprobs": logprobs}
|
||||
return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
|
||||
|
||||
class InitCommunicatorRequest(BaseModel):
|
||||
host: str
|
||||
|
@ -1101,11 +1101,12 @@ class GRPOTrainer(BaseTrainer):
|
||||
**kwargs,
|
||||
)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
|
||||
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
|
||||
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
|
||||
|
||||
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
|
||||
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
|
||||
# tokens are needed for generation.
|
||||
@ -1187,19 +1188,23 @@ class GRPOTrainer(BaseTrainer):
|
||||
guided_decoding_regex=self.guided_decoding_regex,
|
||||
generation_kwargs=self.args.generation_kwargs,
|
||||
)
|
||||
payload = (output["completion_ids"], output["logprobs"])
|
||||
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
|
||||
else:
|
||||
payload = None
|
||||
|
||||
# Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
|
||||
obj_list = [payload]
|
||||
broadcast_object_list(obj_list, from_process=0)
|
||||
all_completion_ids, all_logprobs = obj_list[0]
|
||||
all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
|
||||
|
||||
# At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
|
||||
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
|
||||
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
prompt_ids = all_prompt_ids[process_slice]
|
||||
completion_ids = all_completion_ids[process_slice]
|
||||
logprobs = all_logprobs[process_slice]
|
||||
|
||||
@ -1254,6 +1259,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
|
||||
|
||||
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
|
||||
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
|
||||
all_logprobs = [
|
||||
[next(iter(lp.values())).logprob for lp in output.logprobs]
|
||||
@ -1266,9 +1272,11 @@ class GRPOTrainer(BaseTrainer):
|
||||
# Each rank generates all outputs — we keep only our share.
|
||||
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
|
||||
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
|
||||
prompt_ids = all_prompt_ids[tp_slice]
|
||||
completion_ids = all_completion_ids[tp_slice]
|
||||
logprobs = all_logprobs[tp_slice]
|
||||
else:
|
||||
prompt_ids = all_prompt_ids
|
||||
completion_ids = all_completion_ids
|
||||
logprobs = all_logprobs
|
||||
|
||||
@ -1311,10 +1319,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
|
||||
else:
|
||||
# Regular generation path
|
||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
|
||||
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
|
||||
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
|
||||
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
|
||||
with (
|
||||
profiling_context(self, "transformers.generate"),
|
||||
|
@ -1090,11 +1090,12 @@ class RLOOTrainer(BaseTrainer):
|
||||
**kwargs,
|
||||
)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
|
||||
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
|
||||
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
|
||||
|
||||
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
|
||||
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
|
||||
# tokens are needed for generation.
|
||||
@ -1176,19 +1177,23 @@ class RLOOTrainer(BaseTrainer):
|
||||
guided_decoding_regex=self.guided_decoding_regex,
|
||||
generation_kwargs=self.args.generation_kwargs,
|
||||
)
|
||||
payload = (output["completion_ids"], output["logprobs"])
|
||||
payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
|
||||
else:
|
||||
payload = None
|
||||
|
||||
# Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
|
||||
obj_list = [payload]
|
||||
broadcast_object_list(obj_list, from_process=0)
|
||||
all_completion_ids, _ = obj_list[0]
|
||||
all_prompt_ids, all_completion_ids, _ = obj_list[0]
|
||||
|
||||
# At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
|
||||
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
|
||||
|
||||
process_slice = slice(
|
||||
self.accelerator.process_index * len(prompts),
|
||||
(self.accelerator.process_index + 1) * len(prompts),
|
||||
)
|
||||
prompt_ids = all_prompt_ids[process_slice]
|
||||
completion_ids = all_completion_ids[process_slice]
|
||||
|
||||
# Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
|
||||
@ -1241,6 +1246,7 @@ class RLOOTrainer(BaseTrainer):
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
|
||||
|
||||
all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
|
||||
all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
|
||||
|
||||
if self.vllm_tensor_parallel_size > 1:
|
||||
@ -1248,8 +1254,10 @@ class RLOOTrainer(BaseTrainer):
|
||||
# Each rank generates all outputs — we keep only our share.
|
||||
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
|
||||
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
|
||||
prompt_ids = all_prompt_ids[tp_slice]
|
||||
completion_ids = all_completion_ids[tp_slice]
|
||||
else:
|
||||
prompt_ids = all_prompt_ids
|
||||
completion_ids = all_completion_ids
|
||||
|
||||
if self.args.vllm_enable_sleep_mode:
|
||||
@ -1290,10 +1298,7 @@ class RLOOTrainer(BaseTrainer):
|
||||
|
||||
else:
|
||||
# Regular generation path
|
||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
|
||||
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
|
||||
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
|
||||
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
|
||||
with (
|
||||
profiling_context(self, "transformers.generate"),
|
||||
|
Reference in New Issue
Block a user