mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix][torch.compile] KV scale calculation issues with FP8 quantization (#25513)
Signed-off-by: adabeyta <aabeyta@redhat.com>
This commit is contained in:
@ -139,6 +139,21 @@ def test_custom_compile_config(
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"optimization_level",
|
||||
[CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
|
||||
)
|
||||
def test_fp8_kv_scale_compile(optimization_level: int):
|
||||
model = "Qwen/Qwen2-0.5B"
|
||||
model_kwargs = {
|
||||
"quantization": "fp8",
|
||||
"kv_cache_dtype": "fp8_e4m3",
|
||||
"calculate_kv_scales": True,
|
||||
"max_model_len": 512,
|
||||
}
|
||||
run_model(optimization_level, model, model_kwargs)
|
||||
|
||||
|
||||
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available "
|
||||
|
@ -277,9 +277,8 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
`vllm.forward_context.get_forward_context().attn_metadata`.
|
||||
"""
|
||||
if self.calculate_kv_scales:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(query, key, value)
|
||||
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
|
||||
self.layer_name)
|
||||
|
||||
output_dtype = query.dtype
|
||||
if self.query_quant is not None:
|
||||
@ -554,6 +553,44 @@ def maybe_save_kv_layer_to_connector(
|
||||
attn_metadata[layer_name])
|
||||
|
||||
|
||||
def maybe_calc_kv_scales(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
|
||||
if attn_metadata is None or not getattr(
|
||||
attn_metadata, 'enable_kv_scales_calculation', False):
|
||||
return
|
||||
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.calc_kv_scales(query, key, value)
|
||||
|
||||
|
||||
def maybe_calc_kv_scales_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_calc_kv_scales",
|
||||
op_func=maybe_calc_kv_scales,
|
||||
mutates_args=["query", "key", "value"],
|
||||
fake_impl=maybe_calc_kv_scales_fake,
|
||||
)
|
||||
|
||||
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
|
@ -2351,6 +2351,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.cudagraph_dispatcher.dispatch(batch_descriptor,
|
||||
use_cascade_attn)
|
||||
|
||||
# Set cudagraph mode to none if calc_kv_scales is true.
|
||||
if attn_metadata is not None:
|
||||
metadata_list = (attn_metadata.values() if isinstance(
|
||||
attn_metadata, dict) else [attn_metadata])
|
||||
if any(
|
||||
getattr(m, 'enable_kv_scales_calculation', False)
|
||||
for m in metadata_list):
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
# This is currently to get around the assert in the DPMetadata
|
||||
# where it wants `num_tokens_across_dp` to align with `num_tokens`
|
||||
if ubatch_slices is not None:
|
||||
|
Reference in New Issue
Block a user