[BugFix][torch.compile] KV scale calculation issues with FP8 quantization (#25513)

Signed-off-by: adabeyta <aabeyta@redhat.com>
Adrian Abeyta authored on 2025-09-29 14:52:04 -05:00; committed by GitHub
parent d5ab28511c, commit c42ff4f4fd
3 changed files with 64 additions and 3 deletions
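
Summary of the change, as reflected in the three diffs below: with calculate_kv_scales=True, the check that triggers FP8 KV scale calculation lived as plain Python inside Attention.forward, where torch.compile traces it, and the scales were not reliably computed under compilation. The fix moves the check behind a new custom op, torch.ops.vllm.maybe_calc_kv_scales, which the compiler treats as an opaque call, and forces CUDAGraphMode.NONE on steps that actually compute scales. A new test covers both the NO_COMPILATION and PIECEWISE compilation levels.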


@@ -139,6 +139,21 @@ def test_custom_compile_config(
     run_model(compilation_config, model, model_kwargs)
 
 
+@pytest.mark.parametrize(
+    "optimization_level",
+    [CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE],
+)
+def test_fp8_kv_scale_compile(optimization_level: int):
+    model = "Qwen/Qwen2-0.5B"
+    model_kwargs = {
+        "quantization": "fp8",
+        "kv_cache_dtype": "fp8_e4m3",
+        "calculate_kv_scales": True,
+        "max_model_len": 512,
+    }
+    run_model(optimization_level, model, model_kwargs)
+
+
 def test_inductor_graph_partition_attn_fusion(caplog_vllm):
     if not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available "


@@ -277,9 +277,8 @@ class Attention(nn.Module, AttentionLayerBase):
             `vllm.forward_context.get_forward_context().attn_metadata`.
         """
         if self.calculate_kv_scales:
-            attn_metadata = get_forward_context().attn_metadata
-            if attn_metadata.enable_kv_scales_calculation:
-                self.calc_kv_scales(query, key, value)
+            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
+                                                self.layer_name)
 
         output_dtype = query.dtype
         if self.query_quant is not None:
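
Why a custom op: a registered op shows up in the compiled graph as a single opaque call, so its Python body runs eagerly at runtime instead of being traced. That keeps the get_forward_context() lookup and the conditional scale update working under PIECEWISE compilation, where the removed inline branch was the source of the bug. Note that the op receives layer_name and recovers the layer via forward_context.no_compile_layers[layer_name] (next hunk), since a module reference cannot be passed across the custom-op boundary.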
@@ -554,6 +553,44 @@ def maybe_save_kv_layer_to_connector(
                                 attn_metadata[layer_name])
 
 
+def maybe_calc_kv_scales(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
+
+    if attn_metadata is None or not getattr(
+            attn_metadata, 'enable_kv_scales_calculation', False):
+        return
+
+    self = forward_context.no_compile_layers[layer_name]
+    self.calc_kv_scales(query, key, value)
+
+
+def maybe_calc_kv_scales_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="maybe_calc_kv_scales",
+    op_func=maybe_calc_kv_scales,
+    mutates_args=["query", "key", "value"],
+    fake_impl=maybe_calc_kv_scales_fake,
+)
+
+
 def unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
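
For readers unfamiliar with the pattern: direct_register_custom_op is vLLM's wrapper over PyTorch's custom-op machinery, and the fake (meta) implementation is what lets torch.compile trace calls to the op without running its body. A minimal self-contained sketch of the same idea, using torch.library.custom_op directly on recent PyTorch (all names here are illustrative, not vLLM's):

import torch

@torch.library.custom_op("mylib::calc_scale", mutates_args=["scale"])
def calc_scale(x: torch.Tensor, scale: torch.Tensor) -> None:
    # Eager impl: data-dependent, in-place work the compiler must not trace.
    scale.copy_(x.abs().amax())

@calc_scale.register_fake
def _(x: torch.Tensor, scale: torch.Tensor) -> None:
    # Fake impl: a no-op, mirroring maybe_calc_kv_scales_fake above; the
    # compiler only needs the signature and the declared mutation.
    return

@torch.compile
def forward(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    torch.ops.mylib.calc_scale(x, scale)  # one opaque node in the graph
    return x / scale

Declaring mutates_args also keeps the call from being dead-code-eliminated or reordered, which matters here because the op returns nothing.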


@@ -2351,6 +2351,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.cudagraph_dispatcher.dispatch(batch_descriptor,
                                                use_cascade_attn)
 
+        # Set cudagraph mode to none if calc_kv_scales is true.
+        if attn_metadata is not None:
+            metadata_list = (attn_metadata.values() if isinstance(
+                attn_metadata, dict) else [attn_metadata])
+            if any(
+                    getattr(m, 'enable_kv_scales_calculation', False)
+                    for m in metadata_list):
+                cudagraph_runtime_mode = CUDAGraphMode.NONE
+
         # This is currently to get around the assert in the DPMetadata
         # where it wants `num_tokens_across_dp` to align with `num_tokens`
         if ubatch_slices is not None:
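
Finally, the runner-side guard: a CUDA graph replays a frozen sequence of kernels, so a step that still needs to compute KV scales must run eagerly; hence CUDAGraphMode.NONE whenever any layer's metadata requests scale calculation. The check condenses to a small predicate (a standalone sketch with a hypothetical helper name):

from typing import Any

def kv_scales_pending(attn_metadata: Any) -> bool:
    # attn_metadata is either a single object or a dict keyed by layer name.
    mds = (attn_metadata.values()
           if isinstance(attn_metadata, dict) else [attn_metadata])
    return any(getattr(m, "enable_kv_scales_calculation", False) for m in mds)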