[Hardware][Intel-Gaudi] Update hpu-extension and update bucketing system for HPU device (#17186)

Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>

Author: Agata Dobrzyniewicz
Date: 2025-04-26 14:55:14 +02:00
Committed by: GitHub
Parent: 909fdaf152
Commit: c48334d405

6 changed files with 128 additions and 335 deletions
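
The bulk of this change replaces the runner-local bucketing machinery (HPUBucketingGlobalState, read_bucket_settings, warmup_range, generate_prompt_buckets, generate_decode_buckets, find_bucket) with the bucketing context shipped in vllm-hpu-extension. Below is a minimal sketch of how the runner now obtains and queries that context, using only the constructor arguments and method names visible in the diff; exact signatures and defaults are owned by vllm-hpu-extension, and the concrete numbers are placeholders, not recommendations.

from vllm_hpu_extension.bucketing.common import get_bucketing_context

# Resolve the bucketing implementation provided by the extension.
HPUBucketingContext = get_bucketing_context()

# Argument order mirrors the call in HPUModelRunnerBase.__init__ below.
bucketing_ctx = HPUBucketingContext(
    64,     # max_num_seqs (placeholder)
    16,     # max_num_prefill_seqs (placeholder)
    128,    # block_size (placeholder)
    8192,   # max_num_batched_tokens (placeholder)
    False,  # boolean flag passed as False in the diff; semantics owned by the extension
    4096,   # max_model_len (placeholder)
)

# Padding queries that take over the role of the old find_bucket():
padded_bs = bucketing_ctx.get_padded_batch_size(3, True)          # prompt batch size
padded_seq_len = bucketing_ctx.get_padded_prompt_seq_len(777)     # prompt sequence length
padded_blocks = bucketing_ctx.get_padded_decode_num_blocks(105)   # decode block count

# Warmup consumes the generated bucket lists; in this diff only the decode
# buckets are generated explicitly, once the number of HPU blocks is known.
bucketing_ctx.generate_decode_buckets(1024)
all_buckets = list(bucketing_ctx.prompt_buckets) + list(bucketing_ctx.decode_buckets)
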


@ -9,4 +9,4 @@ numpy==1.26.4
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624


@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import vllm_hpu_extension.kernels as kernels
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
from vllm_hpu_extension.flags import enabled_flags
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
ops.pa_impl = ops.pa
self.fused_scaled_dot_product_attention = kernels.fsdpa()
self.prefill_impl = 'naive'
if "flex_attention" in enabled_flags():
self.prefill_impl = 'flex'
if "fsdpa" in enabled_flags():
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
self.prefill_impl = 'fsdpa'
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
self.fused_scaled_dot_product_attention = None
if self.prefill_usefusedsdpa:
if self.prefill_impl == 'fsdpa':
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(
FusedSDPA)
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")
supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if attn_type != AttentionType.DECODER:
self.attn_type = attn_type
if self.attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt:
key_cache = None
value_cache = None
if attn_metadata.is_prompt and self.attn_type \
is not AttentionType.ENCODER_ONLY \
and attn_metadata.block_list is None:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if attn_metadata.is_prompt:
# Prompt run.
if not self.prefill_usefusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
attn_bias = attn_metadata.attn_bias
if attn_bias is not None and self.alibi_slopes is not None:
position_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads,
attn_bias.dtype,
attn_bias.shape[-1])
attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
impl=self.prefill_impl,
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
is_causal=True,
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args())
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
**self.common_attention_args())
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
def common_attention_args(self):
fsdpa_op = self.fused_scaled_dot_product_attention.apply \
if self.fused_scaled_dot_product_attention is not None else None
return {
'scale': self.scale,
'matmul_qk_op': self.matmul_qk,
'matmul_av_op': self.matmul_av,
'batch2block_matmul_op': self.batch2block_matmul,
'block2batch_matmul_op': self.block2batch_matmul,
'fsdpa_op': fsdpa_op,
'keys_fetch_func': self.k_cache.fetch_from_cache,
'values_fetch_func': self.v_cache.fetch_from_cache,
'softmax_op': self.softmax,
}
def _make_alibi_bias(
alibi_slopes: torch.Tensor,


@ -22,7 +22,6 @@ class HPUPagedAttentionMetadata:
block_usage: Optional[torch.Tensor]
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]


@ -168,7 +168,8 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm_hpu_extension.ops import HPUFusedRMSNorm
from vllm_hpu_extension.kernels import rms_norm
HPUFusedRMSNorm = rms_norm()
if HPUFusedRMSNorm is None:
return self.forward_native(x, residual)
if residual is not None:


@ -11,11 +11,9 @@ import functools
import gc
import itertools
import math
import operator
import os
import time
from array import array
from dataclasses import dataclass, field
from enum import IntEnum
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Set, Tuple, Type, TypeVar, Union)
@ -24,8 +22,9 @@ import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
import vllm_hpu_extension.environment as environment
from vllm_hpu_extension.bucketing.common import get_bucketing_context
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)
@ -77,25 +76,6 @@ LORA_WARMUP_RANK = 8
DUMMY_TOKEN_ID = -1
class Singleton(type):
_instances: Dict[type, object] = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]
@dataclass
class HPUBucketingGlobalState(metaclass=Singleton):
prompt_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
decode_bs_bucket_cfg: Tuple[int, int, int] = field(init=False)
prompt_seq_bucket_cfg: Tuple[int, int, int] = field(init=False)
decode_block_bucket_cfg: Tuple[int, int, int] = field(init=False)
prompt_buckets: List[Tuple[int, int]] = field(init=False)
decode_buckets: List[Tuple[int, int]] = field(init=False)
def subtuple(obj: object,
typename: str,
to_copy: List[str],
@ -115,134 +95,10 @@ def subtuple(obj: object,
return _TYPE_CACHE[typename](**values)
def read_bucket_settings(phase: str, dim: str, **defaults):
"""Read bucketing configuration from env variables.
phase is either 'prompt' or 'decode'
dim is either 'bs', 'seq' or 'block'
param is either 'min', 'step' or 'max'
example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
"""
params = ['min', 'step', 'max']
env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
default_values = [defaults[p] for p in params]
values = [
int(os.environ.get(e, d)) for e, d in zip(env_vars, default_values)
]
for e, v, d in zip(env_vars, values, default_values):
logger.info('%s=%s (default:%s)', e, v, d)
return values
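
The docstring above describes the environment-variable scheme used by this (now removed) helper. A short illustration of how the names are composed and how defaults interact with the environment; the min/step defaults match the old _setup_buckets call sites further down, while max=256 is illustrative (the runner passed max_num_seqs):

import os

# For phase='decode', dim='bs' the helper reads:
#   VLLM_DECODE_BS_BUCKET_MIN, VLLM_DECODE_BS_BUCKET_STEP, VLLM_DECODE_BS_BUCKET_MAX
os.environ['VLLM_DECODE_BS_BUCKET_STEP'] = '128'
bs_min, bs_step, bs_max = read_bucket_settings('decode', 'bs', min=1, step=32, max=256)
# -> (1, 128, 256): STEP comes from the environment, MIN and MAX fall back to the defaults.
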
def warmup_range(config: Tuple[int, int, int]):
"""Generate a warmup range.
Start from bmin and multiply by 2 until you reach bstep.
Then, increase the values in the range by the value of bstep until you
reach bmax.
Example:
bmin = 2, bstep = 32, bmax = 64
=> ramp_up = (2, 4, 8, 16)
=> stable = (32, 64)
=> return ramp_up + stable => (2, 4, 8, 16, 32, 64)
"""
bmin, bstep, bmax = config
assert bmin <= bmax, ("Min. batch size cannot be greater than max. "
"batch size. If you want to skip warmup, "
"set VLLM_SKIP_WARMUP=true")
base = itertools.repeat(2)
ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
ramp_up_acc)
stable = range(bstep, bmax + 1, bstep)
buckets = list(ramp_up_tw) + list(stable)
return list(filter(lambda bucket: bucket >= bmin, buckets))
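
A worked check of the docstring's example for the legacy warmup_range helper: the ramp-up part doubles from bmin and stops strictly below bstep, then the stable part steps by bstep up to bmax.

# Docstring example: bmin=2, bstep=32, bmax=64
assert warmup_range((2, 32, 64)) == [2, 4, 8, 16, 32, 64]

# bmin need not divide bstep; the ramp-up still stops below bstep:
assert warmup_range((3, 16, 32)) == [3, 6, 12, 16, 32]
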
def generate_prompt_buckets(bs_bucket_config,
seq_bucket_config,
max_num_batched_tokens=None):
buckets = list(
itertools.product(warmup_range(bs_bucket_config),
warmup_range(seq_bucket_config)))
if len(buckets) == 0:
msg = ("No buckets could be captured with following config "
f"(min, step, max_warmup): "
f"bs:{bs_bucket_config}, "
f"seq:{seq_bucket_config}")
raise ValueError(msg)
filtered_buckets = buckets
if max_num_batched_tokens is not None:
# Remove buckets exceeding batch token budget
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * bucket[1] <= max_num_batched_tokens,
buckets))
if len(filtered_buckets) == 0:
# we can handle this if we ignore max_num_batched_tokens
min_bucket_bs, min_bucket_seq = min(buckets,
key=lambda b: (b[0] * b[1]))
min_reqd_budget = min_bucket_bs * min_bucket_seq
msg = (
"The current bucketing configuration "
f"(min, step, max_warmup): "
f"bs:{bs_bucket_config}, "
f"seq:{seq_bucket_config} cannot be used with specified "
f"max_num_batched_tokens ({max_num_batched_tokens}), as the "
f"smallest bucket ({min_reqd_budget}) would exceed token "
"budget. Please increase max_num_batched_tokens or decrease "
"bucket minimum Ignoring max_num_batched_tokens at risk of "
"out-of-memory errors.")
logger.error(msg)
return list(
sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))), []
captured_buckets = list(
sorted(filtered_buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
omitted_buckets = list(
sorted([x for x in buckets if x not in filtered_buckets]))
return captured_buckets, omitted_buckets
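
A small worked example of the legacy generate_prompt_buckets, assuming a tiny configuration: bs config (1, 2, 4), seq config (128, 128, 256) and a 512-token budget. The product of the warmup ranges is [1, 2, 4] x [128, 256]; the (4, 256) bucket exceeds the budget and is omitted, and the rest are sorted by batch_size * seq_len, then seq_len, then batch_size.

captured, omitted = generate_prompt_buckets((1, 2, 4), (128, 128, 256),
                                             max_num_batched_tokens=512)
assert captured == [(1, 128), (2, 128), (1, 256), (4, 128), (2, 256)]
assert omitted == [(4, 256)]
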
def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
max_blocks):
buckets = []
bs_buckets = warmup_range(bs_bucket_config)
block_buckets = warmup_range(blocks_bucket_config)
bmin, bstep, bmax = blocks_bucket_config
last_bucket = round_up(max_blocks, bstep)
for bs in bs_buckets:
for blocks in block_buckets:
if blocks < bs:
continue
if blocks > last_bucket:
break
buckets.append((bs, blocks))
return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
def next_pow2(value: int, base: int):
res = base
while value > 1:
value = (value + 1) // 2
res *= 2
return res
def round_up(value: int, k: int):
return (value + k - 1) // k * k
def find_bucket(value: int, config: Tuple[int, int, int]):
bmin, bstep, _ = config
next_step = round_up(value, bstep)
next_pow = next_pow2(value, bmin)
return max(bmin, min(next_step, next_pow))
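
The legacy find_bucket (whose role is taken over by the get_padded_* queries on the bucketing context) takes the smaller of step-rounding and power-of-two rounding, clamped from below by bmin. Two worked values, assuming config (1, 32, 1024):

cfg = (1, 32, 1024)   # (bmin, bstep, bmax); bmax is not used by find_bucket
assert find_bucket(5, cfg) == 8     # round_up(5, 32) = 32, next_pow2(5, 1) = 8  -> 8
assert find_bucket(65, cfg) == 96   # round_up(65, 32) = 96, next_pow2(65, 1) = 128 -> 96
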
def align_workers(value, op):
group = get_world_group().cpu_group
world_size = torch.distributed.get_world_size()
@ -406,16 +262,6 @@ class HpuModelAdapter:
attn_bias=attn_bias)
return metadata
def _set_block_scales(self, metadata, device):
block_mapping = metadata.block_mapping
ones = torch.ones((block_mapping.size(0), ),
device=device,
dtype=block_mapping.dtype)
sums = batch2block(block2batch(ones, block_mapping), block_mapping)
block_scales = torch.reciprocal(torch.maximum(ones, sums))
metadata = metadata._replace(block_scales=block_scales)
return metadata
def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
dtype):
if attn_metadata.is_prompt:
@ -426,7 +272,6 @@ class HpuModelAdapter:
meta = attn_metadata
attn_metadata = self._set_block_mapping(meta, batch_size, device,
dtype)
attn_metadata = self._set_block_scales(attn_metadata, device)
return attn_metadata
def forward(self, *args, **kwargs):
@ -625,6 +470,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
return_hidden_states: bool = False,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
environment.set_model_config(self.model_config)
self.is_driver_worker = is_driver_worker
self.return_hidden_states = return_hidden_states
@ -664,8 +510,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.profiler_counter_helper = HabanaProfilerCounterHelper()
self.seen_configs: set = set()
self._mem_margin: Optional[int] = None
self.bucketing_global_state = HPUBucketingGlobalState()
self._setup_buckets()
HPUBucketingContext = get_bucketing_context()
self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
self.max_num_prefill_seqs,
self.block_size,
self.max_num_batched_tokens,
False, self.max_model_len)
self.graphed_buckets: Set[Any] = set()
self._set_gc_threshold()
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
@ -773,6 +624,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)
def _add_dummy_seq(self, seq_group_metadata_list, is_prompt):
real_batch_size = len(seq_group_metadata_list)
batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
real_batch_size, is_prompt)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
if batch_size_padding > 0:
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
0, 0, is_prompt)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
return seq_group_metadata_list, real_batch_size, batch_size_padded
def _maybe_wrap_in_hpu_graph(self, *args, **kwargs):
return htorch.hpu.wrap_in_hpu_graph(
HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True
@ -792,46 +658,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def _is_valid_bucket(self, bucket):
return bucket[0] * bucket[1] <= self.max_num_batched_tokens
def _setup_buckets(self) -> None:
align_bs = lambda x: min(self.max_num_seqs, x)
#FIXME: The default values should be max_model_len
max_prompt_seq = 1024
max_decode_seq = 2048
self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings(
'prompt',
'bs',
min=1,
step=align_bs(32),
max=self.max_num_prefill_seqs)
self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings(
'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
self.bucketing_global_state.prompt_seq_bucket_cfg = \
read_bucket_settings(
'prompt',
'seq',
min=self.block_size,
step=self.block_size,
max=max_prompt_seq)
self.bucketing_global_state.decode_block_bucket_cfg = \
read_bucket_settings(
'decode',
'block',
min=self.block_size,
step=self.block_size,
max=max(self.block_size,
self.max_num_seqs * max_decode_seq // self.block_size))
self.graphed_buckets: Set[Any] = set()
msg = ("Prompt bucket config (min, step, max_warmup) "
f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, "
f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}")
logger.info(msg)
msg = ("Decode bucket config (min, step, max_warmup) "
f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, "
f"block:{self.bucketing_global_state.decode_block_bucket_cfg}")
logger.info(msg)
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -947,8 +773,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
assert max_query_len > 0
max_prompt_len = max(
find_bucket(max(seq_lens),
self.bucketing_global_state.prompt_seq_bucket_cfg),
self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len),
self.block_size)
lora_ids: List[int] = []
@ -997,7 +822,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_usage=None,
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor,
@ -1124,9 +948,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
padding_fn = None
if self.use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
block_bucket_size = find_bucket(
block_bucket_size,
self.bucketing_global_state.decode_block_bucket_cfg)
block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
@ -1134,9 +957,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
padding_fn = lambda tensor, pad_value: gather_list(
tensor, indices, pad_value)
else:
block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
block_bucket_size = \
self.bucketing_ctx.get_padded_decode_num_blocks(
len(block_list))
padding_fn = lambda tensor, pad_value: pad_list(
tensor, block_bucket_size, pad_value)
@ -1167,7 +990,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_usage=block_usage,
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
@ -1210,17 +1032,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
base_event_name = 'prompt' if is_prompt else 'decode'
self.profiler.start('internal', base_event_name)
real_batch_size = len(seq_group_metadata_list)
bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \
if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg
batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
if batch_size_padding > 0:
dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
0, 0, is_prompt)
seq_group_metadata_list.extend(dummy_seq_group_metadata
for _ in range(batch_size_padding))
seq_group_metadata_list, real_batch_size, batch_size_padded = (
self._add_dummy_seq(seq_group_metadata_list, is_prompt))
prefill_reqs = []
decode_reqs = []
@ -1382,7 +1195,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
'block_offsets', 'block_scales', 'block_groups'
'block_offsets', 'block_groups'
])
return attention_metadata
@ -1420,16 +1233,18 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
bind_kv_cache(
self.vllm_config.compilation_config.static_forward_context,
[kv_caches])
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
_, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
max_batch_size = min(self.max_num_seqs,
self.max_num_batched_tokens // max_seq_len)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
return
def warmup_scenario(self,
batch_size,
seq_len,
is_prompt,
kv_caches,
is_pt_profiler_run=False,
is_lora_profile_run=False) -> None:
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
@ -1565,16 +1380,17 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
f"free_mem:{free_mem}")
logger.info(msg)
def warmup_all_buckets(self, buckets, is_prompt):
def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
self.log_warmup('Prompt' if is_prompt else 'Decode', i,
len(buckets), batch_size, seq_len)
self.warmup_scenario(batch_size, seq_len, is_prompt)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
def warmup_graphs(self,
strategy,
buckets,
is_prompt,
kv_caches,
available_mem,
starting_mem=0,
total_batch_seq=0.001):
@ -1606,7 +1422,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.graphed_buckets.add(graphed_bucket)
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
self.warmup_scenario(batch_size, seq_len, is_prompt)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
used_mem = align_workers(mem_prof.consumed_device_memory,
torch.distributed.ReduceOp.MAX)
available_mem -= used_mem
@ -1630,50 +1446,21 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
@torch.inference_mode()
def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
max_blocks = kv_caches[0][0].size(0)
self.bucketing_ctx.generate_decode_buckets(max_blocks)
if profile := os.environ.get('VLLM_PT_PROFILE', None):
phase, bs, seq_len, graph = profile.split('_')
is_prompt = phase == 'prompt'
graphs = graph == 't'
if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
True)
raise AssertionError("Finished profiling")
if self.skip_warmup:
logger.info("Skipping warmup...")
return
self.profiler.start('internal', 'warmup')
max_blocks = kv_caches[0][0].size(0)
self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \
generate_prompt_buckets(
self.bucketing_global_state.prompt_bs_bucket_cfg,
self.bucketing_global_state.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)
msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} "
f"prompt buckets [bs, seq]: \
{list(sorted(self.bucketing_global_state.prompt_buckets))}")
logger.info(msg)
msg = (f"Omitted {len(prompt_omitted_buckets)} "
"prompt buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)
msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
logger.debug(msg)
self.bucketing_global_state.decode_buckets = generate_decode_buckets(
self.bucketing_global_state.decode_bs_bucket_cfg,
self.bucketing_global_state.decode_block_bucket_cfg, max_blocks)
logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
len(self.bucketing_global_state.decode_buckets),
list(sorted(self.bucketing_global_state.decode_buckets)))
if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
cache_size_limit = 1 + 3 * (
len(self.bucketing_global_state.prompt_buckets) +
len(self.bucketing_global_state.decode_buckets))
len(self.bucketing_ctx.prompt_buckets) +
len(self.bucketing_ctx.decode_buckets))
torch._dynamo.config.cache_size_limit = max(
cache_size_limit, torch._dynamo.config.cache_size_limit)
# Multiply by 8 to follow the original default ratio between
@ -1681,7 +1468,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
torch._dynamo.config.accumulated_cache_size_limit = max(
cache_size_limit * 8,
torch._dynamo.config.accumulated_cache_size_limit)
if self.skip_warmup:
logger.info("Skipping warmup...")
return
self.profiler.start('internal', 'warmup')
start_mem = HabanaMemoryProfiler.current_device_memory_usage()
start_time = time.perf_counter()
@ -1700,10 +1490,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
'Please update Gaudi Software Suite.')
with compile_only_mode_context(
) if can_use_compile_only_mode else contextlib.nullcontext():
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
True)
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
False)
print("aa")
self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
kv_caches)
print("bb")
self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False,
kv_caches)
if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
@ -1733,12 +1525,12 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
'max_bs')
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.bucketing_global_state.prompt_buckets,
True, prompt_available_memory)
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches, prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.bucketing_global_state.decode_buckets,
False, decode_available_memory)
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches, decode_available_memory)
# Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space
@ -1747,8 +1539,8 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
and not prompt_captured_all and decode_captured_all):
mem_post_prompt, _, prompt_captured_all = (
self.warmup_graphs(
prompt_strategy,
self.bucketing_global_state.prompt_buckets, True,
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq))
@ -1759,17 +1551,15 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy,
self.bucketing_global_state.decode_buckets, False,
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)
self.log_graph_warmup_summary(
self.bucketing_global_state.prompt_buckets, True,
mem_post_prompt)
self.bucketing_ctx.prompt_buckets, True, mem_post_prompt)
self.log_graph_warmup_summary(
self.bucketing_global_state.decode_buckets, False,
mem_post_decode)
self.bucketing_ctx.decode_buckets, False, mem_post_decode)
end_time = time.perf_counter()
end_mem = HabanaMemoryProfiler.current_device_memory_usage()


@ -245,6 +245,7 @@ class HPUWorker(LocalOrDistributedWorkerBase):
cache_block_size)
num_hpu_blocks = max(num_hpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)
self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks
if self.model_runner.lora_manager:
self.model_runner.remove_all_loras()