Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[TPU] Optimize RoPE forward_native2 (#7636)
@@ -46,15 +46,23 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
 
 def _apply_rotary_emb(
     x: torch.Tensor,
-    freqs_cis: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
 ) -> torch.Tensor:
-    x_ = torch.view_as_complex(
-        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
-    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
-    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
-    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-                          -1).transpose(1, 2)
-    return x_out
+    """
+    Args:
+        x: [num_tokens, num_heads, head_size]
+        cos: [num_tokens, head_size // 2]
+        sin: [num_tokens, head_size // 2]
+    """
+    orig_dtype = x.dtype
+    x = x.float()
+    x1, x2 = torch.chunk(x, 2, dim=-1)
+    cos = cos.unsqueeze(-2)
+    sin = sin.unsqueeze(-2)
+    o1 = x1 * cos - x2 * sin
+    o2 = x2 * cos + x1 * sin
+    return torch.cat((o1, o2), dim=-1).to(orig_dtype)
 
 
 class RotaryEmbedding(CustomOp):
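This hunk replaces the complex-number rotation (`view_as_complex` on top of a `freqs_cis` buffer) with plain real arithmetic on half-split features, which avoids the transpose/stack reshuffling and lowers more cleanly when compiled for TPU. Below is a minimal sanity check, not part of the commit, that the two formulations compute the same rotation; the helper names and toy sizes are illustrative assumptions.

    import torch

    def rotary_real(x, cos, sin):
        # New path: rotate half-split features using only real ops.
        x1, x2 = torch.chunk(x.float(), 2, dim=-1)
        cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    def rotary_complex(x, freqs_cis):
        # Old path: treat (x1, x2) as one complex number and multiply by
        # freqs_cis = cos + 1j * sin, which performs the same 2-D rotation.
        x1, x2 = torch.chunk(x.float(), 2, dim=-1)
        out = torch.complex(x1, x2) * freqs_cis.unsqueeze(-2)
        return torch.cat((out.real, out.imag), dim=-1)

    num_tokens, num_heads, head_size = 4, 2, 8
    x = torch.randn(num_tokens, num_heads, head_size)
    angles = torch.randn(num_tokens, head_size // 2)
    cos, sin = angles.cos(), angles.sin()

    torch.testing.assert_close(
        rotary_real(x, cos, sin),
        rotary_complex(x, torch.complex(cos, sin)),
    )

Expanding the product (x1 + i*x2)(cos + i*sin) gives (x1*cos - x2*sin) + i*(x2*cos + x1*sin), which is exactly the pair (o1, o2) computed in the new `_apply_rotary_emb`.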
@@ -78,14 +86,10 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        self.use_native2 = current_platform.is_tpu() and is_neox_style
-        if not self.use_native2:
-            cache = cache.to(dtype)
-            self.register_buffer("cos_sin_cache", cache, persistent=False)
-        else:
-            cos, sin = cache.chunk(2, dim=-1)
-            freqs_cis = cos + 1j * sin
-            self.register_buffer("freqs_cis", freqs_cis, persistent=False)
+        cache = cache.to(dtype)
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
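With the TPU-only `freqs_cis` buffer gone, every platform now registers the same `cos_sin_cache`. `_compute_cos_sin_cache()` itself is outside this hunk; the sketch below shows the standard RoPE cache layout that makes `cache.chunk(2, dim=-1)` yield `(cos, sin)`, with `base` and the sizes as assumptions rather than values taken from the commit.

    import torch

    def compute_cos_sin_cache(rotary_dim: int, max_position: int,
                              base: float = 10000.0) -> torch.Tensor:
        # Standard RoPE inverse frequencies, one per pair of rotated dims.
        inv_freq = 1.0 / (base**(
            torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position, dtype=torch.float)
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # [max_position, rotary_dim // 2]
        # cos and sin are concatenated on the last dim, so chunk(2, dim=-1)
        # recovers them -- the split used by both the removed branch and
        # the new forward_native2.
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

    cache = compute_cos_sin_cache(rotary_dim=8, max_position=16)
    cos, sin = cache.chunk(2, dim=-1)  # each [16, 4]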
@@ -173,28 +177,25 @@ class RotaryEmbedding(CustomOp):
 
         This method might perform better than `forward_native()` when compiled.
         """
-        if positions.dim() == 1:
-            batch_size = 1
-            seq_len = positions.shape[0]
-        else:
-            batch_size, seq_len = positions.shape
         if offsets is not None:
             positions = positions + offsets
-        freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
-        freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)
+        positions = positions.flatten()
+        num_tokens = positions.shape[0]
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        cos, sin = cos_sin.chunk(2, dim=-1)
 
         query_shape = query.shape
-        query = query.view(batch_size, seq_len, -1, self.head_size)
+        query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, freqs_cis)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
-        key = key.view(batch_size, seq_len, -1, self.head_size)
+        key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, freqs_cis)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
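The net effect in `forward_native2` is that the batch/seq bookkeeping disappears: positions are flattened to a single token axis, the cos/sin rows are gathered once, and the same rows rotate both query and key. Below is a self-contained sketch of that flow with toy shapes; all names and sizes here are illustrative, not taken from the commit.

    import torch

    def rope(x, cos, sin):
        # Same math as the new _apply_rotary_emb.
        x1, x2 = torch.chunk(x.float(), 2, dim=-1)
        cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin),
                         dim=-1).type_as(x)

    head_size, num_heads, max_pos = 8, 2, 16
    angles = torch.randn(max_pos, head_size // 2)
    cos_sin_cache = torch.cat((angles.cos(), angles.sin()), dim=-1)

    # Positions arrive flattened: no batch_size/seq_len case split needed.
    positions = torch.tensor([0, 1, 5])
    num_tokens = positions.shape[0]
    query = torch.randn(num_tokens, num_heads * head_size)
    key = torch.randn(num_tokens, num_heads * head_size)

    cos_sin = cos_sin_cache.index_select(0, positions)  # one gather for q and k
    cos, sin = cos_sin.chunk(2, dim=-1)

    # rotary_dim == head_size is assumed here, so the rot/pass split is a no-op.
    q = rope(query.view(num_tokens, -1, head_size), cos, sin).reshape(query.shape)
    k = rope(key.view(num_tokens, -1, head_size), cos, sin).reshape(key.shape)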