Compare commits

...

13 Commits

Author SHA1 Message Date
1244c25908 minimize fill_
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-04 22:03:51 +00:00
34fb0cbbd0 minimize changes to gpu_model_runner.py
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-04 21:32:41 +00:00
f33ec6d2d2 Remove unneeded changes
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-04 18:13:50 +00:00
c33aeecf24 simplify - get rid of tokenshape
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-03 22:42:09 +00:00
230730c34d Updates for FA3 and other changes
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-03 20:58:48 +00:00
d151b63b8b Merge branch 'main' into hacking_full_cudagraph
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-03 14:05:45 -05:00
bed9efafeb Merge branch 'main' into hacking_full_cudagraph
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-01-24 16:08:18 -05:00
d4c9448b26 WIP initial working version
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-01-17 13:54:51 -05:00
76732ff701 Merge branch 'main' into hacking_full_cudagraph 2025-01-03 15:06:23 -05:00
22bd7296e4 disable inductor, disable piecewise
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-01-03 15:05:52 -05:00
4024253797 Merge branch 'main' into hacking_full_cudagraph
Eager mode works now

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-01-03 13:16:56 -05:00
7eba374599 flash attn changes
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-01-02 13:26:21 -05:00
0c7e6c1e36 Hacky hacky
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2024-12-18 15:13:26 -05:00
2 changed files with 65 additions and 27 deletions

View File

@@ -2925,8 +2925,9 @@ class CompilationConfig(BaseModel):
             # v1 must split the graph on attention ops
             # for piecewise cudagraph
             self.splitting_ops = [
-                "vllm.unified_attention",
-                "vllm.unified_attention_with_output",
+                # HACK for full cuda graph
+                #"vllm.unified_attention",
+                #"vllm.unified_attention_with_output",
             ]
         else:
             # v0 uses full graph compilation
@@ -3339,8 +3340,7 @@ class VllmConfig:
         batch_size_capture_list = []
         if self.model_config is not None and \
             not self.model_config.enforce_eager:
-            batch_size_capture_list = [1, 2, 4
-                                       ] + [i for i in range(8, 513, 8)]
+            batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]

         self.compilation_config.init_with_cudagraph_sizes(
             batch_size_capture_list)
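
For reference, the capture-size expression touched above is only reflowed onto one line; its value is unchanged: 1, 2, 4, then every multiple of 8 up to 512. A quick illustrative check (not code from this branch):

    batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]
    assert batch_size_capture_list[:6] == [1, 2, 4, 8, 16, 24]
    assert batch_size_capture_list[-1] == 512
    assert len(batch_size_capture_list) == 67  # 3 small sizes + 64 multiples of 8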

View File

@@ -11,7 +11,7 @@ import torch.nn as nn
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
@@ -122,9 +122,7 @@ class GPUModelRunner:
             vocab_size=model_config.get_vocab_size(),
         )

-        self.use_cuda_graph = (self.vllm_config.compilation_config.level
-                               == CompilationLevel.PIECEWISE
-                               and not self.model_config.enforce_eager)
+        self.use_cuda_graph = not self.model_config.enforce_eager
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -171,6 +169,20 @@ class GPUModelRunner:
                                          dtype=self.dtype,
                                          device=self.device)

+        # Attention metadata related persistent buffers
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(
+            self.max_num_tokens,
+            # CPU slot_mapping is int32, but
+            # this one must be int64
+            dtype=torch.int64,
+            device=self.device)
+
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                        self.max_model_len,
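
These buffers are allocated once and reused because a captured CUDA graph replays its kernels against the exact tensor storage recorded at capture time, so per-step inputs have to be copied into fixed device tensors rather than freshly allocated. A minimal, generic PyTorch sketch of that pattern (illustrative only, not this PR's code; needs a CUDA device):

    import torch

    static_in = torch.zeros(512, device="cuda")
    static_out = torch.empty_like(static_in)

    # Warm up outside of capture, as PyTorch recommends before graph capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2 + 1)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2 + 1)  # kernels recorded against these buffers

    static_in.copy_(torch.arange(512, device="cuda", dtype=torch.float32))
    g.replay()  # reruns the captured kernels on the same storage
    assert torch.equal(static_out, static_in * 2 + 1)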
@@ -436,12 +448,19 @@ class GPUModelRunner:
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
-        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
-            self.device, non_blocking=True).long()
+
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)

         # Prepare for cascade attention if needed.
         common_prefix_len = (scheduler_output.num_common_prefix_blocks *
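
The fill_ calls pad the unused tail of each persistent buffer: slot_mapping entries of -1 mark padded tokens that the reshape_and_cache write path treats as padding (no KV-cache write), and inactive request slots get zero sequence length. A toy illustration with hypothetical sizes and slot values:

    import torch

    max_num_tokens, max_num_reqs = 8, 4
    slot_mapping = torch.zeros(max_num_tokens, dtype=torch.int64)
    seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32)

    num_scheduled_tokens, num_reqs = 5, 2
    slot_mapping[:num_scheduled_tokens] = torch.tensor([17, 18, 19, 40, 41])
    slot_mapping[num_scheduled_tokens:] = -1  # padded tokens: no cache write
    seq_lens[:num_reqs] = torch.tensor([3, 2], dtype=torch.int32)
    seq_lens[num_reqs:] = 0                   # inactive request slots

    print(slot_mapping.tolist())  # [17, 18, 19, 40, 41, -1, -1, -1]
    print(seq_lens.tolist())      # [3, 2, 0, 0]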
@@ -524,12 +543,13 @@ class GPUModelRunner:
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
-            query_start_loc=query_start_loc,
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
             max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
+            seq_lens=self.seq_lens[:num_reqs],
             block_table=(
                 self.input_batch.block_table.get_device_tensor()[:num_reqs]),
-            slot_mapping=slot_mapping,
+            slot_mapping=self.slot_mapping[:total_num_scheduled_tokens],
+            # Cascade stuff
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             cu_prefix_query_lens=cu_prefix_query_lens,
@@ -541,7 +561,7 @@ class GPUModelRunner:
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        logits_indices = self.query_start_loc[1:num_reqs + 1] - 1
         return attn_metadata, logits_indices

     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
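
With the persistent buffer, the slice must stop at num_reqs + 1: entries past the active requests hold the -1 padding written above. A worked example with illustrative query lengths:

    import torch

    num_reqs = 3
    query_lens = torch.tensor([3, 5, 2])

    query_start_loc = torch.zeros(6, dtype=torch.int32)  # persistent buffer (max_num_reqs + 1 = 6)
    query_start_loc[1:num_reqs + 1] = torch.cumsum(query_lens, dim=0).to(torch.int32)
    query_start_loc[num_reqs + 1:] = -1                  # padding for inactive slots

    logits_indices = query_start_loc[1:num_reqs + 1] - 1  # last token of each request
    print(query_start_loc.tolist())  # [0, 3, 8, 10, -1, -1]
    print(logits_indices.tolist())   # [2, 7, 9]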
@@ -855,6 +875,7 @@ class GPUModelRunner:
         self,
         num_tokens: int,
         kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[FlashAttentionMetadata] = None,
     ) -> torch.Tensor:
         model = self.model
         if kv_caches is None:
@@ -865,7 +886,7 @@ class GPUModelRunner:
         else:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = None
-        with set_forward_context(None, self.vllm_config):
+        with set_forward_context(attn_metadata, self.vllm_config):
             positions = self.mrope_positions[:, :num_tokens] \
                 if self.model_config.uses_mrope \
                 else self.positions[:num_tokens]
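
Passing attn_metadata here means the dummy runs used for capture go through the same forward-context plumbing as real steps, so attention layers see concrete metadata instead of None. As a rough illustration of that plumbing pattern only (not vLLM's actual set_forward_context implementation), a context-variable based version might look like:

    from contextlib import contextmanager
    from contextvars import ContextVar
    from typing import Any, Optional

    _forward_ctx: ContextVar[Optional[Any]] = ContextVar("_forward_ctx", default=None)

    @contextmanager
    def set_forward_context_sketch(attn_metadata: Any, vllm_config: Any = None):
        # Stash the metadata so layers deep in the model can fetch it without
        # threading it through every forward() signature.
        token = _forward_ctx.set(attn_metadata)
        try:
            yield
        finally:
            _forward_ctx.reset(token)

    def get_forward_context_sketch() -> Optional[Any]:
        return _forward_ctx.get()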
@@ -878,6 +899,28 @@ class GPUModelRunner:
             )
         return hidden_states

+    def metadata_for_dummy_run(self, num_tokens) -> FlashAttentionMetadata:
+        # Create placeholder metadata
+        num_reqs = num_tokens
+        max_query_len = num_tokens
+        max_seq_len = num_tokens
+        return FlashAttentionMetadata(
+            num_actual_tokens=num_tokens,
+            max_query_len=max_query_len,
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
+            max_seq_len=max_seq_len,
+            seq_lens=self.seq_lens[:num_reqs],
+            block_table=(
+                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
+            slot_mapping=self.slot_mapping[:max_seq_len],
+            # Cascade stuff. Non-piecewise CUDA graphs NYI
+            use_cascade=False,
+            common_prefix_len=0,
+            cu_prefix_query_lens=None,
+            prefix_kv_lens=None,
+            suffix_kv_lens=None,
+        )
+
     def profile_run(self) -> None:
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
@@ -994,12 +1037,6 @@ class GPUModelRunner:
         gc.collect()

     def capture_model(self) -> None:
-        if not self.use_cuda_graph:
-            logger.warning(
-                "Skipping CUDA graph capture. Please add "
-                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
-            return
-
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]

@@ -1008,10 +1045,11 @@ class GPUModelRunner:
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
             for num_tokens in reversed(self.cudagraph_batch_sizes):
+                attn_metadata = self.metadata_for_dummy_run(num_tokens)
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, attn_metadata=attn_metadata)
+                self._dummy_run(num_tokens, attn_metadata=attn_metadata)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
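
Putting the pieces of this branch together, the capture path now looks roughly like the following condensed sketch (restated from the hunks above, not verbatim; the graph_capture import is the one the file already uses):

    from vllm.distributed.parallel_state import graph_capture

    def capture_model_sketch(runner) -> None:
        """Condensed restatement of capture_model() as changed in this diff."""
        with graph_capture(device=runner.device):
            for num_tokens in reversed(runner.cudagraph_batch_sizes):
                # Placeholder metadata sliced from the persistent buffers.
                attn_metadata = runner.metadata_for_dummy_run(num_tokens)
                warmups = runner.vllm_config.compilation_config.cudagraph_num_of_warmups
                for _ in range(warmups):
                    runner._dummy_run(num_tokens, attn_metadata=attn_metadata)
                runner._dummy_run(num_tokens, attn_metadata=attn_metadata)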