mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Chore] Remove unused batched RoPE op & kernel (#24789)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key,
|
||||
int64_t head_size, torch::Tensor& cos_sin_cache,
|
||||
bool is_neox, int64_t rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
|
@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
|
||||
token_idx, query_stride, key_stride, head_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
__global__ void batched_rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // nullptr or
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
||||
const scalar_t* cache_ptr =
|
||||
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride, head_stride);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rotary_embedding(
|
||||
@ -211,96 +182,3 @@ void rotary_embedding(
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
Batched version of rotary embedding, pack multiple LoRAs together
|
||||
and process in batched manner.
|
||||
*/
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
||||
// [num_tokens, num_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
std::optional<torch::Tensor>
|
||||
key, // null or
|
||||
// [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox, int64_t rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
|
||||
) {
|
||||
// num_tokens = batch_size * seq_len
|
||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||
TORCH_CHECK(
|
||||
positions.size(0) == num_tokens || positions.numel() == num_tokens,
|
||||
"positions must have the same num_tokens or batch_size as "
|
||||
"cos_sin_cache_offsets");
|
||||
|
||||
int positions_ndim = positions.dim();
|
||||
// Make sure num_tokens dim is consistent across positions, query, and key
|
||||
TORCH_CHECK(
|
||||
positions_ndim == 1 || positions_ndim == 2,
|
||||
"positions must have shape [num_tokens] or [batch_size, seq_len]");
|
||||
if (positions_ndim == 1) {
|
||||
TORCH_CHECK(query.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)),
|
||||
"query, key and positions must have the same number of tokens");
|
||||
}
|
||||
if (positions_ndim == 2) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)) &&
|
||||
query.size(1) == positions.size(1) &&
|
||||
(!key.has_value() || key->size(1) == positions.size(1)),
|
||||
"query, key and positions must have the same batch_size and seq_len");
|
||||
}
|
||||
|
||||
// Make sure head_size is valid for query and key
|
||||
int query_hidden_size = query.numel() / num_tokens;
|
||||
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
|
||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
||||
|
||||
// Make sure query and key have concistent number of heads
|
||||
int num_heads = query_hidden_size / head_size;
|
||||
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
|
||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
||||
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
||||
// head_size
|
||||
int query_ndim = query.dim();
|
||||
int64_t head_stride =
|
||||
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||
if (is_neox) {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, true>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -214,16 +214,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||
|
||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
|
||||
// (supports multiple loras).
|
||||
ops.def(
|
||||
"batched_rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor!? key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox,"
|
||||
" int rot_dim,"
|
||||
" Tensor cos_sin_cache_offsets) -> ()");
|
||||
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
// Quantized GEMM for AWQ.
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from itertools import accumulate, product
|
||||
from itertools import product
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
@ -111,151 +111,6 @@ def test_rotary_embedding(
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_batched_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: float = 10000,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||
"rope_type": "linear",
|
||||
"factor": (1, )
|
||||
})
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# slice tensor if required, noop otherwise
|
||||
query = query[..., :head_size]
|
||||
key = key[..., :head_size] if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
out_query, out_key = rope.forward(positions,
|
||||
query,
|
||||
key,
|
||||
offsets=torch.zeros(batch_size * seq_len,
|
||||
dtype=torch.long,
|
||||
device=device))
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(out_query,
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_batched_rotary_embedding_multi_lora(
|
||||
is_neox_style: bool,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: float = 10000,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
scaling_factors: list[int] = [1, 2, 4]
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||
"rope_type": "linear",
|
||||
"factor": tuple(scaling_factors)
|
||||
})
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
offset_map = torch.tensor(
|
||||
list(
|
||||
accumulate([0] + [
|
||||
max_position * scaling_factor * 2
|
||||
for scaling_factor in scaling_factors[:-1]
|
||||
])))
|
||||
query_types = torch.randint(0,
|
||||
len(scaling_factors), (batch_size, seq_len),
|
||||
device=device)
|
||||
query_offsets = offset_map[query_types]
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key,
|
||||
query_offsets)
|
||||
out_query, out_key = rope.forward(positions, query, key,
|
||||
query_offsets.flatten())
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(out_query,
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_rope_module_cache():
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
|
@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
def rotary_embedding_opcheck(rot,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
key: Optional[torch.Tensor] = None):
|
||||
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
opcheck(torch.ops._C.batched_rotary_embedding,
|
||||
(positions, query, key, rot.head_size, cos_sin_cache,
|
||||
rot.is_neox_style, rot.rotary_dim, offsets))
|
||||
else:
|
||||
opcheck(torch.ops._C.rotary_embedding,
|
||||
(positions, query, key, rot.head_size, cos_sin_cache,
|
||||
rot.is_neox_style))
|
||||
# ops.rotary_embedding() is a in-place operation
|
||||
# that updates the query and key tensors.
|
||||
opcheck(torch.ops._C.rotary_embedding,
|
||||
(positions, query, key, rot.head_size, cos_sin_cache,
|
||||
rot.is_neox_style))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
||||
key = key[..., :head_size] if use_key else None
|
||||
|
||||
rotary_embedding_opcheck(rot, positions, query, key)
|
||||
offsets = torch.zeros(batch_size * seq_len,
|
||||
device=device,
|
||||
dtype=torch.long)
|
||||
rotary_embedding_opcheck(rot, positions, query, key, offsets)
|
||||
|
||||
# if we have a contiguous head stride, test the alternate
|
||||
# [..., num_heads * head_dim] shape/layout
|
||||
|
@ -257,16 +257,6 @@ def rotary_embedding(
|
||||
cos_sin_cache, is_neox)
|
||||
|
||||
|
||||
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
||||
key: Optional[torch.Tensor], head_size: int,
|
||||
cos_sin_cache: torch.Tensor, is_neox: bool,
|
||||
rot_dim: int,
|
||||
cos_sin_cache_offsets: torch.Tensor) -> None:
|
||||
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
|
||||
cos_sin_cache, is_neox, rot_dim,
|
||||
cos_sin_cache_offsets)
|
||||
|
||||
|
||||
# layer norm ops
|
||||
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
||||
epsilon: float) -> None:
|
||||
|
@ -148,17 +148,6 @@ class ipex_ops:
|
||||
head_size, cos_sin_cache,
|
||||
is_neox, rot_dim)
|
||||
|
||||
@staticmethod
|
||||
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
||||
key: torch.Tensor, head_size: int,
|
||||
cos_sin_cache: torch.Tensor, is_neox: bool,
|
||||
rot_dim: int,
|
||||
cos_sin_cache_offsets: torch.Tensor) -> None:
|
||||
ipex.llm.functional.rotary_embedding_batched(positions, query, key,
|
||||
head_size, cos_sin_cache,
|
||||
is_neox, rot_dim,
|
||||
cos_sin_cache_offsets)
|
||||
|
||||
@staticmethod
|
||||
def rms_norm(input: torch.Tensor, weight: torch.Tensor,
|
||||
epsilon: float) -> torch.Tensor:
|
||||
|
@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp):
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp):
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp):
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def forward_xpu(
|
||||
@ -124,29 +114,21 @@ class RotaryEmbedding(CustomOp):
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
||||
dtype=query.dtype)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
if key is None:
|
||||
# XPU kernel doesn't support key=None so fall back to native impl
|
||||
# TODO(sarckk): add support for optional key in
|
||||
# ipex.llm.functional.rotary_embedding_batched
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
return self.forward_native(positions, query, key)
|
||||
else:
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
self.rotary_dim, offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
Reference in New Issue
Block a user