mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] IBM/NASA Prithvi Geospatial model (#12830)
This commit is contained in:
530
examples/offline_inference/prithvi_geospatial_mae.py
Normal file
530
examples/offline_inference/prithvi_geospatial_mae.py
Normal file
@ -0,0 +1,530 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This is a demo script showing how to use the
|
||||
PrithviGeospatialMAE model with vLLM
|
||||
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
|
||||
|
||||
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
|
||||
|
||||
The requirements for running this script are:
|
||||
- Installing [terratorch, albumentations, rasterio] in your python environment
|
||||
- downloading the model weights in a 'model' folder local to the script
|
||||
(temporary measure until the proper config.json file is uploaded to HF)
|
||||
- download an input example image (India_900498_S2Hand.tif) and place it in
|
||||
the same folder with the script (or specify with the --data_file argument)
|
||||
|
||||
Run the example:
|
||||
python prithvi_geospatial_mae.py
|
||||
|
||||
""" # noqa: E501
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
import albumentations
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
NO_DATA = -9999
|
||||
NO_DATA_FLOAT = 0.0001
|
||||
OFFSET = 0
|
||||
PERCENTILE = 99
|
||||
|
||||
model_config = """{
|
||||
"architectures": ["PrithviGeoSpatialMAE"],
|
||||
"num_classes": 0,
|
||||
"pretrained_cfg": {
|
||||
"task_args": {
|
||||
"task": "SemanticSegmentationTask",
|
||||
"model_factory": "EncoderDecoderFactory",
|
||||
"loss": "ce",
|
||||
"ignore_index": -1,
|
||||
"lr": 0.001,
|
||||
"freeze_backbone": false,
|
||||
"freeze_decoder": false,
|
||||
"plot_on_val": 10,
|
||||
"optimizer": "AdamW",
|
||||
"scheduler": "CosineAnnealingLR"
|
||||
},
|
||||
"model_args": {
|
||||
"backbone_pretrained": false,
|
||||
"backbone": "prithvi_eo_v2_300_tl",
|
||||
"decoder": "UperNetDecoder",
|
||||
"decoder_channels": 256,
|
||||
"decoder_scale_modules": true,
|
||||
"num_classes": 2,
|
||||
"rescale": true,
|
||||
"backbone_bands": [
|
||||
"BLUE",
|
||||
"GREEN",
|
||||
"RED",
|
||||
"NIR_NARROW",
|
||||
"SWIR_1",
|
||||
"SWIR_2"
|
||||
],
|
||||
"head_dropout": 0.1,
|
||||
"necks": [
|
||||
{
|
||||
"name": "SelectIndices",
|
||||
"indices": [
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "ReshapeTokensToImage"
|
||||
}
|
||||
]
|
||||
},
|
||||
"optimizer_params" : {
|
||||
"lr": 5.0e-05,
|
||||
"betas": [0.9, 0.999],
|
||||
"eps": [1.0e-08],
|
||||
"weight_decay": 0.05,
|
||||
"amsgrad": false,
|
||||
"maximize": false,
|
||||
"capturable": false,
|
||||
"differentiable": false
|
||||
},
|
||||
"scheduler_params" : {
|
||||
"T_max": 50,
|
||||
"eta_min": 0,
|
||||
"last_epoch": -1,
|
||||
"verbose": "deprecated"
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
"""
|
||||
|
||||
# Temporarily creating the "config.json" for the model.
|
||||
# This is going to disappear once the correct config.json is available on HF
|
||||
with open(os.path.join(os.path.dirname(__file__), "./model/config.json"),
|
||||
'w') as config_file:
|
||||
config_file.write(model_config)
|
||||
|
||||
datamodule_config = {
|
||||
'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'],
|
||||
'batch_size':
|
||||
16,
|
||||
'constant_scale':
|
||||
0.0001,
|
||||
'data_root':
|
||||
'/dccstor/geofm-finetuning/datasets/sen1floods11',
|
||||
'drop_last':
|
||||
True,
|
||||
'no_data_replace':
|
||||
0.0,
|
||||
'no_label_replace':
|
||||
-1,
|
||||
'num_workers':
|
||||
8,
|
||||
'test_transform': [
|
||||
albumentations.Resize(always_apply=False,
|
||||
height=448,
|
||||
interpolation=1,
|
||||
p=1,
|
||||
width=448),
|
||||
albumentations.pytorch.ToTensorV2(transpose_mask=False,
|
||||
always_apply=True,
|
||||
p=1.0)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class PrithviMAE:
|
||||
|
||||
def __init__(self):
|
||||
print("Initializing PrithviMAE model")
|
||||
self.model = LLM(model=os.path.join(os.path.dirname(__file__),
|
||||
"./model"),
|
||||
skip_tokenizer_init=True,
|
||||
dtype="float32")
|
||||
|
||||
def run(self, input_data, location_coords):
|
||||
print("################ Running inference on vLLM ##############")
|
||||
# merge the inputs into one data structure
|
||||
mm_data = {
|
||||
"pixel_values":
|
||||
torch.empty(0) if input_data is None else input_data,
|
||||
"location_coords":
|
||||
torch.empty(0) if location_coords is None else location_coords
|
||||
}
|
||||
|
||||
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
|
||||
|
||||
outputs = self.model.encode(prompt, use_tqdm=False)
|
||||
print(
|
||||
"################ Inference done (it took seconds) ##############"
|
||||
)
|
||||
|
||||
return outputs[0].outputs.data
|
||||
|
||||
|
||||
def generate_datamodule():
|
||||
datamodule = Sen1Floods11NonGeoDataModule(
|
||||
data_root=datamodule_config['data_root'],
|
||||
batch_size=datamodule_config["batch_size"],
|
||||
num_workers=datamodule_config["num_workers"],
|
||||
bands=datamodule_config["bands"],
|
||||
drop_last=datamodule_config["drop_last"],
|
||||
test_transform=datamodule_config["test_transform"
|
||||
""])
|
||||
|
||||
return datamodule
|
||||
|
||||
|
||||
def process_channel_group(orig_img, channels):
|
||||
"""
|
||||
Args:
|
||||
orig_img: torch.Tensor representing original image (reference)
|
||||
with shape = (bands, H, W).
|
||||
channels: list of indices representing RGB channels.
|
||||
|
||||
Returns:
|
||||
torch.Tensor with shape (num_channels, height, width) for original image
|
||||
"""
|
||||
|
||||
orig_img = orig_img[channels, ...]
|
||||
valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
|
||||
valid_mask[orig_img == NO_DATA_FLOAT] = False
|
||||
|
||||
# Rescale (enhancing contrast)
|
||||
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
|
||||
min_value = OFFSET
|
||||
|
||||
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0,
|
||||
1)
|
||||
|
||||
# No data as zeros
|
||||
orig_img[~valid_mask] = 0
|
||||
|
||||
return orig_img
|
||||
|
||||
|
||||
def read_geotiff(file_path: str):
|
||||
"""Read all bands from *file_path* and return image + meta info.
|
||||
|
||||
Args:
|
||||
file_path: path to image file.
|
||||
|
||||
Returns:
|
||||
np.ndarray with shape (bands, height, width)
|
||||
meta info dict
|
||||
"""
|
||||
|
||||
with rasterio.open(file_path) as src:
|
||||
img = src.read()
|
||||
meta = src.meta
|
||||
try:
|
||||
coords = src.lnglat()
|
||||
except Exception:
|
||||
# Cannot read coords
|
||||
coords = None
|
||||
|
||||
return img, meta, coords
|
||||
|
||||
|
||||
def save_geotiff(image, output_path: str, meta: dict):
|
||||
"""Save multi-band image in Geotiff file.
|
||||
|
||||
Args:
|
||||
image: np.ndarray with shape (bands, height, width)
|
||||
output_path: path where to save the image
|
||||
meta: dict with meta info.
|
||||
"""
|
||||
|
||||
with rasterio.open(output_path, "w", **meta) as dest:
|
||||
for i in range(image.shape[0]):
|
||||
dest.write(image[i, :, :], i + 1)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _convert_np_uint8(float_image: torch.Tensor):
|
||||
image = float_image.numpy() * 255.0
|
||||
image = image.astype(dtype=np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def load_example(
|
||||
file_paths: List[str],
|
||||
mean: List[float] = None,
|
||||
std: List[float] = None,
|
||||
indices: Union[list[int], None] = None,
|
||||
):
|
||||
"""Build an input example by loading images in *file_paths*.
|
||||
|
||||
Args:
|
||||
file_paths: list of file paths .
|
||||
mean: list containing mean values for each band in the images
|
||||
in *file_paths*.
|
||||
std: list containing std values for each band in the images
|
||||
in *file_paths*.
|
||||
|
||||
Returns:
|
||||
np.array containing created example
|
||||
list of meta info for each image in *file_paths*
|
||||
"""
|
||||
|
||||
imgs = []
|
||||
metas = []
|
||||
temporal_coords = []
|
||||
location_coords = []
|
||||
|
||||
for file in file_paths:
|
||||
img, meta, coords = read_geotiff(file)
|
||||
|
||||
# Rescaling (don't normalize on nodata)
|
||||
img = np.moveaxis(img, 0, -1) # channels last for rescaling
|
||||
if indices is not None:
|
||||
img = img[..., indices]
|
||||
if mean is not None and std is not None:
|
||||
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
|
||||
|
||||
imgs.append(img)
|
||||
metas.append(meta)
|
||||
if coords is not None:
|
||||
location_coords.append(coords)
|
||||
|
||||
try:
|
||||
match = re.search(r'(\d{7,8}T\d{6})', file)
|
||||
if match:
|
||||
year = int(match.group(1)[:4])
|
||||
julian_day = match.group(1).split('T')[0][4:]
|
||||
if len(julian_day) == 3:
|
||||
julian_day = int(julian_day)
|
||||
else:
|
||||
julian_day = datetime.datetime.strptime(
|
||||
julian_day, '%m%d').timetuple().tm_yday
|
||||
temporal_coords.append([year, julian_day])
|
||||
except Exception as e:
|
||||
print(f'Could not extract timestamp for {file} ({e})')
|
||||
|
||||
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32")
|
||||
imgs = np.expand_dims(imgs, axis=0) # add batch di
|
||||
|
||||
return imgs, temporal_coords, location_coords, metas
|
||||
|
||||
|
||||
def run_model(input_data,
|
||||
temporal_coords,
|
||||
location_coords,
|
||||
model,
|
||||
datamodule,
|
||||
img_size,
|
||||
lightning_model=None):
|
||||
# Reflect pad if not divisible by img_size
|
||||
original_h, original_w = input_data.shape[-2:]
|
||||
pad_h = (img_size - (original_h % img_size)) % img_size
|
||||
pad_w = (img_size - (original_w % img_size)) % img_size
|
||||
input_data = np.pad(input_data,
|
||||
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
|
||||
mode="reflect")
|
||||
|
||||
# Build sliding window
|
||||
batch_size = 1
|
||||
batch = torch.tensor(input_data, device="cpu")
|
||||
windows = (batch.unfold(3, img_size,
|
||||
img_size).unfold(4, img_size, img_size))
|
||||
h1, w1 = windows.shape[3:5]
|
||||
windows = rearrange(windows,
|
||||
"b c t h1 w1 h w -> (b h1 w1) c t h w",
|
||||
h=img_size,
|
||||
w=img_size)
|
||||
|
||||
# Split into batches if number of windows > batch_size
|
||||
num_batches = windows.shape[0] // batch_size if windows.shape[
|
||||
0] > batch_size else 1
|
||||
windows = torch.tensor_split(windows, num_batches, dim=0)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
|
||||
if temporal_coords:
|
||||
temporal_coords = torch.tensor(temporal_coords,
|
||||
device=device).unsqueeze(0)
|
||||
else:
|
||||
temporal_coords = None
|
||||
if location_coords:
|
||||
location_coords = torch.tensor(location_coords[0],
|
||||
device=device).unsqueeze(0)
|
||||
else:
|
||||
location_coords = None
|
||||
|
||||
# Run model
|
||||
pred_imgs = []
|
||||
for x in windows:
|
||||
# Apply standardization
|
||||
x = datamodule.test_transform(
|
||||
image=x.squeeze().numpy().transpose(1, 2, 0))
|
||||
x = datamodule.aug(x)['image']
|
||||
|
||||
with torch.no_grad():
|
||||
x = x.to(device)
|
||||
pred = model.run(x, location_coords=location_coords)
|
||||
if lightning_model:
|
||||
pred_lightning = lightning_model(
|
||||
x,
|
||||
temporal_coords=temporal_coords,
|
||||
location_coords=location_coords)
|
||||
pred_lightning = pred_lightning.output.detach().cpu()
|
||||
if not torch.equal(pred, pred_lightning):
|
||||
print("Inference output is not equal")
|
||||
y_hat = pred.argmax(dim=1)
|
||||
|
||||
y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(),
|
||||
size=img_size,
|
||||
mode="nearest")
|
||||
|
||||
pred_imgs.append(y_hat)
|
||||
|
||||
pred_imgs = torch.concat(pred_imgs, dim=0)
|
||||
|
||||
# Build images from patches
|
||||
pred_imgs = rearrange(
|
||||
pred_imgs,
|
||||
"(b h1 w1) c h w -> b c (h1 h) (w1 w)",
|
||||
h=img_size,
|
||||
w=img_size,
|
||||
b=1,
|
||||
c=1,
|
||||
h1=h1,
|
||||
w1=w1,
|
||||
)
|
||||
|
||||
# Cut padded area back to original size
|
||||
pred_imgs = pred_imgs[..., :original_h, :original_w]
|
||||
|
||||
# Squeeze (batch size 1)
|
||||
pred_imgs = pred_imgs[0]
|
||||
|
||||
return pred_imgs
|
||||
|
||||
|
||||
def main(
|
||||
data_file: str,
|
||||
output_dir: str,
|
||||
rgb_outputs: bool,
|
||||
input_indices: list[int] = None,
|
||||
):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Load model ---------------------------------------------------------------
|
||||
|
||||
model_obj = PrithviMAE()
|
||||
datamodule = generate_datamodule()
|
||||
img_size = 256 # Size of Sen1Floods11
|
||||
|
||||
# Loading data -------------------------------------------------------------
|
||||
|
||||
input_data, temporal_coords, location_coords, meta_data = load_example(
|
||||
file_paths=[data_file],
|
||||
indices=input_indices,
|
||||
)
|
||||
|
||||
meta_data = meta_data[0] # only one image
|
||||
|
||||
if input_data.mean() > 1:
|
||||
input_data = input_data / 10000 # Convert to range 0-1
|
||||
|
||||
# Running model ------------------------------------------------------------
|
||||
|
||||
channels = [
|
||||
datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"]
|
||||
] # BGR -> RGB
|
||||
|
||||
pred = run_model(input_data, temporal_coords, location_coords, model_obj,
|
||||
datamodule, img_size)
|
||||
|
||||
# Save pred
|
||||
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
|
||||
pred_file = os.path.join(
|
||||
output_dir,
|
||||
f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
|
||||
save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)
|
||||
|
||||
# Save image + pred
|
||||
meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)
|
||||
|
||||
if input_data.mean() < 1:
|
||||
input_data = input_data * 10000 # Scale to 0-10000
|
||||
|
||||
rgb_orig = process_channel_group(
|
||||
orig_img=torch.Tensor(input_data[0, :, 0, ...]),
|
||||
channels=channels,
|
||||
)
|
||||
|
||||
pred[pred == 0.] = np.nan
|
||||
img_pred = rgb_orig * 0.7 + pred * 0.3
|
||||
img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
|
||||
|
||||
img_pred_file = os.path.join(
|
||||
output_dir,
|
||||
f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
|
||||
save_geotiff(
|
||||
image=_convert_np_uint8(img_pred),
|
||||
output_path=img_pred_file,
|
||||
meta=meta_data,
|
||||
)
|
||||
|
||||
# Save image rgb
|
||||
if rgb_outputs:
|
||||
rgb_file = os.path.join(
|
||||
output_dir, "original_rgb_"
|
||||
f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff")
|
||||
save_geotiff(
|
||||
image=_convert_np_uint8(rgb_orig),
|
||||
output_path=rgb_file,
|
||||
meta=meta_data,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default="./India_900498_S2Hand.tif",
|
||||
help="Path to the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="output",
|
||||
help="Path to the directory where to save outputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_indices",
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help=
|
||||
"0-based indices of the six Prithvi channels to be selected from the "
|
||||
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgb_outputs",
|
||||
action="store_true",
|
||||
help="If present, output files will only contain RGB channels. "
|
||||
"Otherwise, all bands will be saved.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(**vars(args))
|
@ -214,6 +214,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
|
||||
trust_remote_code=True),
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
|
||||
# The model on Huggingface is currently being updated,
|
||||
# hence I temporarily mark it as not available online
|
||||
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
|
||||
is_available_online=False),
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||
|
@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
# Some input builders such as ModelInputForCPUBuilder do not have the
|
||||
# "inter_data_list" attribute.
|
||||
# Let's check inter_data_list exists before we reference it.
|
||||
if hasattr(self.input_builder, "inter_data_list"):
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
@ -254,8 +254,14 @@ class InputPreprocessor:
|
||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||
returning the corresponding token IDs and metadata.
|
||||
"""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
# At the moment on model (PrithviGeoSpatialMAE) requires to be
|
||||
# initialized without a tokenizer while using also multi-modal
|
||||
# input.
|
||||
if not self.tokenizer:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
@ -273,9 +279,15 @@ class InputPreprocessor:
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> MultiModalInputs:
|
||||
"""Async version of :meth:`_process_multimodal`."""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
|
||||
)
|
||||
# At the moment on model (PrithviGeoSpatialMAE) requires to be
|
||||
# initialized without a tokenizer while using also multi-modal
|
||||
# input.
|
||||
if not self.tokenizer:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
|
||||
lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
|
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 IBM.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM/NASA Prithvi Geospatial model."""
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
SupportsMultiModal)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
PoolingSequenceGroupOutput)
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
pass
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEInputBuilder(
|
||||
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
return ProcessorInputs(
|
||||
prompt_text="",
|
||||
# This model input is fixed and is in the form of a torch Tensor.
|
||||
# The size of pixel_values might change in the cases where we resize
|
||||
# the input but never exceeds the dimensions below.
|
||||
mm_data={
|
||||
"pixel_values": torch.full((1, 6, 512, 512), 1.0),
|
||||
"location_coords": torch.full((1, 2), 1.0)
|
||||
})
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
location_coords=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
pass
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
pass
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputs:
|
||||
mm_kwargs = {}
|
||||
|
||||
for k, v in mm_data.items():
|
||||
mm_kwargs[k] = v
|
||||
|
||||
return MultiModalInputs(
|
||||
type="multimodal",
|
||||
prompt=prompt,
|
||||
prompt_token_ids=[1],
|
||||
mm_kwargs=MultiModalKwargs(mm_kwargs),
|
||||
mm_placeholders={},
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
PrithviGeoSpatialMAEMultiModalProcessor,
|
||||
info=PrithviGeoSpatialMAEProcessingInfo,
|
||||
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
|
||||
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
""" Prithvi Masked Autoencoder"""
|
||||
|
||||
def _instantiate_model(self, config: dict) -> nn.Module | None:
|
||||
|
||||
# We might be able/need to support different tasks with this same model
|
||||
if config["task_args"]["task"] == "SemanticSegmentationTask":
|
||||
from terratorch.cli_tools import SemanticSegmentationTask
|
||||
task = SemanticSegmentationTask(
|
||||
config["model_args"],
|
||||
config["task_args"]["model_factory"],
|
||||
loss=config["task_args"]["loss"],
|
||||
lr=config["task_args"]["lr"],
|
||||
ignore_index=config["task_args"]["ignore_index"],
|
||||
optimizer=config["task_args"]["optimizer"],
|
||||
optimizer_hparams=config["optimizer_params"],
|
||||
scheduler=config["task_args"]["scheduler"],
|
||||
scheduler_hparams=config["scheduler_params"],
|
||||
plot_on_val=config["task_args"]["plot_on_val"],
|
||||
freeze_decoder=config["task_args"]["freeze_decoder"],
|
||||
freeze_backbone=config["task_args"]["freeze_backbone"])
|
||||
|
||||
return task.model
|
||||
else:
|
||||
return None
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
# the actual model is dynamically instantiated using terratorch
|
||||
# allowing us to perform changes to the model architecture
|
||||
# at startup time (e.g., change the model decoder class.)
|
||||
self.model = self._instantiate_model(
|
||||
vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
|
||||
if self.model is None:
|
||||
raise ValueError(
|
||||
"Unsupported task."
|
||||
"Only SemanticSegmentationTask is supported for now"
|
||||
"by PrithviGeospatialMAE.")
|
||||
|
||||
def _parse_and_validate_multimodal_data(
|
||||
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
|
||||
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError(f"Incorrect type of pixel_values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
pixel_values = torch.unbind(pixel_values, dim=0)[0]
|
||||
|
||||
location_coords = kwargs.pop("location_coords", None)
|
||||
if not isinstance(location_coords, torch.Tensor):
|
||||
raise ValueError(f"Incorrect type of location_coords. "
|
||||
f"Got type: {type(location_coords)}")
|
||||
location_coords = torch.unbind(location_coords, dim=0)[0]
|
||||
if location_coords.shape == torch.Size([0]):
|
||||
location_coords = None
|
||||
|
||||
return pixel_values, location_coords
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
|
||||
pixel_values, location_coords = (
|
||||
self._parse_and_validate_multimodal_data(**kwargs))
|
||||
model_output = self.model(pixel_values,
|
||||
location_coords=location_coords)
|
||||
|
||||
return model_output.output
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_list = []
|
||||
model_buffers = dict(self.named_buffers())
|
||||
loaded_buffers = []
|
||||
for key, value in weights:
|
||||
if key == "state_dict":
|
||||
weights_to_parse = value
|
||||
for name, weight in weights_to_parse.items():
|
||||
if "pos_embed" in name:
|
||||
continue
|
||||
|
||||
if "_timm_module." in name:
|
||||
name = name.replace("_timm_module.", "")
|
||||
|
||||
# this model requires a couple of buffers to be loaded
|
||||
# that are not loadable with the AutoWeightsLoader
|
||||
if name in model_buffers:
|
||||
if "_timm_module." in name:
|
||||
name = name.replace("_timm_module.", "")
|
||||
buffer = model_buffers[name]
|
||||
weight_loader = getattr(buffer, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(buffer, weight)
|
||||
loaded_buffers.append(name)
|
||||
else:
|
||||
params_list.append((name, weight))
|
||||
break
|
||||
|
||||
# Load the remaining model parameters
|
||||
loader = AutoWeightsLoader(self)
|
||||
autoloaded_weights = loader.load_weights(params_list)
|
||||
|
||||
return autoloaded_weights.union(set(loaded_buffers))
|
@ -137,6 +137,10 @@ _EMBEDDING_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
# [Auto-converted (see adapters.py)]
|
||||
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
|
||||
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
|
||||
# input and output. I am adding it here because it piggy-backs on embedding
|
||||
# models for the time being.
|
||||
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
|
@ -74,7 +74,16 @@ class PoolingModelRunner(
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
decode_meta = model_input.attn_metadata.decode_metadata
|
||||
virtual_engine = model_input.virtual_engine
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
# Pooling models are (ab-)used also to integrate non text models that
|
||||
# are not autoregressive (PrithviGeosaptialMAE).
|
||||
# These model might not use attention and do not really have a prefill
|
||||
# and decode phase. The model input is processed in one shot and both
|
||||
# decode_metadata and prefill_metadata would be None for such models.
|
||||
# See the PlaceholderAttentionMetadata class.
|
||||
# TODO: Figure out if cuda_graph is of any use for these models and
|
||||
# explore how to leverage it.
|
||||
if (prefill_meta is None and decode_meta is not None
|
||||
and decode_meta.use_cuda_graph):
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
|
Reference in New Issue
Block a user