diff --git a/requirements/hpu.txt b/requirements/hpu.txt index 830f6ef3f5..5ac58bc028 100644 --- a/requirements/hpu.txt +++ b/requirements/hpu.txt @@ -9,4 +9,4 @@ numpy==1.26.4 tabulate setuptools>=61 setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768 +vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624 diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 15625612e0..55a63a8167 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -4,14 +4,14 @@ # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company ############################################################################### -import os from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type import torch +import vllm_hpu_extension.kernels as kernels import vllm_hpu_extension.ops as ops -from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax, - VLLMKVCache) +from vllm_hpu_extension.flags import enabled_flags +from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, @@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): self.block2batch_matmul = Matmul() self.k_cache = VLLMKVCache() self.v_cache = VLLMKVCache() - ops.pa_impl = ops.pa + self.fused_scaled_dot_product_attention = kernels.fsdpa() + + self.prefill_impl = 'naive' + if "flex_attention" in enabled_flags(): + self.prefill_impl = 'flex' + if "fsdpa" in enabled_flags(): + assert alibi_slopes is None, \ + 'Prefill with FusedSDPA not supported with alibi slopes!' + self.prefill_impl = 'fsdpa' self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window @@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', - '0').lower() in ['1', 'true'] - self.fused_scaled_dot_product_attention = None - if self.prefill_usefusedsdpa: + if self.prefill_impl == 'fsdpa': assert alibi_slopes is None, \ 'Prefill with FusedSDPA not supported with alibi slopes!' - try: - from habana_frameworks.torch.hpex.kernels import FusedSDPA - self.fused_scaled_dot_product_attention = ModuleFusedSDPA( - FusedSDPA) - except ImportError: - logger.warning("Could not import HPU FusedSDPA kernel. " - "vLLM will use native implementation.") supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() if head_size not in supported_head_sizes: @@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): f"Head size {head_size} is not supported by PagedAttention. 
" f"Supported head sizes are: {supported_head_sizes}.") - if attn_type != AttentionType.DECODER: + self.attn_type = attn_type + if self.attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " @@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): batch_size, seq_len, hidden_size = query.shape _, seq_len_kv, _ = key.shape - query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) block_indices = attn_metadata.block_indices block_offsets = attn_metadata.block_offsets - if attn_metadata.is_prompt: + key_cache = None + value_cache = None + if attn_metadata.is_prompt and self.attn_type \ + is not AttentionType.ENCODER_ONLY \ + and attn_metadata.block_list is None: key = key.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1)) - if kv_cache is not None: + if kv_cache is not None and isinstance(kv_cache, tuple): key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): if attn_metadata.is_prompt: # Prompt run. - if not self.prefill_usefusedsdpa: - # TODO: move this outside of model - assert attn_metadata.attn_bias is not None, \ - 'attn_bias must be set before calling model.forward!' - attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None: - position_bias = _make_alibi_bias(self.alibi_slopes, - self.num_kv_heads, - attn_bias.dtype, - attn_bias.shape[-1]) - attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) - attn_bias.add_(position_bias) - else: - attn_bias = None - query_shape = (batch_size, seq_len, self.num_heads, self.head_size) kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, self.head_size) + + attn_bias = attn_metadata.attn_bias + if attn_bias is not None and self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) + out = ops.prompt_attention( - query.view(query_shape), - key.view(kv_shape), - value.view(kv_shape), + impl=self.prefill_impl, + query=query.view(query_shape), + key=key.view(kv_shape), + value=value.view(kv_shape), + is_causal=True, attn_bias=attn_bias, - p=0.0, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - softmax_op=self.softmax, - matmul_av_op=self.matmul_av, - fsdpa_op=self.fused_scaled_dot_product_attention, - ) + valid_seq_lengths=attn_metadata.seq_lens_tensor, + **self.common_attention_args()) output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. @@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, - block_scales=attn_metadata.block_scales, block_groups=attn_metadata.block_groups, - scale=self.scale, - matmul_qk_op=self.matmul_qk, - matmul_av_op=self.matmul_av, - batch2block_matmul_op=self.batch2block_matmul, - block2batch_matmul_op=self.block2batch_matmul, - keys_fetch_func=self.k_cache.fetch_from_cache, - values_fetch_func=self.v_cache.fetch_from_cache) + **self.common_attention_args()) # Reshape the output tensor. 
return output.view(batch_size, seq_len, hidden_size) + def common_attention_args(self): + fsdpa_op = self.fused_scaled_dot_product_attention.apply \ + if self.fused_scaled_dot_product_attention is not None else None + return { + 'scale': self.scale, + 'matmul_qk_op': self.matmul_qk, + 'matmul_av_op': self.matmul_av, + 'batch2block_matmul_op': self.batch2block_matmul, + 'block2batch_matmul_op': self.block2batch_matmul, + 'fsdpa_op': fsdpa_op, + 'keys_fetch_func': self.k_cache.fetch_from_cache, + 'values_fetch_func': self.v_cache.fetch_from_cache, + 'softmax_op': self.softmax, + } + def _make_alibi_bias( alibi_slopes: torch.Tensor, diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 49ea420d09..1dedd2ffc5 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata: block_usage: Optional[torch.Tensor] block_indices: Optional[torch.Tensor] block_offsets: Optional[torch.Tensor] - block_scales: Optional[torch.Tensor] block_groups: Optional[torch.Tensor] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 5e8eb6c54c..75a5317b10 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -168,7 +168,8 @@ class RMSNorm(CustomOp): x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - from vllm_hpu_extension.ops import HPUFusedRMSNorm + from vllm_hpu_extension.kernels import rms_norm + HPUFusedRMSNorm = rms_norm() if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 2a49563436..7b606272ad 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -11,11 +11,9 @@ import functools import gc import itertools import math -import operator import os import time from array import array -from dataclasses import dataclass, field from enum import IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -24,8 +22,9 @@ import habana_frameworks.torch as htorch import habana_frameworks.torch.internal.bridge_config as bc import torch import torch.nn as nn +import vllm_hpu_extension.environment as environment +from vllm_hpu_extension.bucketing.common import get_bucketing_context from vllm_hpu_extension.ops import LoraMask as LoraMask -from vllm_hpu_extension.ops import batch2block, block2batch from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, HabanaMemoryProfiler, format_bytes) @@ -77,25 +76,6 @@ LORA_WARMUP_RANK = 8 DUMMY_TOKEN_ID = -1 -class Singleton(type): - _instances: Dict[type, object] = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) - return cls._instances[cls] - - -@dataclass -class HPUBucketingGlobalState(metaclass=Singleton): - prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) - decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False) - prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False) - decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False) - prompt_buckets: List[Tuple[int, int]] = field(init=False) - decode_buckets: List[Tuple[int, int]] = field(init=False) - - def subtuple(obj: object, typename: str, to_copy: List[str], @@ -115,134 +95,10 @@ def subtuple(obj: 
object, return _TYPE_CACHE[typename](**values) -def read_bucket_settings(phase: str, dim: str, **defaults): - """Read bucketing configuration from env variables. - - phase is either 'prompt' or 'decode' - dim is either 'bs', 'seq' or 'block' - param is either 'min', 'step' or 'max' - example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 - """ - params = ['min', 'step', 'max'] - env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] - default_values = [defaults[p] for p in params] - values = [ - int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values) - ] - for e, v, d in zip(env_vars, values, default_values): - logger.info('%s=%s (default:%s)', e, v, d) - return values - - -def warmup_range(config: Tuple[int, int, int]): - """Generate a warmup range. - - Start from bmin and multiply by 2 until you reach bstep. - Then, increase the values in the range by the value of bstep until you - reach bmax. - - Example: - bmin = 2, bstep = 32, bmax = 64 - => ramp_up = (2, 4, 8, 16) - => stable = (32, 64) - => return ramp_up + stable => (2, 4, 8, 16, 32, 64) - """ - bmin, bstep, bmax = config - assert bmin <= bmax, ("Min. batch size cannot be greater than max. " - "batch size. If you want to skip warmup, " - "set VLLM_SKIP_WARMUP=true") - base = itertools.repeat(2) - ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin) - ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \ - ramp_up_acc) - stable = range(bstep, bmax + 1, bstep) - buckets = list(ramp_up_tw) + list(stable) - return list(filter(lambda bucket: bucket >= bmin, buckets)) - - -def generate_prompt_buckets(bs_bucket_config, - seq_bucket_config, - max_num_batched_tokens=None): - buckets = list( - itertools.product(warmup_range(bs_bucket_config), - warmup_range(seq_bucket_config))) - if len(buckets) == 0: - msg = ("No buckets could be captured with following config " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config}") - raise ValueError(msg) - - filtered_buckets = buckets - if max_num_batched_tokens is not None: - # Remove buckets exceeding batch token budget - filtered_buckets = list( - filter( - lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens, - buckets)) - - if len(filtered_buckets) == 0: - # we can handle this if we ignore max_num_batched_tokens - min_bucket_bs, min_bucket_seq = min(buckets, - key=lambda b: (b[0] * b[1])) - min_reqd_budget = min_bucket_bs * min_bucket_seq - msg = ( - "The current bucketing configuration " - f"(min, step, max_warmup): " - f"bs:{bs_bucket_config}, " - f"seq:{seq_bucket_config} cannot be used with specified " - f"max_num_batched_tokens ({max_num_batched_tokens}), as the " - f"smallest bucket ({min_reqd_budget}) would exceed token " - "budget. 
Please increase max_num_batched_tokens or decrease " - "bucket minimum Ignoring max_num_batched_tokens at risk of " - "out-of-memory errors.") - logger.error(msg) - return list( - sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), [] - - captured_buckets = list( - sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) - omitted_buckets = list( - sorted([x for x in buckets if x not in filtered_buckets])) - return captured_buckets, omitted_buckets - - -def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, - max_blocks): - buckets = [] - bs_buckets = warmup_range(bs_bucket_config) - block_buckets = warmup_range(blocks_bucket_config) - bmin, bstep, bmax = blocks_bucket_config - last_bucket = round_up(max_blocks, bstep) - for bs in bs_buckets: - for blocks in block_buckets: - if blocks < bs: - continue - if blocks > last_bucket: - break - buckets.append((bs, blocks)) - return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) - - -def next_pow2(value: int, base: int): - res = base - while value > 1: - value = (value + 1) // 2 - res *= 2 - return res - - def round_up(value: int, k: int): return (value + k - 1) // k * k -def find_bucket(value: int, config: Tuple[int, int, int]): - bmin, bstep, _ = config - next_step = round_up(value, bstep) - next_pow = next_pow2(value, bmin) - return max(bmin, min(next_step, next_pow)) - - def align_workers(value, op): group = get_world_group().cpu_group world_size = torch.distributed.get_world_size() @@ -406,16 +262,6 @@ class HpuModelAdapter: attn_bias=attn_bias) return metadata - def _set_block_scales(self, metadata, device): - block_mapping = metadata.block_mapping - ones = torch.ones((block_mapping.size(0), ), - device=device, - dtype=block_mapping.dtype) - sums = batch2block(block2batch(ones, block_mapping), block_mapping) - block_scales = torch.reciprocal(torch.maximum(ones, sums)) - metadata = metadata._replace(block_scales=block_scales) - return metadata - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, dtype): if attn_metadata.is_prompt: @@ -426,7 +272,6 @@ class HpuModelAdapter: meta = attn_metadata attn_metadata = self._set_block_mapping(meta, batch_size, device, dtype) - attn_metadata = self._set_block_scales(attn_metadata, device) return attn_metadata def forward(self, *args, **kwargs): @@ -625,6 +470,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): return_hidden_states: bool = False, ): ModelRunnerBase.__init__(self, vllm_config=vllm_config) + environment.set_model_config(self.model_config) self.is_driver_worker = is_driver_worker self.return_hidden_states = return_hidden_states @@ -664,8 +510,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): self.profiler_counter_helper = HabanaProfilerCounterHelper() self.seen_configs: set = set() self._mem_margin: Optional[int] = None - self.bucketing_global_state = HPUBucketingGlobalState() - self._setup_buckets() + HPUBucketingContext = get_bucketing_context() + self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, + self.max_num_prefill_seqs, + self.block_size, + self.max_num_batched_tokens, + False, self.max_model_len) + self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH @@ -773,6 +624,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): msg = f"Loading model weights took in total {m.get_summary_string()}" logger.info(msg) + def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): + real_batch_size = 
len(seq_group_metadata_list) + batch_size_padded = self.bucketing_ctx.get_padded_batch_size( + real_batch_size, is_prompt) + batch_size_padding = batch_size_padded - real_batch_size + + seq_group_metadata_list = seq_group_metadata_list.copy() + + if batch_size_padding > 0: + dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( + 0, 0, is_prompt) + seq_group_metadata_list.extend(dummy_seq_group_metadata + for _ in range(batch_size_padding)) + return seq_group_metadata_list, real_batch_size, batch_size_padded + def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): return htorch.hpu.wrap_in_hpu_graph( HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True @@ -792,46 +658,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): def _is_valid_bucket(self, bucket): return bucket[0] * bucket[1] <= self.max_num_batched_tokens - def _setup_buckets(self) -> None: - align_bs = lambda x: min(self.max_num_seqs, x) - #FIXME: The default values should be max_model_len - max_prompt_seq = 1024 - max_decode_seq = 2048 - self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings( - 'prompt', - 'bs', - min=1, - step=align_bs(32), - max=self.max_num_prefill_seqs) - self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings( - 'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs) - self.bucketing_global_state.prompt_seq_bucket_cfg = \ - read_bucket_settings( - 'prompt', - 'seq', - min=self.block_size, - step=self.block_size, - max=max_prompt_seq) - self.bucketing_global_state.decode_block_bucket_cfg = \ - read_bucket_settings( - 'decode', - 'block', - min=self.block_size, - step=self.block_size, - max=max(self.block_size, - self.max_num_seqs * max_decode_seq // self.block_size)) - self.graphed_buckets: Set[Any] = set() - - msg = ("Prompt bucket config (min, step, max_warmup) " - f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, " - f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}") - logger.info(msg) - - msg = ("Decode bucket config (min, step, max_warmup) " - f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, " - f"block:{self.bucketing_global_state.decode_block_bucket_cfg}") - logger.info(msg) - def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -947,8 +773,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): assert max_query_len > 0 max_prompt_len = max( - find_bucket(max(seq_lens), - self.bucketing_global_state.prompt_seq_bucket_cfg), + self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len), self.block_size) lora_ids: List[int] = [] @@ -997,7 +822,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): block_usage=None, block_indices=block_indices, block_offsets=block_offsets, - block_scales=None, block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, @@ -1124,9 +948,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): padding_fn = None if self.use_contiguous_pa: block_bucket_size = max(max(block_list) + 1, len(block_list)) - block_bucket_size = find_bucket( - block_bucket_size, - self.bucketing_global_state.decode_block_bucket_cfg) + block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( + block_bucket_size) indices: List[Any] indices = [None] * block_bucket_size for i, bid in enumerate(block_list): @@ -1134,9 +957,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): padding_fn = lambda tensor, pad_value: gather_list( tensor, indices, pad_value) else: - block_bucket_size = find_bucket( - len(block_list), - 
self.bucketing_global_state.decode_block_bucket_cfg) + block_bucket_size = \ + self.bucketing_ctx.get_padded_decode_num_blocks( + len(block_list)) padding_fn = lambda tensor, pad_value: pad_list( tensor, block_bucket_size, pad_value) @@ -1167,7 +990,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): block_usage=block_usage, block_indices=block_indices, block_offsets=block_offsets, - block_scales=None, block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, @@ -1210,17 +1032,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): base_event_name = 'prompt' if is_prompt else 'decode' self.profiler.start('internal', base_event_name) - real_batch_size = len(seq_group_metadata_list) - bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \ - if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg - batch_size_padded = find_bucket(real_batch_size, bucket_cfg) - batch_size_padding = batch_size_padded - real_batch_size - seq_group_metadata_list = seq_group_metadata_list.copy() - if batch_size_padding > 0: - dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( - 0, 0, is_prompt) - seq_group_metadata_list.extend(dummy_seq_group_metadata - for _ in range(batch_size_padding)) + seq_group_metadata_list, real_batch_size, batch_size_padded = ( + self._add_dummy_seq(seq_group_metadata_list, is_prompt)) prefill_reqs = [] decode_reqs = [] @@ -1382,7 +1195,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', - 'block_offsets', 'block_scales', 'block_groups' + 'block_offsets', 'block_groups' ]) return attention_metadata @@ -1420,16 +1233,18 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, [kv_caches]) - max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] - max_batch_size = min(self.max_num_batched_tokens // max_seq_len, - self.scheduler_config.max_num_seqs) - self.warmup_scenario(max_batch_size, max_seq_len, True, False, True) + _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() + max_batch_size = min(self.max_num_seqs, + self.max_num_batched_tokens // max_seq_len) + self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, + False, True) return def warmup_scenario(self, batch_size, seq_len, is_prompt, + kv_caches, is_pt_profiler_run=False, is_lora_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) @@ -1565,16 +1380,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): f"free_mem:{free_mem}") logger.info(msg) - def warmup_all_buckets(self, buckets, is_prompt): + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt) + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) def warmup_graphs(self, strategy, buckets, is_prompt, + kv_caches, available_mem, starting_mem=0, total_batch_seq=0.001): @@ -1606,7 +1422,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): self.graphed_buckets.add(graphed_bucket) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) with HabanaMemoryProfiler() as mem_prof: - 
self.warmup_scenario(batch_size, seq_len, is_prompt)
+                self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
             used_mem = align_workers(mem_prof.consumed_device_memory,
                                      torch.distributed.ReduceOp.MAX)
             available_mem -= used_mem
@@ -1630,50 +1446,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
 
     @torch.inference_mode()
     def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
+        max_blocks = kv_caches[0][0].size(0)
+        self.bucketing_ctx.generate_decode_buckets(max_blocks)
         if profile := os.environ.get('VLLM_PT_PROFILE', None):
             phase, bs, seq_len, graph = profile.split('_')
             is_prompt = phase == 'prompt'
             graphs = graph == 't'
             if graphs:
                 self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
-            self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
+            self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
+                                 True)
             raise AssertionError("Finished profiling")
-        if self.skip_warmup:
-            logger.info("Skipping warmup...")
-            return
-        self.profiler.start('internal', 'warmup')
-        max_blocks = kv_caches[0][0].size(0)
-
-        self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \
-            generate_prompt_buckets(
-                self.bucketing_global_state.prompt_bs_bucket_cfg,
-                self.bucketing_global_state.prompt_seq_bucket_cfg,
-                self.max_num_batched_tokens)
-
-        msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} "
-               f"prompt buckets [bs, seq]: \
-                {list(sorted(self.bucketing_global_state.prompt_buckets))}")
-        logger.info(msg)
-
-        msg = (f"Omitted {len(prompt_omitted_buckets)} "
-               "prompt buckets due to exceeded token budget "
-               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
-        logger.info(msg)
-
-        msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
-        logger.debug(msg)
-
-        self.bucketing_global_state.decode_buckets = generate_decode_buckets(
-            self.bucketing_global_state.decode_bs_bucket_cfg,
-            self.bucketing_global_state.decode_block_bucket_cfg, max_blocks)
-        logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
-                    len(self.bucketing_global_state.decode_buckets),
-                    list(sorted(self.bucketing_global_state.decode_buckets)))
-
         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
             cache_size_limit = 1 + 3 * (
-                len(self.bucketing_global_state.prompt_buckets) +
-                len(self.bucketing_global_state.decode_buckets))
+                len(self.bucketing_ctx.prompt_buckets) +
+                len(self.bucketing_ctx.decode_buckets))
             torch._dynamo.config.cache_size_limit = max(
                 cache_size_limit, torch._dynamo.config.cache_size_limit)
             # Multiply by 8 to follow the original default ratio between
@@ -1681,7 +1468,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             torch._dynamo.config.accumulated_cache_size_limit = max(
                 cache_size_limit * 8,
                 torch._dynamo.config.accumulated_cache_size_limit)
-
+        if self.skip_warmup:
+            logger.info("Skipping warmup...")
+            return
+        self.profiler.start('internal', 'warmup')
         start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
 
@@ -1700,10 +1490,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 'Please update Gaudi Software Suite.')
         with compile_only_mode_context(
         ) if can_use_compile_only_mode else contextlib.nullcontext():
-            self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
-                                    True)
-            self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
-                                    False)
+            self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
+                                    kv_caches)
+            
self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, + kv_caches) if not self.enforce_eager and htorch.utils.internal.is_lazy(): assert self.mem_margin is not None, \ @@ -1733,12 +1525,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): 'max_bs') mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ self.warmup_graphs( - prompt_strategy, self.bucketing_global_state.prompt_buckets, - True, prompt_available_memory) + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, prompt_available_memory) mem_post_decode, decode_batch_seq, decode_captured_all = \ self.warmup_graphs( - decode_strategy, self.bucketing_global_state.decode_buckets, - False, decode_available_memory) + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, decode_available_memory) # Not all prompt buckets were captured, but all decode buckets # were captured and we have some free graph-allocated space @@ -1747,8 +1539,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): and not prompt_captured_all and decode_captured_all): mem_post_prompt, _, prompt_captured_all = ( self.warmup_graphs( - prompt_strategy, - self.bucketing_global_state.prompt_buckets, True, + prompt_strategy, self.bucketing_ctx.prompt_buckets, + True, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_prompt, prompt_batch_seq)) @@ -1759,17 +1551,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): and not decode_captured_all \ and prompt_captured_all: mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, - self.bucketing_global_state.decode_buckets, False, + decode_strategy, self.bucketing_ctx.decode_buckets, + False, kv_caches, graph_free_mem - mem_post_prompt - mem_post_decode, mem_post_decode, decode_batch_seq) self.log_graph_warmup_summary( - self.bucketing_global_state.prompt_buckets, True, - mem_post_prompt) + self.bucketing_ctx.prompt_buckets, True, mem_post_prompt) self.log_graph_warmup_summary( - self.bucketing_global_state.decode_buckets, False, - mem_post_decode) + self.bucketing_ctx.decode_buckets, False, mem_post_decode) end_time = time.perf_counter() end_mem = HabanaMemoryProfiler.current_device_memory_usage() diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index ccb175d88f..8d7d5d7adc 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -245,6 +245,7 @@ class HPUWorker(LocalOrDistributedWorkerBase): cache_block_size) num_hpu_blocks = max(num_hpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks if self.model_runner.lora_manager: self.model_runner.remove_all_loras()
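
Note on the prefill path selection in hpu_attn.py above: HPUAttentionImpl now picks the prompt attention implementation from the extension's feature flags instead of the old VLLM_PROMPT_USE_FUSEDSDPA environment variable. The standalone sketch below restates that selection order for clarity; it is illustrative only, the function name is not part of the diff, and the `flags` argument stands in for whatever set vllm_hpu_extension.flags.enabled_flags() returns at runtime.

from typing import Optional, Sequence, Set


def select_prefill_impl(flags: Set[str],
                        alibi_slopes: Optional[Sequence[float]] = None) -> str:
    """Mirror the flag-driven selection added in HPUAttentionImpl.__init__."""
    impl = 'naive'  # default prompt attention path
    if 'flex_attention' in flags:
        impl = 'flex'
    if 'fsdpa' in flags:
        # FusedSDPA prefill is incompatible with alibi slopes,
        # matching the assert in the diff.
        assert alibi_slopes is None, \
            'Prefill with FusedSDPA not supported with alibi slopes!'
        impl = 'fsdpa'
    return impl


# When both flags are enabled, FusedSDPA takes precedence.
assert select_prefill_impl({'flex_attention', 'fsdpa'}) == 'fsdpa'
assert select_prefill_impl({'flex_attention'}) == 'flex'
assert select_prefill_impl(set()) == 'naive'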
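
Similarly, the bucket padding that hpu_model_runner.py used to compute locally (find_bucket built on round_up and next_pow2) is now delegated to the bucketing context from vllm-hpu-extension through get_padded_batch_size, get_padded_prompt_seq_len and get_padded_decode_num_blocks. The sketch below reproduces the deleted padding rule as a reference point; whether the extension's helpers keep exactly this behavior is an assumption, not something the diff shows.

from typing import Tuple


def round_up(value: int, k: int) -> int:
    return (value + k - 1) // k * k


def next_pow2(value: int, base: int) -> int:
    # Smallest base * 2**n that is >= value (as in the removed helper).
    res = base
    while value > 1:
        value = (value + 1) // 2
        res *= 2
    return res


def find_bucket(value: int, config: Tuple[int, int, int]) -> int:
    # Pad small values along a power-of-two ramp-up starting at bmin, and
    # larger values to the next multiple of bstep; the max component of the
    # config is not used by this helper.
    bmin, bstep, _ = config
    return max(bmin, min(round_up(value, bstep), next_pow2(value, bmin)))


# With a (min=1, step=32) config like the removed decode batch-size default:
assert find_bucket(5, (1, 32, 256)) == 8    # power-of-two ramp-up
assert find_bucket(20, (1, 32, 256)) == 32  # step padding takes over
assert find_bucket(47, (1, 32, 256)) == 64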