Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

Compare commits: 7e9c6e45d5...grpo-ssr-r (16 commits)
| SHA1 |
|---|
| b5f8feb4a6 |
| 6b485118a7 |
| f88569edd9 |
| e7986290ba |
| b562ec28af |
| 631499efe8 |
| 446fd001ab |
| 727b3767d1 |
| edbf12a061 |
| f9b645d292 |
| 91ad02766f |
| b2b285872e |
| 7ce8661b84 |
| 7a17bc3fbc |
| 29007cc546 |
| e7ceb78c93 |
@@ -24,7 +24,7 @@ from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
-from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict
+from trl.trainer.grpo_trainer import RepeatSampler, shuffle_dict_list, shuffle_tensor_dict, split_tensor_dict

from .testing_utils import require_vllm

@@ -104,6 +104,40 @@ class ShuffleTensorDictTester(unittest.TestCase):
        self.assertEqual(shuffled["x"].shape, x.shape)


class ShuffleDictListTester(unittest.TestCase):
    def test_shuffle_preserves_length(self):
        a = [1, 2, 3, 4]
        b = ["a", "b", "c", "d"]
        tensor_dict = {"a": a.copy(), "b": b.copy()}

        shuffled = shuffle_dict_list(tensor_dict)

        self.assertEqual(len(shuffled["a"]), len(a))
        self.assertEqual(len(shuffled["b"]), len(b))

    def test_shuffle_consistent_across_lists(self):
        a = [10, 20, 30]
        b = ["x", "y", "z"]
        tensor_dict = {"a": a.copy(), "b": b.copy()}

        shuffled = shuffle_dict_list(tensor_dict)

        mapping = dict(zip(shuffled["a"], shuffled["b"]))

        self.assertEqual(mapping[10], "x")
        self.assertEqual(mapping[20], "y")
        self.assertEqual(mapping[30], "z")

    def test_none_list_remains_none(self):
        a = [1, 2, 3]
        tensor_dict = {"a": a.copy(), "b": None}

        shuffled = shuffle_dict_list(tensor_dict)

        self.assertIsNone(shuffled["b"])
        self.assertEqual(len(shuffled["a"]), len(a))


class RepeatRandomSamplerTester(unittest.TestCase):
    def test_sampler(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
@@ -145,6 +145,15 @@ class GRPOConfig(TrainingArguments):
        epsilon_high (`float` or `None`, *optional*, defaults to `None`):
            Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
            specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
        replay_buffer_class (`str`, *optional*, defaults to `"ReplayBuffer"`):
            Replay buffer class to use. Options are [`ReplayBuffer`] and [`SSRReplayBuffer`]. The default is
            `"ReplayBuffer"`, which randomly samples without replacement.
        ssr_capacity_scalar (`int`, *optional*, defaults to `4`):
            Scalar to multiply the replay buffer capacity. The capacity is this value times the number of training
            samples in the effective batch, so the default of `4` stores four effective batches.
        ssr_alpha (`float`, *optional*, defaults to `1.0`):
            Alpha parameter controlling the sampling distribution of the replay buffer. Sampling probabilities are
            proportional to the absolute advantage raised to this power; `0.0` samples uniformly.
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
@@ -417,6 +426,26 @@ class GRPOConfig(TrainingArguments):
            "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
        },
    )
    replay_buffer_class: str = field(
        default="ReplayBuffer",
        metadata={
            "help": "Replay buffer class to use. Options are `ReplayBuffer` and `SSRReplayBuffer`. The default is "
            "`ReplayBuffer`, which randomly samples without replacement."
        },
    )
    ssr_capacity_scalar: int = field(
        default=4,
        metadata={
            "help": "Scalar to multiply the replay buffer capacity. The capacity is this value times the number of "
            "training samples in the effective batch."
        },
    )
    ssr_alpha: float = field(
        default=1.0,
        metadata={
            "help": "Alpha parameter controlling the sampling distribution of the replay buffer. Sampling "
            "probabilities are proportional to the absolute advantage raised to this power; 0.0 samples uniformly."
        },
    )

    reward_weights: Optional[list[float]] = field(
        default=None,
        metadata={
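For orientation, a minimal configuration that switches on the prioritized buffer might look like the sketch below (illustrative values only; `output_dir` and the numbers are placeholders, and only the three replay-buffer arguments come from the fields defined above):

```python
from trl import GRPOConfig

# Hypothetical values for illustration; only the replay-buffer arguments are specific to this change.
config = GRPOConfig(
    output_dir="grpo-ssr-demo",
    replay_buffer_class="SSRReplayBuffer",  # prioritized sampling by absolute advantage
    ssr_capacity_scalar=4,                  # buffer capacity = 4x the effective batch
    ssr_alpha=1.0,                          # probabilities proportional to |advantage|**1
)
```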
@@ -13,6 +13,7 @@
# limitations under the License.

import os
+import random
import textwrap
import warnings
from collections import defaultdict, deque

@@ -21,6 +22,7 @@ from contextlib import nullcontext
from typing import Any, Callable, Optional, Union

import datasets
+import numpy as np
import torch
import torch.utils.data
import transformers
@@ -81,6 +83,124 @@ if is_wandb_available():
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.sample_indices = []

    def add(self, experience):
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer.pop(0)
            self.buffer.append(experience)

        # Clear index queue when buffer changes
        self.sample_indices.clear()

    def add_batch(self, experiences: dict[str, list[torch.Tensor]]):
        """
        Add a batch of experiences to the replay buffer.
        """
        first_tensor = next(tensor for tensor in experiences.values() if tensor is not None)
        num_items = len(first_tensor)

        for i in range(num_items):
            experience = {key: tensor[i] if tensor is not None else None for key, tensor in experiences.items()}
            self.add(experience)

    def _init_sampling_queue(self):
        self.sample_indices = list(range(len(self.buffer)))
        random.shuffle(self.sample_indices)

    def sample(self, batch_size):
        if not self.sample_indices:
            self._init_sampling_queue()

        batch = []
        while len(batch) < batch_size and self.sample_indices:
            idx = self.sample_indices.pop(0)
            batch.append(self.buffer[idx])

        if len(batch) != batch_size:
            raise ValueError("Not enough samples in the buffer to fill the batch.")

        return {k: [d[k] for d in batch] if batch[0][k] is not None else None for k in batch[0]}

    def __len__(self):
        return len(self.buffer)

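As a rough sketch of the buffer contract (toy tensors; the import path is assumed from this branch): `add_batch` splits a dict of per-key lists into per-sample experiences, and `sample` returns the same dict-of-lists layout, keeping `None` columns as `None`.

```python
import torch

from trl.trainer.grpo_trainer import ReplayBuffer  # import path assumed from this branch

buffer = ReplayBuffer(capacity=4)
buffer.add_batch(
    {
        "completion_ids": [torch.tensor([1, 2]), torch.tensor([3])],
        "advantages": [torch.tensor(0.5), torch.tensor(-0.1)],
        "old_per_token_logps": None,  # a None column stays None end to end
    }
)
batch = buffer.sample(batch_size=2)
print(len(batch["completion_ids"]))  # 2
print(batch["old_per_token_logps"])  # None
```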
class SSRReplayBuffer(ReplayBuffer):
    # implementation of the SSR replay buffer from https://arxiv.org/pdf/2504.08837
    def __init__(self, capacity, alpha=1.0):
        super().__init__(capacity)
        self.alpha = alpha
        self.advantages = []

    def add(self, experience):
        EPS = 0.0001  # ensures non-zero priorities even when every advantage in the buffer is 0
        advantage = experience["advantages"].item()
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
            self.advantages.append(abs(advantage) + EPS)  # store the absolute advantage
        elif abs(advantage) > EPS:
            # Replace the oldest entry if the buffer is full and the advantage is non-zero
            self.buffer.pop(0)
            self.advantages.pop(0)
            self.buffer.append(experience)
            self.advantages.append(abs(advantage))

    def sample(self, batch_size):
        if not self.buffer:
            raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")

        # Convert absolute advantages to sampling probabilities
        scaled_priorities = np.power(self.advantages, self.alpha)
        total_priority = np.sum(scaled_priorities)
        probabilities = scaled_priorities / total_priority

        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities, replace=True)
        batch = [self.buffer[i] for i in indices]

        return {k: [d[k] for d in batch] if batch[0][k] is not None else None for k in batch[0]}

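To make the effect of `alpha` concrete, a small worked example with hypothetical absolute advantages stored in the buffer:

```python
import numpy as np

# Hypothetical absolute advantages held in the buffer.
advantages = np.array([0.1, 0.4, 0.5])

for alpha in (0.0, 1.0, 2.0):
    p = advantages**alpha / np.sum(advantages**alpha)
    print(alpha, np.round(p, 3))
# 0.0 -> [0.333 0.333 0.333]  (uniform)
# 1.0 -> [0.1   0.4   0.5  ]  (proportional to |advantage|)
# 2.0 -> [0.024 0.381 0.595]  (sharper emphasis on large advantages)
```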
class DapoReplayBuffer(ReplayBuffer):
    # DAPO-style variant of the replay buffer: stored samples are weighted uniformly and, once the buffer is full,
    # only experiences with a non-zero advantage replace old entries
    def __init__(self, capacity, alpha=1.0):
        super().__init__(capacity)
        self.alpha = alpha
        self.weights = []

    def add(self, experience):
        EPS = 0.0001  # threshold below which an advantage is treated as zero
        advantage = experience["advantages"].item()
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
            self.weights.append(1.0)  # uniform weight for every stored sample
        elif abs(advantage) > EPS:
            # Replace the oldest entry if the buffer is full and the advantage is non-zero
            self.buffer.pop(0)
            self.weights.pop(0)
            self.buffer.append(experience)
            self.weights.append(1.0)

    def sample(self, batch_size):
        if not self.buffer:
            raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")

        # Convert weights to sampling probabilities (uniform, since all weights are 1.0)
        scaled_priorities = np.power(self.weights, self.alpha)
        total_priority = np.sum(scaled_priorities)
        probabilities = scaled_priorities / total_priority

        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities, replace=False)
        batch = [self.buffer[i] for i in indices]

        return {k: [d[k] for d in batch] if batch[0][k] is not None else None for k in batch[0]}


class RepeatSampler(Sampler):
    """
    Sampler that repeats the indices of a dataset in a structured manner.
@@ -215,7 +335,7 @@ def split_tensor_dict(
        ]
    """
    first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
-    chunk_size = first_tensor.shape[0] // num_chunks
+    chunk_size = len(first_tensor) // num_chunks
    return [
        {
            key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
@@ -225,6 +345,30 @@ def split_tensor_dict(
    ]


def combine_tensor_dict(split_dicts: list[dict[str, Optional[torch.Tensor]]]) -> dict[str, Optional[torch.Tensor]]:
    """
    Combines a list of dictionaries containing tensors into a single dictionary by concatenating the tensors along
    the first dimension.

    Example:
        >>> d1 = {"x": torch.tensor([[0, 1], [2, 3]]), "y": torch.tensor([[0], [1]])}
        >>> d2 = {"x": torch.tensor([[4, 5], [6, 7]]), "y": torch.tensor([[2], [3]])}
        >>> d3 = {"x": torch.tensor([[8, 9], [10, 11]]), "y": torch.tensor([[4], [5]])}
        >>> combine_tensor_dict([d1, d2, d3])
        {
            "x": tensor([[ 0,  1], [ 2,  3], [ 4,  5], [ 6,  7], [ 8,  9], [10, 11]]),
            "y": tensor([[0], [1], [2], [3], [4], [5]])
        }
    """
    combined_dict = {}
    keys = split_dicts[0].keys()

    for key in keys:
        tensors = [d[key] for d in split_dicts if d[key] is not None]
        combined_dict[key] = torch.cat(tensors, dim=0) if tensors else None  # concatenate along dim 0, as documented

    return combined_dict


def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[str, Optional[torch.Tensor]]:
    """
    Shuffles a dictionary of tensors along the first dimension in unison.
@@ -247,6 +391,25 @@ def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[
    return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()}


def shuffle_dict_list(tensor_dict: dict[str, Optional[list[Any]]]) -> dict[str, Optional[list[Any]]]:
    """
    Shuffles a dictionary of lists along the first dimension in unison.

    Example:
        >>> x = [1, 2, 3]
        >>> y = [4, 5, 6]
        >>> tensor_dict = {"x": x, "y": y}
        >>> shuffle_dict_list(tensor_dict)
        {'x': [2, 1, 3], 'y': [5, 4, 6]}
    """
    first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
    batch_size = len(first_tensor)
    permutation = torch.randperm(batch_size)
    return {
        key: [tensor[i] for i in permutation] if tensor is not None else None for key, tensor in tensor_dict.items()
    }


def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
@@ -528,7 +691,22 @@ class GRPOTrainer(Trainer):
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
        # `_get_train_sampler` and `_prepare_inputs`.
        self._buffered_inputs = None
        if args.replay_buffer_class == "ReplayBuffer":
            self.replay_buffer = ReplayBuffer(capacity=args.generation_batch_size)
        elif args.replay_buffer_class == "SSRReplayBuffer":
            self.replay_buffer = SSRReplayBuffer(
                capacity=args.generation_batch_size * args.ssr_capacity_scalar,
                alpha=args.ssr_alpha,
            )
        elif args.replay_buffer_class == "DapoReplayBuffer":
            self.replay_buffer = DapoReplayBuffer(
                capacity=args.generation_batch_size * args.ssr_capacity_scalar,
                alpha=args.ssr_alpha,
            )
        else:
            raise ValueError(
                "Invalid `replay_buffer_class` passed to `GRPOConfig`. Expected 'ReplayBuffer', 'SSRReplayBuffer', "
                f"or 'DapoReplayBuffer', but got {self.args.replay_buffer_class}."
            )

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
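For a concrete sense of the sizing logic above (illustrative numbers, not defaults): the plain `ReplayBuffer` holds one generation batch, while the SSR and DAPO variants hold `generation_batch_size * ssr_capacity_scalar` experiences.

```python
# Illustrative sizing only; the values are hypothetical.
generation_batch_size = 64
ssr_capacity_scalar = 4

plain_capacity = generation_batch_size                          # ReplayBuffer -> 64
scaled_capacity = generation_batch_size * ssr_capacity_scalar   # SSRReplayBuffer / DapoReplayBuffer -> 256
print(plain_capacity, scaled_capacity)  # 64 256
```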
@@ -967,12 +1145,16 @@ class GRPOTrainer(Trainer):
        mode = "train" if self.model.training else "eval"
        if mode == "train":
            generate_every = self.args.steps_per_generation * self.num_iterations
-            if self._step % generate_every == 0 or self._buffered_inputs is None:
+            if self._step % generate_every == 0 or len(self.replay_buffer) == 0:
                # self._buffered_inputs=None can occur when resuming from a checkpoint
                generation_batch = self._generate_and_score_completions(generation_batch)
-                generation_batch = shuffle_tensor_dict(generation_batch)
-                self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation)
-            inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
+                # we shuffle the generation batch to ensure that the order of prompts is randomized
+                # across different steps, only relevant if the generation batch is larger than the optimization batch
+                generation_batch_shuffled = shuffle_dict_list(generation_batch)
+                self.replay_buffer.add_batch(generation_batch_shuffled)
+
+            inputs = self.replay_buffer.sample(self.args.per_device_train_batch_size)
+
            self._step += 1
        else:
            # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
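The schedule implied by the condition above can be sketched with toy numbers (hypothetical values, not trainer code): generation refills the buffer every `steps_per_generation * num_iterations` steps, and every step draws one micro-batch from the buffer.

```python
# Hypothetical schedule illustration.
steps_per_generation, num_iterations = 4, 2
generate_every = steps_per_generation * num_iterations  # 8

for step in range(10):
    refill = step % generate_every == 0
    action = "generate + add_batch, then sample" if refill else "sample from buffer"
    print(step, action)
# Steps 0 and 8 trigger generation; every step draws per_device_train_batch_size samples via sample().
```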
@@ -1254,13 +1436,45 @@ class GRPOTrainer(Trainer):
            self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
        self._textual_logs["advantages"].extend(all_process_advantages.tolist())

        # Unpad the prompt and completion ids to optimise memory: https://github.com/huggingface/trl/pull/3495
        unpadded_prompt_ids = []
        unpadded_completion_ids = []
        unpadded_per_token_logps = []
        prompt_ids_length = prompt_ids.size(1)
        logp_iter = attention_mask if old_per_token_logps is None else old_per_token_logps  # dummy iterable when old logps are absent
        # get the start and end indices of the attention_mask
        for p_ids, c_ids, old_logps, mask in zip(prompt_ids, completion_ids, logp_iter, attention_mask):
            indices = torch.where(mask == 1)[0]

            if len(indices) > 0:
                start = indices[0]
                end = indices[-1] + 1
                if prompt_ids_length > end:
                    raise ValueError(
                        f"End index {end} is smaller than prompt_ids_length {prompt_ids_length}. "
                        "This can happen if the attention mask is not correctly set."
                    )
                # prompt ids were left padded
                unpadded_prompt_ids.append(p_ids[start:prompt_ids_length])
                # completion ids were right padded
                unpadded_completion_ids.append(c_ids[: end - prompt_ids_length])
                if mask[start:end].sum() != end - start:
                    raise ValueError(f"Attention mask from {start} to {end} contains zeros; expected a contiguous block of ones.")
                if old_per_token_logps is not None:
                    unpadded_per_token_logps.append(old_logps[: end - prompt_ids_length])
            else:
                # case where the attention mask is all zeros, e.g. when mask_truncated_completions is enabled
                raise ValueError(
                    "Attention mask is all zeros"
                    f" for prompt {p_ids.tolist()} and completion {c_ids.tolist()}."
                )

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "unpadded_prompt_ids": unpadded_prompt_ids,
            "unpadded_completion_ids": unpadded_completion_ids,
            "unpadded_per_token_logps": unpadded_per_token_logps if old_per_token_logps is not None else None,
            "advantages": advantages,
            "old_per_token_logps": old_per_token_logps,
        }

    def compute_liger_loss(self, unwrapped_model, inputs):
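A toy illustration of the start/end bookkeeping above (hypothetical ids; prompts are left-padded, completions right-padded, and the attention mask covers both):

```python
import torch

# Hypothetical row: prompt left-padded to length 4, completion right-padded to length 3.
prompt_ids = torch.tensor([0, 0, 11, 12])      # two pad tokens, then the prompt
completion_ids = torch.tensor([21, 22, 0])     # completion, then one pad token
mask = torch.tensor([0, 0, 1, 1, 1, 1, 0])     # attention mask over prompt + completion

prompt_len = prompt_ids.size(0)
indices = torch.where(mask == 1)[0]
start, end = indices[0], indices[-1] + 1       # start=2, end=6

print(prompt_ids[start:prompt_len])            # tensor([11, 12])  -> unpadded prompt
print(completion_ids[: end - prompt_len])      # tensor([21, 22])  -> unpadded completion
```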
@@ -1312,14 +1526,42 @@ class GRPOTrainer(Trainer):

    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        prompt_ids, prompt_mask = pad(
            inputs["unpadded_prompt_ids"],
            padding_value=self.processing_class.pad_token_id,
            padding_side="left",
            return_mask=True,
        )
        completion_ids, completion_mask = pad(
            inputs["unpadded_completion_ids"],
            padding_value=self.processing_class.pad_token_id,
            padding_side="right",
            return_mask=True,
        )

        old_per_token_logps = inputs["unpadded_per_token_logps"]

        if old_per_token_logps is not None:
            old_per_token_logps = pad(inputs["unpadded_per_token_logps"], padding_value=0.0, padding_side="right")

        padded_inputs = {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "advantages": torch.stack(inputs["advantages"]),
            "old_per_token_logps": old_per_token_logps,
        }
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        if self.use_liger_loss:
            # Compute the loss using the liger grpo loss
            unwrapped_model = self.accelerator.unwrap_model(model)
-            return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
+            return self._forward_redirection(
+                model, unwrapped_model, self.compute_liger_loss, unwrapped_model, padded_inputs
+            )
        else:
-            return self._compute_loss(model, inputs)
+            return self._compute_loss(model, padded_inputs)

    def _compute_loss(self, model, inputs):
        # Compute the per-token log probabilities for the model
@@ -420,7 +420,8 @@ def pad(
    padding_value: int = 0,
    padding_side: str = "right",
    pad_to_multiple_of: Optional[int] = None,
-) -> torch.Tensor:
+    return_mask: bool = False,
+) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Pads a list of tensors to the same shape along the first dimension.

@@ -433,6 +434,8 @@ def pad(
            Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
        pad_to_multiple_of (`int`, *optional*, defaults to `None`):
            If set will pad the sequence to a multiple of the provided value.
        return_mask (`bool`, *optional*, defaults to `False`):
            If `True`, also returns an attention mask tensor (1 for real tokens, 0 for padding).

    Returns:
        `torch.Tensor`:
@@ -461,7 +464,9 @@ def pad(

    # Create an output tensor filled with the padding value
    output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)

    # Initialize mask tensor if required
    if return_mask:
        mask = torch.zeros((len(tensors), output_shape[0]), dtype=torch.long, device=tensors[0].device)
    for i, t in enumerate(tensors):
        if padding_side == "left":
            seq_start = output_shape[0] - t.shape[0]
@@ -474,6 +479,10 @@ def pad(
        seq_slice = slice(seq_start, seq_start + t.shape[0])
        slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
        output[i][slices] = t
        if return_mask:
            mask[i, seq_slice] = 1
    if return_mask:
        return output, mask

    return output
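A small usage sketch for the extended helper (toy tensors; the import path is assumed from context):

```python
import torch

from trl.trainer.utils import pad  # location assumed for this helper

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
ids, mask = pad(seqs, padding_value=0, padding_side="left", return_mask=True)
print(ids)   # tensor([[1, 2, 3], [0, 4, 5]])
print(mask)  # tensor([[1, 1, 1], [0, 1, 1]])
```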