Compare commits

13 Commits

SHA1        Message  Date

1244c25908  minimize fill_  (2025-02-04 22:03:51 +00:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

34fb0cbbd0  minimize changes to gpu_model_runner.py  (2025-02-04 21:32:41 +00:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

f33ec6d2d2  Remove unneeded changes  (2025-02-04 18:13:50 +00:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

c33aeecf24  simplify - get rid of tokenshape  (2025-02-03 22:42:09 +00:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

230730c34d  Updates for FA3 and other changes  (2025-02-03 20:58:48 +00:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

d151b63b8b  Merge branch 'main' into hacking_full_cudagraph  (2025-02-03 14:05:45 -05:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

bed9efafeb  Merge branch 'main' into hacking_full_cudagraph  (2025-01-24 16:08:18 -05:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

d4c9448b26  WIP initial working version  (2025-01-17 13:54:51 -05:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

76732ff701  Merge branch 'main' into hacking_full_cudagraph  (2025-01-03 15:06:23 -05:00)

22bd7296e4  disable inductor, disable piecewise  (2025-01-03 15:05:52 -05:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

4024253797  Merge branch 'main' into hacking_full_cudagraph  (2025-01-03 13:16:56 -05:00)
            Eager mode works now
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

7eba374599  flash attn changes  (2025-01-02 13:26:21 -05:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

0c7e6c1e36  Hacky hacky  (2024-12-18 15:13:26 -05:00)
            Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
2 changed files with 65 additions and 27 deletions

View File

@@ -2925,8 +2925,9 @@ class CompilationConfig(BaseModel):
             # v1 must split the graph on attention ops
             # for piecewise cudagraph
             self.splitting_ops = [
-                "vllm.unified_attention",
-                "vllm.unified_attention_with_output",
+                # HACK for full cuda graph
+                #"vllm.unified_attention",
+                #"vllm.unified_attention_with_output",
             ]
         else:
             # v0 uses full graph compilation
@@ -3339,8 +3340,7 @@ class VllmConfig:
         batch_size_capture_list = []
         if self.model_config is not None and \
             not self.model_config.enforce_eager:
-            batch_size_capture_list = [1, 2, 4
-                                       ] + [i for i in range(8, 513, 8)]
+            batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]

         self.compilation_config.init_with_cudagraph_sizes(
             batch_size_capture_list)
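
For reference, the one-liner above builds the CUDA graph capture sizes: 1, 2, 4, then every multiple of 8 up to 512. At run time a batch is normally padded up to the nearest captured size (general vLLM behaviour, not shown in this diff). A standalone sketch of that lookup, independent of vLLM:

    import bisect

    batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]

    def padded_capture_size(num_tokens: int) -> int:
        # Smallest captured size that can hold this many tokens.
        idx = bisect.bisect_left(batch_size_capture_list, num_tokens)
        return batch_size_capture_list[idx]

    assert padded_capture_size(3) == 4
    assert padded_capture_size(9) == 16
    assert padded_capture_size(512) == 512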

View File

@@ -11,7 +11,7 @@ import torch.nn as nn

 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
@@ -122,9 +122,7 @@ class GPUModelRunner:
             vocab_size=model_config.get_vocab_size(),
         )

-        self.use_cuda_graph = (self.vllm_config.compilation_config.level
-                               == CompilationLevel.PIECEWISE
-                               and not self.model_config.enforce_eager)
+        self.use_cuda_graph = not self.model_config.enforce_eager
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -171,6 +169,20 @@ class GPUModelRunner:
                                          dtype=self.dtype,
                                          device=self.device)

+        # Attention metadata related persistent buffers
+        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.seq_lens = torch.zeros(self.max_num_reqs,
+                                    dtype=torch.int32,
+                                    device=self.device)
+        self.slot_mapping = torch.zeros(
+            self.max_num_tokens,
+            # CPU slot_mapping is int32, but
+            # this one must be int64
+            dtype=torch.int64,
+            device=self.device)
+
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                        self.max_model_len,
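
The buffers added above (query_start_loc, seq_lens, slot_mapping) are allocated once on the device and reused every step. That matters for full CUDA graph capture: a captured graph replays kernels against the device pointers recorded at capture time, so the attention metadata has to live at stable addresses and be refreshed in place instead of being re-allocated with .to(device) each step. A minimal PyTorch-only sketch of that pattern (not vLLM code; the model and shapes are placeholders):

    import torch

    device = "cuda"
    model = torch.nn.Linear(64, 64).to(device)
    static_input = torch.zeros(8, 64, device=device)  # persistent buffer

    # Warm up on a side stream before capture, as PyTorch recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Each step: write new data into the same storage, then replay.
    static_input.copy_(torch.randn(8, 64, device=device))
    graph.replay()
    print(static_output.sum())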
@@ -436,12 +448,19 @@ class GPUModelRunner:
         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
-        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
-            self.device, non_blocking=True)
-        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
-                                                   non_blocking=True)
-        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
-            self.device, non_blocking=True).long()
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
+        self.slot_mapping[:total_num_scheduled_tokens].copy_(
+            self.slot_mapping_cpu[:total_num_scheduled_tokens],
+            non_blocking=True)
+
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
+        self.query_start_loc[num_reqs + 1:].fill_(-1)

         # Prepare for cascade attention if needed.
         common_prefix_len = (scheduler_output.num_common_prefix_blocks *
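
Per the comment above, the tail of each persistent buffer beyond the tokens actually scheduled gets sentinel values, so a graph captured for a larger padded size can safely consume the whole buffer: slot_mapping entries of -1 mark slots the reshape_and_cache KV-cache write should skip, padded seq_lens are 0, and unused query_start_loc entries are -1. A toy illustration with hypothetical sizes (an 8-slot buffer, 5 real tokens):

    import torch

    slot_mapping = torch.zeros(8, dtype=torch.int64)
    slot_mapping[:5].copy_(torch.tensor([40, 41, 42, 7, 8], dtype=torch.int64))
    slot_mapping[5:].fill_(-1)  # padded slots marked invalid
    print(slot_mapping)         # tensor([40, 41, 42,  7,  8, -1, -1, -1])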
@@ -524,12 +543,13 @@ class GPUModelRunner:
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
-            query_start_loc=query_start_loc,
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
             max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
+            seq_lens=self.seq_lens[:num_reqs],
             block_table=(
                 self.input_batch.block_table.get_device_tensor()[:num_reqs]),
-            slot_mapping=slot_mapping,
+            slot_mapping=self.slot_mapping[:total_num_scheduled_tokens],
+            # Cascade stuff
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             cu_prefix_query_lens=cu_prefix_query_lens,
@@ -541,7 +561,7 @@ class GPUModelRunner:
         # partial request, we do so for simplicity. We will ignore the sampled
         # token from the partial request.
         # TODO: Support prompt logprobs.
-        logits_indices = query_start_loc[1:] - 1
+        logits_indices = self.query_start_loc[1:num_reqs + 1] - 1
         return attn_metadata, logits_indices

     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
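
A worked example of the logits_indices line above, with hypothetical per-request query lengths 3, 4 and 2: query_start_loc is their cumulative sum, so subtracting 1 from every entry after the first gives the index of each request's last token, which is where the next-token logits are taken.

    import torch

    query_start_loc = torch.tensor([0, 3, 7, 9])   # num_reqs = 3
    logits_indices = query_start_loc[1:] - 1
    print(logits_indices)                          # tensor([2, 6, 8])

The [1:num_reqs + 1] slice in the new code does the same thing on the persistent buffer, stopping before its -1 padding.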
@@ -855,6 +875,7 @@ class GPUModelRunner:
         self,
         num_tokens: int,
         kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[FlashAttentionMetadata] = None,
     ) -> torch.Tensor:
         model = self.model
         if kv_caches is None:
@@ -865,7 +886,7 @@ class GPUModelRunner:
         else:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = None
-        with set_forward_context(None, self.vllm_config):
+        with set_forward_context(attn_metadata, self.vllm_config):
             positions = self.mrope_positions[:, :num_tokens] \
                 if self.model_config.uses_mrope \
                 else self.positions[:num_tokens]
@@ -878,6 +899,28 @@ class GPUModelRunner:
             )
         return hidden_states

+    def metadata_for_dummy_run(self, num_tokens) -> FlashAttentionMetadata:
+        # Create placeholder metadata
+        num_reqs = num_tokens
+        max_query_len = num_tokens
+        max_seq_len = num_tokens
+        return FlashAttentionMetadata(
+            num_actual_tokens=num_tokens,
+            max_query_len=max_query_len,
+            query_start_loc=self.query_start_loc[:num_reqs + 1],
+            max_seq_len=max_seq_len,
+            seq_lens=self.seq_lens[:num_reqs],
+            block_table=(
+                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
+            slot_mapping=self.slot_mapping[:max_seq_len],
+            # Cascade stuff. Non-piecewise CUDA graphs NYI
+            use_cascade=False,
+            common_prefix_len=0,
+            cu_prefix_query_lens=None,
+            prefix_kv_lens=None,
+            suffix_kv_lens=None,
+        )
+
     def profile_run(self) -> None:
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
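
The placeholder metadata above is built from slices of the persistent buffers rather than from real request state; no meaningful values are written here. Because a slice shares storage with its parent tensor, whatever _prepare_inputs later copies into query_start_loc, seq_lens and slot_mapping is seen by the kernels that were captured against these views. A tiny illustration of that aliasing:

    import torch

    buf = torch.zeros(8, dtype=torch.int32)
    view = buf[:4]                     # what the captured graph holds on to
    buf[:4].copy_(torch.tensor([0, 3, 7, 9], dtype=torch.int32))
    print(view)                        # tensor([0, 3, 7, 9], dtype=torch.int32)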
@@ -994,12 +1037,6 @@ class GPUModelRunner:
         gc.collect()

     def capture_model(self) -> None:
-        if not self.use_cuda_graph:
-            logger.warning(
-                "Skipping CUDA graph capture. Please add "
-                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
-            return
-
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -1008,10 +1045,11 @@ class GPUModelRunner:
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
             for num_tokens in reversed(self.cudagraph_batch_sizes):
+                attn_metadata = self.metadata_for_dummy_run(num_tokens)
                 for _ in range(self.vllm_config.compilation_config.
                                cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens, attn_metadata=attn_metadata)
+                self._dummy_run(num_tokens, attn_metadata=attn_metadata)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
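
The loop above captures the largest batch sizes first so that, inside vLLM's graph_capture context, the smaller graphs can reuse the memory pool already reserved for the big shapes, and both the warm-up runs and the capture run now carry the dummy attention metadata so the attention kernels can be recorded as part of the full graph. A generic PyTorch sketch of the shared-pool, largest-first idea (placeholder model and sizes, not vLLM's helper):

    import torch

    device = "cuda"
    model = torch.nn.Linear(64, 64).to(device)
    static_input = torch.zeros(512, 64, device=device)

    # Warm up before capturing, as in the earlier sketch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    pool = torch.cuda.graph_pool_handle()   # one pool shared by all graphs
    graphs, outputs = {}, {}
    for num_tokens in sorted([1, 2, 4, 8, 512], reverse=True):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, pool=pool):
            outputs[num_tokens] = model(static_input[:num_tokens])
        graphs[num_tokens] = g

    # Replay one size after refreshing its slice of the persistent input.
    static_input[:8].copy_(torch.randn(8, 64, device=device))
    graphs[8].replay()
    print(outputs[8].sum())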