Disable torch.compile for dynamic rope models in Transformers backend (#23738)
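Dynamic ("NTK-aware") rope scaling recomputes the rotary embedding tables once the running
sequence length exceeds the model's original context window, so the forward pass is not
shape- and value-stable in the way torch.compile expects. In a Hugging Face text config this
mode shows up in the rope_scaling dict; a minimal illustrative excerpt (the factor key/value
here are invented for the example):

    # Any config whose rope_scaling carries rope_type == "dynamic"
    # now skips torch.compile in the Transformers backend.
    rope_scaling = {"rope_type": "dynamic", "factor": 2.0}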

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor (committed by GitHub)
Date:   2025-08-27 20:03:05 +01:00
Parent: 3c0ef769ba
Commit: 0585a9e73c

@@ -88,6 +88,23 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
     logger.debug("%s: %s -> %s", name, old_module, new_module)
 
 
+def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
+    """
+    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
+
+    Defaults to `True` but is disabled in the following situations:
+
+    - The model uses dynamic rope scaling.
+    """
+    enable = True
+    text_config = vllm_config.model_config.hf_config.get_text_config()
+    # Dynamic rope scaling is not compatible with torch.compile
+    rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
+    if rope_scaling.get("rope_type") == "dynamic":
+        enable = False
+    return enable
+
+
 def replace_linear_class(
     linear: nn.Linear, style: Literal["colwise", "rowwise"],
     quant_config: QuantizationConfig
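To see the new predicate in action, here is a small self-contained sketch; the
SimpleNamespace objects only mimic the attribute chain
vllm_config.model_config.hf_config.get_text_config() and are stand-ins, not the real
vLLM types:

    from types import SimpleNamespace

    # Stand-ins mimicking the attribute chain the helper walks.
    text_config = SimpleNamespace(
        rope_scaling={"rope_type": "dynamic", "factor": 2.0})
    hf_config = SimpleNamespace(get_text_config=lambda: text_config)
    vllm_config = SimpleNamespace(
        model_config=SimpleNamespace(hf_config=hf_config))

    assert can_enable_torch_compile(vllm_config) is False  # dynamic rope

    text_config.rope_scaling = None  # no rope scaling configured
    assert can_enable_torch_compile(vllm_config) is True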
@@ -641,7 +658,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
-@support_torch_compile
+@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersModel(TransformersBase):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -653,7 +670,7 @@ class TransformersModel(TransformersBase):
         })
 
 
-@support_torch_compile
+@support_torch_compile(enable_if=can_enable_torch_compile)
 class TransformersForCausalLM(TransformersBase):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
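For readers unfamiliar with enable_if: below is a simplified, hypothetical rendering of
how a decorator like support_torch_compile can consult such a predicate. The real
implementation in vLLM is considerably more involved; this toy version only illustrates
the bare-vs-parameterized usage and the predicate gate:

    from typing import Callable, Optional, Type

    def support_torch_compile_sketch(
        cls: Optional[Type] = None,
        *,
        enable_if: Optional[Callable] = None,
        dynamic_arg_dims: Optional[dict] = None,
    ):
        """Toy decorator: usable bare or with arguments, honouring enable_if."""

        def decorator(klass: Type) -> Type:
            orig_init = klass.__init__

            def __init__(self, *, vllm_config, prefix: str = ""):
                orig_init(self, vllm_config=vllm_config, prefix=prefix)
                if enable_if is not None and not enable_if(vllm_config):
                    return  # predicate said no: leave the model eager
                # ... here the real decorator would wire up torch.compile ...

            klass.__init__ = __init__
            return klass

        # Bare usage (@decorator) receives the class directly; parameterized
        # usage (@decorator(...)) returns the decorator to apply next.
        return decorator(cls) if cls is not None else decorator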
@@ -709,12 +726,14 @@ def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
                                         info=MultiModalProcessingInfo,
                                         dummy_inputs=MultiModalDummyInputsBuilder)
 @support_torch_compile(
+    # set `positions` to last dim to support Qwen-mrope
     dynamic_arg_dims={
         "input_ids": 0,
         "positions": -1,
         "intermediate_tensors": 0,
         "inputs_embeds": 0,
-    })  # set `positions` to last dim to support Qwen-mrope
+    },
+    enable_if=can_enable_torch_compile)
 class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
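On the retained dynamic_arg_dims comment: with Qwen-style mrope the positions tensor is,
as I understand it, 2-D with the rope sections stacked in front (roughly (3, num_tokens))
instead of the usual 1-D (num_tokens,), so marking the last dimension dynamic covers both
layouts. A small shape check under that assumption:

    import torch

    num_tokens = 16
    positions_1d = torch.arange(num_tokens)      # plain rope: (num_tokens,)
    positions_mrope = positions_1d.repeat(3, 1)  # assumed mrope layout: (3, num_tokens)

    # Dimension -1 is the token dimension in both layouts, which is why
    # dynamic_arg_dims marks `positions` as dynamic in its last dim.
    assert positions_1d.shape[-1] == num_tokens
    assert positions_mrope.shape[-1] == num_tokens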