Mirror of https://github.com/huggingface/trl.git
Synced 2025-10-21 11:29:23 +08:00

Compare commits: v0.22.0...simplify-d (4 commits)

Author | SHA1 | Date
---|---|---
 | db7270ca70 |
 | 70f92d209e |
 | 39faf36a91 |
 | 1cb4150dfb |
@@ -162,6 +162,18 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))

    def test_pad_to_multiple_of_and_padding_free(self):
        """Test padding to multiple of specified value."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

        result = collator(examples)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))

    def test_custom_position_ids(self):
        """Test handling of custom position IDs in examples."""
        self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
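For reference, the expected tensors in the new `test_pad_to_multiple_of_and_padding_free` test can be reproduced with a minimal standalone sketch in plain PyTorch (an illustration only, not TRL's collator implementation): the examples are flattened into a single row, position IDs restart for each example, and every tensor is right-padded to a multiple of 4.

```python
import torch

examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
pad_to_multiple_of = 4

# Flatten all examples into one row; position IDs restart for each example.
flat_ids = torch.cat([torch.tensor(ex["input_ids"]) for ex in examples])
flat_pos = torch.cat([torch.arange(len(ex["input_ids"])) for ex in examples])

# Right-pad the flattened row up to the next multiple of `pad_to_multiple_of`.
target_len = -(-flat_ids.numel() // pad_to_multiple_of) * pad_to_multiple_of  # ceiling division
pad_len = target_len - flat_ids.numel()
zeros = torch.zeros(pad_len, dtype=torch.long)

input_ids = torch.cat([flat_ids, zeros]).unsqueeze(0)                        # [[1, 2, 3, 4, 5, 0, 0, 0]]
attention_mask = torch.cat([torch.ones_like(flat_ids), zeros]).unsqueeze(0)  # [[1, 1, 1, 1, 1, 0, 0, 0]]
position_ids = torch.cat([flat_pos, zeros]).unsqueeze(0)                     # [[0, 1, 2, 0, 1, 0, 0, 0]]
labels = torch.cat([flat_ids, torch.full((pad_len,), -100)]).unsqueeze(0)    # [[1, 2, 3, 4, 5, -100, -100, -100]]
```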
@@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import TYPE_CHECKING

from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available


# Read version from VERSION file
_version_file = Path(__file__).parent.parent / "VERSION"
try:
    with open(_version_file, encoding="utf-8") as f:
        __version__ = f.read().strip()
except FileNotFoundError:
    __version__ = version("trl")
except PackageNotFoundError:
    __version__ = "unknown"

_import_structure = {
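Read without diff markers, the old and new lines above interleave, but the underlying pattern is a version lookup with fallbacks: prefer the VERSION file, fall back to the installed package metadata, then to `"unknown"`. A minimal standalone sketch of that pattern (the `resolve_version` helper is hypothetical, not TRL's code):

```python
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path


def resolve_version(package_name: str, version_file: Path) -> str:
    """Prefer a VERSION file, fall back to installed package metadata, then to 'unknown'."""
    try:
        return version_file.read_text(encoding="utf-8").strip()
    except FileNotFoundError:
        pass
    try:
        return version(package_name)
    except PackageNotFoundError:
        return "unknown"


print(resolve_version("trl", Path("VERSION")))
```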
@@ -663,8 +663,9 @@ def pack_dataset(
    >>> dataset = Dataset.from_dict(examples)
    >>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
    >>> packed_dataset[:]
    {'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
     'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
    {'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
     'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]],
     'seq_lengths': [[3, 1], [3], [2]]}
    ```
    """
    if map_kwargs is None:
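The docstring example above uses `strategy="bfd"`, i.e. best-fit decreasing packing. A simplified sketch of the general idea (the `pack_bfd` helper is an illustration, not TRL's `pack_dataset`; the input sequences `[1, 2, 3]`, `[4, 5]`, `[6, 7, 8]`, `[9]` are inferred from the packed output shown above):

```python
def pack_bfd(sequences: list[list[int]], seq_length: int) -> list[list[list[int]]]:
    """Best-fit decreasing bin packing: longest sequences first, each placed into
    the existing bin with the least remaining room that still fits it."""
    bins: list[list[list[int]]] = []  # each bin holds the sequences packed into one row
    space: list[int] = []             # remaining capacity of each bin
    for seq in sorted(sequences, key=len, reverse=True):
        candidates = [i for i, s in enumerate(space) if s >= len(seq)]
        if candidates:
            best = min(candidates, key=lambda i: space[i])
            bins[best].append(seq)
            space[best] -= len(seq)
        else:
            bins.append([seq])
            space.append(seq_length - len(seq))
    return bins


print(pack_bfd([[1, 2, 3], [4, 5], [6, 7, 8], [9]], seq_length=4))
# [[[1, 2, 3], [9]], [[6, 7, 8]], [[4, 5]]]  -- the same grouping as in the docstring example
```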
@@ -124,7 +124,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            that are not in the completion.
        padding_free (`bool`, *optional*, defaults to `False`):
            If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
            generated accordingly. The attention mask will be set to 1 for all tokens.
            generated accordingly.
        return_position_ids (`bool`, *optional*, defaults to `True`):
            Whether to return position IDs. If `True`, position IDs are generated and returned. If `False`, attention
            masks are generated and returned instead. Note that when using FlashAttention, this should be set to
            `True`.
        pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
            If set, the sequences will be padded to a multiple of this value.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
@@ -139,8 +143,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
    >>> collator(examples)
    {'input_ids': tensor([[ 1, 2, 3],
            [ 4, 5, 0]]),
    'attention_mask': tensor([[ 1, 1, 1],
            [ 1, 1, 0]]),
    'position_ids': tensor([[0, 1, 2],
            [0, 1, 0]]),
    'labels': tensor([[ 1, 2, 3],
@@ -154,8 +156,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
    >>> collator(examples)
    {'input_ids': tensor([[ 1, 2, 3],
            [ 4, 5, 0]]),
    'attention_mask': tensor([[ 1, 1, 1],
            [ 1, 1, 0]]),
    'position_ids': tensor([[0, 1, 2],
            [0, 1, 0]]),
    'labels': tensor([[-100, 2, 3],
@@ -182,14 +182,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
        # Convert to tensor
        input_ids = [torch.tensor(example["input_ids"]) for example in examples]

        # Check if we have meaningful seq_lengths from packing (restarting sequences)
        has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free

        # For packing with position_ids, we should NOT create attention_mask as it causes
        # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
        if not has_packed_position_ids:
            attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]

        # In practice, self.return_position_ids is True when using FlashAttention. When using FlashAttention, we should
        # NOT create attention_mask as it causes FlashAttention to ignore position_ids and compute wrong cu_seq_lens
        # from the all-1s mask.
        if self.return_position_ids:
            if "seq_lengths" in examples[0]:
                position_ids = self.get_position_ids_from_packed_seq_lengths(
@@ -197,6 +192,13 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
                )
            else:
                position_ids = [torch.arange(len(ids)) for ids in input_ids]
        else:
            if "seq_lengths" in examples[0]:
                logger.warning(
                    "The input examples contain `seq_lengths` but `return_position_ids` is set to `False`. "
                    "`seq_lengths` will be ignored."
                )
            attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
        if "labels" in examples[0]:
            labels = [torch.tensor(example["labels"]) for example in examples]
        else:
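The comments in the two hunks above refer to FlashAttention's `cu_seq_lens`. When packed sequences are flattened into one row, the sequence boundaries can be recovered from position IDs (each restart to 0 marks a new sequence), whereas an all-1s attention mask carries no boundary information. An illustration only, not the transformers or flash-attn implementation:

```python
import torch

# A packed row: a 3-token sequence followed by a 2-token sequence.
position_ids = torch.tensor([0, 1, 2, 0, 1])

starts = torch.nonzero(position_ids == 0).flatten()                         # tensor([0, 3])
seq_lens = torch.diff(starts, append=torch.tensor([position_ids.numel()]))  # tensor([3, 2])
cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=0)])        # tensor([0, 3, 5])

# With an all-1s attention mask instead, the two sequences would be
# indistinguishable from a single 5-token sequence (boundaries [0, 5]).
```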
@@ -206,48 +208,48 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
        if "assistant_masks" in examples[0]:
            assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples]

        # Pad
        # If padding_free, flatten everything into a single sequence
        output = {}
        if self.padding_free:
            output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
            if not has_packed_position_ids:
                output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
            input_ids = [torch.cat(input_ids, dim=0)]
            if self.return_position_ids:
                output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
            output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
                position_ids = [torch.cat(position_ids, dim=0)]
            else:
                attention_mask = [torch.cat(attention_mask, dim=0)]
            labels = [torch.cat(labels, dim=0)]
            if self.completion_only_loss and "completion_mask" in examples[0]:
                completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0)
                output["labels"][completion_mask == 0] = -100
                completion_mask = [torch.cat(completion_mask, dim=0)]
            if "assistant_masks" in examples[0]:
                assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
                output["labels"][assistant_masks == 0] = -100
        else:
            output["input_ids"] = pad(
                input_ids,
                padding_value=self.pad_token_id,
                padding_side="right",
                pad_to_multiple_of=self.pad_to_multiple_of,
                assistant_masks = [torch.cat(assistant_masks, dim=0)]

        # Pad
        output["input_ids"] = pad(
            input_ids,
            padding_value=self.pad_token_id,
            padding_side="right",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        if self.return_position_ids:
            output["position_ids"] = pad(
                position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
        else:
            output["attention_mask"] = pad(
                attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
            if self.return_position_ids:
                output["position_ids"] = pad(
                    position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
                )
        output["labels"] = pad(
            labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            output["labels"] = pad(
                labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
        )
        if self.completion_only_loss and "completion_mask" in examples[0]:
            completion_mask = pad(
                completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
            if self.completion_only_loss and "completion_mask" in examples[0]:
                completion_mask = pad(
                    completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
                )
                output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
            if "assistant_masks" in examples[0]:
                assistant_masks = pad(
                    assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
                )
                output["labels"][assistant_masks == 0] = -100
            output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
        if "assistant_masks" in examples[0]:
            assistant_masks = pad(
                assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
            output["labels"][assistant_masks == 0] = -100
        return output

    @staticmethod
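The diff references `get_position_ids_from_packed_seq_lengths`. A minimal sketch of the general idea (the `position_ids_from_seq_lengths` helper is hypothetical, not necessarily TRL's implementation): for each packed example, position IDs restart at 0 at every segment boundary given by `seq_lengths`.

```python
import torch


def position_ids_from_seq_lengths(seq_lengths: list[int]) -> torch.Tensor:
    """Position IDs for one packed example, restarting at 0 for every segment."""
    return torch.cat([torch.arange(length) for length in seq_lengths])


print(position_ids_from_seq_lengths([3, 1]))  # tensor([0, 1, 2, 0])
print(position_ids_from_seq_lengths([3]))     # tensor([0, 1, 2])
print(position_ids_from_seq_lengths([2]))     # tensor([0, 1])
```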