[Core][Model] Terratorch backend integration (#23513)

Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
mgazz committed 2025-09-04 08:22:41 +01:00 (committed by GitHub)
parent e7fc70016f · commit 51d5e9be7d
23 changed files with 305 additions and 208 deletions

View File

@ -45,7 +45,11 @@ datamodule_config = {
class PrithviMAE:
def __init__(self, model):
self.model = LLM(
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
model=model,
skip_tokenizer_init=True,
dtype="float16",
enforce_eager=True,
model_impl="terratorch",
)
def run(self, input_data, location_coords):
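For orientation, a minimal end-to-end sketch of offline inference with the new backend selector. The prompt layout (a dummy prompt_token_ids entry plus pixel_values/location_coords tensors), the pooling parameters, and the Sen1Floods11 checkpoint name are all taken from tests and examples elsewhere in this diff; treat it as a sketch rather than the canonical example script.

import torch
from vllm import LLM
from vllm.pooling_params import PoolingParams

llm = LLM(
    model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    skip_tokenizer_init=True,   # geospatial models ship no text tokenizer
    dtype="float16",
    enforce_eager=True,
    model_impl="terratorch",    # route model loading through the Terratorch backend
)

prompt = dict(
    prompt_token_ids=[1],       # dummy token id, since the tokenizer is skipped
    multi_modal_data=dict(
        pixel_values=torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        location_coords=torch.full((1, 2), 1.0, dtype=torch.float16),
    ),
)

outputs = llm.encode(prompt, pooling_params=PoolingParams(task="encode", softmax=False))
print(outputs[0].outputs.data.shape)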

View File

@ -37,6 +37,7 @@ def main():
# The maximum number depends on the available GPU memory
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff_india",
model_impl="terratorch",
)
pooling_params = PoolingParams(task="encode", softmax=False)

View File

@ -15,6 +15,7 @@ import requests
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india
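Putting the comments above together, one plausible full invocation (the vllm serve entry point is an assumption; the flags themselves are exactly those listed above):

vllm serve christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM \
    --model-impl terratorch \
    --task embed --trust-remote-code \
    --skip-tokenizer-init --enforce-eager \
    --io-processor-plugin prithvi_to_tiff_india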

View File

@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
decord==0.6.0
terratorch==1.1rc3 # required for PrithviMAE test

View File

@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
terratorch==1.1rc2
terratorch==1.1rc3
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn

View File

@ -298,6 +298,8 @@ def _compare_tp(
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
hf_config = get_config(model_id, trust_remote_code)
skip_tokenizer_init = model_info.skip_tokenizer_init
max_num_seqs = model_info.max_num_seqs
dtype = "float16"
if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
@ -351,6 +353,10 @@ def _compare_tp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if max_num_seqs:
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
testing_ray_compiled_graph = False

View File

@ -178,6 +178,7 @@ def _compare_sp(
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
skip_tokenizer_init = model_info.skip_tokenizer_init
if load_format == "dummy":
# Avoid OOM
@ -227,6 +228,8 @@ def _compare_sp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
compilation_config = {
'level': 3,

View File

@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Initialize the tokenizer
tokenizer = get_tokenizer(

View File

@ -11,7 +11,7 @@ import torch
from ...utils import RemoteOpenAIServer
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"
@ -35,7 +35,9 @@ def server():
"--trust-remote-code",
"--skip-tokenizer-init",
"--max-num-seqs",
"32"
"32",
"--model-impl",
"terratorch"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
tokenizer_group = TokenizerGroup(
model,
@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
tokenizer_group = TokenizerGroup(
model_config.tokenizer,

View File

@ -69,6 +69,9 @@ def run_test(
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides:
vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
if model_info.skip_tokenizer_init:
vllm_runner_kwargs_[
"skip_tokenizer_init"] = model_info.skip_tokenizer_init
if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)

View File

@ -46,7 +46,7 @@ def _run_test(
vllm_model.encode(prompt)
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@pytest.mark.core_model

View File

@ -66,7 +66,9 @@ def _test_processing_correctness(
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

View File

@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=hf_overrides_fn,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

View File

@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
original_weights = create_repo_dummy_weights(model_id)

View File

@ -6,10 +6,11 @@ from dataclasses import dataclass, field
from typing import Any, Literal, Optional
import pytest
import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.config import TokenizerMode
from vllm.config import ModelDType, TokenizerMode
@dataclass(frozen=True)
@ -47,6 +48,23 @@ class _HfExamplesInfo:
The reason for the minimum/maximum version requirement.
"""
skip_tokenizer_init: bool = False
"""
If true, skip initialization of tokenizer and detokenizer.
"""
dtype: ModelDType = "auto"
"""
The data type for the model weights and activations.
"""
enforce_eager: bool = False
"""
Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
"""
is_available_online: bool = True
"""
Set this to ``False`` if the name of this architecture no longer exists on
@ -76,6 +94,9 @@ class _HfExamplesInfo:
If not specified, the default revision will be used.
"""
max_num_seqs: Optional[int] = None
"""Maximum number of sequences to be processed in a single iteration."""
def check_transformers_version(
self,
*,
@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
is_available_online=False), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
}
_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {

View File

@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
enforce_eager=model_info.enforce_eager,
skip_tokenizer_init=model_info.skip_tokenizer_init,
dtype=model_info.dtype,
speculative_config={
"model": model_info.speculative_model,
"num_speculative_tokens": 1,
@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
model_impl=ModelImpl.TRANSFORMERS
if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
hf_overrides=hf_overrides_fn,
)
max_num_seqs=model_info.max_num_seqs)
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())

View File

@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.conftest import VllmRunner
from vllm.utils import set_default_torch_num_threads
@pytest.mark.parametrize(
"model",
[
"mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"mgazz/Prithvi_v2_eo_300_tl_unet_agb"
],
)
def test_inference(
vllm_runner: type[VllmRunner],
model: str,
) -> None:
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
prompt = dict(prompt_token_ids=[1],
multi_modal_data=dict(pixel_values=pixel_values,
location_coords=location_coords))
with (
set_default_torch_num_threads(1),
vllm_runner(
model,
runner="pooling",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,
) as vllm_model,
):
vllm_output = vllm_model.llm.encode(prompt)
assert torch.equal(
torch.isnan(vllm_output[0].outputs.data).any(),
torch.tensor(False))

View File

@ -294,6 +294,8 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_cache_gb=mm_processor_cache_gb,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
**model_config_kwargs,
)
return InputContext(model_config)

View File

@ -7,12 +7,11 @@ import requests
from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.llm import LLM
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
@ -23,61 +22,7 @@ def test_loading_missing_plugin():
get_io_processor(vllm_config, "wrong_plugin")
def test_loading_engine_with_wrong_plugin():
with pytest.raises(ValueError):
LLM(
model=MODEL_NAME,
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=32,
io_processor_plugin="wrong_plugin",
)
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
pooling_params = PoolingParams(task="encode", softmax=False)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)
@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def server():
args = [
"--runner",
@ -90,7 +35,9 @@ def server():
"--max-num-seqs",
"32",
"--io-processor-plugin",
"prithvi_to_tiff_valencia"
"prithvi_to_tiff_valencia",
"--model-impl",
"terratorch",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online(
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(plugin_data["data"])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)
pooling_params = PoolingParams(task="encode", softmax=False)
with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)

View File

@ -171,6 +171,7 @@ class ModelImpl(str, enum.Enum):
AUTO = "auto"
VLLM = "vllm"
TRANSFORMERS = "transformers"
TERRATORCH = "terratorch"
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
@ -496,7 +497,9 @@ class ModelConfig:
back to the Transformers implementation if no vLLM implementation is
available.\n
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation."""
- "transformers" will use the Transformers model implementation.\n
- "terratorch" will use the TerraTorch model implementation.
"""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
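A trivial sketch of the new enum member's behaviour, assuming ModelImpl is imported from vllm.config as it is in the tests elsewhere in this diff:

from vllm.config import ModelImpl

# "terratorch" now round-trips through the str-based enum
assert ModelImpl("terratorch") is ModelImpl.TERRATORCH
assert ModelImpl.TERRATORCH.value == "terratorch"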

View File

@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggybacks on embedding
# Technically Terratorch models work on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
"PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
"Terratorch": ("terratorch", "Terratorch"),
}
_CROSS_ENCODER_MODELS = {
@ -639,6 +640,9 @@ class _ModelRegistry:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return (model_info, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
model_info = self._try_inspect_model_cls("Terratorch")
return (model_info, "Terratorch")
# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)
@ -687,6 +691,11 @@ class _ModelRegistry:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
arch = "Terratorch"
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)
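Illustrative only, not the registry's exact code: a sketch of how architecture resolution behaves once --model-impl terratorch is selected. Any architecture without a registered vLLM implementation falls through to the generic Terratorch wrapper; resolve_arch and the plain dict registry are stand-ins for the real _ModelRegistry machinery.

def resolve_arch(architectures: list[str], model_impl: str,
                 registry: dict[str, type]) -> tuple[type, str]:
    for arch in architectures:
        if arch in registry:
            # A registered implementation always wins.
            return registry[arch], arch
        if model_impl == "terratorch":
            # Otherwise the Terratorch wrapper handles the model.
            return registry["Terratorch"], "Terratorch"
    raise ValueError(f"No model class found for {architectures}")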

View File

@ -15,13 +15,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
"""Wrapper around `Terratorch` models"""
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
InputDefinition, InputTypeEnum)
from transformers import BatchFeature
from vllm.config import VllmConfig
@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems,
@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
from .interfaces_base import default_pooling_type
def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
# This model receives in input a multi-dimensional tensor representing
# a single image patch and therefore it is not to be split
# into multiple elements, but rather to be considered a single one.
# Hence, the decision of using a MultiModalSharedField.
# The expected shape is (num_channels, width, height).
# This model however allows the user to also submit multiple image
# patches as a batch, adding a further dimension to the above shape.
# At this stage we only support submitting one patch per request and
# batching is achieved via vLLM batching.
# TODO (christian-pinto): enable support for multi patch requests
# in tandem with vLLM batching.
return dict(
pixel_values=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
location_coords=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
)
def _terratorch_field_names(pretrained_cfg: dict):
input_definition = InputDefinition(**pretrained_cfg["input"])
return set(input_definition.data.keys())
class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser):
def _terratorch_field_factory(
pretrained_cfg: dict
) -> Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
]:
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="image",
required_fields={"pixel_values", "location_coords"},
fields_factory=_prithvi_field_config,
)
def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
input_definition = InputDefinition(**pretrained_cfg["input"])
fields = {}
for input_name, input in input_definition.data.items():
if input.type == InputTypeEnum.tensor:
fields[input_name] = "image"
return super()._parse_image_data(data)
mm_fields_config = {}
for field_name, field_modality in fields.items():
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
batch_size=1, modality=field_modality)
return mm_fields_config
return _terratorch_field_config
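For readers unfamiliar with Terratorch checkpoints, a hypothetical shape of the pretrained_cfg["input"] section consumed by the factory above; the exact schema belongs to terratorch's InputDefinition and is assumed here rather than taken from a real checkpoint.

# Hypothetical config fragment (schema assumed from the usage above):
pretrained_cfg = {
    "input": {
        "data": {
            "pixel_values": {"type": "tensor"},
            "location_coords": {"type": "tensor"},
        },
    },
}
# _terratorch_field_names(pretrained_cfg) would then yield
# {"pixel_values", "location_coords"}, each mapped to a shared
# MultiModalFieldConfig on the "image" modality.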
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
class TerratorchProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
def __init__(self, info: TerratorchProcessingInfo):
super().__init__(info)
self.dummy_data_generator = DummyDataGenerator(
self.info.get_hf_config().to_dict()["pretrained_cfg"])
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder(
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
# This model input is fixed and is in the form of a torch Tensor.
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
image_data = {
"pixel_values": torch.full((6, 512, 512), 1.0,
dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
return {"image": image_data}
# Dummy data is generated based on the 'input' section
# defined in the HF configuration file
return self.dummy_data_generator.get_dummy_mm_data()
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class TerratorchMultiModalDataParser(MultiModalDataParser):
def __init__(self, pretrained_cfg: dict, *args, **kwargs):
self._pretrained_cfg = pretrained_cfg
super().__init__(*args, **kwargs)
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
return DictEmbeddingItems(
data,
modality="image",
required_fields=terratorch_fields,
fields_factory=_terratorch_field_factory(self._pretrained_cfg),
)
return super()._parse_image_data(data)
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
def __init__(
self,
info: TerratorchProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
*,
cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
def _get_data_parser(self) -> MultiModalDataParser:
return PrithviGeoSpatialMAEMultiModalDataParser()
return TerratorchMultiModalDataParser(
pretrained_cfg=self.pretrained_cfg)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _prithvi_field_config(hf_inputs)
return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)
def _get_prompt_updates(
self,
@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
TerratorchMultiModalProcessor,
info=TerratorchProcessingInfo,
dummy_inputs=TerratorchInputBuilder,
)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"""Prithvi Masked Autoencoder"""
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
supports_multimodal_raw_input_only = True
is_pooling_model = True
@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
raise ValueError("Only image modality is supported")
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask
task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
loss=config["task_args"]["loss"],
lr=config["task_args"]["lr"],
ignore_index=config["task_args"]["ignore_index"],
optimizer=config["task_args"]["optimizer"],
optimizer_hparams=config["optimizer_params"],
scheduler=config["task_args"]["scheduler"],
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"],
)
return task.model
else:
return None
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
# the actual model is dynamically instantiated using terratorch
# allowing us to perform changes to the model architecture
# at startup time (e.g., change the model decoder class.)
self.model = self._instantiate_model(
vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
if self.model is None:
raise ValueError(
"Unsupported task. "
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.")
config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]
self.inference_runner = InferenceRunner(config)
self.model = self.inference_runner.model
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@ -234,23 +228,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")
location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
raise ValueError(f"Incorrect type of location_coords. "
f"Got type: {type(location_coords)}")
location_coords = torch.unbind(location_coords, dim=0)[0]
if location_coords.shape == torch.Size([0]):
location_coords = None
return pixel_values, location_coords
def get_input_embeddings(
self,
input_ids: torch.Tensor,
@ -270,10 +247,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
pixel_values, location_coords = (
self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values,
location_coords=location_coords)
model_output = self.inference_runner.forward(**kwargs)
return model_output.output
@ -283,28 +257,34 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
model_buffers = dict(self.named_buffers())
loaded_buffers = []
for key, value in weights:
if key == "state_dict":
weights_to_parse = value
for name, weight in weights_to_parse.items():
if "pos_embed" in name:
continue
if isinstance(value, (dict, OrderedDict)):
if key == "state_dict":
weights_to_parse = value
for name, weight in weights_to_parse.items():
name = f"inference_runner.{name}"
if "_timm_module." in name:
name = name.replace("_timm_module.", "")
if "pos_embed" in name:
continue
# this model requires a couple of buffers to be loaded
# that are not loadable with the AutoWeightsLoader
if name in model_buffers:
if "_timm_module." in name:
name = name.replace("_timm_module.", "")
buffer = model_buffers[name]
weight_loader = getattr(buffer, "weight_loader",
default_weight_loader)
weight_loader(buffer, weight)
loaded_buffers.append(name)
else:
params_list.append((name, weight))
break
# this model requires a couple of buffers to be loaded
# that are not loadable with the AutoWeightsLoader
if name in model_buffers:
if "_timm_module." in name:
name = name.replace("_timm_module.", "")
buffer = model_buffers[name]
weight_loader = getattr(buffer, "weight_loader",
default_weight_loader)
weight_loader(buffer, weight)
loaded_buffers.append(name)
else:
params_list.append((name, weight))
break
elif isinstance(value, torch.Tensor):
params_list.append((f"inference_runner.model.{key}", value))
# Load the remaining model parameters
loader = AutoWeightsLoader(self)
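In short, the rewritten loader accepts two checkpoint layouts; the summary below restates the branches above rather than adding behaviour.

# 1. Nested, Lightning-style checkpoints: {"state_dict": {name: tensor, ...}}.
#    Names are re-prefixed with "inference_runner.", "_timm_module." segments
#    are stripped, "pos_embed" entries are skipped, and a few buffers are
#    loaded directly because AutoWeightsLoader cannot handle them.
# 2. Flat tensor checkpoints (e.g. safetensors): {name: tensor, ...}.
#    Each entry is re-prefixed with "inference_runner.model." and handed to
#    AutoWeightsLoader together with the rest of the parameters.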