[Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Author: Christian Pinto
Date: 2025-07-23 19:00:23 +01:00
Committed by: GitHub
Parent: 316b1bf706
Commit: 8560a5b258
15 changed files with 704 additions and 238 deletions

View File

@ -1,122 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
The requirements for running this script are:
- Installing [terratorch, albumentations, rasterio] in your python environment
- Downloading the model weights into a 'model' folder local to the script
(temporary measure until the proper config.json file is uploaded to HF)
- Downloading an input example image (India_900498_S2Hand.tif) and placing it in
the same folder as the script (or specifying it with the --data_file argument)
Run the example:
python prithvi_geospatial_mae.py
""" # noqa: E501
import argparse
import datetime
import os
import re
from typing import Union
import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm import LLM
torch.set_default_dtype(torch.float16)
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99
model_config = """{
"architectures": ["PrithviGeoSpatialMAE"],
"num_classes": 0,
"pretrained_cfg": {
"task_args": {
"task": "SemanticSegmentationTask",
"model_factory": "EncoderDecoderFactory",
"loss": "ce",
"ignore_index": -1,
"lr": 0.001,
"freeze_backbone": false,
"freeze_decoder": false,
"plot_on_val": 10,
"optimizer": "AdamW",
"scheduler": "CosineAnnealingLR"
},
"model_args": {
"backbone_pretrained": false,
"backbone": "prithvi_eo_v2_300_tl",
"decoder": "UperNetDecoder",
"decoder_channels": 256,
"decoder_scale_modules": true,
"num_classes": 2,
"rescale": true,
"backbone_bands": [
"BLUE",
"GREEN",
"RED",
"NIR_NARROW",
"SWIR_1",
"SWIR_2"
],
"head_dropout": 0.1,
"necks": [
{
"name": "SelectIndices",
"indices": [
5,
11,
17,
23
]
},
{
"name": "ReshapeTokensToImage"
}
]
},
"optimizer_params" : {
"lr": 5.0e-05,
"betas": [0.9, 0.999],
"eps": [1.0e-08],
"weight_decay": 0.05,
"amsgrad": false,
"maximize": false,
"capturable": false,
"differentiable": false
},
"scheduler_params" : {
"T_max": 50,
"eta_min": 0,
"last_epoch": -1,
"verbose": "deprecated"
}
},
"torch_dtype": "float32"
}
"""
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF
with open(
os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
) as config_file:
config_file.write(model_config)
datamodule_config = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size": 16,
@ -138,28 +43,24 @@ datamodule_config = {
class PrithviMAE:
def __init__(self):
print("Initializing PrithviMAE model")
self.llm = LLM(
model=os.path.join(os.path.dirname(__file__), "./model"),
skip_tokenizer_init=True,
dtype="float32",
def __init__(self, model):
self.model = LLM(
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
)
def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure
if input_data is not None and input_data.dtype == torch.float32:
input_data = input_data.to(torch.float16)
input_data = input_data[0]
mm_data = {
"pixel_values": torch.empty(0) if input_data is None else input_data,
"location_coords": torch.empty(0)
if location_coords is None
else location_coords,
"pixel_values": input_data,
"location_coords": location_coords,
}
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.llm.encode(prompt, use_tqdm=False)
print("################ Inference done (it took seconds) ##############")
outputs = self.model.encode(prompt, use_tqdm=False)
return outputs[0].outputs.data
@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
"""
Args:
orig_img: torch.Tensor representing original image (reference)
with shape = (bands, H, W).
with shape = (bands, H, W).
channels: list of indices representing RGB channels.
Returns:
torch.Tensor with shape (num_channels, height, width) for original image
torch.Tensor with shape (num_channels, height, width)
for original image
"""
orig_img = orig_img[channels, ...]
@ -260,10 +162,10 @@ def load_example(
Args:
file_paths: list of file paths.
mean: list containing mean values for each band in the images
in *file_paths*.
std: list containing std values for each band in the images
in *file_paths*.
mean: list containing mean values for each band in the
images in *file_paths*.
std: list containing std values for each band in the
images in *file_paths*.
Returns:
np.array containing created example
@ -308,7 +210,7 @@ def load_example(
print(f"Could not extract timestamp for {file} ({e})")
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
imgs = np.expand_dims(imgs, axis=0) # add batch dim
return imgs, temporal_coords, location_coords, metas
@ -332,8 +234,10 @@ def run_model(
)
# Build sliding window
batch_size = 1
batch = torch.tensor(input_data, device="cpu")
# batch = torch.tensor(input_data, device="cpu")
batch = torch.tensor(input_data)
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]
windows = rearrange(
@ -344,18 +248,16 @@ def run_model(
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if temporal_coords:
temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
else:
temporal_coords = None
if location_coords:
location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
else:
location_coords = None
# Run model
# Run Prithvi-EO-V2-300M-TL-Sen1Floods11
pred_imgs = []
for x in windows:
# Apply standardization
@ -363,15 +265,7 @@ def run_model(
x = datamodule.aug(x)["image"]
with torch.no_grad():
x = x.to(device)
pred = model.run(x, location_coords=location_coords)
if lightning_model:
pred_lightning = lightning_model(
x, temporal_coords=temporal_coords, location_coords=location_coords
)
pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning):
print("Inference output is not equal")
y_hat = pred.argmax(dim=1)
y_hat = torch.nn.functional.interpolate(
@ -403,52 +297,18 @@ def run_model(
return pred_imgs
def parse_args():
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
def main(
data_file: str,
model: str,
output_dir: str,
rgb_outputs: bool,
input_indices: list[int] = None,
):
os.makedirs(output_dir, exist_ok=True)
# Load model ---------------------------------------------------------------
model_obj = PrithviMAE()
model_obj = PrithviMAE(model=model)
datamodule = generate_datamodule()
img_size = 256 # Size of Sen1Floods11
# Loading data -------------------------------------------------------------
img_size = 512 # Size of Sen1Floods11
input_data, temporal_coords, location_coords, meta_data = load_example(
file_paths=[data_file],
@ -460,8 +320,6 @@ def main(
if input_data.mean() > 1:
input_data = input_data / 10000 # Convert to range 0-1
# Running model ------------------------------------------------------------
channels = [
datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB
@ -469,7 +327,6 @@ def main(
pred = run_model(
input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
)
# Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join(
@ -487,6 +344,7 @@ def main(
orig_img=torch.Tensor(input_data[0, :, 0, ...]),
channels=channels,
)
rgb_orig = rgb_orig.to(torch.float32)
pred[pred == 0.0] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3
@ -503,9 +361,10 @@ def main(
# Save image rgb
if rgb_outputs:
name_suffix = os.path.splitext(os.path.basename(data_file))[0]
rgb_file = os.path.join(
output_dir,
f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
f"original_rgb_{name_suffix}.tiff",
)
save_geotiff(
image=_convert_np_uint8(rgb_orig),
@ -515,6 +374,42 @@ def main(
if __name__ == "__main__":
args = parse_args()
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
parser.add_argument(
"--data_file",
type=str,
default="./India_900498_S2Hand.tif",
help="Path to the file.",
)
parser.add_argument(
"--model",
type=str,
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
help="Path to a checkpoint file to load from.",
)
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Path to the directory where to save outputs.",
)
parser.add_argument(
"--input_indices",
default=[1, 2, 3, 8, 11, 12],
type=int,
nargs="+",
help="""
0-based indices of the six Prithvi channels to be selected from the input.
By default selects [1,2,3,8,11,12] for S2L1C data.
""",
)
parser.add_argument(
"--rgb_outputs",
action="store_true",
help="If present, output files will only contain RGB channels. "
"Otherwise, all bands will be saved.",
)
args = parser.parse_args()
main(**vars(args))

View File

@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test

View File

@ -6,6 +6,10 @@ accelerate==1.0.1
# via
# lm-eval
# peft
aenum==3.1.16
# via lightly
affine==2.4.0
# via rasterio
aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.10.11
@ -21,8 +25,18 @@ aiosignal==1.3.1
# via
# aiohttp
# ray
albucore==0.0.16
# via terratorch
albumentations==1.4.6
# via terratorch
alembic==1.16.4
# via mlflow
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.6.2.post1
# via
# httpx
@ -34,10 +48,12 @@ arrow==1.3.0
attrs==24.2.0
# via
# aiohttp
# fiona
# hypothesis
# jsonlines
# jsonschema
# pytest-subtests
# rasterio
# referencing
audioread==3.0.1
# via librosa
@ -46,9 +62,13 @@ backoff==2.2.1
# -r requirements/test.in
# schemathesis
bitsandbytes==0.46.1
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
black==24.10.0
# via datamodel-code-generator
blinker==1.9.0
# via flask
blobfile==3.0.0
# via -r requirements/test.in
bm25s==0.2.13
@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
buildkite-test-collector==0.1.9
# via -r requirements/test.in
cachetools==5.5.2
# via google-auth
# via
# google-auth
# mlflow-skinny
certifi==2024.8.30
# via
# fiona
# httpcore
# httpx
# lightly
# pyogrio
# pyproj
# rasterio
# requests
cffi==1.17.1
# via soundfile
@ -79,11 +106,28 @@ charset-normalizer==3.4.0
click==8.1.7
# via
# black
# click-plugins
# cligj
# fiona
# flask
# jiwer
# mlflow-skinny
# nltk
# rasterio
# ray
# schemathesis
# typer
# uvicorn
click-plugins==1.1.1.2
# via
# fiona
# rasterio
cligj==0.7.2
# via
# fiona
# rasterio
cloudpickle==3.1.1
# via mlflow-skinny
colorama==0.4.6
# via
# sacrebleu
@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
# via ray
cycler==0.12.1
# via matplotlib
databricks-sdk==0.59.0
# via mlflow-skinny
datamodel-code-generator==0.26.3
# via -r requirements/test.in
dataproperty==1.0.1
@ -122,13 +168,21 @@ distlib==0.3.9
# via virtualenv
dnspython==2.7.0
# via email-validator
docker==7.1.0
# via mlflow
docopt==0.6.2
# via num2words
einops==0.8.0
docstring-parser==0.17.0
# via jsonargparse
efficientnet-pytorch==0.7.1
# via segmentation-models-pytorch
einops==0.8.1
# via
# -r requirements/test.in
# encodec
# mamba-ssm
# terratorch
# torchgeo
# vector-quantize-pytorch
# vocos
einx==0.3.0
@ -141,6 +195,8 @@ eval-type-backport==0.2.2
# via mteb
evaluate==0.4.3
# via lm-eval
fastapi==0.116.1
# via mlflow-skinny
fastparquet==2024.11.0
# via genai-perf
fastrlock==0.8.2
@ -156,6 +212,10 @@ filelock==3.16.1
# torch
# transformers
# virtualenv
fiona==1.10.1
# via torchgeo
flask==3.1.1
# via mlflow
fonttools==4.54.1
# via matplotlib
fqdn==1.5.1
@ -173,6 +233,8 @@ fsspec==2024.9.0
# evaluate
# fastparquet
# huggingface-hub
# lightning
# pytorch-lightning
# torch
ftfy==6.3.1
# via open-clip-torch
@ -180,18 +242,41 @@ genai-perf==0.0.8
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
geopandas==1.0.1
# via terratorch
gitdb==4.0.12
# via gitpython
gitpython==3.1.44
# via mlflow-skinny
google-api-core==2.24.2
# via opencensus
google-auth==2.40.2
# via google-api-core
# via
# databricks-sdk
# google-api-core
googleapis-common-protos==1.70.0
# via google-api-core
graphene==3.4.3
# via mlflow
graphql-core==3.2.6
# via hypothesis-graphql
# via
# graphene
# graphql-relay
# hypothesis-graphql
graphql-relay==3.2.0
# via graphene
greenlet==3.2.3
# via sqlalchemy
grpcio==1.71.0
# via ray
gunicorn==23.0.0
# via mlflow
h11==0.14.0
# via httpcore
# via
# httpcore
# uvicorn
h5py==3.13.0
# via terratorch
harfile==0.3.0
# via schemathesis
hf-xet==1.1.3
@ -204,7 +289,7 @@ httpx==0.27.2
# via
# -r requirements/test.in
# schemathesis
huggingface-hub==0.33.0
huggingface-hub==0.33.1
# via
# -r requirements/test.in
# accelerate
@ -212,13 +297,19 @@ huggingface-hub==0.33.0
# evaluate
# open-clip-torch
# peft
# segmentation-models-pytorch
# sentence-transformers
# terratorch
# timm
# tokenizers
# transformers
# vocos
humanize==4.11.0
# via runai-model-streamer
hydra-core==1.3.2
# via
# lightly
# lightning
hypothesis==6.131.0
# via
# hypothesis-graphql
@ -236,6 +327,14 @@ idna==3.10
# jsonschema
# requests
# yarl
imageio==2.37.0
# via scikit-image
importlib-metadata==8.7.0
# via
# mlflow-skinny
# opentelemetry-api
importlib-resources==6.5.2
# via typeshed-client
inflect==5.6.2
# via datamodel-code-generator
iniconfig==2.0.0
@ -244,9 +343,13 @@ isoduration==20.11.0
# via jsonschema
isort==5.13.2
# via datamodel-code-generator
itsdangerous==2.2.0
# via flask
jinja2==3.1.6
# via
# datamodel-code-generator
# flask
# mlflow
# torch
jiwer==3.0.5
# via -r requirements/test.in
@ -259,6 +362,10 @@ joblib==1.4.2
# librosa
# nltk
# scikit-learn
jsonargparse==4.35.0
# via
# lightning
# terratorch
jsonlines==4.0.0
# via lm-eval
jsonpointer==3.0.0
@ -277,12 +384,33 @@ kaleido==0.2.1
# via genai-perf
kiwisolver==1.4.7
# via matplotlib
kornia==0.8.1
# via torchgeo
kornia-rs==0.1.9
# via kornia
lazy-loader==0.4
# via librosa
# via
# librosa
# scikit-image
libnacl==2.1.0
# via tensorizer
librosa==0.10.2.post1
# via -r requirements/test.in
lightly==1.5.20
# via
# terratorch
# torchgeo
lightly-utils==0.0.2
# via lightly
lightning==2.5.1.post0
# via
# terratorch
# torchgeo
lightning-utilities==0.14.3
# via
# lightning
# pytorch-lightning
# torchmetrics
llvmlite==0.44.0
# via numba
lm-eval==0.4.8
@ -291,16 +419,27 @@ lxml==5.3.0
# via
# blobfile
# sacrebleu
mako==1.3.10
# via alembic
mamba-ssm==2.2.4
# via -r requirements/test.in
markdown==3.8.2
# via mlflow
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.1
# via
# flask
# jinja2
# mako
# werkzeug
matplotlib==3.9.2
# via -r requirements/test.in
# via
# -r requirements/test.in
# lightning
# mlflow
# pycocotools
# torchgeo
mbstrdecoder==1.1.3
# via
# dataproperty
@ -310,6 +449,10 @@ mdurl==0.1.2
# via markdown-it-py
mistral-common==1.8.0
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
mlflow-skinny==2.22.0
# via mlflow
more-itertools==10.5.0
# via lm-eval
mpmath==1.3.0
@ -328,10 +471,14 @@ multiprocess==0.70.16
# via
# datasets
# evaluate
munch==4.0.0
# via pretrainedmodels
mypy-extensions==1.0.0
# via black
networkx==3.2.1
# via torch
# via
# scikit-image
# torch
ninja==1.11.1.3
# via mamba-ssm
nltk==3.9.1
@ -348,6 +495,8 @@ numpy==1.26.4
# via
# -r requirements/test.in
# accelerate
# albucore
# albumentations
# bitsandbytes
# bm25s
# contourpy
@ -358,9 +507,15 @@ numpy==1.26.4
# evaluate
# fastparquet
# genai-perf
# geopandas
# h5py
# imageio
# librosa
# lightly
# lightly-utils
# matplotlib
# mistral-common
# mlflow
# mteb
# numba
# numexpr
@ -368,18 +523,30 @@ numpy==1.26.4
# pandas
# patsy
# peft
# pycocotools
# pyogrio
# rasterio
# rioxarray
# rouge-score
# runai-model-streamer
# sacrebleu
# scikit-image
# scikit-learn
# scipy
# segmentation-models-pytorch
# shapely
# soxr
# statsmodels
# tensorboardx
# tensorizer
# tifffile
# torchgeo
# torchmetrics
# torchvision
# transformers
# tritonclient
# vocos
# xarray
nvidia-cublas-cu12==12.8.3.14
# via
# nvidia-cudnn-cu12
@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
nvidia-nvtx-cu12==12.8.55
# via torch
omegaconf==2.3.0
# via
# hydra-core
# lightning
open-clip-torch==2.32.0
# via -r requirements/test.in
opencensus==0.11.4
@ -426,7 +597,18 @@ opencensus-context==0.1.3
opencv-python-headless==4.11.0.86
# via
# -r requirements/test.in
# albucore
# albumentations
# mistral-common
opentelemetry-api==1.35.0
# via
# mlflow-skinny
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-sdk==1.35.0
# via mlflow-skinny
opentelemetry-semantic-conventions==0.56b0
# via opentelemetry-sdk
packaging==24.2
# via
# accelerate
@ -435,26 +617,44 @@ packaging==24.2
# datasets
# evaluate
# fastparquet
# geopandas
# gunicorn
# huggingface-hub
# hydra-core
# kornia
# lazy-loader
# lightning
# lightning-utilities
# mamba-ssm
# matplotlib
# mlflow-skinny
# peft
# plotly
# pooch
# pyogrio
# pytest
# pytest-rerunfailures
# pytorch-lightning
# ray
# rioxarray
# scikit-image
# statsmodels
# tensorboardx
# torchmetrics
# transformers
# typepy
# xarray
pandas==2.2.3
# via
# datasets
# evaluate
# fastparquet
# genai-perf
# geopandas
# mlflow
# statsmodels
# torchgeo
# xarray
pathspec==0.12.1
# via black
pathvalidate==3.2.1
@ -468,9 +668,14 @@ peft==0.13.2
pillow==10.4.0
# via
# genai-perf
# imageio
# lightly-utils
# matplotlib
# mistral-common
# scikit-image
# segmentation-models-pytorch
# sentence-transformers
# torchgeo
# torchvision
platformdirs==4.3.6
# via
@ -489,6 +694,8 @@ portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements/test.in
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0
# via ray
propcache==0.2.0
@ -499,8 +706,10 @@ protobuf==5.28.3
# via
# google-api-core
# googleapis-common-protos
# mlflow-skinny
# proto-plus
# ray
# tensorboardx
# tensorizer
psutil==6.1.0
# via
@ -515,6 +724,7 @@ pyarrow==18.0.0
# via
# datasets
# genai-perf
# mlflow
pyasn1==0.6.1
# via
# pyasn1-modules
@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycocotools==2.0.8
# via terratorch
pycountry==24.6.1
# via pydantic-extra-types
pycparser==2.22
@ -532,8 +744,12 @@ pycryptodomex==3.22.0
pydantic==2.11.5
# via
# -r requirements/test.in
# albumentations
# datamodel-code-generator
# fastapi
# lightly
# mistral-common
# mlflow-skinny
# mteb
# pydantic-extra-types
# ray
@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
# via mistral-common
pygments==2.18.0
# via rich
pyogrio==0.11.0
# via geopandas
pyparsing==3.2.0
# via matplotlib
# via
# matplotlib
# rasterio
pyproj==3.7.1
# via
# geopandas
# rioxarray
# torchgeo
pyrate-limiter==3.7.0
# via schemathesis
pystemmer==3.0.0
# via mteb
pytablewriter==1.2.0
# via lm-eval
pytest==8.3.3
pytest==8.3.5
# via
# -r requirements/test.in
# buildkite-test-collector
@ -564,6 +789,7 @@ pytest==8.3.3
# pytest-subtests
# pytest-timeout
# schemathesis
# terratorch
pytest-asyncio==0.24.0
# via -r requirements/test.in
pytest-forked==1.6.0
@ -578,15 +804,23 @@ pytest-subtests==0.14.1
# via schemathesis
pytest-timeout==2.3.1
# via -r requirements/test.in
python-box==7.3.2
# via terratorch
python-dateutil==2.9.0.post0
# via
# arrow
# botocore
# graphene
# lightly
# matplotlib
# pandas
# typepy
python-rapidjson==1.20
# via tritonclient
pytorch-lightning==2.5.2
# via
# lightly
# lightning
pytrec-eval-terrier==0.5.7
# via mteb
pytz==2024.2
@ -596,11 +830,17 @@ pytz==2024.2
pyyaml==6.0.2
# via
# accelerate
# albumentations
# datamodel-code-generator
# datasets
# genai-perf
# huggingface-hub
# jsonargparse
# lightning
# mlflow-skinny
# omegaconf
# peft
# pytorch-lightning
# ray
# responses
# schemathesis
@ -609,6 +849,11 @@ pyyaml==6.0.2
# vocos
rapidfuzz==3.12.1
# via jiwer
rasterio==1.4.3
# via
# rioxarray
# terratorch
# torchgeo
ray==2.43.0
# via -r requirements/test.in
redis==5.2.0
@ -627,12 +872,16 @@ regex==2024.9.11
requests==2.32.3
# via
# buildkite-test-collector
# databricks-sdk
# datasets
# docker
# evaluate
# google-api-core
# huggingface-hub
# lightly
# lm-eval
# mistral-common
# mlflow-skinny
# mteb
# pooch
# ray
@ -650,8 +899,11 @@ rfc3987==1.3.8
rich==13.9.4
# via
# genai-perf
# lightning
# mteb
# typer
rioxarray==0.19.0
# via terratorch
rouge-score==0.1.2
# via lm-eval
rpds-py==0.20.1
@ -660,6 +912,8 @@ rpds-py==0.20.1
# referencing
rsa==4.9.1
# via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.11.0
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
@ -677,21 +931,32 @@ safetensors==0.4.5
# transformers
schemathesis==3.39.15
# via -r requirements/test.in
scikit-image==0.25.2
# via albumentations
scikit-learn==1.5.2
# via
# albumentations
# librosa
# lm-eval
# mlflow
# mteb
# sentence-transformers
scipy==1.13.1
# via
# albumentations
# bm25s
# librosa
# mlflow
# mteb
# scikit-image
# scikit-learn
# sentence-transformers
# statsmodels
# vocos
segmentation-models-pytorch==0.4.0
# via
# terratorch
# torchgeo
sentence-transformers==3.2.1
# via
# -r requirements/test.in
@ -700,21 +965,30 @@ sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3
# via
# lightning-utilities
# mamba-ssm
# pytablewriter
# torch
# triton
shapely==2.1.1
# via
# geopandas
# torchgeo
shellingham==1.5.4
# via typer
six==1.16.0
# via
# junit-xml
# lightly
# opencensus
# python-dateutil
# rfc3339-validator
# rouge-score
# segmentation-models-pytorch
smart-open==7.1.0
# via ray
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via
# anyio
@ -727,10 +1001,17 @@ soundfile==0.12.1
# librosa
soxr==0.5.0.post1
# via librosa
sqlalchemy==2.0.41
# via
# alembic
# mlflow
sqlitedict==2.1.0
# via lm-eval
sqlparse==0.5.3
# via mlflow-skinny
starlette==0.46.2
# via
# fastapi
# schemathesis
# starlette-testclient
starlette-testclient==0.4.1
@ -751,18 +1032,29 @@ tenacity==9.0.0
# via
# lm-eval
# plotly
tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
terratorch==1.1rc2
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
tifffile==2025.3.30
# via
# scikit-image
# terratorch
tiktoken==0.7.0
# via
# lm-eval
# mistral-common
timm==1.0.11
timm==1.0.15
# via
# -r requirements/test.in
# open-clip-torch
# segmentation-models-pytorch
# terratorch
# torchgeo
tokenizers==0.21.1
# via
# -r requirements/test.in
@ -776,18 +1068,28 @@ torch==2.7.1+cu128
# -r requirements/test.in
# accelerate
# bitsandbytes
# efficientnet-pytorch
# encodec
# fastsafetensors
# kornia
# lightly
# lightning
# lm-eval
# mamba-ssm
# mteb
# open-clip-torch
# peft
# pretrainedmodels
# pytorch-lightning
# runai-model-streamer
# segmentation-models-pytorch
# sentence-transformers
# tensorizer
# terratorch
# timm
# torchaudio
# torchgeo
# torchmetrics
# torchvision
# vector-quantize-pytorch
# vocos
@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128
# -r requirements/test.in
# encodec
# vocos
torchgeo==0.7.0
# via terratorch
torchmetrics==1.7.4
# via
# lightning
# pytorch-lightning
# terratorch
# torchgeo
torchvision==0.22.1+cu128
# via
# -r requirements/test.in
# lightly
# open-clip-torch
# pretrainedmodels
# segmentation-models-pytorch
# terratorch
# timm
# torchgeo
tqdm==4.66.6
# via
# datasets
# evaluate
# huggingface-hub
# lightly
# lightning
# lm-eval
# mteb
# nltk
# open-clip-torch
# peft
# pqdm
# pretrainedmodels
# pytorch-lightning
# segmentation-models-pytorch
# sentence-transformers
# tqdm-multiprocess
# transformers
@ -843,18 +1163,34 @@ typer==0.15.2
# via fastsafetensors
types-python-dateutil==2.9.0.20241206
# via arrow
typeshed-client==2.8.2
# via jsonargparse
typing-extensions==4.12.2
# via
# albumentations
# alembic
# fastapi
# graphene
# huggingface-hub
# librosa
# lightning
# lightning-utilities
# mistral-common
# mlflow-skinny
# mteb
# opentelemetry-api
# opentelemetry-sdk
# opentelemetry-semantic-conventions
# pqdm
# pydantic
# pydantic-core
# pydantic-extra-types
# pytorch-lightning
# sqlalchemy
# torch
# torchgeo
# typer
# typeshed-client
# typing-inspection
typing-inspection==0.4.1
# via pydantic
@ -866,9 +1202,13 @@ urllib3==2.2.3
# via
# blobfile
# botocore
# docker
# lightly
# requests
# responses
# tritonclient
uvicorn==0.35.0
# via mlflow-skinny
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
@ -880,11 +1220,15 @@ wcwidth==0.2.13
webcolors==24.11.1
# via jsonschema
werkzeug==3.1.3
# via schemathesis
# via
# flask
# schemathesis
word2number==1.1
# via lm-eval
wrapt==1.17.2
# via smart-open
xarray==2025.7.1
# via rioxarray
xxhash==3.5.0
# via
# datasets
@ -893,5 +1237,7 @@ yarl==1.17.1
# via
# aiohttp
# schemathesis
zipp==3.23.0
# via importlib-metadata
zstandard==0.23.0
# via lm-eval

View File

@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.utils import set_default_torch_num_threads
from ....conftest import VllmRunner
def generate_test_mm_data():
mm_data = {
"pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
return mm_data
def _run_test(
vllm_runner: type[VllmRunner],
model: str,
) -> None:
prompt = [
{
# This model deals with no text input
"prompt_token_ids": [1],
"multi_modal_data": generate_test_mm_data(),
} for _ in range(10)
]
with (
set_default_torch_num_threads(1),
vllm_runner(
model,
task="embed",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,
) as vllm_model,
):
vllm_model.encode(prompt)
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
) -> None:
_run_test(
vllm_runner,
model,
)

View File

@ -651,6 +651,8 @@ class ModelConfig:
self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()
self.model_supports_multimodal_raw_input = (
self.registry.supports_multimodal_raw_input(self.architectures))
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@ -1243,10 +1245,10 @@ class ModelConfig:
return self.get_hf_config_sliding_window()
def get_vocab_size(self) -> int:
return self.hf_text_config.vocab_size
return getattr(self.hf_text_config, "vocab_size", 0)
def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size
return getattr(self.hf_text_config, "hidden_size", 0)
@property
def is_deepseek_mla(self) -> bool:

View File

@ -238,14 +238,14 @@ class LLMEngine:
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs
if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
else:
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
self.detokenizer = None
tokenizer_group = None
else:
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
tokenizer_group = self.get_tokenizer_group()
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues

View File

@ -136,6 +136,40 @@ def supports_multimodal(
return getattr(model, "supports_multimodal", False)
@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
"""The interface required for all multi-modal models."""
supports_multimodal_raw_input: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs and processes
them in their raw form and not embeddings.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
@overload
def supports_multimodal_raw_input(
model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
...
@overload
def supports_multimodal_raw_input(
model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
...
def supports_multimodal_raw_input(
model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
TypeIs[SupportsMultiModalWithRawInput]]:
return getattr(model, "supports_multimodal_raw_input", False)
@runtime_checkable
class SupportsScoreTemplate(Protocol):
"""The interface required for all models that support score template."""

View File

@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Union
@ -27,13 +28,14 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
PoolerIdentity, SimplePooler)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree,
SupportsMultiModal,
SupportsV0Only)
from vllm.model_executor.models.interfaces import (
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
MultiModalFieldElem, MultiModalInputs,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalSharedField, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptUpdate)
@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
return {
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
"location_coords": torch.full((1, 2), 1.0),
"pixel_values": torch.full((6, 512, 512), 1.0,
dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
location_coords=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
location_coords=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
)
def _get_prompt_updates(
@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
for k, v in mm_data.items():
mm_kwargs[k] = v
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
# This model receives as input a multi-dimensional tensor representing
# a single image patch, so it is not split into multiple elements but
# treated as a single one; hence the decision to use a
# MultiModalSharedField. The expected shape is
# (num_channels, width, height).
# The model also allows the user to submit multiple image patches as a
# batch, adding a further dimension to the above shape. At this stage we
# only support submitting one patch per request, and batching is
# achieved via vLLM batching (see the batching sketch below).
# TODO (christian-pinto): enable support for multi patch requests
# in tandem with vLLM batching.
multimodal_kwargs_items = [
MultiModalKwargsItem.from_elems([
MultiModalFieldElem(
modality="image",
key=key,
data=data,
field=MultiModalSharedField(1),
) for key, data in mm_kwargs.items()
])
]
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=[1],
mm_kwargs=MultiModalKwargs(mm_kwargs),
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
mm_hashes=None,
mm_placeholders={},
mm_placeholders=mm_placeholders,
)
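Because each request carries exactly one patch, multi-patch workloads are expressed as one request per patch and left to vLLM's scheduler to batch; a hedged sketch (checkpoint name and tensor shapes as elsewhere in this PR, random patch data assumed):

import torch
from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
)
patches = [torch.rand(6, 512, 512, dtype=torch.float16) for _ in range(4)]
prompts = [
    {
        "prompt_token_ids": [1],
        "multi_modal_data": {
            "pixel_values": patch,  # exactly one patch per request
            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
        },
    }
    for patch in patches
]
outputs = llm.encode(prompts, use_tqdm=False)  # vLLM batches the four requests
print(len(outputs))  # one pooled output per patch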
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
SupportsV0Only):
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
SupportsMultiModalWithRawInput):
"""Prithvi Masked Autoencoder"""
is_pooling_model = True
@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
raise ValueError("Only image modality is supported")
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"])
freeze_backbone=config["task_args"]["freeze_backbone"],
)
return task.model
else:
@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")
pixel_values = torch.unbind(pixel_values, dim=0)[0]
location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
return pixel_values, location_coords
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
# We do not really use any input tokens, so there are no embeddings
# to calculate. However, because token ids are mandatory in the input
# prompt, we pass one token, and the size of the dummy embedding
# tensor must reflect that.
return torch.empty((input_ids.shape[0], 0))
def forward(
self,
input_ids: Optional[torch.Tensor],

View File

@ -22,8 +22,8 @@ from vllm.logger import init_logger
from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding,
supports_multimodal, supports_pp,
supports_transcription, supports_v0_only)
supports_multimodal, supports_multimodal_raw_input,
supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model
logger = init_logger(__name__)
@ -287,6 +287,7 @@ class _ModelInfo:
is_pooling_model: bool
supports_cross_encoding: bool
supports_multimodal: bool
supports_multimodal_raw_input: bool
supports_pp: bool
has_inner_state: bool
is_attention_free: bool
@ -304,6 +305,7 @@ class _ModelInfo:
is_pooling_model=True, # Can convert any model into a pooling model
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
supports_pp=supports_pp(model),
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
@ -573,6 +575,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal
def supports_multimodal_raw_input(
self,
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal_raw_input
def is_pp_supported_model(
self,
architectures: Union[str, list[str]],

View File

@ -266,7 +266,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
if tokenizer is None:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None:
mm_config = model_config.get_multimodal_config()

View File

@ -94,11 +94,14 @@ class AsyncLLM(EngineClient):
self.log_requests = log_requests
self.log_stats = log_stats
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
@ -525,6 +528,10 @@ class AsyncLLM(EngineClient):
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
return self.tokenizer.get_lora_tokenizer(lora_request)
async def is_tracing_enabled(self) -> bool:

View File

@ -82,11 +82,14 @@ class LLMEngine:
self.dp_group = None
self.should_execute_dummy_batch = False
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
if self.model_config.skip_tokenizer_init:
self.tokenizer = None
else:
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
lora_config=vllm_config.lora_config)
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config=vllm_config,

View File

@ -327,14 +327,16 @@ class OutputProcessor:
if request_id in self.request_states:
raise ValueError(f"Request id {request_id} already running.")
req_state = RequestState.from_new_request(
tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
log_stats=self.log_stats)
tokenizer = None if not self.tokenizer else \
self.tokenizer.get_lora_tokenizer(request.lora_request)
req_state = RequestState.from_new_request(tokenizer=tokenizer,
request=request,
prompt=prompt,
parent_req=parent_req,
request_index=request_index,
queue=queue,
log_stats=self.log_stats)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:

View File

@ -380,7 +380,6 @@ class Processor:
prompt_type: Literal["encoder", "decoder"],
):
model_config = self.model_config
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids:
@ -389,9 +388,14 @@ class Processor:
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:

View File

@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_multimodal_model = model_config.is_multimodal_model
self.is_pooling_model = model_config.pooler_config is not None
self.model_supports_multimodal_raw_input = (
model_config.model_supports_multimodal_raw_input)
self.max_model_len = model_config.max_model_len
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
scheduler_output: The scheduler output.
"""
# Attention-free models have zero kv_cache_groups. However, models
# like Mamba are also attention-free but use the kv_cache to keep
# their internal state. This is why we check the number of
# kv_cache groups instead of relying solely on
# self.model_config.is_attention_free.
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output)
@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
def _init_model_kwargs_for_multimodal_model(
self,
scheduler_output: Optional["SchedulerOutput"] = None,
num_reqs: int = -1,
) -> dict[str, Any]:
model_kwargs: dict[str, Any] = {}
if self.model_supports_multimodal_raw_input:
# This model requires the raw multimodal data as input.
if scheduler_output:
multi_modal_kwargs_list = []
for req in scheduler_output.scheduled_new_reqs:
req_mm_inputs = req.mm_inputs
if not isinstance(req_mm_inputs, list):
req_mm_inputs = list(req_mm_inputs)
multi_modal_kwargs_list.extend(req_mm_inputs)
multi_modal_kwargs = MultiModalKwargs.batch(
multi_modal_kwargs_list)
else:
# The only case where SchedulerOutput is None is for
# a dummy run, so let's build some dummy data.
dummy_data = [
self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=1).multi_modal_data for i in range(num_reqs)
]
multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
model_kwargs.update(multi_modal_kwargs)
return model_kwargs
def _get_cumsum_and_arange(
self,
num_tokens: np.ndarray,
@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
model_kwargs = self._init_model_kwargs_for_multimodal_model(
scheduler_output=scheduler_output)
inputs_embeds = self.model.get_input_embeddings(
input_ids=input_ids,
multimodal_embeddings=mm_embeds or None,
)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens]
@ -1374,6 +1420,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {}
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
else:
@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_kwargs,
device=self.device,
),
)
self.maybe_wait_for_kv_save()
@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens):
model = self.model
if self.is_multimodal_model:
model_kwargs = self._init_model_kwargs_for_multimodal_model(
num_reqs=num_reqs)
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
model_kwargs = {}
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
else:
@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**MultiModalKwargs.as_kwargs(
model_kwargs,
device=self.device,
),
)
if self.use_aux_hidden_state_outputs:
hidden_states, _ = outputs
else: