[Core][Model] Terratorch backend integration (#23513)
Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -45,7 +45,11 @@ datamodule_config = {
class PrithviMAE:
def __init__(self, model):
self.model = LLM(
model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
model=model,
skip_tokenizer_init=True,
dtype="float16",
enforce_eager=True,
model_impl="terratorch",
)

def run(self, input_data, location_coords):
@@ -37,6 +37,7 @@ def main():
# The maximum number depends on the available GPU memory
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff_india",
model_impl="terratorch",
)

pooling_params = PoolingParams(task="encode", softmax=False)
@@ -15,6 +15,7 @@ import requests
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india
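For readers assembling these flags by hand, the offline-API counterpart would look roughly as follows. This is only a sketch: the keyword arguments mirror options used elsewhere in this patch, and runner="pooling" is assumed here as the offline stand-in for "--task embed".

# Sketch only: offline-API counterpart of the serving flags above.
# runner="pooling" is an assumption standing in for "--task embed".
from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    model_impl="terratorch",
    runner="pooling",
    trust_remote_code=True,
    skip_tokenizer_init=True,
    enforce_eager=True,
    io_processor_plugin="prithvi_to_tiff_india",
)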
@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
decord==0.6.0
terratorch==1.1rc3 # required for PrithviMAE test
@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
# via lightning
tensorizer==2.10.1
# via -r requirements/test.in
terratorch==1.1rc2
terratorch==1.1rc3
# via -r requirements/test.in
threadpoolctl==3.5.0
# via scikit-learn
@@ -298,6 +298,8 @@ def _compare_tp(
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
hf_config = get_config(model_id, trust_remote_code)
skip_tokenizer_init = model_info.skip_tokenizer_init
max_num_seqs = model_info.max_num_seqs

dtype = "float16"
if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
@@ -351,6 +353,10 @@ def _compare_tp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")
if max_num_seqs:
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])

specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
testing_ray_compiled_graph = False
@@ -178,6 +178,7 @@ def _compare_sp(
trust_remote_code = model_info.trust_remote_code
tokenizer_mode = model_info.tokenizer_mode
hf_overrides = model_info.hf_overrides
skip_tokenizer_init = model_info.skip_tokenizer_init

if load_format == "dummy":
# Avoid OOM
@@ -227,6 +228,8 @@ def _compare_sp(
common_args.extend(["--load-format", load_format])
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
if skip_tokenizer_init:
common_args.append("--skip-tokenizer-init")

compilation_config = {
'level': 3,
@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)

# Initialize the tokenizer
tokenizer = get_tokenizer(
@@ -11,7 +11,7 @@ import torch

from ...utils import RemoteOpenAIServer

MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"

@@ -35,7 +35,9 @@ def server():
"--trust-remote-code",
"--skip-tokenizer-init",
"--max-num-seqs",
"32"
"32",
"--model-impl",
"terratorch"
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)

# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)

tokenizer_group = TokenizerGroup(
model,
@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)

tokenizer_group = TokenizerGroup(
model_config.tokenizer,
@@ -69,6 +69,9 @@ def run_test(
vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
if model_info.hf_overrides:
vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
if model_info.skip_tokenizer_init:
vllm_runner_kwargs_[
"skip_tokenizer_init"] = model_info.skip_tokenizer_init

if vllm_runner_kwargs:
vllm_runner_kwargs_.update(vllm_runner_kwargs)
@@ -46,7 +46,7 @@ def _run_test(
vllm_model.encode(prompt)


MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]


@pytest.mark.core_model
@@ -66,7 +66,9 @@ def _test_processing_correctness(
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)

model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
@@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=hf_overrides_fn,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

original_weights = create_repo_dummy_weights(model_id)
@@ -6,10 +6,11 @@ from dataclasses import dataclass, field
from typing import Any, Literal, Optional

import pytest
import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.config import TokenizerMode
from vllm.config import ModelDType, TokenizerMode


@dataclass(frozen=True)
@@ -47,6 +48,23 @@ class _HfExamplesInfo:
The reason for the minimum/maximum version requirement.
"""

skip_tokenizer_init: bool = False
"""
If true, skip initialization of tokenizer and detokenizer.
"""

dtype: ModelDType = "auto"
"""
The data type for the model weights and activations.
"""

enforce_eager: bool = False
"""
Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
"""

is_available_online: bool = True
"""
Set this to ``False`` if the name of this architecture no longer exists on
@@ -76,6 +94,9 @@ class _HfExamplesInfo:
If not specified, the default revision will be used.
"""

max_num_seqs: Optional[int] = None
"""Maximum number of sequences to be processed in a single iteration."""

def check_transformers_version(
self,
*,
@@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = {
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
is_available_online=False), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model
# going OOM in CI
max_num_seqs=32,
),
"Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# This is to avoid the model going OOM in CI
max_num_seqs=32,
),
}

_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
@@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
tokenizer=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
enforce_eager=model_info.enforce_eager,
skip_tokenizer_init=model_info.skip_tokenizer_init,
dtype=model_info.dtype,
speculative_config={
"model": model_info.speculative_model,
"num_speculative_tokens": 1,
@@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
model_impl=ModelImpl.TRANSFORMERS
if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
hf_overrides=hf_overrides_fn,
)
max_num_seqs=model_info.max_num_seqs)


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
tests/models/test_terratorch.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.conftest import VllmRunner
from vllm.utils import set_default_torch_num_threads


@pytest.mark.parametrize(
"model",
[
"mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
"mgazz/Prithvi_v2_eo_300_tl_unet_agb"
],
)
def test_inference(
vllm_runner: type[VllmRunner],
model: str,
) -> None:

pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
prompt = dict(prompt_token_ids=[1],
multi_modal_data=dict(pixel_values=pixel_values,
location_coords=location_coords))
with (
set_default_torch_num_threads(1),
vllm_runner(
model,
runner="pooling",
dtype=torch.float16,
enforce_eager=True,
skip_tokenizer_init=True,
# Limit the maximum number of sequences to avoid the
# test going OOM during the warmup run
max_num_seqs=32,
) as vllm_model,
):

vllm_output = vllm_model.llm.encode(prompt)
assert torch.equal(
torch.isnan(vllm_output[0].outputs.data).any(),
torch.tensor(False))
@@ -294,6 +294,8 @@ def build_model_context(
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_cache_gb=mm_processor_cache_gb,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
**model_config_kwargs,
)
return InputContext(model_config)
@@ -7,12 +7,11 @@ import requests

from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.llm import LLM
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams

MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"

image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501

@@ -23,61 +22,7 @@ def test_loading_missing_plugin():
get_io_processor(vllm_config, "wrong_plugin")


def test_loading_engine_with_wrong_plugin():

with pytest.raises(ValueError):
LLM(
model=MODEL_NAME,
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=32,
io_processor_plugin="wrong_plugin",
)


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):

img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)

pooling_params = PoolingParams(task="encode", softmax=False)

with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs

# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])

# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)


@pytest.fixture(scope="module")
@pytest.fixture(scope="function")
def server():
args = [
"--runner",
@@ -90,7 +35,9 @@ def server():
"--max-num-seqs",
"32",
"--io-processor-plugin",
"prithvi_to_tiff_valencia"
"prithvi_to_tiff_valencia",
"--model-impl",
"terratorch",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online(
# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(plugin_data["data"])


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):

img_prompt = dict(
data=image_url,
data_format="url",
image_format="tiff",
out_data_format="b64_json",
)

pooling_params = PoolingParams(task="encode", softmax=False)

with vllm_runner(
model_name,
runner="pooling",
skip_tokenizer_init=True,
trust_remote_code=True,
enforce_eager=True,
# Limit the maximum number of parallel requests
# to avoid the model going OOM in CI.
max_num_seqs=1,
model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff_valencia",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt,
pooling_params=pooling_params,
)
output = pooler_output[0].outputs

# verify the output is formatted as expected for this plugin
assert all(
hasattr(output, attr)
for attr in ["type", "format", "data", "request_id"])

# We just check that the output is a valid base64 string.
# Raises an exception and fails the test if the string is corrupted.
base64.b64decode(output.data)
@@ -171,6 +171,7 @@ class ModelImpl(str, enum.Enum):
AUTO = "auto"
VLLM = "vllm"
TRANSFORMERS = "transformers"
TERRATORCH = "terratorch"


def get_attr_docs(cls: type[Any]) -> dict[str, str]:
@@ -496,7 +497,9 @@ class ModelConfig:
back to the Transformers implementation if no vLLM implementation is
available.\n
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation."""
- "transformers" will use the Transformers model implementation.\n
- "terratorch" will use the TerraTorch model implementation.
"""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
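With the new enum value and docstring entry above, a Terratorch-backed model is selected with a single keyword. A minimal end-to-end sketch, with the model name, tensor shapes, and remaining arguments borrowed from the tests in this patch:

import torch
from vllm import LLM

# Minimal sketch: model_impl="terratorch" maps to ModelImpl.TERRATORCH.
# The raw-tensor prompt mirrors the dummy inputs used by the tests above.
llm = LLM(
    model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    model_impl="terratorch",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
)

prompt = dict(
    prompt_token_ids=[1],  # placeholder token; the tokenizer is skipped
    multi_modal_data=dict(
        pixel_values=torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        location_coords=torch.full((1, 2), 1.0, dtype=torch.float16),
    ),
)
outputs = llm.encode(prompt)
print(outputs[0].outputs.data.shape)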
@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
# input and output. I am adding it here because it piggybacks on embedding
# Technically Terratorch models work on images, both in
# input and output. I am adding it here because it piggy-backs on embedding
# models for the time being.
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
"PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
"Terratorch": ("terratorch", "Terratorch"),
}

_CROSS_ENCODER_MODELS = {
@@ -639,6 +640,9 @@ class _ModelRegistry:
model_info = self._try_inspect_model_cls(arch)
if model_info is not None:
return (model_info, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
model_info = self._try_inspect_model_cls("Terratorch")
return (model_info, "Terratorch")

# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)
@@ -687,6 +691,11 @@ class _ModelRegistry:
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
elif model_config.model_impl == ModelImpl.TERRATORCH:
arch = "Terratorch"
model_cls = self._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)

# Fallback to transformers impl (after resolving convert_type)
if (all(arch not in self.models for arch in architectures)
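In plain terms, the two registry hunks above add the same fallback at inspection time and at load time. A rough paraphrase of the resolution logic (not the actual registry code):

# Rough paraphrase of the fallback added above, not the real registry code:
# an architecture that is not registered resolves to the generic "Terratorch"
# wrapper whenever the user requested model_impl="terratorch".
def resolve_architecture(architectures: list[str], registered: set[str],
                         model_impl: str) -> str:
    for arch in architectures:
        if arch in registered:
            return arch
        if model_impl == "terratorch":
            return "Terratorch"
    raise ValueError(f"No supported architecture among {architectures}")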
@@ -15,13 +15,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
"""Wrapper around `Terratorch` models"""

from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
InputDefinition, InputTypeEnum)
from transformers import BatchFeature

from vllm.config import VllmConfig
@@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems,
@@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
from .interfaces_base import default_pooling_type


def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
# This model receives in input a multi-dimensional tensor representing
# a single image patch and therefore it is not to be split
# into multiple elements, but rather to be considered a single one.
# Hence, the decision of using a MultiModalSharedField.
# The expected shape is (num_channels, width, height).

# This model however allows the user to also submit multiple image
# patches as a batch, adding a further dimension to the above shape.
# At this stage we only support submitting one patch per request and
# batching is achieved via vLLM batching.
# TODO (christian-pinto): enable support for multi patch requests
# in tandem with vLLM batching.
return dict(
pixel_values=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
location_coords=MultiModalFieldConfig.shared(batch_size=1,
modality="image"),
)
def _terratorch_field_names(pretrained_cfg: dict):
input_definition = InputDefinition(**pretrained_cfg["input"])
return set(input_definition.data.keys())


class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser):
def _terratorch_field_factory(
pretrained_cfg: dict
) -> Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
]:

def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="image",
required_fields={"pixel_values", "location_coords"},
fields_factory=_prithvi_field_config,
)
def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
input_definition = InputDefinition(**pretrained_cfg["input"])
fields = {}
for input_name, input in input_definition.data.items():
if input.type == InputTypeEnum.tensor:
fields[input_name] = "image"

return super()._parse_image_data(data)
mm_fields_config = {}
for field_name, field_modality in fields.items():
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
batch_size=1, modality=field_modality)
return mm_fields_config

return _terratorch_field_config


class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
class TerratorchProcessingInfo(BaseProcessingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}


class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):

def __init__(self, info: TerratorchProcessingInfo):
super().__init__(info)
self.dummy_data_generator = DummyDataGenerator(
self.info.get_hf_config().to_dict()["pretrained_cfg"])

def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
@@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder(
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
# This model input is fixed and is in the form of a torch Tensor.
# The size of pixel_values might change in the cases where we resize
# the input but never exceeds the dimensions below.
image_data = {
"pixel_values": torch.full((6, 512, 512), 1.0,
dtype=torch.float16),
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}

return {"image": image_data}
# Dummy data is generated based on the 'input' section
# defined in the HF configuration file
return self.dummy_data_generator.get_dummy_mm_data()


class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class TerratorchMultiModalDataParser(MultiModalDataParser):

def __init__(self, pretrained_cfg: dict, *args, **kwargs):
self._pretrained_cfg = pretrained_cfg
super().__init__(*args, **kwargs)

def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict):

terratorch_fields = _terratorch_field_names(self._pretrained_cfg)

return DictEmbeddingItems(
data,
modality="image",
required_fields=terratorch_fields,
fields_factory=_terratorch_field_factory(self._pretrained_cfg),
)

return super()._parse_image_data(data)


class TerratorchMultiModalProcessor(BaseMultiModalProcessor):

def __init__(
self,
info: TerratorchProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
*,
cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:

self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

def _get_data_parser(self) -> MultiModalDataParser:
return PrithviGeoSpatialMAEMultiModalDataParser()
return TerratorchMultiModalDataParser(
pretrained_cfg=self.pretrained_cfg)

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _prithvi_field_config(hf_inputs)
return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)

def _get_prompt_updates(
self,
@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):

@default_pooling_type("All")
@MULTIMODAL_REGISTRY.register_processor(
PrithviGeoSpatialMAEMultiModalProcessor,
info=PrithviGeoSpatialMAEProcessingInfo,
dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
TerratorchMultiModalProcessor,
info=TerratorchProcessingInfo,
dummy_inputs=TerratorchInputBuilder,
)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"""Prithvi Masked Autoencoder"""

class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
supports_multimodal_raw_input_only = True
is_pooling_model = True

@@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):

raise ValueError("Only image modality is supported")

def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
from terratorch.cli_tools import SemanticSegmentationTask

task = SemanticSegmentationTask(
config["model_args"],
config["task_args"]["model_factory"],
loss=config["task_args"]["loss"],
lr=config["task_args"]["lr"],
ignore_index=config["task_args"]["ignore_index"],
optimizer=config["task_args"]["optimizer"],
optimizer_hparams=config["optimizer_params"],
scheduler=config["task_args"]["scheduler"],
scheduler_hparams=config["scheduler_params"],
plot_on_val=config["task_args"]["plot_on_val"],
freeze_decoder=config["task_args"]["freeze_decoder"],
freeze_backbone=config["task_args"]["freeze_backbone"],
)

return task.model
else:
return None

def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

# the actual model is dynamically instantiated using terratorch
# allowing us to perform changes to the model architecture
# at startup time (e.g., change the model decoder class.)
self.model = self._instantiate_model(
vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
if self.model is None:
raise ValueError(
"Unsupported task. "
"Only SemanticSegmentationTask is supported for now "
"by PrithviGeospatialMAE.")
config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]

self.inference_runner = InferenceRunner(config)
self.model = self.inference_runner.model

pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@@ -234,23 +228,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )

def _parse_and_validate_multimodal_data(
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
raise ValueError(f"Incorrect type of pixel_values. "
f"Got type: {type(pixel_values)}")

location_coords = kwargs.pop("location_coords", None)
if not isinstance(location_coords, torch.Tensor):
raise ValueError(f"Incorrect type of location_coords. "
f"Got type: {type(location_coords)}")
location_coords = torch.unbind(location_coords, dim=0)[0]
if location_coords.shape == torch.Size([0]):
location_coords = None

return pixel_values, location_coords

def get_input_embeddings(
self,
input_ids: torch.Tensor,
@@ -270,10 +247,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
pixel_values, location_coords = (
self._parse_and_validate_multimodal_data(**kwargs))
model_output = self.model(pixel_values,
location_coords=location_coords)
model_output = self.inference_runner.forward(**kwargs)

return model_output.output

@@ -283,28 +257,34 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
model_buffers = dict(self.named_buffers())
loaded_buffers = []
for key, value in weights:
if key == "state_dict":
weights_to_parse = value
for name, weight in weights_to_parse.items():
if "pos_embed" in name:
continue
if isinstance(value, (dict, OrderedDict)):
if key == "state_dict":
weights_to_parse = value
for name, weight in weights_to_parse.items():
name = f"inference_runner.{name}"

if "_timm_module." in name:
name = name.replace("_timm_module.", "")
if "pos_embed" in name:
continue

# this model requires a couple of buffers to be loaded
# that are not loadable with the AutoWeightsLoader
if name in model_buffers:
if "_timm_module." in name:
name = name.replace("_timm_module.", "")
buffer = model_buffers[name]
weight_loader = getattr(buffer, "weight_loader",
default_weight_loader)
weight_loader(buffer, weight)
loaded_buffers.append(name)
else:
params_list.append((name, weight))
break

# this model requires a couple of buffers to be loaded
# that are not loadable with the AutoWeightsLoader
if name in model_buffers:
if "_timm_module." in name:
name = name.replace("_timm_module.", "")
buffer = model_buffers[name]
weight_loader = getattr(buffer, "weight_loader",
default_weight_loader)
weight_loader(buffer, weight)
loaded_buffers.append(name)
else:
params_list.append((name, weight))
break

elif isinstance(value, torch.Tensor):
params_list.append((f"inference_runner.model.{key}", value))

# Load the remaining model parameters
loader = AutoWeightsLoader(self)