Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Compare commits: v4.55.2...pixtral_ba (9 commits)
Commits (SHA1):
8e07742afa
31858be2c0
d9f165af99
c0515cebcd
9cd45c8b48
0837c7e442
8041515fd7
7158f7488f
19876ea405
@@ -37,7 +37,7 @@ from ...image_utils import (
     validate_kwargs,
     validate_preprocess_arguments,
 )
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
+from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
 from ...utils.import_utils import requires_backends
@@ -63,10 +63,24 @@ class BatchMixFeature(BatchFeature):
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+
+        def _recursive_to(obj, device, *args, **kwargs):
+            # Lists can be nested, so keep digging until we hit tensors
+            if isinstance(obj, list):
+                return [_recursive_to(o, device, *args, **kwargs) for o in obj]
+            # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+            elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
+                # cast and send to device
+                return obj.to(*args, **kwargs)
+            elif isinstance(obj, torch.Tensor) and device is not None:
+                # only send to device, don't cast
+                return obj.to(device=device)
+            else:
+                return obj
+
         requires_backends(self, ["torch"])
         import torch  # noqa

-        new_data = {}
         device = kwargs.get("device")
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
@@ -80,21 +94,8 @@ class BatchMixFeature(BatchFeature):
             else:
                 # it's something else
                 raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
-        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
-        for k, v in self.items():
-            # check if v is a floating point
-            if isinstance(v, list):
-                new_data[k] = [
-                    element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
-                ]
-            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
-                # cast and send to device
-                new_data[k] = v.to(*args, **kwargs)
-            elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
-            else:
-                new_data[k] = v
-        self.data = new_data
+
+        self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
         return self
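For readers skimming the diff: the net effect of the two hunks above is that `BatchMixFeature.to()` now recurses into arbitrarily nested lists instead of flattening them and dropping non-tensor entries. A minimal standalone sketch of that behavior (the helper name, shapes, and dict keys below are illustrative, not copied from the repository):

```python
# Standalone sketch of the recursive conversion, assuming the same cast/move rules as the diff.
import torch


def recursive_to(obj, device=None, *args, **kwargs):
    if isinstance(obj, list):
        # preserve nesting: convert element-wise instead of flattening
        return [recursive_to(o, device, *args, **kwargs) for o in obj]
    if isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        return obj.to(*args, **kwargs)  # cast and/or move floating-point tensors
    if isinstance(obj, torch.Tensor) and device is not None:
        return obj.to(device=device)  # move integer tensors without casting
    return obj


data = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),                   # must stay int64
    "pixel_values": [[torch.rand(3, 16, 16), torch.rand(3, 32, 32)]],  # nesting must survive
}
out = {k: recursive_to(v, None, dtype=torch.bfloat16) for k, v in data.items()}
assert out["input_ids"].dtype == torch.long
assert isinstance(out["pixel_values"][0], list)
assert out["pixel_values"][0][1].dtype == torch.bfloat16
```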
@@ -489,7 +489,7 @@ class PixtralVisionModel(PixtralPreTrainedModel):
             all tokens of all images of shape (N_toks, D)
         """
         # pass images through initial convolution independently
-        patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
+        patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]

         # flatten to a single sequence
         patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
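To see why the comprehension changed, note that `pixel_values` is now a list of samples, each itself a list of per-image tensors, so the loop walks two levels before calling the convolution. A rough sketch under that assumption (the `Conv2d` stand-in, channel counts, and image sizes are made up; the real `patch_conv` comes from the Pixtral config):

```python
# Rough sketch of the nested comprehension feeding the patch convolution.
import torch
from torch import nn

patch_conv = nn.Conv2d(3, 8, kernel_size=2, stride=2, bias=False)

# one inner list per sample, one (1, C, H, W) tensor per image; images may differ in size
pixel_values = [
    [torch.rand(1, 3, 16, 16), torch.rand(1, 3, 32, 32)],  # sample 0: two images
    [torch.rand(1, 3, 8, 8)],                               # sample 1: one image
]

# walk samples, then images within each sample, like the comprehension in the diff
patch_embeds_list = [patch_conv(img) for sample in pixel_values for img in sample]

# flatten each (1, D, h, w) feature map to (1, h*w, D) and concatenate into one token sequence
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
print(patch_embeds.shape)  # torch.Size([1, 336, 8]) == (1, 64 + 256 + 16, 8)
```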
@@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, load_image
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends


 logger = logging.get_logger(__name__)
@@ -66,10 +66,24 @@ class BatchMixFeature(BatchFeature):
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+
+        def _recursive_to(obj, device, *args, **kwargs):
+            # Lists can be nested, so keep digging until we hit tensors
+            if isinstance(obj, list):
+                return [_recursive_to(o, device, *args, **kwargs) for o in obj]
+            # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+            elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
+                # cast and send to device
+                return obj.to(*args, **kwargs)
+            elif isinstance(obj, torch.Tensor) and device is not None:
+                # only send to device, don't cast
+                return obj.to(device=device)
+            else:
+                return obj
+
         requires_backends(self, ["torch"])
         import torch  # noqa

-        new_data = {}
         device = kwargs.get("device")
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
@@ -83,21 +97,8 @@ class BatchMixFeature(BatchFeature):
             else:
                 # it's something else
                 raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
-        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
-        for k, v in self.items():
-            # check if v is a floating point
-            if isinstance(v, list):
-                new_data[k] = [
-                    element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
-                ]
-            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
-                # cast and send to device
-                new_data[k] = v.to(*args, **kwargs)
-            elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
-            else:
-                new_data[k] = v
-        self.data = new_data
+
+        self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
         return self
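The same `_recursive_to` rewrite is mirrored here in the processing module. One detail both copies share: floating-point tensors honor the requested dtype and/or device, while integer tensors such as `input_ids` are only ever moved, never recast. A small hedged sketch of that rule (the tensors below are invented for illustration):

```python
# Sketch of the per-tensor cast-vs-move rule; not taken verbatim from the diff.
import torch


def convert(t, device=None, *args, **kwargs):
    if torch.is_floating_point(t):
        return t.to(*args, **kwargs)   # floats: apply the requested dtype and/or device
    if device is not None:
        return t.to(device=device)     # ints/bools: move only, keep their dtype
    return t


input_ids = torch.ones(1, 5, dtype=torch.long)
pixel_patch = torch.rand(1, 3, 4, 4)

# a dtype-only call leaves input_ids untouched but casts the image tensor
assert convert(input_ids, None, torch.bfloat16).dtype == torch.long
assert convert(pixel_patch, None, torch.bfloat16).dtype == torch.bfloat16
```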
@@ -274,3 +274,26 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         if batch_size < 1:
             raise ValueError("batch_size must be greater than 0")
         return [[super().prepare_image_inputs()]] * batch_size
+
+    def test_batch_feature_mix_to(self):
+        # This test is here because BatchFeatureMix.to() was breaking the structure of some inputs
+        # so we ensure it doesn't regress
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
+
+        # Make small for checking image token expansion
+        processor.image_processor.size = {"longest_edge": 30}
+        processor.image_processor.patch_size = {"height": 2, "width": 2}
+
+        # Test passing in an image
+        inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
+        # Convert to a random other dtype and ensure structure is preserved
+        inputs_image = inputs_image.to(torch.bfloat16)
+        self.assertIn("input_ids", inputs_image)
+        self.assertTrue(len(inputs_image["input_ids"]) == 1)
+        self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
+        self.assertIsInstance(inputs_image["pixel_values"], list)
+        self.assertTrue(len(inputs_image["pixel_values"]) == 1)
+        self.assertIsInstance(inputs_image["pixel_values"][0], list)
+        self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
+        self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
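Putting the pieces together, the test above pins down the contract a caller can rely on after this branch: casting the processor output keeps the batch/image nesting of `pixel_values` intact. A hypothetical usage sketch (the checkpoint id, the PIL images, and the exact nesting are assumptions based on the test, not output from a real run):

```python
# Hypothetical usage sketch; the checkpoint id and images are placeholders, and the
# nested pixel_values layout assumes the BatchMixFeature changes from this branch.
import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")
image_0 = Image.new("RGB", (64, 48))
image_1 = Image.new("RGB", (32, 32))
prompt = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"

inputs = processor(text=prompt, images=[image_0, image_1], return_tensors="pt")
inputs = inputs.to(torch.bfloat16)

# input_ids stays a single LongTensor; pixel_values stays a list (batch) of lists (images per sample)
assert isinstance(inputs["input_ids"], torch.Tensor) and inputs["input_ids"].dtype == torch.long
assert isinstance(inputs["pixel_values"], list) and isinstance(inputs["pixel_values"][0], list)
assert all(t.dtype == torch.bfloat16 for t in inputs["pixel_values"][0])
```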