@@ -297,9 +297,14 @@ class Blip2PreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, Blip2Encoder):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
+
+        # Enable / disable GC for the language model as well
+        if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
+            self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
 
 
 BLIP_2_START_DOCSTRING = r"""
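Note (not part of the patch itself): with the new hook signature above, turning gradient checkpointing on for the wrapper model is expected to look roughly like the sketch below. `gradient_checkpointing_enable()` is the public entry point on `PreTrainedModel`; the call into `_set_gradient_checkpointing` happens internally.

```python
# Hedged usage sketch, not taken from this diff; the checkpoint id and training
# setup are illustrative only.
import torch
from transformers import Blip2ForConditionalGeneration

model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
model.gradient_checkpointing_enable()  # dispatches to _set_gradient_checkpointing under the hood
model.train()  # checkpointing is only active in training mode (see the `self.training` checks below)
```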
@@ -473,17 +478,11 @@ class Blip2Encoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -944,15 +943,8 @@ class Blip2QFormerEncoder(nn.Module):
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions, query_length)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
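For context (again not part of the patch), the two encoder hunks above swap the nested `create_custom_forward` closure for a configurable `gradient_checkpointing_func`. A minimal standalone sketch of the pattern, assuming that function is `torch.utils.checkpoint.checkpoint`, could look like this:

```python
# Self-contained sketch with assumed names; only the layer.__call__ and
# functools.partial wiring mirror the pattern in the hunks above.
import functools

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(16, 16)
gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)

hidden_states = torch.randn(2, 16, requires_grad=True)
# Extra positional arguments (attention masks, flags, ...) are passed straight
# to the layer instead of being captured by a closure.
layer_outputs = gradient_checkpointing_func(layer.__call__, hidden_states)
layer_outputs.sum().backward()  # activations are recomputed during this backward pass
```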
@@ -1272,14 +1264,10 @@ class Blip2Model(Blip2PreTrainedModel):
         >>> import torch
         >>> from transformers import AutoTokenizer, Blip2Model
 
-        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
-
-        >>> model.to(device)  # doctest: +IGNORE_RESULT
+        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
 
         >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")
-        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").to(device)
+        >>> inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
         >>> text_features = model.get_text_features(**inputs)
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1333,16 +1321,12 @@ class Blip2Model(Blip2PreTrainedModel):
         >>> import requests
         >>> from transformers import AutoProcessor, Blip2Model
 
-        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
-
-        >>> model.to(device)  # doctest: +IGNORE_RESULT
+        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
 
         >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+        >>> inputs = processor(images=image, return_tensors="pt")
         >>> image_outputs = model.get_image_features(**inputs)
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1381,15 +1365,12 @@ class Blip2Model(Blip2PreTrainedModel):
         >>> import requests
         >>> from transformers import Blip2Processor, Blip2Model
 
-        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
-
         >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
-        >>> model.to(device)  # doctest: +IGNORE_RESULT
+        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b")
 
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+        >>> inputs = processor(images=image, return_tensors="pt")
         >>> qformer_outputs = model.get_qformer_features(**inputs)
         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1654,34 +1635,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
 
         Examples:
 
-        Image captioning (without providing a text prompt):
-
-        ```python
-        >>> from PIL import Image
-        >>> import requests
-        >>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
-        >>> import torch
-
-        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        >>> model = Blip2ForConditionalGeneration.from_pretrained(
-        ...     "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
-        ... )
-        >>> model.to(device)  # doctest: +IGNORE_RESULT
-
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
-
-        >>> generated_ids = model.generate(**inputs)
-        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-        >>> print(generated_text)
-        two cats laying on a couch
-        ```
-
-        Visual question answering (prompt = question):
+        Prepare processor, model and image input
 
         ```python
         >>> from PIL import Image
@@ -1698,7 +1652,22 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
 
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-
+        ```
+
+        Image captioning (without providing a text prompt):
+
+        ```python
+        >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
+
+        >>> generated_ids = model.generate(**inputs)
+        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+        >>> print(generated_text)
+        two cats laying on a couch
+        ```
+
+        Visual question answering (prompt = question):
+
+        ```python
         >>> prompt = "Question: how many cats are there? Answer:"
         >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.float16)
 
@@ -1712,20 +1681,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
         This greatly reduces the amount of memory used by the model while maintaining the same performance.
 
         ```python
-        >>> from PIL import Image
-        >>> import requests
-        >>> from transformers import Blip2Processor, Blip2ForConditionalGeneration
-        >>> import torch
-
-        >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
         >>> model = Blip2ForConditionalGeneration.from_pretrained(
-        ...     "Salesforce/blip2-flan-t5-xl", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.bfloat16
+        ...     "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.bfloat16
         ... )  # doctest: +IGNORE_RESULT
 
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> prompt = "Question: how many cats are there? Answer:"
         >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
 
         >>> generated_ids = model.generate(**inputs)