[BugFix] AssertionError: Do not capture num_reqs > max_num_reqs for uniform batch (#25505)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson authored 2025-09-23 20:00:29 -04:00, committed by GitHub
parent 1210e4d95b
commit dc464a3d39

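In short: _dummy_run no longer hard-defaults to CUDAGraphMode.NONE. When no mode is passed, the runner now asks self.cudagraph_dispatcher which cudagraph mode applies to the batch, and the over-eager num_reqs <= max_num_reqs assertion that fired for uniform-decode batches is dropped. An explicit CUDAGraphMode.NONE may still be forced, even when the dispatcher would pick a graph mode, so warmup runs can execute eagerly.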

@@ -2828,7 +2828,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
         allow_microbatching: bool = True,
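One way to read the new default: passing nothing now means "let the dispatcher decide" rather than "no cudagraph". A minimal, self-contained sketch of that semantics (together with the widened precondition from the assert hunk below); the enum and function here are illustrative stand-ins, not vLLM's actual classes:

from enum import Enum
from typing import Optional

class CUDAGraphMode(Enum):  # stand-in enum mirroring the names in the diff
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def _dummy_run(num_tokens: int,
               cudagraph_runtime_mode: Optional[CUDAGraphMode] = None) -> str:
    # widened precondition from the diff: None is now acceptable too
    assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
        CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
    }
    if cudagraph_runtime_mode is None:
        return "defer to dispatcher"
    return cudagraph_runtime_mode.name

print(_dummy_run(512))                      # defer to dispatcher
print(_dummy_run(512, CUDAGraphMode.NONE))  # NONE (eager warmup run)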
@@ -2844,6 +2844,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         Args:
             num_tokens: Number of tokens to run the dummy forward pass.
             cudagraph_runtime_mode: used to control the behavior.
+                - if not set, the cudagraph mode is determined using
+                  self.cudagraph_dispatcher.
                 - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                 - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                 - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@@ -2857,7 +2859,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode in {
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
        }
@@ -2899,10 +2901,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         elif uniform_decode:
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
-                f"{max_num_reqs} for uniform batch. Num tokens: " \
-                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
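The removed assertion is the one from the issue title. Under uniform decode, num_reqs is a ceiling division, so any dummy-run size above max_num_reqs * max_query_len tripped it. A toy recomputation of that arithmetic, with assumed values rather than any real config:

def cdiv(a: int, b: int) -> int:
    # ceiling division, matching what vLLM's cdiv helper computes
    return -(a // -b)

max_num_reqs = 256   # assumed cap on concurrent requests
max_query_len = 1    # uniform decode: one new token per request
num_tokens = 512     # dummy-run batch larger than the cap allows

num_reqs = cdiv(num_tokens, max_query_len)  # 512
print(num_reqs > max_num_reqs)  # True -> the removed assert fired here

With the assertion gone, the uniform batch is simply built with that many requests; guarding which shapes actually get captured is presumably left to the dispatcher path changed in the last hunk.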
@@ -3043,18 +3041,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)
-        if cudagraph_runtime_mode == CUDAGraphMode.NONE:
-            batch_descriptor = None
-        else:
-            # filter out the valid batch descriptor
-            _cg_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(
-                    BatchDescriptor(num_tokens=num_tokens,
-                                    uniform_decode=uniform_decode))
-            # sanity check
-            assert cudagraph_runtime_mode == _cg_mode, (
+        # filter out the valid batch descriptor
+        _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+            BatchDescriptor(num_tokens=num_tokens,
+                            uniform_decode=uniform_decode))
+        if cudagraph_runtime_mode is not None:
+            # we allow forcing NONE when the dispatcher disagrees to support
+            # warm ups for cudagraph capture
+            assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
+                cudagraph_runtime_mode == _cg_mode, (
                 f"Cudagraph runtime mode mismatch at dummy_run. "
                 f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
+        else:
+            cudagraph_runtime_mode = _cg_mode

         if ubatch_slices is not None:
             num_tokens = num_tokens // 2
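The restructured tail always consults the dispatcher, then treats an explicit argument as a constraint: it must match the dispatched mode, except that NONE may always be forced so capture warmup can run eagerly. A self-contained sketch of that control flow; the dispatcher below is a toy stand-in and its threshold is invented for illustration:

from enum import Enum
from typing import Optional, Tuple

class Mode(Enum):  # stand-in for CUDAGraphMode
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def dispatch(num_tokens: int) -> Tuple[Mode, Optional[str]]:
    # toy dispatcher: pretend only small batches have captured graphs
    if num_tokens <= 512:
        return Mode.FULL, "descriptor(%d)" % num_tokens
    return Mode.NONE, None

def resolve_mode(requested: Optional[Mode], num_tokens: int) -> Mode:
    dispatched, _descriptor = dispatch(num_tokens)
    if requested is None:
        return dispatched  # new default path: trust the dispatcher
    # forcing NONE is allowed even when the dispatcher disagrees, so
    # warmup runs can execute eagerly before any graph is captured
    assert requested is Mode.NONE or requested is dispatched
    return requested

assert resolve_mode(None, 128) is Mode.FULL        # dispatcher decides
assert resolve_mode(Mode.NONE, 128) is Mode.NONE   # forced warmup override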