[Bugfix] Fix Fuyu tensor parallel inference (#8986)

Author: Isotr0py
Date: 2024-10-01 17:51:41 +08:00
Committed by: GitHub
Parent: 82f3937e59
Commit: bc4eb65b54

3 changed files with 15 additions and 12 deletions


@@ -37,7 +37,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
     (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
     (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
     (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
-    (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
+    (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
+    # TP only models
+    (2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
 ],
 )
 @fork_new_process_for_each_test
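The new entry exercises Fuyu with tensor parallelism only (tp_size=2, pp_size=1) under the "mp" backend. As a rough sketch of the scenario this test covers, using vLLM's offline API (the 2-GPU setup and prompt are illustrative, not part of this diff):

```python
from vllm import LLM, SamplingParams

# Illustrative only: run Fuyu with its weights sharded across 2 GPUs,
# the configuration the new test case exercises.
llm = LLM(model="adept/fuyu-8b",
          tensor_parallel_size=2,
          distributed_executor_backend="mp")
outputs = llm.generate("Describe this picture.",
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```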


@@ -237,8 +237,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
             self.image_feature_size,
             config.hidden_size,
             quant_config=quant_config,
+            gather_output=True,
         )
-        self.language_model = PersimmonForCausalLM(config,
+        self.language_model = PersimmonForCausalLM(config.text_config,
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)
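ColumnParallelLinear shards its output dimension across tensor-parallel ranks, so without gather_output=True each rank holds only a slice of the projected patch embeddings. Fuyu splices these embeddings into the language model's input sequence, which requires the full hidden size on every rank. A conceptual sketch of what the flag does (plain PyTorch, no distributed setup, illustrative shapes):

```python
import torch

# A column-parallel linear splits the weight's output rows across ranks;
# each rank computes a partial output, and gather_output=True corresponds
# to the concat/all-gather step that restores the full-width result.
in_dim, out_dim, tp = 8, 6, 2
x = torch.randn(1, in_dim)
w = torch.randn(out_dim, in_dim)        # full weight
shards = w.chunk(tp, dim=0)             # out_dim // tp rows per "rank"

partials = [x @ s.t() for s in shards]  # per-rank partial outputs
gathered = torch.cat(partials, dim=-1)  # what gather_output=True yields

assert torch.allclose(gathered, x @ w.t())
```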


@@ -25,11 +25,11 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PersimmonConfig
-from transformers.activations import ReLUSquaredActivation

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
@@ -57,7 +57,7 @@ class PersimmonMLP(nn.Module):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = ReLUSquaredActivation()
+        self.act = get_act_fn(config.hidden_act, quant_config)

     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
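Rather than hard-coding the transformers activation class, get_act_fn resolves the activation from the config's hidden_act name through vLLM's registry (for Persimmon that name is "relu2", i.e. squared ReLU). A functional equivalent, for reference:

```python
import torch

def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # "relu2": square(relu(x)), the activation Persimmon's config names.
    return torch.square(torch.relu(x))
```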
@@ -96,7 +96,7 @@ class PersimmonAttention(nn.Module):
             quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
-            self.num_heads * self.head_dim,
+            self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
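This is the core tensor-parallel bug: in vLLM attention modules, self.num_heads is already the per-rank head count (total_num_heads // tp_size), while RowParallelLinear expects the global input size and shards it by the TP world size internally. Passing the per-rank count therefore shrank the output projection by an extra factor of tp_size. Illustrative arithmetic (the head counts are hypothetical):

```python
total_num_heads, head_dim, tp_size = 64, 64, 2
num_heads = total_num_heads // tp_size  # per-rank heads, as vLLM stores them

# RowParallelLinear divides its declared input size by tp_size itself:
correct_per_rank = total_num_heads * head_dim // tp_size  # 2048
buggy_per_rank = num_heads * head_dim // tp_size          # 1024, too small

assert correct_per_rank == num_heads * head_dim  # matches per-rank activations
```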
@@ -213,10 +213,10 @@ class PersimmonModel(nn.Module):
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size = config.vocab_size

-        self.embed_tokens = VocabParallelEmbedding(
-            config.text_config.vocab_size, config.hidden_size)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             PersimmonDecoderLayer(config,
                                   cache_config=cache_config,
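The vocab_size edits follow from the fuyu.py change above: PersimmonForCausalLM now receives the nested text config directly, so this file reads fields off a plain PersimmonConfig rather than reaching through the multimodal wrapper. A sketch of the nesting, assuming a transformers version with Fuyu support:

```python
from transformers import FuyuConfig

fuyu_cfg = FuyuConfig()            # top-level multimodal config
text_cfg = fuyu_cfg.text_config    # nested PersimmonConfig
assert hasattr(text_cfg, "vocab_size")
# With text_cfg passed in, `config.vocab_size` is the correct access;
# `config.text_config.vocab_size` would now raise AttributeError.
assert not hasattr(text_cfg, "text_config")
```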
@@ -252,19 +252,19 @@ class PersimmonModel(nn.Module):
 class PersimmonForCausalLM(nn.Module):

     def __init__(self,
-                 config,
+                 config: PersimmonConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size = config.vocab_size
         self.model = PersimmonModel(config,
                                     cache_config=cache_config,
                                     quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.text_config.vocab_size,
+        self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       bias=False)
-        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

     def forward(