[Core][Model] PrithviMAE Enablement on vLLM v1 engine (#20577)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
@ -1,122 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This is a demo script showing how to use the
PrithviGeospatialMAE model with vLLM
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa

Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa

The requirements for running this script are:
- Installing [terratorch, albumentations, rasterio] in your Python environment
- Downloading the model weights into a 'model' folder local to the script
  (a temporary measure until the proper config.json file is uploaded to HF)
- Downloading an input example image (India_900498_S2Hand.tif) and placing it
  in the same folder as the script (or specifying it with the --data_file argument)

Run the example:
python prithvi_geospatial_mae.py

""" # noqa: E501

import argparse
import datetime
import os
import re
from typing import Union

import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

from vllm import LLM

torch.set_default_dtype(torch.float16)

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99

model_config = """{
|
||||
"architectures": ["PrithviGeoSpatialMAE"],
|
||||
"num_classes": 0,
|
||||
"pretrained_cfg": {
|
||||
"task_args": {
|
||||
"task": "SemanticSegmentationTask",
|
||||
"model_factory": "EncoderDecoderFactory",
|
||||
"loss": "ce",
|
||||
"ignore_index": -1,
|
||||
"lr": 0.001,
|
||||
"freeze_backbone": false,
|
||||
"freeze_decoder": false,
|
||||
"plot_on_val": 10,
|
||||
"optimizer": "AdamW",
|
||||
"scheduler": "CosineAnnealingLR"
|
||||
},
|
||||
"model_args": {
|
||||
"backbone_pretrained": false,
|
||||
"backbone": "prithvi_eo_v2_300_tl",
|
||||
"decoder": "UperNetDecoder",
|
||||
"decoder_channels": 256,
|
||||
"decoder_scale_modules": true,
|
||||
"num_classes": 2,
|
||||
"rescale": true,
|
||||
"backbone_bands": [
|
||||
"BLUE",
|
||||
"GREEN",
|
||||
"RED",
|
||||
"NIR_NARROW",
|
||||
"SWIR_1",
|
||||
"SWIR_2"
|
||||
],
|
||||
"head_dropout": 0.1,
|
||||
"necks": [
|
||||
{
|
||||
"name": "SelectIndices",
|
||||
"indices": [
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "ReshapeTokensToImage"
|
||||
}
|
||||
]
|
||||
},
|
||||
"optimizer_params" : {
|
||||
"lr": 5.0e-05,
|
||||
"betas": [0.9, 0.999],
|
||||
"eps": [1.0e-08],
|
||||
"weight_decay": 0.05,
|
||||
"amsgrad": false,
|
||||
"maximize": false,
|
||||
"capturable": false,
|
||||
"differentiable": false
|
||||
},
|
||||
"scheduler_params" : {
|
||||
"T_max": 50,
|
||||
"eta_min": 0,
|
||||
"last_epoch": -1,
|
||||
"verbose": "deprecated"
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
"""
|
||||
|
||||
# Temporarily creating the "config.json" for the model.
|
||||
# This is going to disappear once the correct config.json is available on HF
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
|
||||
) as config_file:
|
||||
config_file.write(model_config)
|
||||
|
||||
datamodule_config = {
|
||||
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
|
||||
"batch_size": 16,
|
||||
@ -138,28 +43,24 @@ datamodule_config = {


class PrithviMAE:
    def __init__(self):
        print("Initializing PrithviMAE model")
        self.llm = LLM(
            model=os.path.join(os.path.dirname(__file__), "./model"),
            skip_tokenizer_init=True,
            dtype="float32",
    def __init__(self, model):
        self.model = LLM(
            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
        )

    def run(self, input_data, location_coords):
        print("################ Running inference on vLLM ##############")
        # merge the inputs into one data structure
        if input_data is not None and input_data.dtype == torch.float32:
            input_data = input_data.to(torch.float16)
        input_data = input_data[0]

        mm_data = {
            "pixel_values": torch.empty(0) if input_data is None else input_data,
            "location_coords": torch.empty(0)
            if location_coords is None
            else location_coords,
            "pixel_values": input_data,
            "location_coords": location_coords,
        }

        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

        outputs = self.llm.encode(prompt, use_tqdm=False)
        print("################ Inference done (it took seconds) ##############")
        outputs = self.model.encode(prompt, use_tqdm=False)

        return outputs[0].outputs.data

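For orientation, here is a minimal, self-contained sketch (not part of the diff) of how the updated wrapper above drives vLLM's pooling path with raw multimodal inputs. It assumes the christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM checkpoint referenced later in this change and uses dummy float16 tensors in place of a real GeoTIFF patch:

# Sketch only: drive the Prithvi pooling model directly through vLLM's encode API.
import torch
from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    skip_tokenizer_init=True,   # the model consumes no text
    dtype="float16",
    enforce_eager=True,
)

mm_data = {
    # a single (bands, H, W) patch; a real run would load this from a GeoTIFF
    "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
    "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

outputs = llm.encode(prompt, use_tqdm=False)
print(outputs[0].outputs.data.shape)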
@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
|
||||
"""
|
||||
Args:
|
||||
orig_img: torch.Tensor representing original image (reference)
|
||||
with shape = (bands, H, W).
|
||||
with shape = (bands, H, W).
|
||||
channels: list of indices representing RGB channels.
|
||||
|
||||
Returns:
|
||||
torch.Tensor with shape (num_channels, height, width) for original image
|
||||
torch.Tensor with shape (num_channels, height, width)
|
||||
for original image
|
||||
"""
|
||||
|
||||
orig_img = orig_img[channels, ...]
|
||||
@ -260,10 +162,10 @@ def load_example(
|
||||
|
||||
Args:
|
||||
file_paths: list of file paths .
|
||||
mean: list containing mean values for each band in the images
|
||||
in *file_paths*.
|
||||
std: list containing std values for each band in the images
|
||||
in *file_paths*.
|
||||
mean: list containing mean values for each band in the
|
||||
images in *file_paths*.
|
||||
std: list containing std values for each band in the
|
||||
images in *file_paths*.
|
||||
|
||||
Returns:
|
||||
np.array containing created example
|
||||
@ -308,7 +210,7 @@ def load_example(
|
||||
print(f"Could not extract timestamp for {file} ({e})")
|
||||
|
||||
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
||||
imgs = np.expand_dims(imgs, axis=0) # add batch di
|
||||
|
||||
return imgs, temporal_coords, location_coords, metas
|
||||
@ -332,8 +234,10 @@ def run_model(
|
||||
)
|
||||
|
||||
# Build sliding window
|
||||
|
||||
batch_size = 1
|
||||
batch = torch.tensor(input_data, device="cpu")
|
||||
# batch = torch.tensor(input_data, device="cpu")
|
||||
batch = torch.tensor(input_data)
|
||||
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
|
||||
h1, w1 = windows.shape[3:5]
|
||||
windows = rearrange(
|
||||
@ -344,18 +248,16 @@ def run_model(
|
||||
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
|
||||
windows = torch.tensor_split(windows, num_batches, dim=0)
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
if temporal_coords:
|
||||
temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
|
||||
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
|
||||
else:
|
||||
temporal_coords = None
|
||||
if location_coords:
|
||||
location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
|
||||
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
|
||||
else:
|
||||
location_coords = None
|
||||
|
||||
# Run model
|
||||
# Run Prithvi-EO-V2-300M-TL-Sen1Floods11
|
||||
pred_imgs = []
|
||||
for x in windows:
|
||||
# Apply standardization
|
||||
@ -363,15 +265,7 @@ def run_model(
|
||||
x = datamodule.aug(x)["image"]
|
||||
|
||||
with torch.no_grad():
|
||||
x = x.to(device)
|
||||
pred = model.run(x, location_coords=location_coords)
|
||||
if lightning_model:
|
||||
pred_lightning = lightning_model(
|
||||
x, temporal_coords=temporal_coords, location_coords=location_coords
|
||||
)
|
||||
pred_lightning = pred_lightning.output.detach().cpu()
|
||||
if not torch.equal(pred, pred_lightning):
|
||||
print("Inference output is not equal")
|
||||
y_hat = pred.argmax(dim=1)
|
||||
|
||||
y_hat = torch.nn.functional.interpolate(
|
||||
@ -403,52 +297,18 @@ def run_model(
|
||||
return pred_imgs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default="./India_900498_S2Hand.tif",
|
||||
help="Path to the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="output",
|
||||
help="Path to the directory where to save outputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_indices",
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="0-based indices of the six Prithvi channels to be selected from the "
|
||||
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgb_outputs",
|
||||
action="store_true",
|
||||
help="If present, output files will only contain RGB channels. "
|
||||
"Otherwise, all bands will be saved.",
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
data_file: str,
|
||||
model: str,
|
||||
output_dir: str,
|
||||
rgb_outputs: bool,
|
||||
input_indices: list[int] = None,
|
||||
):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Load model ---------------------------------------------------------------
|
||||
|
||||
model_obj = PrithviMAE()
|
||||
model_obj = PrithviMAE(model=model)
|
||||
datamodule = generate_datamodule()
|
||||
img_size = 256 # Size of Sen1Floods11
|
||||
|
||||
# Loading data -------------------------------------------------------------
|
||||
img_size = 512 # Size of Sen1Floods11
|
||||
|
||||
input_data, temporal_coords, location_coords, meta_data = load_example(
|
||||
file_paths=[data_file],
|
||||
@ -460,8 +320,6 @@ def main(
|
||||
if input_data.mean() > 1:
|
||||
input_data = input_data / 10000 # Convert to range 0-1
|
||||
|
||||
# Running model ------------------------------------------------------------
|
||||
|
||||
channels = [
|
||||
datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
|
||||
] # BGR -> RGB
|
||||
@ -469,7 +327,6 @@ def main(
|
||||
pred = run_model(
|
||||
input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
|
||||
)
|
||||
|
||||
# Save pred
|
||||
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
|
||||
pred_file = os.path.join(
|
||||
@ -487,6 +344,7 @@ def main(
|
||||
orig_img=torch.Tensor(input_data[0, :, 0, ...]),
|
||||
channels=channels,
|
||||
)
|
||||
rgb_orig = rgb_orig.to(torch.float32)
|
||||
|
||||
pred[pred == 0.0] = np.nan
|
||||
img_pred = rgb_orig * 0.7 + pred * 0.3
|
||||
@ -503,9 +361,10 @@ def main(
|
||||
|
||||
# Save image rgb
|
||||
if rgb_outputs:
|
||||
name_suffix = os.path.splitext(os.path.basename(data_file))[0]
|
||||
rgb_file = os.path.join(
|
||||
output_dir,
|
||||
f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
|
||||
f"original_rgb_{name_suffix}.tiff",
|
||||
)
|
||||
save_geotiff(
|
||||
image=_convert_np_uint8(rgb_orig),
|
||||
@ -515,6 +374,42 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default="./India_900498_S2Hand.tif",
|
||||
help="Path to the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
|
||||
help="Path to a checkpoint file to load from.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="output",
|
||||
help="Path to the directory where to save outputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_indices",
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="""
|
||||
0-based indices of the six Prithvi channels to be selected from the input.
|
||||
By default selects [1,2,3,8,11,12] for S2L1C data.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgb_outputs",
|
||||
action="store_true",
|
||||
help="If present, output files will only contain RGB channels. "
|
||||
"Otherwise, all bands will be saved.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(**vars(args))
|
||||
|
@ -54,3 +54,4 @@ runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
terratorch==1.1rc2 # required for PrithviMAE test
|
@ -6,6 +6,10 @@ accelerate==1.0.1
|
||||
# via
|
||||
# lm-eval
|
||||
# peft
|
||||
aenum==3.1.16
|
||||
# via lightly
|
||||
affine==2.4.0
|
||||
# via rasterio
|
||||
aiohappyeyeballs==2.4.3
|
||||
# via aiohttp
|
||||
aiohttp==3.10.11
|
||||
@ -21,8 +25,18 @@ aiosignal==1.3.1
|
||||
# via
|
||||
# aiohttp
|
||||
# ray
|
||||
albucore==0.0.16
|
||||
# via terratorch
|
||||
albumentations==1.4.6
|
||||
# via terratorch
|
||||
alembic==1.16.4
|
||||
# via mlflow
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
antlr4-python3-runtime==4.9.3
|
||||
# via
|
||||
# hydra-core
|
||||
# omegaconf
|
||||
anyio==4.6.2.post1
|
||||
# via
|
||||
# httpx
|
||||
@ -34,10 +48,12 @@ arrow==1.3.0
|
||||
attrs==24.2.0
|
||||
# via
|
||||
# aiohttp
|
||||
# fiona
|
||||
# hypothesis
|
||||
# jsonlines
|
||||
# jsonschema
|
||||
# pytest-subtests
|
||||
# rasterio
|
||||
# referencing
|
||||
audioread==3.0.1
|
||||
# via librosa
|
||||
@ -46,9 +62,13 @@ backoff==2.2.1
|
||||
# -r requirements/test.in
|
||||
# schemathesis
|
||||
bitsandbytes==0.46.1
|
||||
# via -r requirements/test.in
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightning
|
||||
black==24.10.0
|
||||
# via datamodel-code-generator
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
blobfile==3.0.0
|
||||
# via -r requirements/test.in
|
||||
bm25s==0.2.13
|
||||
@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3
|
||||
buildkite-test-collector==0.1.9
|
||||
# via -r requirements/test.in
|
||||
cachetools==5.5.2
|
||||
# via google-auth
|
||||
# via
|
||||
# google-auth
|
||||
# mlflow-skinny
|
||||
certifi==2024.8.30
|
||||
# via
|
||||
# fiona
|
||||
# httpcore
|
||||
# httpx
|
||||
# lightly
|
||||
# pyogrio
|
||||
# pyproj
|
||||
# rasterio
|
||||
# requests
|
||||
cffi==1.17.1
|
||||
# via soundfile
|
||||
@ -79,11 +106,28 @@ charset-normalizer==3.4.0
|
||||
click==8.1.7
|
||||
# via
|
||||
# black
|
||||
# click-plugins
|
||||
# cligj
|
||||
# fiona
|
||||
# flask
|
||||
# jiwer
|
||||
# mlflow-skinny
|
||||
# nltk
|
||||
# rasterio
|
||||
# ray
|
||||
# schemathesis
|
||||
# typer
|
||||
# uvicorn
|
||||
click-plugins==1.1.1.2
|
||||
# via
|
||||
# fiona
|
||||
# rasterio
|
||||
cligj==0.7.2
|
||||
# via
|
||||
# fiona
|
||||
# rasterio
|
||||
cloudpickle==3.1.1
|
||||
# via mlflow-skinny
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# sacrebleu
|
||||
@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0
|
||||
# via ray
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
databricks-sdk==0.59.0
|
||||
# via mlflow-skinny
|
||||
datamodel-code-generator==0.26.3
|
||||
# via -r requirements/test.in
|
||||
dataproperty==1.0.1
|
||||
@ -122,13 +168,21 @@ distlib==0.3.9
|
||||
# via virtualenv
|
||||
dnspython==2.7.0
|
||||
# via email-validator
|
||||
docker==7.1.0
|
||||
# via mlflow
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
einops==0.8.0
|
||||
docstring-parser==0.17.0
|
||||
# via jsonargparse
|
||||
efficientnet-pytorch==0.7.1
|
||||
# via segmentation-models-pytorch
|
||||
einops==0.8.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# encodec
|
||||
# mamba-ssm
|
||||
# terratorch
|
||||
# torchgeo
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
einx==0.3.0
|
||||
@ -141,6 +195,8 @@ eval-type-backport==0.2.2
|
||||
# via mteb
|
||||
evaluate==0.4.3
|
||||
# via lm-eval
|
||||
fastapi==0.116.1
|
||||
# via mlflow-skinny
|
||||
fastparquet==2024.11.0
|
||||
# via genai-perf
|
||||
fastrlock==0.8.2
|
||||
@ -156,6 +212,10 @@ filelock==3.16.1
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
fiona==1.10.1
|
||||
# via torchgeo
|
||||
flask==3.1.1
|
||||
# via mlflow
|
||||
fonttools==4.54.1
|
||||
# via matplotlib
|
||||
fqdn==1.5.1
|
||||
@ -173,6 +233,8 @@ fsspec==2024.9.0
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# huggingface-hub
|
||||
# lightning
|
||||
# pytorch-lightning
|
||||
# torch
|
||||
ftfy==6.3.1
|
||||
# via open-clip-torch
|
||||
@ -180,18 +242,41 @@ genai-perf==0.0.8
|
||||
# via -r requirements/test.in
|
||||
genson==1.3.0
|
||||
# via datamodel-code-generator
|
||||
geopandas==1.0.1
|
||||
# via terratorch
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.44
|
||||
# via mlflow-skinny
|
||||
google-api-core==2.24.2
|
||||
# via opencensus
|
||||
google-auth==2.40.2
|
||||
# via google-api-core
|
||||
# via
|
||||
# databricks-sdk
|
||||
# google-api-core
|
||||
googleapis-common-protos==1.70.0
|
||||
# via google-api-core
|
||||
graphene==3.4.3
|
||||
# via mlflow
|
||||
graphql-core==3.2.6
|
||||
# via hypothesis-graphql
|
||||
# via
|
||||
# graphene
|
||||
# graphql-relay
|
||||
# hypothesis-graphql
|
||||
graphql-relay==3.2.0
|
||||
# via graphene
|
||||
greenlet==3.2.3
|
||||
# via sqlalchemy
|
||||
grpcio==1.71.0
|
||||
# via ray
|
||||
gunicorn==23.0.0
|
||||
# via mlflow
|
||||
h11==0.14.0
|
||||
# via httpcore
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
h5py==3.13.0
|
||||
# via terratorch
|
||||
harfile==0.3.0
|
||||
# via schemathesis
|
||||
hf-xet==1.1.3
|
||||
@ -204,7 +289,7 @@ httpx==0.27.2
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# schemathesis
|
||||
huggingface-hub==0.33.0
|
||||
huggingface-hub==0.33.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
@ -212,13 +297,19 @@ huggingface-hub==0.33.0
|
||||
# evaluate
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# terratorch
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
# vocos
|
||||
humanize==4.11.0
|
||||
# via runai-model-streamer
|
||||
hydra-core==1.3.2
|
||||
# via
|
||||
# lightly
|
||||
# lightning
|
||||
hypothesis==6.131.0
|
||||
# via
|
||||
# hypothesis-graphql
|
||||
@ -236,6 +327,14 @@ idna==3.10
|
||||
# jsonschema
|
||||
# requests
|
||||
# yarl
|
||||
imageio==2.37.0
|
||||
# via scikit-image
|
||||
importlib-metadata==8.7.0
|
||||
# via
|
||||
# mlflow-skinny
|
||||
# opentelemetry-api
|
||||
importlib-resources==6.5.2
|
||||
# via typeshed-client
|
||||
inflect==5.6.2
|
||||
# via datamodel-code-generator
|
||||
iniconfig==2.0.0
|
||||
@ -244,9 +343,13 @@ isoduration==20.11.0
|
||||
# via jsonschema
|
||||
isort==5.13.2
|
||||
# via datamodel-code-generator
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# datamodel-code-generator
|
||||
# flask
|
||||
# mlflow
|
||||
# torch
|
||||
jiwer==3.0.5
|
||||
# via -r requirements/test.in
|
||||
@ -259,6 +362,10 @@ joblib==1.4.2
|
||||
# librosa
|
||||
# nltk
|
||||
# scikit-learn
|
||||
jsonargparse==4.35.0
|
||||
# via
|
||||
# lightning
|
||||
# terratorch
|
||||
jsonlines==4.0.0
|
||||
# via lm-eval
|
||||
jsonpointer==3.0.0
|
||||
@ -277,12 +384,33 @@ kaleido==0.2.1
|
||||
# via genai-perf
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
kornia==0.8.1
|
||||
# via torchgeo
|
||||
kornia-rs==0.1.9
|
||||
# via kornia
|
||||
lazy-loader==0.4
|
||||
# via librosa
|
||||
# via
|
||||
# librosa
|
||||
# scikit-image
|
||||
libnacl==2.1.0
|
||||
# via tensorizer
|
||||
librosa==0.10.2.post1
|
||||
# via -r requirements/test.in
|
||||
lightly==1.5.20
|
||||
# via
|
||||
# terratorch
|
||||
# torchgeo
|
||||
lightly-utils==0.0.2
|
||||
# via lightly
|
||||
lightning==2.5.1.post0
|
||||
# via
|
||||
# terratorch
|
||||
# torchgeo
|
||||
lightning-utilities==0.14.3
|
||||
# via
|
||||
# lightning
|
||||
# pytorch-lightning
|
||||
# torchmetrics
|
||||
llvmlite==0.44.0
|
||||
# via numba
|
||||
lm-eval==0.4.8
|
||||
@ -291,16 +419,27 @@ lxml==5.3.0
|
||||
# via
|
||||
# blobfile
|
||||
# sacrebleu
|
||||
mako==1.3.10
|
||||
# via alembic
|
||||
mamba-ssm==2.2.4
|
||||
# via -r requirements/test.in
|
||||
markdown==3.8.2
|
||||
# via mlflow
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==3.0.1
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# mako
|
||||
# werkzeug
|
||||
matplotlib==3.9.2
|
||||
# via -r requirements/test.in
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightning
|
||||
# mlflow
|
||||
# pycocotools
|
||||
# torchgeo
|
||||
mbstrdecoder==1.1.3
|
||||
# via
|
||||
# dataproperty
|
||||
@ -310,6 +449,10 @@ mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.8.0
|
||||
# via -r requirements/test.in
|
||||
mlflow==2.22.0
|
||||
# via terratorch
|
||||
mlflow-skinny==2.22.0
|
||||
# via mlflow
|
||||
more-itertools==10.5.0
|
||||
# via lm-eval
|
||||
mpmath==1.3.0
|
||||
@ -328,10 +471,14 @@ multiprocess==0.70.16
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
munch==4.0.0
|
||||
# via pretrainedmodels
|
||||
mypy-extensions==1.0.0
|
||||
# via black
|
||||
networkx==3.2.1
|
||||
# via torch
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.11.1.3
|
||||
# via mamba-ssm
|
||||
nltk==3.9.1
|
||||
@ -348,6 +495,8 @@ numpy==1.26.4
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
# albucore
|
||||
# albumentations
|
||||
# bitsandbytes
|
||||
# bm25s
|
||||
# contourpy
|
||||
@ -358,9 +507,15 @@ numpy==1.26.4
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
# geopandas
|
||||
# h5py
|
||||
# imageio
|
||||
# librosa
|
||||
# lightly
|
||||
# lightly-utils
|
||||
# matplotlib
|
||||
# mistral-common
|
||||
# mlflow
|
||||
# mteb
|
||||
# numba
|
||||
# numexpr
|
||||
@ -368,18 +523,30 @@ numpy==1.26.4
|
||||
# pandas
|
||||
# patsy
|
||||
# peft
|
||||
# pycocotools
|
||||
# pyogrio
|
||||
# rasterio
|
||||
# rioxarray
|
||||
# rouge-score
|
||||
# runai-model-streamer
|
||||
# sacrebleu
|
||||
# scikit-image
|
||||
# scikit-learn
|
||||
# scipy
|
||||
# segmentation-models-pytorch
|
||||
# shapely
|
||||
# soxr
|
||||
# statsmodels
|
||||
# tensorboardx
|
||||
# tensorizer
|
||||
# tifffile
|
||||
# torchgeo
|
||||
# torchmetrics
|
||||
# torchvision
|
||||
# transformers
|
||||
# tritonclient
|
||||
# vocos
|
||||
# xarray
|
||||
nvidia-cublas-cu12==12.8.3.14
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.8.55
|
||||
# via torch
|
||||
omegaconf==2.3.0
|
||||
# via
|
||||
# hydra-core
|
||||
# lightning
|
||||
open-clip-torch==2.32.0
|
||||
# via -r requirements/test.in
|
||||
opencensus==0.11.4
|
||||
@ -426,7 +597,18 @@ opencensus-context==0.1.3
|
||||
opencv-python-headless==4.11.0.86
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# albucore
|
||||
# albumentations
|
||||
# mistral-common
|
||||
opentelemetry-api==1.35.0
|
||||
# via
|
||||
# mlflow-skinny
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
opentelemetry-sdk==1.35.0
|
||||
# via mlflow-skinny
|
||||
opentelemetry-semantic-conventions==0.56b0
|
||||
# via opentelemetry-sdk
|
||||
packaging==24.2
|
||||
# via
|
||||
# accelerate
|
||||
@ -435,26 +617,44 @@ packaging==24.2
|
||||
# datasets
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# geopandas
|
||||
# gunicorn
|
||||
# huggingface-hub
|
||||
# hydra-core
|
||||
# kornia
|
||||
# lazy-loader
|
||||
# lightning
|
||||
# lightning-utilities
|
||||
# mamba-ssm
|
||||
# matplotlib
|
||||
# mlflow-skinny
|
||||
# peft
|
||||
# plotly
|
||||
# pooch
|
||||
# pyogrio
|
||||
# pytest
|
||||
# pytest-rerunfailures
|
||||
# pytorch-lightning
|
||||
# ray
|
||||
# rioxarray
|
||||
# scikit-image
|
||||
# statsmodels
|
||||
# tensorboardx
|
||||
# torchmetrics
|
||||
# transformers
|
||||
# typepy
|
||||
# xarray
|
||||
pandas==2.2.3
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
# geopandas
|
||||
# mlflow
|
||||
# statsmodels
|
||||
# torchgeo
|
||||
# xarray
|
||||
pathspec==0.12.1
|
||||
# via black
|
||||
pathvalidate==3.2.1
|
||||
@ -468,9 +668,14 @@ peft==0.13.2
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# genai-perf
|
||||
# imageio
|
||||
# lightly-utils
|
||||
# matplotlib
|
||||
# mistral-common
|
||||
# scikit-image
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# torchgeo
|
||||
# torchvision
|
||||
platformdirs==4.3.6
|
||||
# via
|
||||
@ -489,6 +694,8 @@ portalocker==2.10.1
|
||||
# via sacrebleu
|
||||
pqdm==0.2.0
|
||||
# via -r requirements/test.in
|
||||
pretrainedmodels==0.7.4
|
||||
# via segmentation-models-pytorch
|
||||
prometheus-client==0.22.0
|
||||
# via ray
|
||||
propcache==0.2.0
|
||||
@ -499,8 +706,10 @@ protobuf==5.28.3
|
||||
# via
|
||||
# google-api-core
|
||||
# googleapis-common-protos
|
||||
# mlflow-skinny
|
||||
# proto-plus
|
||||
# ray
|
||||
# tensorboardx
|
||||
# tensorizer
|
||||
psutil==6.1.0
|
||||
# via
|
||||
@ -515,6 +724,7 @@ pyarrow==18.0.0
|
||||
# via
|
||||
# datasets
|
||||
# genai-perf
|
||||
# mlflow
|
||||
pyasn1==0.6.1
|
||||
# via
|
||||
# pyasn1-modules
|
||||
@ -523,6 +733,8 @@ pyasn1-modules==0.4.2
|
||||
# via google-auth
|
||||
pybind11==2.13.6
|
||||
# via lm-eval
|
||||
pycocotools==2.0.8
|
||||
# via terratorch
|
||||
pycountry==24.6.1
|
||||
# via pydantic-extra-types
|
||||
pycparser==2.22
|
||||
@ -532,8 +744,12 @@ pycryptodomex==3.22.0
|
||||
pydantic==2.11.5
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# albumentations
|
||||
# datamodel-code-generator
|
||||
# fastapi
|
||||
# lightly
|
||||
# mistral-common
|
||||
# mlflow-skinny
|
||||
# mteb
|
||||
# pydantic-extra-types
|
||||
# ray
|
||||
@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5
|
||||
# via mistral-common
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyogrio==0.11.0
|
||||
# via geopandas
|
||||
pyparsing==3.2.0
|
||||
# via matplotlib
|
||||
# via
|
||||
# matplotlib
|
||||
# rasterio
|
||||
pyproj==3.7.1
|
||||
# via
|
||||
# geopandas
|
||||
# rioxarray
|
||||
# torchgeo
|
||||
pyrate-limiter==3.7.0
|
||||
# via schemathesis
|
||||
pystemmer==3.0.0
|
||||
# via mteb
|
||||
pytablewriter==1.2.0
|
||||
# via lm-eval
|
||||
pytest==8.3.3
|
||||
pytest==8.3.5
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# buildkite-test-collector
|
||||
@ -564,6 +789,7 @@ pytest==8.3.3
|
||||
# pytest-subtests
|
||||
# pytest-timeout
|
||||
# schemathesis
|
||||
# terratorch
|
||||
pytest-asyncio==0.24.0
|
||||
# via -r requirements/test.in
|
||||
pytest-forked==1.6.0
|
||||
@ -578,15 +804,23 @@ pytest-subtests==0.14.1
|
||||
# via schemathesis
|
||||
pytest-timeout==2.3.1
|
||||
# via -r requirements/test.in
|
||||
python-box==7.3.2
|
||||
# via terratorch
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# arrow
|
||||
# botocore
|
||||
# graphene
|
||||
# lightly
|
||||
# matplotlib
|
||||
# pandas
|
||||
# typepy
|
||||
python-rapidjson==1.20
|
||||
# via tritonclient
|
||||
pytorch-lightning==2.5.2
|
||||
# via
|
||||
# lightly
|
||||
# lightning
|
||||
pytrec-eval-terrier==0.5.7
|
||||
# via mteb
|
||||
pytz==2024.2
|
||||
@ -596,11 +830,17 @@ pytz==2024.2
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# accelerate
|
||||
# albumentations
|
||||
# datamodel-code-generator
|
||||
# datasets
|
||||
# genai-perf
|
||||
# huggingface-hub
|
||||
# jsonargparse
|
||||
# lightning
|
||||
# mlflow-skinny
|
||||
# omegaconf
|
||||
# peft
|
||||
# pytorch-lightning
|
||||
# ray
|
||||
# responses
|
||||
# schemathesis
|
||||
@ -609,6 +849,11 @@ pyyaml==6.0.2
|
||||
# vocos
|
||||
rapidfuzz==3.12.1
|
||||
# via jiwer
|
||||
rasterio==1.4.3
|
||||
# via
|
||||
# rioxarray
|
||||
# terratorch
|
||||
# torchgeo
|
||||
ray==2.43.0
|
||||
# via -r requirements/test.in
|
||||
redis==5.2.0
|
||||
@ -627,12 +872,16 @@ regex==2024.9.11
|
||||
requests==2.32.3
|
||||
# via
|
||||
# buildkite-test-collector
|
||||
# databricks-sdk
|
||||
# datasets
|
||||
# docker
|
||||
# evaluate
|
||||
# google-api-core
|
||||
# huggingface-hub
|
||||
# lightly
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
# mlflow-skinny
|
||||
# mteb
|
||||
# pooch
|
||||
# ray
|
||||
@ -650,8 +899,11 @@ rfc3987==1.3.8
|
||||
rich==13.9.4
|
||||
# via
|
||||
# genai-perf
|
||||
# lightning
|
||||
# mteb
|
||||
# typer
|
||||
rioxarray==0.19.0
|
||||
# via terratorch
|
||||
rouge-score==0.1.2
|
||||
# via lm-eval
|
||||
rpds-py==0.20.1
|
||||
@ -660,6 +912,8 @@ rpds-py==0.20.1
|
||||
# referencing
|
||||
rsa==4.9.1
|
||||
# via google-auth
|
||||
rtree==1.4.0
|
||||
# via torchgeo
|
||||
runai-model-streamer==0.11.0
|
||||
# via -r requirements/test.in
|
||||
runai-model-streamer-s3==0.11.0
|
||||
@ -677,21 +931,32 @@ safetensors==0.4.5
|
||||
# transformers
|
||||
schemathesis==3.39.15
|
||||
# via -r requirements/test.in
|
||||
scikit-image==0.25.2
|
||||
# via albumentations
|
||||
scikit-learn==1.5.2
|
||||
# via
|
||||
# albumentations
|
||||
# librosa
|
||||
# lm-eval
|
||||
# mlflow
|
||||
# mteb
|
||||
# sentence-transformers
|
||||
scipy==1.13.1
|
||||
# via
|
||||
# albumentations
|
||||
# bm25s
|
||||
# librosa
|
||||
# mlflow
|
||||
# mteb
|
||||
# scikit-image
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
# statsmodels
|
||||
# vocos
|
||||
segmentation-models-pytorch==0.4.0
|
||||
# via
|
||||
# terratorch
|
||||
# torchgeo
|
||||
sentence-transformers==3.2.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
@ -700,21 +965,30 @@ sentencepiece==0.2.0
|
||||
# via mistral-common
|
||||
setuptools==77.0.3
|
||||
# via
|
||||
# lightning-utilities
|
||||
# mamba-ssm
|
||||
# pytablewriter
|
||||
# torch
|
||||
# triton
|
||||
shapely==2.1.1
|
||||
# via
|
||||
# geopandas
|
||||
# torchgeo
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via
|
||||
# junit-xml
|
||||
# lightly
|
||||
# opencensus
|
||||
# python-dateutil
|
||||
# rfc3339-validator
|
||||
# rouge-score
|
||||
# segmentation-models-pytorch
|
||||
smart-open==7.1.0
|
||||
# via ray
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
@ -727,10 +1001,17 @@ soundfile==0.12.1
|
||||
# librosa
|
||||
soxr==0.5.0.post1
|
||||
# via librosa
|
||||
sqlalchemy==2.0.41
|
||||
# via
|
||||
# alembic
|
||||
# mlflow
|
||||
sqlitedict==2.1.0
|
||||
# via lm-eval
|
||||
sqlparse==0.5.3
|
||||
# via mlflow-skinny
|
||||
starlette==0.46.2
|
||||
# via
|
||||
# fastapi
|
||||
# schemathesis
|
||||
# starlette-testclient
|
||||
starlette-testclient==0.4.1
|
||||
@ -751,18 +1032,29 @@ tenacity==9.0.0
|
||||
# via
|
||||
# lm-eval
|
||||
# plotly
|
||||
tensorboardx==2.6.4
|
||||
# via lightning
|
||||
tensorizer==2.10.1
|
||||
# via -r requirements/test.in
|
||||
terratorch==1.1rc2
|
||||
# via -r requirements/test.in
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
tifffile==2025.3.30
|
||||
# via
|
||||
# scikit-image
|
||||
# terratorch
|
||||
tiktoken==0.7.0
|
||||
# via
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
timm==1.0.11
|
||||
timm==1.0.15
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# open-clip-torch
|
||||
# segmentation-models-pytorch
|
||||
# terratorch
|
||||
# torchgeo
|
||||
tokenizers==0.21.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
@ -776,18 +1068,28 @@ torch==2.7.1+cu128
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
# bitsandbytes
|
||||
# efficientnet-pytorch
|
||||
# encodec
|
||||
# fastsafetensors
|
||||
# kornia
|
||||
# lightly
|
||||
# lightning
|
||||
# lm-eval
|
||||
# mamba-ssm
|
||||
# mteb
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# pretrainedmodels
|
||||
# pytorch-lightning
|
||||
# runai-model-streamer
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# tensorizer
|
||||
# terratorch
|
||||
# timm
|
||||
# torchaudio
|
||||
# torchgeo
|
||||
# torchmetrics
|
||||
# torchvision
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128
|
||||
# -r requirements/test.in
|
||||
# encodec
|
||||
# vocos
|
||||
torchgeo==0.7.0
|
||||
# via terratorch
|
||||
torchmetrics==1.7.4
|
||||
# via
|
||||
# lightning
|
||||
# pytorch-lightning
|
||||
# terratorch
|
||||
# torchgeo
|
||||
torchvision==0.22.1+cu128
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# lightly
|
||||
# open-clip-torch
|
||||
# pretrainedmodels
|
||||
# segmentation-models-pytorch
|
||||
# terratorch
|
||||
# timm
|
||||
# torchgeo
|
||||
tqdm==4.66.6
|
||||
# via
|
||||
# datasets
|
||||
# evaluate
|
||||
# huggingface-hub
|
||||
# lightly
|
||||
# lightning
|
||||
# lm-eval
|
||||
# mteb
|
||||
# nltk
|
||||
# open-clip-torch
|
||||
# peft
|
||||
# pqdm
|
||||
# pretrainedmodels
|
||||
# pytorch-lightning
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# tqdm-multiprocess
|
||||
# transformers
|
||||
@ -843,18 +1163,34 @@ typer==0.15.2
|
||||
# via fastsafetensors
|
||||
types-python-dateutil==2.9.0.20241206
|
||||
# via arrow
|
||||
typeshed-client==2.8.2
|
||||
# via jsonargparse
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# albumentations
|
||||
# alembic
|
||||
# fastapi
|
||||
# graphene
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# lightning
|
||||
# lightning-utilities
|
||||
# mistral-common
|
||||
# mlflow-skinny
|
||||
# mteb
|
||||
# opentelemetry-api
|
||||
# opentelemetry-sdk
|
||||
# opentelemetry-semantic-conventions
|
||||
# pqdm
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pydantic-extra-types
|
||||
# pytorch-lightning
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# torchgeo
|
||||
# typer
|
||||
# typeshed-client
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
@ -866,9 +1202,13 @@ urllib3==2.2.3
|
||||
# via
|
||||
# blobfile
|
||||
# botocore
|
||||
# docker
|
||||
# lightly
|
||||
# requests
|
||||
# responses
|
||||
# tritonclient
|
||||
uvicorn==0.35.0
|
||||
# via mlflow-skinny
|
||||
vector-quantize-pytorch==1.21.2
|
||||
# via -r requirements/test.in
|
||||
virtualenv==20.31.2
|
||||
@ -880,11 +1220,15 @@ wcwidth==0.2.13
|
||||
webcolors==24.11.1
|
||||
# via jsonschema
|
||||
werkzeug==3.1.3
|
||||
# via schemathesis
|
||||
# via
|
||||
# flask
|
||||
# schemathesis
|
||||
word2number==1.1
|
||||
# via lm-eval
|
||||
wrapt==1.17.2
|
||||
# via smart-open
|
||||
xarray==2025.7.1
|
||||
# via rioxarray
|
||||
xxhash==3.5.0
|
||||
# via
|
||||
# datasets
|
||||
@ -893,5 +1237,7 @@ yarl==1.17.1
|
||||
# via
|
||||
# aiohttp
|
||||
# schemathesis
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
zstandard==0.23.0
|
||||
# via lm-eval
|
||||
|
tests/models/multimodal/pooling/test_prithvi_mae.py (new file, 63 lines)
@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.utils import set_default_torch_num_threads

from ....conftest import VllmRunner


def generate_test_mm_data():
    mm_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    return mm_data


def _run_test(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:

    prompt = [
        {
            # This model deals with no text input
            "prompt_token_ids": [1],
            "multi_modal_data": generate_test_mm_data(),
        } for _ in range(10)
    ]

    with (
        set_default_torch_num_threads(1),
        vllm_runner(
            model,
            task="embed",
            dtype=torch.float16,
            enforce_eager=True,
            skip_tokenizer_init=True,
            # Limit the maximum number of sequences to avoid the
            # test going OOM during the warmup run
            max_num_seqs=32,
        ) as vllm_model,
    ):
        vllm_model.encode(prompt)


MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]


@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
) -> None:
    _run_test(
        vllm_runner,
        model,
    )
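A quick way to run just this new test locally (a sketch; it assumes a CUDA-capable environment with terratorch installed, as pinned in requirements/test):

# Sketch only: invoke the new Prithvi test file through pytest programmatically.
import sys
import pytest

sys.exit(pytest.main([
    "-v",
    "tests/models/multimodal/pooling/test_prithvi_mae.py",
]))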
@ -651,6 +651,8 @@ class ModelConfig:
|
||||
self.original_max_model_len = self.max_model_len
|
||||
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
|
||||
self.multimodal_config = self._init_multimodal_config()
|
||||
self.model_supports_multimodal_raw_input = (
|
||||
self.registry.supports_multimodal_raw_input(self.architectures))
|
||||
if not self.skip_tokenizer_init:
|
||||
self._verify_tokenizer_mode()
|
||||
|
||||
@ -1243,10 +1245,10 @@ class ModelConfig:
|
||||
return self.get_hf_config_sliding_window()
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
return self.hf_text_config.vocab_size
|
||||
return getattr(self.hf_text_config, "vocab_size", 0)
|
||||
|
||||
def get_hidden_size(self) -> int:
|
||||
return self.hf_text_config.hidden_size
|
||||
return getattr(self.hf_text_config, "hidden_size", 0)
|
||||
|
||||
@property
|
||||
def is_deepseek_mla(self) -> bool:
|
||||
|
@ -238,14 +238,14 @@ class LLMEngine:
|
||||
self.log_stats = log_stats
|
||||
self.use_cached_outputs = use_cached_outputs
|
||||
|
||||
if not self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
else:
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
self.detokenizer = None
|
||||
tokenizer_group = None
|
||||
else:
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
|
||||
# Ensure that the function doesn't contain a reference to self,
|
||||
# to avoid engine GC issues
|
||||
|
@ -136,6 +136,40 @@ def supports_multimodal(
    return getattr(model, "supports_multimodal", False)


@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
    """The interface required for multi-modal models that consume raw inputs."""

    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs and processes
    them in their raw form rather than as embeddings.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """


@overload
def supports_multimodal_raw_input(
        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
    ...


@overload
def supports_multimodal_raw_input(
        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
    ...


def supports_multimodal_raw_input(
    model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
           TypeIs[SupportsMultiModalWithRawInput]]:
    return getattr(model, "supports_multimodal_raw_input", False)


@runtime_checkable
class SupportsScoreTemplate(Protocol):
    """The interface required for all models that support score template."""

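Because the new flag is a plain class attribute resolved with getattr, the registry-level and instance-level checks behave identically. A rough, standalone illustration (with a stand-in class; the real protocol and helper live in vllm.model_executor.models.interfaces):

# Illustration only: a stand-in model class and a simplified version of the
# check above. The flag is a plain class attribute, so getattr() resolves it
# both on the class (registry inspection) and on instances.
class _FakeRawInputModel:
    supports_multimodal = True
    supports_multimodal_raw_input = True  # set by the mixin in the real model


def _supports_raw_input(model) -> bool:  # simplified, untyped stand-in
    return getattr(model, "supports_multimodal_raw_input", False)


assert _supports_raw_input(_FakeRawInputModel)      # class-level check
assert _supports_raw_input(_FakeRawInputModel())    # instance-level check
assert not _supports_raw_input(object())            # models without the flag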
@ -16,6 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM/NASA Prithvi Geospatial model."""
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
@ -27,13 +28,14 @@ from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import (AllPool, PoolerHead,
|
||||
PoolerIdentity, SimplePooler)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
SupportsMultiModal,
|
||||
SupportsV0Only)
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
MultiModalFieldElem, MultiModalInputs,
|
||||
MultiModalKwargs, MultiModalKwargsItem,
|
||||
MultiModalSharedField, PlaceholderRange)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptUpdate)
|
||||
@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder(
|
||||
# The size of pixel_values might change in the cases where we resize
|
||||
# the input but never exceeds the dimensions below.
|
||||
return {
|
||||
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
|
||||
"location_coords": torch.full((1, 2), 1.0),
|
||||
"pixel_values": torch.full((6, 512, 512), 1.0,
|
||||
dtype=torch.float16),
|
||||
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
|
||||
}
|
||||
|
||||
|
||||
@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            location_coords=MultiModalFieldConfig.batched("image"),
            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
                                                      modality="image"),
            location_coords=MultiModalFieldConfig.shared(batch_size=1,
                                                         modality="image"),
        )

    def _get_prompt_updates(
@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):

        for k, v in mm_data.items():
            mm_kwargs[k] = v
        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

        # This model receives as input a multi-dimensional tensor representing
        # a single image patch, so it is not to be split into multiple
        # elements but rather treated as a single one. Hence the decision to
        # use a MultiModalSharedField.
        # The expected shape is (num_channels, width, height).

        # The model, however, also allows the user to submit multiple image
        # patches as a batch, adding a further dimension to the above shape.
        # At this stage we only support submitting one patch per request;
        # batching is achieved via vLLM batching.
        # TODO (christian-pinto): enable support for multi patch requests
        # in tandem with vLLM batching.
        multimodal_kwargs_items = [
            MultiModalKwargsItem.from_elems([
                MultiModalFieldElem(
                    modality="image",
                    key=key,
                    data=data,
                    field=MultiModalSharedField(1),
                ) for key, data in mm_kwargs.items()
            ])
        ]

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=[1],
            mm_kwargs=MultiModalKwargs(mm_kwargs),
            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
            mm_hashes=None,
            mm_placeholders={},
            mm_placeholders=mm_placeholders,
        )

|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
PrithviGeoSpatialMAEMultiModalProcessor,
|
||||
info=PrithviGeoSpatialMAEProcessingInfo,
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
SupportsV0Only):
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
|
||||
)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
|
||||
SupportsMultiModalWithRawInput):
|
||||
"""Prithvi Masked Autoencoder"""
|
||||
|
||||
is_pooling_model = True
|
||||
@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
raise ValueError("Only image modality is supported")
|
||||
|
||||
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
|
||||
|
||||
# We might be able/need to support different tasks with this same model
|
||||
if config["task_args"]["task"] == "SemanticSegmentationTask":
|
||||
from terratorch.cli_tools import SemanticSegmentationTask
|
||||
|
||||
task = SemanticSegmentationTask(
|
||||
config["model_args"],
|
||||
config["task_args"]["model_factory"],
|
||||
@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
scheduler_hparams=config["scheduler_params"],
|
||||
plot_on_val=config["task_args"]["plot_on_val"],
|
||||
freeze_decoder=config["task_args"]["freeze_decoder"],
|
||||
freeze_backbone=config["task_args"]["freeze_backbone"])
|
||||
freeze_backbone=config["task_args"]["freeze_backbone"],
|
||||
)
|
||||
|
||||
return task.model
|
||||
else:
|
||||
@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
|
||||
def _parse_and_validate_multimodal_data(
|
||||
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError(f"Incorrect type of pixel_values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
pixel_values = torch.unbind(pixel_values, dim=0)[0]
|
||||
|
||||
location_coords = kwargs.pop("location_coords", None)
|
||||
if not isinstance(location_coords, torch.Tensor):
|
||||
@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
|
||||
|
||||
return pixel_values, location_coords
|
||||
|
||||
    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        # We do not really use any input tokens, so there are no embeddings
        # to calculate. However, due to the mandatory token ids in the input
        # prompt we pass one token, and the size of the dummy embedding
        # tensor must reflect that.
        return torch.empty((input_ids.shape[0], 0))

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
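Since the model ignores token ids, the dummy embedding only has to agree with the number of placeholder tokens in the prompt. A standalone sanity check of that shape logic (not vLLM code):

# Standalone check: one mandatory placeholder token id per request yields an
# empty (1, 0) tensor, so no text embedding is ever actually used.
import torch

input_ids = torch.tensor([1])                 # the single mandatory token id
dummy = torch.empty((input_ids.shape[0], 0))
assert dummy.shape == (1, 0) and dummy.numel() == 0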
@ -22,8 +22,8 @@ from vllm.logger import init_logger
|
||||
|
||||
from .interfaces import (has_inner_state, has_noops, is_attention_free,
|
||||
is_hybrid, supports_cross_encoding,
|
||||
supports_multimodal, supports_pp,
|
||||
supports_transcription, supports_v0_only)
|
||||
supports_multimodal, supports_multimodal_raw_input,
|
||||
supports_pp, supports_transcription, supports_v0_only)
|
||||
from .interfaces_base import is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -287,6 +287,7 @@ class _ModelInfo:
|
||||
is_pooling_model: bool
|
||||
supports_cross_encoding: bool
|
||||
supports_multimodal: bool
|
||||
supports_multimodal_raw_input: bool
|
||||
supports_pp: bool
|
||||
has_inner_state: bool
|
||||
is_attention_free: bool
|
||||
@ -304,6 +305,7 @@ class _ModelInfo:
|
||||
is_pooling_model=True, # Can convert any model into a pooling model
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_multimodal_raw_input=supports_multimodal_raw_input(model),
|
||||
supports_pp=supports_pp(model),
|
||||
has_inner_state=has_inner_state(model),
|
||||
is_attention_free=is_attention_free(model),
|
||||
@ -573,6 +575,13 @@ class _ModelRegistry:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_multimodal
|
||||
|
||||
def supports_multimodal_raw_input(
|
||||
self,
|
||||
architectures: Union[str, list[str]],
|
||||
) -> bool:
|
||||
model_cls, _ = self.inspect_model_cls(architectures)
|
||||
return model_cls.supports_multimodal_raw_input
|
||||
|
||||
def is_pp_supported_model(
|
||||
self,
|
||||
architectures: Union[str, list[str]],
|
||||
|
@ -266,7 +266,7 @@ class MultiModalRegistry:
|
||||
if not model_config.is_multimodal_model:
|
||||
raise ValueError(f"{model_config.model} is not a multimodal model")
|
||||
|
||||
if tokenizer is None:
|
||||
if tokenizer is None and not model_config.skip_tokenizer_init:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
if disable_cache is None:
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
@ -94,11 +94,14 @@ class AsyncLLM(EngineClient):
|
||||
self.log_requests = log_requests
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config)
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
else:
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config)
|
||||
|
||||
# Processor (converts Inputs --> EngineCoreRequests).
|
||||
self.processor = Processor(
|
||||
@ -525,6 +528,10 @@ class AsyncLLM(EngineClient):
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> AnyTokenizer:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError("Unable to get tokenizer because "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
return self.tokenizer.get_lora_tokenizer(lora_request)
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
|
@ -82,11 +82,14 @@ class LLMEngine:
|
||||
self.dp_group = None
|
||||
self.should_execute_dummy_batch = False
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config)
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer = None
|
||||
else:
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
lora_config=vllm_config.lora_config)
|
||||
|
||||
# Processor (convert Inputs --> EngineCoreRequests)
|
||||
self.processor = Processor(vllm_config=vllm_config,
|
||||
|
@ -327,14 +327,16 @@ class OutputProcessor:
|
||||
if request_id in self.request_states:
|
||||
raise ValueError(f"Request id {request_id} already running.")
|
||||
|
||||
req_state = RequestState.from_new_request(
|
||||
tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
|
||||
request=request,
|
||||
prompt=prompt,
|
||||
parent_req=parent_req,
|
||||
request_index=request_index,
|
||||
queue=queue,
|
||||
log_stats=self.log_stats)
|
||||
tokenizer = None if not self.tokenizer else \
|
||||
self.tokenizer.get_lora_tokenizer(request.lora_request)
|
||||
|
||||
req_state = RequestState.from_new_request(tokenizer=tokenizer,
|
||||
request=request,
|
||||
prompt=prompt,
|
||||
parent_req=parent_req,
|
||||
request_index=request_index,
|
||||
queue=queue,
|
||||
log_stats=self.log_stats)
|
||||
self.request_states[request_id] = req_state
|
||||
self.lora_states.add_request(req_state)
|
||||
if parent_req:
|
||||
|
@ -380,7 +380,6 @@ class Processor:
|
||||
prompt_type: Literal["encoder", "decoder"],
|
||||
):
|
||||
model_config = self.model_config
|
||||
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
|
||||
|
||||
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||
if not prompt_ids:
|
||||
@ -389,9 +388,14 @@ class Processor:
|
||||
else:
|
||||
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
||||
|
||||
max_input_id = max(prompt_ids, default=0)
|
||||
if max_input_id > tokenizer.max_token_id:
|
||||
raise ValueError(f"Token id {max_input_id} is out of vocabulary")
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
|
||||
max_input_id = max(prompt_ids, default=0)
|
||||
if max_input_id > tokenizer.max_token_id:
|
||||
raise ValueError(
|
||||
f"Token id {max_input_id} is out of vocabulary")
|
||||
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
if len(prompt_ids) > max_prompt_len:
|
||||
|
@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
self.is_multimodal_model = model_config.is_multimodal_model
|
||||
self.is_pooling_model = model_config.pooler_config is not None
|
||||
self.model_supports_multimodal_raw_input = (
|
||||
model_config.model_supports_multimodal_raw_input)
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        Args:
            scheduler_output: The scheduler output.
        """
        # Attention-free models have zero kv_cache_groups. However, models
        # like Mamba are also attention-free but use the kv_cache for
        # keeping their internal state, which is why we check the number
        # of kv_cache groups instead of solely checking
        # self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                     scheduler_output)

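A minimal illustration of why the group count, rather than is_attention_free alone, is the right guard here (hypothetical configs, not vLLM objects):

# Hypothetical configs: both models are attention-free, but only the Mamba-like
# one keeps state in the KV-cache machinery and therefore needs batch reordering.
prithvi_like = {"is_attention_free": True, "kv_cache_groups": []}
mamba_like = {"is_attention_free": True, "kv_cache_groups": ["mamba_state"]}


def needs_reorder(cfg) -> bool:
    return len(cfg["kv_cache_groups"]) > 0


assert not needs_reorder(prithvi_like)
assert needs_reorder(mamba_like)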
@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # Refresh batch metadata with any pending updates.
            self.input_batch.refresh_metadata()

    def _init_model_kwargs_for_multimodal_model(
        self,
        scheduler_output: Optional["SchedulerOutput"] = None,
        num_reqs: int = -1,
    ) -> dict[str, Any]:

        model_kwargs: dict[str, Any] = {}
        if self.model_supports_multimodal_raw_input:
            # This model requires the raw multimodal data as input.
            if scheduler_output:
                multi_modal_kwargs_list = []
                for req in scheduler_output.scheduled_new_reqs:
                    req_mm_inputs = req.mm_inputs
                    if not isinstance(req_mm_inputs, list):
                        req_mm_inputs = list(req_mm_inputs)
                    multi_modal_kwargs_list.extend(req_mm_inputs)
                multi_modal_kwargs = MultiModalKwargs.batch(
                    multi_modal_kwargs_list)
            else:
                # The only case where SchedulerOutput is None is a dummy run,
                # so let's get some dummy data.
                dummy_data = [
                    self.mm_registry.get_decoder_dummy_data(
                        model_config=self.model_config,
                        seq_len=1).multi_modal_data for i in range(num_reqs)
                ]
                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)

            model_kwargs.update(multi_modal_kwargs)

        return model_kwargs

def _get_cumsum_and_arange(
|
||||
self,
|
||||
num_tokens: np.ndarray,
|
||||
@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
input_ids = self.input_ids[:num_scheduled_tokens]
|
||||
|
||||
model_kwargs = self._init_model_kwargs_for_multimodal_model(
|
||||
scheduler_output=scheduler_output)
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids=input_ids,
|
||||
multimodal_embeddings=mm_embeds or None,
|
||||
)
|
||||
|
||||
# TODO(woosuk): Avoid the copy. Optimize.
|
||||
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
|
||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
||||
@ -1374,6 +1420,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# then the embedding layer is not included in the CUDA graph.
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
model_kwargs = {}
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
else:
|
||||
@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_kwargs,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
self.maybe_wait_for_kv_save()
|
||||
@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_scheduled_tokens):
|
||||
model = self.model
|
||||
if self.is_multimodal_model:
|
||||
model_kwargs = self._init_model_kwargs_for_multimodal_model(
|
||||
num_reqs=num_reqs)
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
||||
else:
|
||||
input_ids = self.input_ids[:num_tokens]
|
||||
inputs_embeds = None
|
||||
model_kwargs = {}
|
||||
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_tokens]
|
||||
else:
|
||||
@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_kwargs,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, _ = outputs
|
||||
else:
|
||||
|