mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Avoid multiple instantiations of the RoPE class (#1828)
This commit is contained in:
@ -272,6 +272,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level cache of already-built rotary embedding objects. get_rope()
# looks up a tuple built from its arguments here and returns the stored
# instance instead of constructing a new one.
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_rope(
|
def get_rope(
|
||||||
head_size: int,
|
head_size: int,
|
||||||
rotary_dim: int,
|
rotary_dim: int,
|
||||||
@ -280,6 +283,10 @@ def get_rope(
|
|||||||
is_neox_style: bool = True,
|
is_neox_style: bool = True,
|
||||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
) -> RotaryEmbedding:
|
) -> RotaryEmbedding:
|
||||||
|
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||||
|
rope_scaling)
|
||||||
|
if key in _ROPE_DICT:
|
||||||
|
return _ROPE_DICT[key]
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||||
is_neox_style)
|
is_neox_style)
|
||||||
@ -312,4 +319,5 @@ def get_rope(
|
|||||||
**extra_kwargs)
|
**extra_kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||||
|
_ROPE_DICT[key] = rotary_emb
|
||||||
return rotary_emb
|
return rotary_emb
|
||||||
|
Reference in New Issue
Block a user