[Model] IBM/NASA Prithvi Geospatial model (#12830)

Christian Pinto
2025-02-12 04:34:30 +00:00
committed by GitHub
parent 3ee696a63d
commit 974dfd4971
7 changed files with 811 additions and 9 deletions


@@ -0,0 +1,530 @@
# SPDX-License-Identifier: Apache-2.0
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM.
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
The requirements for running this script are:
- install [terratorch, albumentations, rasterio] in your python environment
- download the model weights into a 'model' folder local to the script
(temporary measure until the proper config.json file is uploaded to HF)
- download an input example image (India_900498_S2Hand.tif) and place it in
the same folder as the script (or point to it with the --data_file argument)
Run the example:
python prithvi_geospatial_mae.py
""" # noqa: E501
import argparse
import datetime
import os
import re
from typing import List, Union
import albumentations
import numpy as np
import rasterio
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm import LLM
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99
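# NO_DATA marks missing pixels in the source GeoTIFFs; when band statistics
# are supplied, load_example() replaces them with NO_DATA_FLOAT during
# normalization.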
model_config = """{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},
"torch_dtype": "float32"
}
"""
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with open(os.path.join(os.path.dirname(__file__), "./model/config.json"),
'w') as config_file:
config_file.write(model_config)
datamodule_config = {
'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'],
'batch_size': 16,
'constant_scale': 0.0001,
'data_root': '/dccstor/geofm-finetuning/datasets/sen1floods11',
'drop_last': True,
'no_data_replace': 0.0,
'no_label_replace': -1,
'num_workers': 8,
'test_transform': [
albumentations.Resize(always_apply=False,
height=448,
interpolation=1,
p=1,
width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False,
always_apply=True,
p=1.0)
],
}
class PrithviMAE:
def __init__(self):
print("Initializing PrithviMAE model")
self.model = LLM(model=os.path.join(os.path.dirname(__file__),
"./model"),
skip_tokenizer_init=True,
dtype="float32")
def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure
mm_data = {
"pixel_values":
torch.empty(0) if input_data is None else input_data,
"location_coords":
torch.empty(0) if location_coords is None else location_coords
}
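# With tokenizer initialization skipped, the prompt reduces to a single
# placeholder token id; the actual inputs travel via multi_modal_data.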
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False)
print("################ Inference done ##############")
return outputs[0].outputs.data
def generate_datamodule():
datamodule = Sen1Floods11NonGeoDataModule(
data_root=datamodule_config['data_root'],
batch_size=datamodule_config["batch_size"],
num_workers=datamodule_config["num_workers"],
bands=datamodule_config["bands"],
drop_last=datamodule_config["drop_last"],
test_transform=datamodule_config["test_transform"])
return datamodule
def process_channel_group(orig_img, channels):
"""
Args:
orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W).
channels: list of indices representing RGB channels.
Returns:
torch.Tensor with shape (num_channels, height, width) for original image
"""
orig_img = orig_img[channels, ...]
valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
valid_mask[orig_img == NO_DATA_FLOAT] = False
# Rescale (enhancing contrast)
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
min_value = OFFSET
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0,
1)
# No data as zeros
orig_img[~valid_mask] = 0
return orig_img
def read_geotiff(file_path: str):
"""Read all bands from *file_path* and return image + meta info.
Args:
file_path: path to image file.
Returns:
np.ndarray with shape (bands, height, width)
meta info dict
"""
with rasterio.open(file_path) as src:
img = src.read()
meta = src.meta
try:
coords = src.lnglat()
except Exception:
# Cannot read coords
coords = None
return img, meta, coords
def save_geotiff(image, output_path: str, meta: dict):
"""Save multi-band image in Geotiff file.
Args:
image: np.ndarray with shape (bands, height, width)
output_path: path where to save the image
meta: dict with meta info.
"""
with rasterio.open(output_path, "w", **meta) as dest:
for i in range(image.shape[0]):
dest.write(image[i, :, :], i + 1)
return
def _convert_np_uint8(float_image: torch.Tensor):
image = float_image.numpy() * 255.0
image = image.astype(dtype=np.uint8)
return image
def load_example(
file_paths: List[str],
mean: List[float] = None,
std: List[float] = None,
indices: Union[list[int], None] = None,
):
"""Build an input example by loading images in *file_paths*.
Args:
file_paths: list of file paths.
mean: list containing mean values for each band in the images
in *file_paths*.
std: list containing std values for each band in the images
in *file_paths*.
Returns:
np.array containing created example
list of meta info for each image in *file_paths*
"""
imgs = []
metas = []
temporal_coords = []
location_coords = []
for file in file_paths:
img, meta, coords = read_geotiff(file)
# Rescaling (don't normalize on nodata)
img = np.moveaxis(img, 0, -1) # channels last for rescaling
if indices is not None:
img = img[..., indices]
if mean is not None and std is not None:
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
imgs.append(img)
metas.append(meta)
if coords is not None:
location_coords.append(coords)
try:
match = re.search(r'(\d{7,8}T\d{6})', file)
if match:
year = int(match.group(1)[:4])
julian_day = match.group(1).split('T')[0][4:]
if len(julian_day) == 3:
julian_day = int(julian_day)
else:
julian_day = datetime.datetime.strptime(
julian_day, '%m%d').timetuple().tm_yday
temporal_coords.append([year, julian_day])
except Exception as e:
print(f'Could not extract timestamp for {file} ({e})')
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
imgs = np.expand_dims(imgs, axis=0)  # add batch dim
return imgs, temporal_coords, location_coords, metas
def run_model(input_data,
temporal_coords,
location_coords,
model,
datamodule,
img_size,
lightning_model=None):
# Reflect pad if not divisible by img_size
original_h, original_w = input_data.shape[-2:]
pad_h = (img_size - (original_h % img_size)) % img_size
pad_w = (img_size - (original_w % img_size)) % img_size
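# e.g. a 490x490 scene with img_size=256 gets pad_h = pad_w = 22 -> 512x512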
input_data = np.pad(input_data,
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
mode="reflect")
# Build sliding window
batch_size = 1
batch = torch.tensor(input_data, device="cpu")
windows = (batch.unfold(3, img_size,
img_size).unfold(4, img_size, img_size))
h1, w1 = windows.shape[3:5]
windows = rearrange(windows,
"b c t h1 w1 h w -> (b h1 w1) c t h w",
h=img_size,
w=img_size)
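# Each img_size x img_size tile now sits on the batch axis, so the model
# processes the (h1 * w1) windows independently.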
# Split into batches if number of windows > batch_size
num_batches = windows.shape[0] // batch_size if windows.shape[
0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0)
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
if temporal_coords:
temporal_coords = torch.tensor(temporal_coords,
device=device).unsqueeze(0)
else:
temporal_coords = None
if location_coords:
location_coords = torch.tensor(location_coords[0],
device=device).unsqueeze(0)
else:
location_coords = None
# Run model
pred_imgs = []
for x in windows:
# Apply standardization
x = datamodule.test_transform(
image=x.squeeze().numpy().transpose(1, 2, 0))
x = datamodule.aug(x)['image']
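# test_transform resizes each window to 448x448 and converts it to a tensor
# (see datamodule_config above); the aug step is assumed to apply the
# datamodule's band-wise normalization before inference.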
with torch.no_grad():
x = x.to(device)
pred = model.run(x, location_coords=location_coords)
if lightning_model:
pred_lightning = lightning_model(
x,
temporal_coords=temporal_coords,
location_coords=location_coords)
pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning):
print("Inference output is not equal")
y_hat = pred.argmax(dim=1)
y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(),
size=img_size,
mode="nearest")
pred_imgs.append(y_hat)
pred_imgs = torch.concat(pred_imgs, dim=0)
# Build images from patches
pred_imgs = rearrange(
pred_imgs,
"(b h1 w1) c h w -> b c (h1 h) (w1 w)",
h=img_size,
w=img_size,
b=1,
c=1,
h1=h1,
w1=w1,
)
# Cut padded area back to original size
pred_imgs = pred_imgs[..., :original_h, :original_w]
# Squeeze (batch size 1)
pred_imgs = pred_imgs[0]
return pred_imgs
def main(
data_file: str,
output_dir: str,
rgb_outputs: bool,
input_indices: list[int] = None,
):
os.makedirs(output_dir, exist_ok=True)
# Load model ---------------------------------------------------------------
model_obj = PrithviMAE()
datamodule = generate_datamodule()
img_size = 256 # Size of Sen1Floods11
# Loading data -------------------------------------------------------------
input_data, temporal_coords, location_coords, meta_data = load_example(
file_paths=[data_file],
indices=input_indices,
)
meta_data = meta_data[0] # only one image
if input_data.mean() > 1:
input_data = input_data / 10000 # Convert to range 0-1
# Running model ------------------------------------------------------------
channels = [
datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB
pred = run_model(input_data, temporal_coords, location_coords, model_obj,
datamodule, img_size)
# Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join(
output_dir,
f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)
# Save image + pred
meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)
if input_data.mean() < 1:
input_data = input_data * 10000 # Scale to 0-10000
rgb_orig = process_channel_group(
orig_img=torch.Tensor(input_data[0, :, 0, ...]),
channels=channels,
)
pred[pred == 0.] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3
img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
img_pred_file = os.path.join(
output_dir,
f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
save_geotiff(
image=_convert_np_uint8(img_pred),
output_path=img_pred_file,
meta=meta_data,
)
# Save image rgb
if rgb_outputs:
rgb_file = os.path.join(
output_dir, "original_rgb_"
f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
save_geotiff(
image=_convert_np_uint8(rgb_orig),
output_path=rgb_file,
meta=meta_data,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help=
"0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
args = parser.parse_args()
main(**vars(args))
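# Example invocation (assuming the weights live in ./model and the sample
# GeoTIFF sits next to the script):
# python prithvi_geospatial_mae.py --data_file ./India_900498_S2Hand.tif \
#     --output_dir output --input_indices 1 2 3 8 11 12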


@@ -214,6 +214,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
# The model on Hugging Face is currently being updated,
# hence it is temporarily marked as not available online.
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
is_available_online=False),
}
_CROSS_ENCODER_EXAMPLE_MODELS = {


@@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder(
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
# Some input builders, such as ModelInputForCPUBuilder, do not have the
# "inter_data_list" attribute, so check that it exists before
# referencing it.
if hasattr(self.input_builder, "inter_data_list"):
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1


@@ -254,8 +254,14 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
# At the moment one model (PrithviGeoSpatialMAE) requires being
# initialized without a tokenizer while still accepting multi-modal
# input.
if not self.tokenizer:
tokenizer = None
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)
@@ -273,9 +279,15 @@
lora_request: Optional[LoRARequest],
) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`."""
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
)
# At the moment one model (PrithviGeoSpatialMAE) requires being
# initialized without a tokenizer while still accepting multi-modal
# input.
if not self.tokenizer:
tokenizer = None
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer)


@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The vLLM team.
# Copyright 2025 IBM.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
from transformers import BatchFeature
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree,
SupportsMultiModal)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
pass
class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
return ProcessorInputs(
prompt_text="",
# This model input is fixed and is in the form of a torch Tensor.
# The size of pixel_values might change when the input is resized,
# but it never exceeds the dimensions below.
mm_data={
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
"location_coords": torch.full((1, 2), 1.0)
})
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
location_coords=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
pass
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputs:
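# No HF processor is involved here: the multi-modal tensors are forwarded
# unchanged and paired with a single placeholder token id.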
mm_kwargs = {}
for k, v in mm_data.items():
mm_kwargs[k] = v
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=[1],
mm_kwargs=MultiModalKwargs(mm_kwargs),
mm_placeholders={},
)
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
""" Prithvi Masked Autoencoder"""
def _instantiate_model(self, config: dict) -> nn.Module | None:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
loss=config["task_args"]["loss"],
lr=config["task_args"]["lr"],
ignore_index=config["task_args"]["ignore_index"],
optimizer=config["task_args"]["optimizer"],
optimizer_hparams=config["optimizer_params"],
scheduler=config["task_args"]["scheduler"],
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"])
return task.model
else:
return None
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
# the actual model is dynamically instantiated using terratorch
# allowing us to perform changes to the model architecture
# at startup time (e.g., change the model decoder class.)
self.model = self._instantiate_model(
vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
if self.model is None:
raise ValueError(
"Unsupported task."
"Only SemanticSegmentationTask is supported for now"
"by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")
pixel_values = torch.unbind(pixel_values, dim=0)[0]
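# unbind drops the batch dimension added by the multi-modal plumbing;
# a single image is processed per request.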
location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
raise ValueError(f"Incorrect type of location_coords. "
f"Got type: {type(location_coords)}")
location_coords = torch.unbind(location_coords, dim=0)[0]
if location_coords.shape == torch.Size([0]):
location_coords = None
return pixel_values, location_coords
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
pixel_values, location_coords = (
self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values,
location_coords=location_coords)
return model_output.output
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_list = []
model_buffers = dict(self.named_buffers())
loaded_buffers = []
for key, value in weights:
if key == "state_dict":
weights_to_parse = value
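# The checkpoint stores everything under a single "state_dict" entry,
# so iterate over its contents rather than the outer weights iterable.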
for name, weight in weights_to_parse.items():
if "pos_embed" in name:
continue
if "_timm_module." in name:
name = name.replace("_timm_module.", "")
# this model requires a couple of buffers to be loaded
# that are not loadable with the AutoWeightsLoader
if name in model_buffers:
if "_timm_module." in name:
name = name.replace("_timm_module.", "")
buffer = model_buffers[name]
weight_loader = getattr(buffer, "weight_loader",
default_weight_loader)
weight_loader(buffer, weight)
loaded_buffers.append(name)
else:
params_list.append((name, weight))
break
# Load the remaining model parameters
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(params_list)
return autoloaded_weights.union(set(loaded_buffers))


@@ -137,6 +137,10 @@ _EMBEDDING_MODELS = {
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# [Auto-converted (see adapters.py)]
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
# Technically PrithviGeoSpatialMAE is a model that works on images, both as
# input and output. It is listed here because it piggy-backs on the
# embedding-model machinery for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
}
_CROSS_ENCODER_MODELS = {


@@ -74,7 +74,16 @@ class PoolingModelRunner(
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph:
# Pooling models are also (ab-)used to integrate non-text models that
# are not autoregressive (PrithviGeoSpatialMAE).
# These models might not use attention and do not really have a prefill
# and decode phase. The model input is processed in one shot and both
# decode_metadata and prefill_metadata would be None for such models.
# See the PlaceholderAttentionMetadata class.
# TODO: Figure out if cuda_graph is of any use for these models and
# explore how to leverage it.
if (prefill_meta is None and decode_meta is not None
and decode_meta.use_cuda_graph):
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][