Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

[Hardware][Intel-Gaudi] Update hpu-extension and update bucketing system for HPU device (#17186)

Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>
Committed by: GitHub
Parent: 909fdaf152
Commit: c48334d405
@@ -9,4 +9,4 @@ numpy==1.26.4
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624
@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
ops.pa_impl = ops.pa
self.fused_scaled_dot_product_attention = kernels.fsdpa()

self.prefill_impl = 'naive'
if "flex_attention" in enabled_flags():
self.prefill_impl = 'flex'
if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'

self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
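For context, the new constructor selects the prompt-attention path from the extension's feature flags instead of an environment variable. A minimal runnable sketch of that selection order, assuming enabled_flags() behaves like a set of strings (the helper below is illustrative, not part of the diff):

def select_prefill_impl(flags, alibi_slopes=None):
    # Mirrors the logic above: 'naive' by default, 'flex' if flex_attention
    # is enabled, and 'fsdpa' wins when the fused kernel flag is set.
    impl = 'naive'
    if "flex_attention" in flags:
        impl = 'flex'
    if "fsdpa" in flags:
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        impl = 'fsdpa'
    return impl

# e.g. select_prefill_impl({"fsdpa"}) -> 'fsdpa'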
@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa:
if self.prefill_impl == 'fsdpa':
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")

if attn_type != AttentionType.DECODER:
self.attn_type = attn_type
if self.attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt:
key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):

if attn_metadata.is_prompt:
# Prompt run.
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)

attn_bias = attn_metadata.attn_bias
if attn_bias is not None and self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
impl=self.prefill_impl,
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
is_causal=True,
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args())
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
**self.common_attention_args())
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

def common_attention_args(self):
fsdpa_op = self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None
return {
'scale': self.scale,
'matmul_qk_op': self.matmul_qk,
'matmul_av_op': self.matmul_av,
'batch2block_matmul_op': self.batch2block_matmul,
'block2batch_matmul_op': self.block2batch_matmul,
'fsdpa_op': fsdpa_op,
'keys_fetch_func': self.k_cache.fetch_from_cache,
'values_fetch_func': self.v_cache.fetch_from_cache,
'softmax_op': self.softmax,
}
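The refactor above routes the shared operator handles through one dictionary that both the prompt and decode paths unpack with **. A toy sketch of the pattern, under the assumption that the kernel accepts keyword arguments and ignores the ones it does not need (attention_kernel and TinyImpl below are illustrative, not the extension's real kernels):

import torch

def attention_kernel(query, key, value, *, scale, softmax_op, **_unused):
    # Stand-in for ops.prompt_attention / flat_pa; extra kwargs are ignored.
    scores = softmax_op(query @ key.transpose(-1, -2) * scale)
    return scores @ value

class TinyImpl:
    def __init__(self, scale):
        self.scale = scale
        self.softmax = lambda x: torch.softmax(x, dim=-1)

    def common_attention_args(self):
        # Same pattern as HPUAttentionImpl.common_attention_args above.
        return {'scale': self.scale, 'softmax_op': self.softmax}

    def forward(self, q, k, v):
        return attention_kernel(q, k, v, **self.common_attention_args())

out = TinyImpl(0.125).forward(torch.rand(1, 4, 8), torch.rand(1, 4, 8),
                              torch.rand(1, 4, 8))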
def _make_alibi_bias(
alibi_slopes: torch.Tensor,

@@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor]
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
@@ -168,7 +168,8 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.ops import HPUFusedRMSNorm
from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None:
return self.forward_native(x, residual)
if residual is not None:
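The RMSNorm change above swaps a direct import for a kernel lookup that may return None, in which case the native path is used. A minimal sketch of this "fused kernel if available, otherwise fall back" pattern, assuming the fused op takes (x, weight, eps) — the real kernel's signature is not shown in the diff:

import torch

def rms_norm_with_fallback(x, weight, eps=1e-6, fused_kernel=None):
    if fused_kernel is not None:
        # Fused HPU path (assumed call signature).
        return fused_kernel(x, weight, eps)
    # Native reference implementation.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight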
@@ -11,11 +11,9 @@ import functools
import gc
import itertools
import math
import operator
import os
import time
from array import array
from dataclasses import dataclass, field
from enum import IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Set, Tuple, Type, TypeVar, Union)
@@ -24,8 +22,9 @@ import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
import vllm_hpu_extension.environment as environment
from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)
@@ -77,25 +76,6 @@ LORA_WARMUP_RANK = 8
DUMMY_TOKEN_ID = -1


class Singleton(type):
_instances: Dict[type, object] = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]


@dataclass
class HPUBucketingGlobalState(metaclass=Singleton):
prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False)
decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False)
prompt_buckets: List[Tuple[int, int]] = field(init=False)
decode_buckets: List[Tuple[int, int]] = field(init=False)


def subtuple(obj: object,
typename: str,
to_copy: List[str],
@@ -115,134 +95,10 @@ def subtuple(obj: object,
return _TYPE_CACHE[typename](**values)
def read_bucket_settings(phase: str, dim: str, **defaults):
"""Read bucketing configuration from env variables.

phase is either 'prompt' or 'decode'
dim is either 'bs', 'seq' or 'block'
param is either 'min', 'step' or 'max'
example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
"""
params = ['min', 'step', 'max']
env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
default_values = [defaults[p] for p in params]
values = [
int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values)
]
for e, v, d in zip(env_vars, values, default_values):
logger.info('%s=%s (default:%s)', e, v, d)
return values
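As the docstring notes, each (phase, dim, param) triple maps to one VLLM_* environment variable; a short illustration of how an override would flow through this helper (values are examples only):

import os

# Make decode batch-size buckets start at 1, grow in steps of 128, cap at 256.
os.environ["VLLM_DECODE_BS_BUCKET_MIN"] = "1"
os.environ["VLLM_DECODE_BS_BUCKET_STEP"] = "128"
os.environ["VLLM_DECODE_BS_BUCKET_MAX"] = "256"
# read_bucket_settings('decode', 'bs', min=1, step=32, max=256)
# would then return [1, 128, 256] instead of the defaults.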
def warmup_range(config: Tuple[int, int, int]):
"""Generate a warmup range.

Start from bmin and multiply by 2 until you reach bstep.
Then, increase the values in the range by the value of bstep until you
reach bmax.

Example:
bmin = 2, bstep = 32, bmax = 64
=> ramp_up = (2, 4, 8, 16)
=> stable = (32, 64)
=> return ramp_up + stable => (2, 4, 8, 16, 32, 64)
"""
bmin, bstep, bmax = config
assert bmin <= bmax, ("Min. batch size cannot be greater than max. "
"batch size. If you want to skip warmup, "
"set VLLM_SKIP_WARMUP=true")
base = itertools.repeat(2)
ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
ramp_up_acc)
stable = range(bstep, bmax + 1, bstep)
buckets = list(ramp_up_tw) + list(stable)
return list(filter(lambda bucket: bucket >= bmin, buckets))
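A runnable standalone equivalent of warmup_range, valid for the common case bmin <= bstep (where the trailing filter is a no-op), reproducing the docstring example:

def warmup_range_demo(config):
    bmin, bstep, bmax = config
    ramp_up, b = [], bmin
    while b < bstep and b <= bmax:     # exponential ramp-up below bstep
        ramp_up.append(b)
        b *= 2
    stable = list(range(bstep, bmax + 1, bstep))   # linear region in bstep increments
    return ramp_up + stable

assert warmup_range_demo((2, 32, 64)) == [2, 4, 8, 16, 32, 64]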
def generate_prompt_buckets(bs_bucket_config,
seq_bucket_config,
max_num_batched_tokens=None):
buckets = list(
itertools.product(warmup_range(bs_bucket_config),
warmup_range(seq_bucket_config)))
if len(buckets) == 0:
msg = ("No buckets could be captured with following config "
f"(min, step, max_warmup): "
f"bs:{bs_bucket_config}, "
f"seq:{seq_bucket_config}")
raise ValueError(msg)

filtered_buckets = buckets
if max_num_batched_tokens is not None:
# Remove buckets exceeding batch token budget
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
buckets))

if len(filtered_buckets) == 0:
# we can handle this if we ignore max_num_batched_tokens
min_bucket_bs, min_bucket_seq = min(buckets,
key=lambda b: (b[0] * b[1]))
min_reqd_budget = min_bucket_bs * min_bucket_seq
msg = (
"The current bucketing configuration "
f"(min, step, max_warmup): "
f"bs:{bs_bucket_config}, "
f"seq:{seq_bucket_config} cannot be used with specified "
f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
f"smallest bucket ({min_reqd_budget}) would exceed token "
"budget. Please increase max_num_batched_tokens or decrease "
"bucket minimum Ignoring max_num_batched_tokens at risk of "
"out-of-memory errors.")
logger.error(msg)
return list(
sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), []

captured_buckets = list(
sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
omitted_buckets = list(
sorted([x for x in buckets if x not in filtered_buckets]))
return captured_buckets, omitted_buckets
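A small worked example of the helper above with hypothetical configs, showing how the token budget splits the bucket grid into captured and omitted sets:

# Batch-size buckets (min=1, step=2, max=4), sequence buckets
# (min=128, step=128, max=512), and a 512-token budget.
captured, omitted = generate_prompt_buckets(
    bs_bucket_config=(1, 2, 4),
    seq_bucket_config=(128, 128, 512),
    max_num_batched_tokens=512)
# warmup_range yields bs candidates [1, 2, 4] and seq candidates
# [128, 256, 384, 512]; of the 12 (bs, seq) pairs, the 7 with
# bs * seq <= 512 are captured and the remaining 5 are omitted.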
def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
max_blocks):
buckets = []
bs_buckets = warmup_range(bs_bucket_config)
block_buckets = warmup_range(blocks_bucket_config)
bmin, bstep, bmax = blocks_bucket_config
last_bucket = round_up(max_blocks, bstep)
for bs in bs_buckets:
for blocks in block_buckets:
if blocks < bs:
continue
if blocks > last_bucket:
break
buckets.append((bs, blocks))
return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
def next_pow2(value: int, base: int):
res = base
while value > 1:
value = (value + 1) // 2
res *= 2
return res


def round_up(value: int, k: int):
return (value + k - 1) // k * k


def find_bucket(value: int, config: Tuple[int, int, int]):
bmin, bstep, _ = config
next_step = round_up(value, bstep)
next_pow = next_pow2(value, bmin)
return max(bmin, min(next_step, next_pow))
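Worked examples of the padding helpers above, using illustrative bucket parameters (bmin=1, bstep=32): small values snap to the next power of two, larger values to the next step multiple, and the smaller of the two wins.

assert round_up(100, 32) == 128               # next multiple of the step
assert next_pow2(5, 1) == 8                   # next power of two above the value
assert find_bucket(5, (1, 32, 2048)) == 8     # min(round_up=32, next_pow2=8)
assert find_bucket(100, (1, 32, 2048)) == 128 # min(round_up=128, next_pow2=128)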
def align_workers(value, op):
group = get_world_group().cpu_group
world_size = torch.distributed.get_world_size()
@@ -406,16 +262,6 @@ class HpuModelAdapter:
attn_bias=attn_bias)
return metadata

def _set_block_scales(self, metadata, device):
block_mapping = metadata.block_mapping
ones = torch.ones((block_mapping.size(0), ),
device=device,
dtype=block_mapping.dtype)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
metadata = metadata._replace(block_scales=block_scales)
return metadata

def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
dtype):
if attn_metadata.is_prompt:
@@ -426,7 +272,6 @@ class HpuModelAdapter:
meta = attn_metadata
attn_metadata = self._set_block_mapping(meta, batch_size, device,
dtype)
attn_metadata = self._set_block_scales(attn_metadata, device)
return attn_metadata

def forward(self, *args, **kwargs):
@@ -625,6 +470,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
return_hidden_states: bool = False,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
environment.set_model_config(self.model_config)
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states

@@ -664,8 +510,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.profiler_counter_helper = HabanaProfilerCounterHelper()
self.seen_configs: set = set()
self._mem_margin: Optional[int] = None
self.bucketing_global_state = HPUBucketingGlobalState()
self._setup_buckets()
HPUBucketingContext = get_bucketing_context()
self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
self.max_num_prefill_seqs,
self.block_size,
self.max_num_batched_tokens,
False, self.max_model_len)
self.graphed_buckets: Set[Any] = set()
self._set_gc_threshold()
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
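With this change, bucket management moves from the runner's own global state into a context object provided by the extension. A sketch of the new flow, using only names that appear in this diff and illustrative constructor values (the meaning of the fifth positional argument, passed as False above, is not shown in the diff and is left uninterpreted here):

from vllm_hpu_extension.bucketing.common import get_bucketing_context

HPUBucketingContext = get_bucketing_context()
bucketing_ctx = HPUBucketingContext(
    128,      # max_num_seqs
    16,       # max_num_prefill_seqs
    128,      # block_size
    4096,     # max_num_batched_tokens
    False,    # flag passed as-is in the diff above
    4096)     # max_model_len
# Padding queries used later in the runner (see the hunks below):
#   bucketing_ctx.get_padded_batch_size(real_batch_size, is_prompt)
#   bucketing_ctx.get_padded_prompt_seq_len(max_query_len)
#   bucketing_ctx.get_padded_decode_num_blocks(num_blocks)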
@@ -773,6 +624,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)

def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
real_batch_size = len(seq_group_metadata_list)
batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
real_batch_size, is_prompt)
batch_size_padding = batch_size_padded - real_batch_size

seq_group_metadata_list = seq_group_metadata_list.copy()

if batch_size_padding > 0:
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
0, 0, is_prompt)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
return seq_group_metadata_list, real_batch_size, batch_size_padded
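Illustrative behaviour of the new helper, assuming a hypothetical runner instance and hypothetical bucket sizes (the exact padding depends on the bucketing context's configuration):

padded_list, real_bs, padded_bs = runner._add_dummy_seq(
    seq_group_metadata_list, is_prompt=True)
# If len(seq_group_metadata_list) == 3 and the context pads a 3-request
# prompt batch to 4, then real_bs == 3, padded_bs == 4, and padded_list has
# one extra dummy group built via create_dummy_seq_group_metadata(0, 0, True).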
def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
@@ -792,46 +658,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def _is_valid_bucket(self, bucket):
return bucket[0] * bucket[1] <= self.max_num_batched_tokens

def _setup_buckets(self) -> None:
align_bs = lambda x: min(self.max_num_seqs, x)
#FIXME: The default values should be max_model_len
max_prompt_seq = 1024
max_decode_seq = 2048
self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings(
'prompt',
'bs',
min=1,
step=align_bs(32),
max=self.max_num_prefill_seqs)
self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings(
'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
self.bucketing_global_state.prompt_seq_bucket_cfg = \
read_bucket_settings(
'prompt',
'seq',
min=self.block_size,
step=self.block_size,
max=max_prompt_seq)
self.bucketing_global_state.decode_block_bucket_cfg = \
read_bucket_settings(
'decode',
'block',
min=self.block_size,
step=self.block_size,
max=max(self.block_size,
self.max_num_seqs * max_decode_seq // self.block_size))
self.graphed_buckets: Set[Any] = set()

msg = ("Prompt bucket config (min, step, max_warmup) "
f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, "
f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}")
logger.info(msg)

msg = ("Decode bucket config (min, step, max_warmup) "
f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, "
f"block:{self.bucketing_global_state.decode_block_bucket_cfg}")
logger.info(msg)

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -947,8 +773,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
assert max_query_len > 0

max_prompt_len = max(
find_bucket(max(seq_lens),
self.bucketing_global_state.prompt_seq_bucket_cfg),
self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len),
self.block_size)

lora_ids: List[int] = []
@@ -997,7 +822,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_usage=None,
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor,
@@ -1124,9 +948,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
padding_fn = None
if self.use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
block_bucket_size = find_bucket(
block_bucket_size,
self.bucketing_global_state.decode_block_bucket_cfg)
block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
@@ -1134,9 +957,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
padding_fn = lambda tensor, pad_value: gather_list(
tensor, indices, pad_value)
else:
block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
block_bucket_size = \
self.bucketing_ctx.get_padded_decode_num_blocks(
len(block_list))
padding_fn = lambda tensor, pad_value: pad_list(
tensor, block_bucket_size, pad_value)

@@ -1167,7 +990,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_usage=block_usage,
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
@@ -1210,17 +1032,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
base_event_name = 'prompt' if is_prompt else 'decode'
self.profiler.start('internal', base_event_name)

real_batch_size = len(seq_group_metadata_list)
bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \
if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg
batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
if batch_size_padding > 0:
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
0, 0, is_prompt)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
seq_group_metadata_list, real_batch_size, batch_size_padded = (
self._add_dummy_seq(seq_group_metadata_list, is_prompt))

prefill_reqs = []
decode_reqs = []
@@ -1382,7 +1195,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
'block_offsets', 'block_scales', 'block_groups'
'block_offsets', 'block_groups'
])
return attention_metadata
@@ -1420,16 +1233,18 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[kv_caches])
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
_, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
max_batch_size = min(self.max_num_seqs,
self.max_num_batched_tokens // max_seq_len)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
return

def warmup_scenario(self,
batch_size,
seq_len,
is_prompt,
kv_caches,
is_pt_profiler_run=False,
is_lora_profile_run=False) -> None:
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
@@ -1565,16 +1380,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
f"free_mem:{free_mem}")
logger.info(msg)

def warmup_all_buckets(self, buckets, is_prompt):
def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
self.log_warmup('Prompt' if is_prompt else 'Decode', i,
len(buckets), batch_size, seq_len)
self.warmup_scenario(batch_size, seq_len, is_prompt)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)

def warmup_graphs(self,
strategy,
buckets,
is_prompt,
kv_caches,
available_mem,
starting_mem=0,
total_batch_seq=0.001):
@@ -1606,7 +1422,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.graphed_buckets.add(graphed_bucket)
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
self.warmup_scenario(batch_size, seq_len, is_prompt)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
used_mem = align_workers(mem_prof.consumed_device_memory,
torch.distributed.ReduceOp.MAX)
available_mem -= used_mem
@@ -1630,50 +1446,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):

@torch.inference_mode()
def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
max_blocks = kv_caches[0][0].size(0)
self.bucketing_ctx.generate_decode_buckets(max_blocks)
if profile := os.environ.get('VLLM_PT_PROFILE', None):
phase, bs, seq_len, graph = profile.split('_')
is_prompt = phase == 'prompt'
graphs = graph == 't'
if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
True)
raise AssertionError("Finished profiling")
if self.skip_warmup:
logger.info("Skipping warmup...")
return
self.profiler.start('internal', 'warmup')
max_blocks = kv_caches[0][0].size(0)

self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \
generate_prompt_buckets(
self.bucketing_global_state.prompt_bs_bucket_cfg,
self.bucketing_global_state.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)

msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} "
f"prompt buckets [bs, seq]: \
{list(sorted(self.bucketing_global_state.prompt_buckets))}")
logger.info(msg)

msg = (f"Omitted {len(prompt_omitted_buckets)} "
"prompt buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)

msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
logger.debug(msg)

self.bucketing_global_state.decode_buckets = generate_decode_buckets(
self.bucketing_global_state.decode_bs_bucket_cfg,
self.bucketing_global_state.decode_block_bucket_cfg, max_blocks)
logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
len(self.bucketing_global_state.decode_buckets),
list(sorted(self.bucketing_global_state.decode_buckets)))

if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
cache_size_limit = 1 + 3 * (
len(self.bucketing_global_state.prompt_buckets) +
len(self.bucketing_global_state.decode_buckets))
len(self.bucketing_ctx.prompt_buckets) +
len(self.bucketing_ctx.decode_buckets))
torch._dynamo.config.cache_size_limit = max(
cache_size_limit, torch._dynamo.config.cache_size_limit)
# Multiply by 8 to follow the original default ratio between
@@ -1681,7 +1468,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
torch._dynamo.config.accumulated_cache_size_limit = max(
cache_size_limit * 8,
torch._dynamo.config.accumulated_cache_size_limit)

if self.skip_warmup:
logger.info("Skipping warmup...")
return
self.profiler.start('internal', 'warmup')
start_mem = HabanaMemoryProfiler.current_device_memory_usage()
start_time = time.perf_counter()
@@ -1700,10 +1490,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
'Please update Gaudi Software Suite.')
with compile_only_mode_context(
) if can_use_compile_only_mode else contextlib.nullcontext():
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
True)
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
False)
print("aa")
self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
kv_caches)
print("bb")
self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False,
kv_caches)

if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
@@ -1733,12 +1525,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
'max_bs')
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.bucketing_global_state.prompt_buckets,
True, prompt_available_memory)
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches, prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.bucketing_global_state.decode_buckets,
False, decode_available_memory)
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches, decode_available_memory)

# Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space
@@ -1747,8 +1539,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
and not prompt_captured_all and decode_captured_all):
mem_post_prompt, _, prompt_captured_all = (
self.warmup_graphs(
prompt_strategy,
self.bucketing_global_state.prompt_buckets, True,
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq))

@@ -1759,17 +1551,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy,
self.bucketing_global_state.decode_buckets, False,
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)

self.log_graph_warmup_summary(
self.bucketing_global_state.prompt_buckets, True,
mem_post_prompt)
self.bucketing_ctx.prompt_buckets, True, mem_post_prompt)
self.log_graph_warmup_summary(
self.bucketing_global_state.decode_buckets, False,
mem_post_decode)
self.bucketing_ctx.decode_buckets, False, mem_post_decode)

end_time = time.perf_counter()
end_mem = HabanaMemoryProfiler.current_device_memory_usage()
@@ -245,6 +245,7 @@ class HPUWorker(LocalOrDistributedWorkerBase):
cache_block_size)
num_hpu_blocks = max(num_hpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks

if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()