Disable torch.compile for dynamic rope models in Transformers backend (#23738)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@ -88,6 +88,23 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
||||
logger.debug("%s: %s -> %s", name, old_module, new_module)
|
||||
|
||||
|
||||
def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
|
||||
"""
|
||||
Callable to be passed to `@support_torch_compile`'s `enable_if` argument.
|
||||
|
||||
Defaults to `True` but is disabled in the following situations:
|
||||
|
||||
- The model uses dynamic rope scaling.
|
||||
"""
|
||||
enable = True
|
||||
text_config = vllm_config.model_config.hf_config.get_text_config()
|
||||
# Dynamic rope scaling is not compatible with torch.compile
|
||||
rope_scaling: dict = getattr(text_config, "rope_scaling", None) or {}
|
||||
if rope_scaling.get("rope_type") == "dynamic":
|
||||
enable = False
|
||||
return enable
|
||||
|
||||
|
||||
def replace_linear_class(
|
||||
linear: nn.Linear, style: Literal["colwise", "rowwise"],
|
||||
quant_config: QuantizationConfig
|
||||
@ -641,7 +658,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||
class TransformersModel(TransformersBase):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
@ -653,7 +670,7 @@ class TransformersModel(TransformersBase):
|
||||
})
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@support_torch_compile(enable_if=can_enable_torch_compile)
|
||||
class TransformersForCausalLM(TransformersBase):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@ -709,12 +726,14 @@ def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
|
||||
info=MultiModalProcessingInfo,
|
||||
dummy_inputs=MultiModalDummyInputsBuilder)
|
||||
@support_torch_compile(
|
||||
# set `positions` to last dim to support Qwen-mrope
|
||||
dynamic_arg_dims={
|
||||
"input_ids": 0,
|
||||
"positions": -1,
|
||||
"intermediate_tensors": 0,
|
||||
"inputs_embeds": 0,
|
||||
}) # set `positions` to last dim to support Qwen-mrope
|
||||
},
|
||||
enable_if=can_enable_torch_compile)
|
||||
class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
|
||||
# Backwards compatibility for prev released models. State dicts back then
|
||||
# had different formats and cannot be loaded with `AutoModel` mapping as is
|
||||
|
||||
Reference in New Issue
Block a user