Compare commits


9 Commits

SHA1 Message Date
8e07742afa Comment on processor test 2024-12-05 17:58:22 +00:00
31858be2c0 No more unsqueezing 2024-12-05 17:58:22 +00:00
d9f165af99 [run-slow] pixtral 2024-12-05 17:58:22 +00:00
c0515cebcd Add test 2024-12-05 17:58:22 +00:00
9cd45c8b48 Correctly pass device arg during recursion 2024-12-05 17:58:22 +00:00
0837c7e442 make fixup 2024-12-05 17:58:22 +00:00
8041515fd7 Fix incorrectly inserted device arg 2024-12-05 17:58:22 +00:00
7158f7488f make fixup 2024-12-05 17:58:22 +00:00
19876ea405 Fix case of nested tensors in BatchMixFeature 2024-12-05 17:58:22 +00:00
4 changed files with 60 additions and 35 deletions

src/transformers/models/pixtral/image_processing_pixtral.py

@@ -37,7 +37,7 @@ from ...image_utils import (
     validate_kwargs,
     validate_preprocess_arguments,
 )
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
+from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
 from ...utils.import_utils import requires_backends
@@ -63,10 +63,24 @@ class BatchMixFeature(BatchFeature):
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+
+        def _recursive_to(obj, device, *args, **kwargs):
+            # Lists can be nested, so keep digging until we hit tensors
+            if isinstance(obj, list):
+                return [_recursive_to(o, device, *args, **kwargs) for o in obj]
+            # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+            elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
+                # cast and send to device
+                return obj.to(*args, **kwargs)
+            elif isinstance(obj, torch.Tensor) and device is not None:
+                # only send to device, don't cast
+                return obj.to(device=device)
+            else:
+                return obj
+
         requires_backends(self, ["torch"])
         import torch  # noqa
-        new_data = {}
         device = kwargs.get("device")
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
@@ -80,21 +94,8 @@ class BatchMixFeature(BatchFeature):
             else:
                 # it's something else
                 raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
-        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
-        for k, v in self.items():
-            # check if v is a floating point
-            if isinstance(v, list):
-                new_data[k] = [
-                    element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
-                ]
-            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
-                # cast and send to device
-                new_data[k] = v.to(*args, **kwargs)
-            elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
-            else:
-                new_data[k] = v
-        self.data = new_data
+        self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
         return self
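Why the recursion matters: the old loop flattened exactly one level of list nesting and silently dropped non-tensor leaves, while `_recursive_to` preserves arbitrarily nested lists and only touches the tensors inside them. A minimal self-contained sketch of the same pattern (the helper name and the toy batch below are illustrative, not part of this diff):

import torch

def recursive_to(obj, device=None, dtype=None):
    # Keep digging through nested lists until we hit tensors
    if isinstance(obj, list):
        return [recursive_to(o, device=device, dtype=dtype) for o in obj]
    if isinstance(obj, torch.Tensor):
        if torch.is_floating_point(obj):
            # floating point tensors are cast and moved
            return obj.to(device=device, dtype=dtype)
        if device is not None:
            # integer tensors (e.g. input_ids) are moved but never cast
            return obj.to(device=device)
    return obj

batch = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "pixel_values": [[torch.rand(3, 4, 4), torch.rand(3, 6, 6)]],  # one sample, two images
}
out = {k: recursive_to(v, dtype=torch.bfloat16) for k, v in batch.items()}
assert out["input_ids"].dtype == torch.long               # left untouched
assert out["pixel_values"][0][0].dtype == torch.bfloat16  # cast, nesting preserved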

src/transformers/models/pixtral/modeling_pixtral.py

@@ -489,7 +489,7 @@ class PixtralVisionModel(PixtralPreTrainedModel):
             all tokens of all images of shape (N_toks, D)
         """
         # pass images through initial convolution independently
-        patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
+        patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]

         # flatten to a single sequence
         patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
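Because the fixed `.to()` no longer flattens `pixel_values`, the model-side comprehension now iterates two levels (samples, then images per sample). A rough shape sketch of the bookkeeping, assuming each image tensor already carries a leading batch dimension of 1 (the sizes and the toy conv are illustrative, not taken from the model config):

import torch

pixel_values = [[torch.rand(1, 3, 8, 8), torch.rand(1, 3, 6, 6)]]  # one sample, two images
patch_conv = torch.nn.Conv2d(3, 16, kernel_size=2, stride=2)       # stand-in for self.patch_conv

# two-level iteration mirrors the nested structure that .to() now preserves
patch_embeds_list = [patch_conv(img) for sample in pixel_values for img in sample]

# (1, D, H', W') -> (1, H'*W', D), then concatenate all patches into one token sequence
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
print(patch_embeds.shape)  # torch.Size([1, 25, 16]): 16 patches from the 8x8 image + 9 from the 6x6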

src/transformers/models/pixtral/processing_pixtral.py

@@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, load_image
 from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
 from ...tokenization_utils_base import PreTokenizedInput, TextInput
-from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends

 logger = logging.get_logger(__name__)
@@ -66,10 +66,24 @@ class BatchMixFeature(BatchFeature):
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+
+        def _recursive_to(obj, device, *args, **kwargs):
+            # Lists can be nested, so keep digging until we hit tensors
+            if isinstance(obj, list):
+                return [_recursive_to(o, device, *args, **kwargs) for o in obj]
+            # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+            elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
+                # cast and send to device
+                return obj.to(*args, **kwargs)
+            elif isinstance(obj, torch.Tensor) and device is not None:
+                # only send to device, don't cast
+                return obj.to(device=device)
+            else:
+                return obj
+
         requires_backends(self, ["torch"])
         import torch  # noqa
-        new_data = {}
         device = kwargs.get("device")
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
@@ -83,21 +97,8 @@ class BatchMixFeature(BatchFeature):
             else:
                 # it's something else
                 raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
-        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
-        for k, v in self.items():
-            # check if v is a floating point
-            if isinstance(v, list):
-                new_data[k] = [
-                    element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
-                ]
-            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
-                # cast and send to device
-                new_data[k] = v.to(*args, **kwargs)
-            elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
-            else:
-                new_data[k] = v
-        self.data = new_data
+        self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
         return self
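Both copies of `BatchMixFeature.to()` (the class is duplicated across the image-processing and processing modules) share the positional-argument dispatch whose middle the hunks above elide: a single positional argument may be a dtype, a device spec, or something unsupported. A hedged standalone sketch of that dispatch (simplified, not the verbatim source):

import torch

def resolve_device(args, kwargs):
    # `to(dtype)` means cast only; `to(device)` means move; anything else is rejected
    device = kwargs.get("device")
    if device is None and len(args) > 0:
        arg = args[0]
        if isinstance(arg, torch.dtype):
            pass  # dtype-only call: tensor.to(*args) handles the cast, no device to extract
        elif isinstance(arg, (str, int, torch.device)):
            device = arg  # a device spec such as "cuda:0", 0, or torch.device("cpu")
        else:
            raise ValueError(f"Attempting to cast a BatchFeature to type {arg}. This is not supported.")
    return device

print(resolve_device((torch.float16,), {}))   # None: cast only, nothing moves
print(resolve_device(("cpu",), {}))           # cpu
print(resolve_device((), {"device": "cpu"}))  # cpu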

tests/models/pixtral/test_processor_pixtral.py

@@ -274,3 +274,26 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         if batch_size < 1:
             raise ValueError("batch_size must be greater than 0")
         return [[super().prepare_image_inputs()]] * batch_size
+
+    def test_batch_feature_mix_to(self):
+        # This test is here because BatchMixFeature.to() was breaking the structure of some inputs,
+        # so we ensure it doesn't regress
+        processor = self.processor_class.from_pretrained(self.tmpdirname)
+        prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:"
+
+        # Make small for checking image token expansion
+        processor.image_processor.size = {"longest_edge": 30}
+        processor.image_processor.patch_size = {"height": 2, "width": 2}
+
+        # Test passing in an image
+        inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt")
+
+        # Convert to a random other dtype and ensure structure is preserved
+        inputs_image = inputs_image.to(torch.bfloat16)
+
+        self.assertIn("input_ids", inputs_image)
+        self.assertTrue(len(inputs_image["input_ids"]) == 1)
+        self.assertIsInstance(inputs_image["input_ids"], torch.Tensor)
+
+        self.assertIsInstance(inputs_image["pixel_values"], list)
+        self.assertTrue(len(inputs_image["pixel_values"]) == 1)
+        self.assertIsInstance(inputs_image["pixel_values"][0], list)
+        self.assertTrue(len(inputs_image["pixel_values"][0]) == 2)
+        self.assertIsInstance(inputs_image["pixel_values"][0][0], torch.Tensor)
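The contract the new test pins down: after a dtype-only `.to()`, `input_ids` remains one tensor with batch dimension 1, while `pixel_values` keeps its list-of-samples / list-of-images nesting rather than being flattened. To run just this test locally (the test-file path is assumed from the class name above):

python -m pytest tests/models/pixtral/test_processor_pixtral.py -k test_batch_feature_mix_to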