Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

🧺 [2/N] Refactor _generate in GRPO/RLOO: Use prompt_ids from generation (#4152)

Commit cc578b6b14, parent 30cf68a97b, committed by GitHub.
@@ -74,36 +74,42 @@ class TestVLLMClientServer(TrlTestCase):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
             "completion_ids"
         ]

         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)

         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)

         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32

     def test_update_model_params(self):
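Aside (not part of the diff): a minimal, self-contained sketch of the return contract these updated tests now exercise. The `fake_generate` stub below stands in for `VLLMClient.generate`, and the token values are made up.

```python
# Sketch of the new generate() contract: a dict whose "prompt_ids" and
# "completion_ids" lists are index-aligned with the input prompts.
def fake_generate(prompts: list[str]) -> dict[str, list]:
    # One prompt_ids entry and one completion_ids entry per input prompt.
    return {
        "prompt_ids": [[101, 102, 103] for _ in prompts],
        "completion_ids": [[201, 202] for _ in prompts],
        "logprobs": [[-0.1, -0.2] for _ in prompts],
    }

prompts = ["Hello, AI!", "Tell me a joke"]
outputs = fake_generate(prompts)
prompt_ids, completion_ids = outputs["prompt_ids"], outputs["completion_ids"]

assert len(prompt_ids) == len(prompts) == len(completion_ids)
assert all(isinstance(tok, int) for seq in prompt_ids + completion_ids for tok in seq)
```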
@@ -148,36 +154,42 @@ class TestVLLMClientServerBaseURL(TrlTestCase):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
             "completion_ids"
         ]

         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)

         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)

         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32

     def test_update_model_params(self):
@@ -224,16 +236,22 @@ class TestVLLMClientServerTP(TrlTestCase):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_update_model_params(self):
@@ -280,16 +298,22 @@ class TestVLLMClientServerDP(TrlTestCase):

     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]

-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)

-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)

-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)

     def test_update_model_params(self):
@@ -336,9 +360,13 @@ class TestVLLMClientServerDeviceParameter(TrlTestCase):

         # Test basic functionality
         prompts = ["Hello, AI!"]
-        outputs = client.generate(prompts)["completion_ids"]
-        assert isinstance(outputs, list)
-        assert len(outputs) == len(prompts)
+        outputs = client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
+        assert isinstance(prompt_ids, list)
+        assert len(prompt_ids) == len(prompts)
+        assert isinstance(completion_ids, list)
+        assert len(completion_ids) == len(prompts)

         client.close_communicator()

@@ -83,8 +83,12 @@ class VLLMClient:

        >>> client = VLLMClient()
        >>> client.generate(["Hello, AI!", "Tell me a joke"])
-        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
-         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
+        {'prompt_ids': [[9707, 11, 15235, 0],
+         [40451, 752, 264, 21646]],
+         'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733],
+         [911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]],
+         'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963],
+         [-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]}

        >>> from transformers import AutoModelForCausalLM

@@ -212,6 +216,8 @@ class VLLMClient:

        Returns:
            `dict` with keys:
+                - `prompt_ids` (`list[list[int]]`):
+                    List of lists of token IDs representing the tokenized input prompts.
                - `completion_ids` (`list[list[int]]`):
                    List of lists of token IDs representing the model-generated completions for each prompt.
                - `logprobs` (`list[list[float]]`):
@@ -246,7 +252,11 @@ class VLLMClient:
        )
        if response.status_code == 200:
            json_response = response.json()
-            return {"completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]}
+            return {
+                "prompt_ids": json_response["prompt_ids"],
+                "completion_ids": json_response["completion_ids"],
+                "logprobs": json_response["logprobs"],
+            }
        else:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

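Aside (not part of the diff): a usage sketch of why returning `prompt_ids` alongside `completion_ids` is convenient for callers: the prompt and completion halves can be concatenated and decoded without re-tokenizing the prompt locally. It assumes a vLLM server already started via `trl vllm-serve`, and the tokenizer name is only illustrative.

```python
from transformers import AutoTokenizer

from trl.extras.vllm_client import VLLMClient

# Illustrative model name; use the tokenizer that matches the served model.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

client = VLLMClient()  # assumes a running `trl vllm-serve` instance on the default host/port
out = client.generate(["Hello, AI!", "Tell me a joke"])

for p_ids, c_ids in zip(out["prompt_ids"], out["completion_ids"]):
    # The prompt token IDs come from the server's own tokenization, so this
    # concatenation is exactly the sequence the model scored.
    full_ids = list(p_ids) + list(c_ids)
    print(tokenizer.decode(full_ids, skip_special_tokens=True))
```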
@@ -499,6 +499,7 @@ def main(script_args: ScriptArguments):
        generation_kwargs: dict = field(default_factory=dict)

    class GenerateResponse(BaseModel):
+        prompt_ids: list[list[int]]
        completion_ids: list[list[int]]
        logprobs: list[list[float]]

@@ -532,6 +533,7 @@ def main(script_args: ScriptArguments):

        Returns:
            `GenerateResponse`:
+                - `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
                - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
                - `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
                  generated completions.
@@ -543,7 +545,11 @@ def main(script_args: ScriptArguments):

        Example response:
        ```json
-        {"completion_ids": [[101, 102, 103], [201, 202, 203]], "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]}
+        {
+            "prompt_ids": [[101, 102], [201, 202]],
+            "completion_ids": [[103, 104, 105], [203, 204, 205]],
+            "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]
+        }
        ```
        """
        request.images = request.images or [None] * len(request.prompts)
@@ -596,13 +602,14 @@ def main(script_args: ScriptArguments):

        # Flatten and combine all results
        all_outputs = list(chain.from_iterable(all_outputs))  # from list of list to single list
+        prompt_ids = [output.prompt_token_ids for output in all_outputs]
        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
        logprobs: list[list[float]] = [
            [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
            for outputs in all_outputs
            for output in outputs.outputs
        ]
-        return {"completion_ids": completion_ids, "logprobs": logprobs}
+        return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}

    class InitCommunicatorRequest(BaseModel):
        host: str
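Aside (not part of the diff): note the shape relationship in the server response. `prompt_ids` has one entry per request, while `completion_ids` and `logprobs` are flattened over the `n` samples generated per request, so the lists only line up one-to-one when `n=1`. An illustrative sketch with made-up values:

```python
# Shape sketch (illustration only, not library code): one prompt_ids entry per
# request, n completion_ids entries per request when sampling n completions.
num_prompts, n = 2, 3

prompt_ids = [[101, 102], [201, 202]]                       # len == num_prompts
completion_ids = [[7] * 4 for _ in range(num_prompts * n)]  # len == num_prompts * n

assert len(prompt_ids) == num_prompts
assert len(completion_ids) == num_prompts * n

# A consumer that wants prompt/completion pairs can repeat each prompt n times:
repeated_prompt_ids = [ids for ids in prompt_ids for _ in range(n)]
assert len(repeated_prompt_ids) == len(completion_ids)
```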
@@ -1101,11 +1101,12 @@ class GRPOTrainer(BaseTrainer):
            **kwargs,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]

        if self.max_prompt_length is not None:
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
+
            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
            # tokens are needed for generation.
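Aside (not part of the diff): the two moved lines convert the left-padded `input_ids` batch into ragged, unpadded per-example token lists by boolean-masking with the attention mask, and after this change that only happens when `max_prompt_length` trimming needs it. A standalone sketch of the pattern with made-up values:

```python
import torch

# Standalone sketch of the masked "unpad" pattern used above; values are made up.
pad_token_id = 0
prompt_ids = torch.tensor([[pad_token_id, pad_token_id, 11, 12], [21, 22, 23, 24]])
prompt_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])

# Boolean-mask each row to drop padding, yielding ragged Python lists.
unpadded = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
assert unpadded == [[11, 12], [21, 22, 23, 24]]
```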
@@ -1187,19 +1188,23 @@ class GRPOTrainer(BaseTrainer):
                        guided_decoding_regex=self.guided_decoding_regex,
                        generation_kwargs=self.args.generation_kwargs,
                    )
-                payload = (output["completion_ids"], output["logprobs"])
+                payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
            else:
                payload = None

            # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
            obj_list = [payload]
            broadcast_object_list(obj_list, from_process=0)
-            all_completion_ids, all_logprobs = obj_list[0]
+            all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
+
+            # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
+            all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]

            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
+            prompt_ids = all_prompt_ids[process_slice]
            completion_ids = all_completion_ids[process_slice]
            logprobs = all_logprobs[process_slice]

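Aside (not part of the diff): the broadcast payload carries one `prompt_ids` entry per unique prompt, so it is repeated `num_generations` times to align with the flattened completions before each process takes its contiguous slice. A small sketch with illustrative sizes:

```python
# Bookkeeping sketch (illustrative values only): repeat each unique prompt's
# token IDs num_generations times, then slice out this process's share.
num_generations = 2
all_prompt_ids = [[1, 2], [3, 4]]              # one entry per unique prompt
all_completion_ids = [[10], [11], [20], [21]]  # num_generations entries per prompt

all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
assert len(all_prompt_ids) == len(all_completion_ids)

process_index, prompts_per_process = 1, 2  # e.g. 2 processes, 2 entries each
process_slice = slice(process_index * prompts_per_process,
                      (process_index + 1) * prompts_per_process)
assert all_prompt_ids[process_slice] == [[3, 4], [3, 4]]
assert all_completion_ids[process_slice] == [[20], [21]]
```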
@@ -1254,6 +1259,7 @@ class GRPOTrainer(BaseTrainer):
                with profiling_context(self, "vLLM.generate"):
                    all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

+                all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
                all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
                all_logprobs = [
                    [next(iter(lp.values())).logprob for lp in output.logprobs]
@@ -1266,9 +1272,11 @@ class GRPOTrainer(BaseTrainer):
                    # Each rank generates all outputs — we keep only our share.
                    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+                    prompt_ids = all_prompt_ids[tp_slice]
                    completion_ids = all_completion_ids[tp_slice]
                    logprobs = all_logprobs[tp_slice]
                else:
+                    prompt_ids = all_prompt_ids
                    completion_ids = all_completion_ids
                    logprobs = all_logprobs

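Aside (not part of the diff): under tensor parallelism every rank in the TP group generated outputs for the whole gathered batch, so the new lines apply the same `tp_slice` to the prompt IDs that was already applied to the completions. A sketch with illustrative sizes:

```python
# Sketch of the TP bookkeeping (illustrative sizes): each rank produced outputs
# for the whole gathered batch and keeps only its own contiguous share.
orig_size = 2            # number of prompts this rank contributed
local_rank_in_group = 1  # e.g. rank 1 of a 2-way TP group

all_prompt_ids = [[1], [2], [3], [4]]         # gathered batch: 2 ranks * orig_size
all_completion_ids = [[10], [20], [30], [40]]

tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
prompt_ids = all_prompt_ids[tp_slice]          # [[3], [4]]
completion_ids = all_completion_ids[tp_slice]  # [[30], [40]]
assert prompt_ids == [[3], [4]] and completion_ids == [[30], [40]]
```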
@@ -1311,10 +1319,7 @@ class GRPOTrainer(BaseTrainer):

        else:
            # Regular generation path
-            prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
-            prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-            prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-            prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

            with (
                profiling_context(self, "transformers.generate"),
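Aside (not part of the diff): in the regular `transformers.generate` path the ragged `prompt_ids` lists no longer have to be re-tensorized and left-padded, because the already padded `input_ids` and `attention_mask` from `prompt_inputs` are reused directly. A sketch of the round trip the removed lines performed; the `left_pad` helper below is only a stand-in for TRL's `pad` utility.

```python
import torch

# Sketch of what the removed lines were doing: ragged prompt_ids lists were
# re-tensorized and left-padded before being handed to generate().
def left_pad(seqs: list[torch.Tensor], value: int) -> torch.Tensor:
    width = max(len(s) for s in seqs)
    return torch.stack([torch.cat([s.new_full((width - len(s),), value), s]) for s in seqs])

ragged = [torch.tensor([11, 12]), torch.tensor([21, 22, 23])]
ids = left_pad(ragged, value=0)
mask = left_pad([torch.ones_like(s) for s in ragged], value=0)
# After this change, the regular path skips this round trip and reuses the
# already left-padded prompt_inputs["input_ids"] / ["attention_mask"].
assert ids.tolist() == [[0, 11, 12], [21, 22, 23]] and mask.tolist() == [[0, 1, 1], [1, 1, 1]]
```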
@@ -1090,11 +1090,12 @@ class RLOOTrainer(BaseTrainer):
            **kwargs,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]

        if self.max_prompt_length is not None:
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
+
            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
            # tokens are needed for generation.
@@ -1176,19 +1177,23 @@ class RLOOTrainer(BaseTrainer):
                        guided_decoding_regex=self.guided_decoding_regex,
                        generation_kwargs=self.args.generation_kwargs,
                    )
-                payload = (output["completion_ids"], output["logprobs"])
+                payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
            else:
                payload = None

            # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
            obj_list = [payload]
            broadcast_object_list(obj_list, from_process=0)
-            all_completion_ids, _ = obj_list[0]
+            all_prompt_ids, all_completion_ids, _ = obj_list[0]
+
+            # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
+            all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]

            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
+            prompt_ids = all_prompt_ids[process_slice]
            completion_ids = all_completion_ids[process_slice]

        # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
@@ -1241,6 +1246,7 @@ class RLOOTrainer(BaseTrainer):
                with profiling_context(self, "vLLM.generate"):
                    all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

+                all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
                all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

                if self.vllm_tensor_parallel_size > 1:
@@ -1248,8 +1254,10 @@ class RLOOTrainer(BaseTrainer):
                    # Each rank generates all outputs — we keep only our share.
                    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+                    prompt_ids = all_prompt_ids[tp_slice]
                    completion_ids = all_completion_ids[tp_slice]
                else:
+                    prompt_ids = all_prompt_ids
                    completion_ids = all_completion_ids

                if self.args.vllm_enable_sleep_mode:
@@ -1290,10 +1298,7 @@ class RLOOTrainer(BaseTrainer):

        else:
            # Regular generation path
-            prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
-            prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-            prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-            prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

            with (
                profiling_context(self, "transformers.generate"),