This commit is contained in:
Antoni Baum
2024-03-13 13:56:49 -07:00
committed by GitHub
parent 7e9bd08f60
commit c33afd89f5

View File

@ -143,8 +143,8 @@ class RotaryEmbedding(nn.Module):
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
# ops.rotary_embedding()/batched_rotary_embedding() are in-place operations that
# update the query and key tensors.
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,