[BugFix] AssertionError: Do not capture num_reqs > max_num_reqs for uniform batch (#25505)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -2828,7 +2828,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
         allow_microbatching: bool = True,
@@ -2844,6 +2844,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         Args:
             num_tokens: Number of tokens to run the dummy forward pass.
             cudagraph_runtime_mode: used to control the behavior.
+            - if not set will determine the cudagraph mode based on using
+              the self.cudagraph_dispatcher.
             - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
             - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
             - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@@ -2857,7 +2859,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode in {
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }

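The hunks above make cudagraph_runtime_mode optional: passing None now means "let self.cudagraph_dispatcher pick the mode", while an explicit value must still be one of the known modes. Below is a minimal, self-contained sketch of that pattern; the Mode enum, the size threshold, and the dummy_run helper are simplified stand-ins for illustration, not vLLM's actual CUDAGraphMode or dispatcher.

from enum import Enum
from typing import Optional


class Mode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def dummy_run(num_tokens: int, mode: Optional[Mode] = None) -> Mode:
    # Accept either "unset" (None) or one of the known modes.
    assert mode is None or mode in {Mode.NONE, Mode.PIECEWISE, Mode.FULL}
    if mode is None:
        # Hypothetical policy standing in for the dispatcher: small batches
        # get a full capture, larger ones fall back to piecewise graphs.
        mode = Mode.FULL if num_tokens <= 512 else Mode.PIECEWISE
    return mode


print(dummy_run(256))             # Mode.FULL, chosen by the stand-in policy
print(dummy_run(256, Mode.NONE))  # Mode.NONE, explicitly forced (e.g. a warm-up run)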
@@ -2899,10 +2901,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         elif uniform_decode:
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
-                f"{max_num_reqs} for uniform batch. Num tokens: " \
-                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
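The assertion deleted here is the one from the commit title: for a uniform-decode dummy batch, num_reqs is the ceiling of num_tokens / max_query_len, and when max_query_len is greater than one (e.g. with speculative decoding) that ceiling can exceed max_num_reqs. A small worked example with made-up numbers; cdiv below is a plain ceiling-division helper written out for the sketch, not the vLLM import.

def cdiv(a: int, b: int) -> int:
    # Ceiling division: cdiv(1024, 4) == 256, cdiv(1023, 4) == 256.
    return -(-a // b)


max_num_reqs = 128   # assumed scheduler limit, illustrative only
max_query_len = 4    # e.g. 1 decode token + 3 speculative tokens, illustrative
num_tokens = 1024    # dummy-run token count, illustrative

num_reqs = cdiv(num_tokens, max_query_len)
print(num_reqs, num_reqs > max_num_reqs)  # 256 True -> the removed assert would have raised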
@@ -3043,18 +3041,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)
-        if cudagraph_runtime_mode == CUDAGraphMode.NONE:
-            batch_descriptor = None
-        else:
-            # filter out the valid batch descriptor
-            _cg_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(
-                    BatchDescriptor(num_tokens=num_tokens,
-                                    uniform_decode=uniform_decode))
-            # sanity check
-            assert cudagraph_runtime_mode == _cg_mode, (
+
+        # filter out the valid batch descriptor
+        _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+            BatchDescriptor(num_tokens=num_tokens,
+                            uniform_decode=uniform_decode))
+        if cudagraph_runtime_mode is not None:
+            # we allow forcing NONE when the dispatcher disagrees to support
+            # warm ups for cudagraph capture
+            assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
+                cudagraph_runtime_mode == _cg_mode, (
                 f"Cudagraph runtime mode mismatch at dummy_run. "
                 f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
+        else:
+            cudagraph_runtime_mode = _cg_mode

         if ubatch_slices is not None:
             num_tokens = num_tokens // 2
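The restructured block now always asks the dispatcher for a (mode, batch descriptor) pair and only afterwards reconciles it with a caller-forced mode, treating a forced NONE as always legal so warm-up runs before capture no longer trip the mismatch assert. Below is a self-contained sketch of that control flow; Mode and dispatch are simplified stand-ins, not vLLM's CUDAGraphMode, BatchDescriptor, or CudagraphDispatcher.

from enum import Enum
from typing import Optional, Tuple


class Mode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def dispatch(num_tokens: int, uniform_decode: bool) -> Tuple[Mode, dict]:
    # Stand-in policy: uniform-decode batches get a full capture,
    # everything else falls back to piecewise graphs.
    mode = Mode.FULL if uniform_decode else Mode.PIECEWISE
    return mode, {"num_tokens": num_tokens, "uniform_decode": uniform_decode}


def resolve_mode(forced: Optional[Mode], num_tokens: int,
                 uniform_decode: bool) -> Tuple[Mode, dict]:
    cg_mode, batch_descriptor = dispatch(num_tokens, uniform_decode)
    if forced is not None:
        # Forcing NONE is always allowed (warm-up before capture); any other
        # forced mode must agree with what the dispatcher chose.
        assert forced == Mode.NONE or forced == cg_mode, (
            f"mode mismatch: expected {cg_mode}, got {forced}")
        return forced, batch_descriptor
    return cg_mode, batch_descriptor


print(resolve_mode(None, 64, uniform_decode=True))       # dispatcher decides: FULL
print(resolve_mode(Mode.NONE, 64, uniform_decode=True))  # forced NONE for a warm-up run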