mirror of https://github.com/huggingface/transformers.git
synced 2025-11-14 22:55:41 +08:00

Compare commits (5 commits): cb-prefix- ... remove-sli

| Author | SHA1 | Date |
|---|---|---|
| | eda8aaa849 | |
| | 39fea75bf3 | |
| | 70bb3bb300 | |
| | c4cfc2e023 | |
| | 5c6d6bed4d | |
@@ -17,7 +17,6 @@ import contextlib
import json
import os
import time
from itertools import cycle
from typing import Optional

import datasets
@@ -30,32 +29,42 @@ from transformers.generation import GenerationConfig
from transformers.generation.continuous_batching.requests import logger


def generate_without_cb(
model_id: str, sliding_window: int, attn_impl: str, batched_inputs: list[int], generation_config: GenerationConfig
# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
SLIDING_WINDOW = 0
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
SKIP_SPECIAL_TOKENS = False


def generate_simple(
attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
) -> dict[str, str]:
# Setup model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, attn_implementation=attn_impl)
attn_impl = {
"sdpa": "sdpa",
"eager": "eager",
"paged_attention": "eager", # TODO: this does not work on AMD docker
"flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker
"kernels-community/flash-attn": "eager",
}[attn_impl]

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
model = model.cuda().eval()
if sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
model.config.sliding_window = sliding_window
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate one by one
if getattr(model.config, "sliding_window", None) is not None:
model.config.sliding_window = SLIDING_WINDOW

decoded_outputs = {}
for input_ids in tqdm(batched_inputs, desc="Generating outputs without CB"):
for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
key = " ".join(map(str, input_ids)) # This will be used to identify the output after batched generation
input_ids = torch.tensor([input_ids]).to("cuda")
attention_mask = torch.ones_like(input_ids)
outputs = model.generate(
input_ids, attention_mask=attention_mask, generation_config=generation_config, use_model_defaults=False
)
# attention_mask = torch.ones_like(input_ids)
outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
generated_tokens = outputs[0][input_ids.shape[1] :]
decoded_outputs[key] = tokenizer.decode(generated_tokens, skip_special_tokens=False)
decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
decoded_outputs[key] = decoded_output
return decoded_outputs


def maybe_setup_metrics(use_metrics: bool) -> None:
if not use_metrics:
return
def setup_metrics():
try:
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -110,14 +119,16 @@ def batch_generate(
token_count = 0
data = []
for i, request in enumerate(batch_outputs):
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
# The key is used to tie back to the output of unbatched generation
key = " ".join(map(str, batch_outputs[request].prompt_ids))
data.append({"input": input_text, "key": key})

# Try to decode the output
try:
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
output_text = tokenizer.decode(
batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
)
token_count += len(batch_outputs[request].generated_tokens[1:])
data[-1]["cb_outputs"] = output_text
except Exception as e:
@@ -127,7 +138,14 @@ def batch_generate(

# Display sample if asked
if i < displayed_samples:
print("-" * 20, f"{request} Input: {input_text}", f"{request} Output: {output_text}", sep="\n")
if len(output_text) > 0:
print("-" * 20)
print(f"{request} Input: {input_text}")
print(f"{request} Output: {output_text}")
else:
print(f"{request} Input: {input_text}")
print("[WARN]")
print(f"{request} Output was empty!")

# Compare with classic generate if asked
if expected_outputs is not None:
@@ -164,102 +182,75 @@ def batch_generate(


if __name__ == "__main__":
# Parse args
parser = argparse.ArgumentParser()

# Continuous batching parameters
parser.add_argument("--num-blocks", "-n", type=int, default=None)
parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)

# Model parameters
parser.add_argument("--sliding-window", type=int, default=0)
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")

# Performance parameters
parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
parser.add_argument("--do-sample", action="store_true", help="Activate sampling")

# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")

# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--output-file", type=str, default=None)

parser.add_argument("--compare", action="store_true")
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--profile", type=str, default=None)
args = parser.parse_args()

# Create model
model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
has_system_role = args.sliding_window == 0

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn, dtype=torch.bfloat16)
model = model.cuda().eval()

if args.sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
print(f"Setting sliding window from {model.config.sliding_window} to {args.sliding_window}")
model.config.sliding_window = args.sliding_window

# Set up diagnostics
# Set log level
logger.setLevel(args.log_level.upper())
maybe_setup_metrics(args.metrics)

# Set up performance
# If turned on, we setup metrics
if args.metrics:
setup_metrics()

# Set matmul precision if not none
if args.matmul_precision != "none":
torch.set_float32_matmul_precision(args.matmul_precision)
# Parse cuda graph argument
if args.cuda_graph is not None:
use_cuda_graph = {
"none": None,
"yes": True, "y": True, "true": True, "t": True, "1": True,
"no": False, "n": False, "false": False, "f": False, "0": False,
}[args.cuda_graph.lower()] # fmt: skip
else:
use_cuda_graph = None

cuda_graph_arg = args.cuda_graph.lower() if args.cuda_graph is not None else None
use_cuda_graph = {
"none": None, None: None,
"yes": True, "y": True, "true": True, "t": True, "1": True,
"no": False, "n": False, "false": False, "f": False, "0": False,
}[cuda_graph_arg] # fmt: skip
# Prepare model
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
attn_implementation=args.attn,
dtype=torch.bfloat16,
)
model = model.cuda().eval()
if getattr(model.config, "sliding_window", None) is not None:
print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
model.config.sliding_window = SLIDING_WINDOW

# If turned on, we compile the model
if args.compile:
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")

# Prepare tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")

dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(args.samples))

if args.add_prefix:
possible_prefixes = [
None,
"You are a bot that solves math problems.",
"You are a bot who solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning.",
"You are a bot with the aim to solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning. No loud words or emojis, all responses must be readable by a child. Here is now the problem:",
] # fmt: skip
else:
possible_prefixes = [None]

batched_inputs = []
for item, prefix in zip(dataset, cycle(possible_prefixes)):
messages = []
question = item["question"]
if prefix is not None:
if has_system_role:
messages.append({"role": "system", "content": prefix})
else:
question = prefix + "\n\n" + question
messages.append({"role": "user", "content": question})
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
batched_inputs.append(inputs["input_ids"])
simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]

# Prepare generation config
generation_cfg = GenerationConfig(
generation_config = GenerationConfig(
max_new_tokens=512,
use_cuda_graph=use_cuda_graph,
eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=args.do_sample,
do_sample=not args.compare,
temperature=0.8,
top_p=0.9,
num_blocks=args.num_blocks,
@@ -267,12 +258,7 @@ if __name__ == "__main__":
)

# If we need to compare, we need to generate the reference outputs
if args.compare:
expected_outputs = generate_without_cb(
model_id, args.sliding_window, args.attn, batched_inputs, generation_cfg
)
else:
expected_outputs = None
expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None

# If no output file is provided, we pick a name based on the args
if args.output_file is None:
@@ -285,8 +271,8 @@ if __name__ == "__main__":
# Run warmup batch generation # TODO: understand why warmup incurs a large overhead during cache creation
batch_generate(
model,
batched_inputs[: min(5, args.samples)],
generation_cfg,
simple_batch_inputs[: min(5, args.samples)],
generation_config,
tokenizer,
displayed_samples=-1,
)
@@ -299,8 +285,8 @@ if __name__ == "__main__":
# Run batch generation
gen_time, tok_per_sec = batch_generate(
model,
batched_inputs,
generation_cfg,
simple_batch_inputs,
generation_config,
tokenizer,
displayed_samples=args.displayed,
output_file=args.output_file,
@@ -311,5 +297,5 @@ if __name__ == "__main__":
prof.export_chrome_trace(filename)

# Example usage:
# python examples/pytorch/continuous_batching.py --attn sdpa --add-prefix --samples 10 --compare
# python examples/pytorch/continuous_batching.py --attn flash_attention_2 -mp none --add-prefix --samples 500
# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
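
The `--compare` path relies on a shared key: both `generate_simple` and `batch_generate` identify a sample by its prompt token ids joined with spaces. A minimal sketch of that matching step, with made-up stand-ins for the structures built above:

```python
# Hypothetical data shaped like the script's structures: generate_simple returns
# {key: decoded_text}, batch_generate collects [{"key": ..., "cb_outputs": ...}, ...]
expected = {"1 2 3": "The answer is 4."}
cb_data = [{"key": "1 2 3", "cb_outputs": "The answer is 4."}]

for entry in cb_data:
    # Look up the reference output produced without continuous batching
    reference = expected.get(entry["key"])
    matches = reference is not None and reference == entry["cb_outputs"]
    print(f"{entry['key']}: {'match' if matches else 'mismatch'}")
```
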

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from math import floor, gcd, sqrt
from typing import Optional
@@ -20,8 +21,8 @@ import torch
from ...configuration_utils import PreTrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...utils.metrics import attach_tracer, traced
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import RequestState, get_device_and_memory_breakdown, logger
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import get_device_and_memory_breakdown, logger


def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
@@ -31,7 +32,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
- All groups have the same number of layers

For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
"""
# If the config has no layer_type attribute, it means all layers are the same attention type
layer_types = getattr(config, "layer_types", None)
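
A minimal sketch of the grouping described in the docstring above, assuming the shared group size is taken as the gcd of the per-type layer counts (the actual helper may choose it differently):

```python
from collections import defaultdict
from math import gcd

layer_types = ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]

# Collect layer indices per attention type
by_type: dict[str, list[int]] = defaultdict(list)
for i, layer_type in enumerate(layer_types):
    by_type[layer_type].append(i)

# Every group gets the same size, here the gcd of the per-type counts (2)
group_size = gcd(*(len(ids) for ids in by_type.values()))
groups, group_types = [], []
for layer_type, ids in by_type.items():
    for start in range(0, len(ids), group_size):
        groups.append(ids[start : start + group_size])
        group_types.append(layer_type)

print(groups)       # [[0, 3], [1, 2], [4, 5], [6, 7]]
print(group_types)  # ['sliding', 'full', 'full', 'full']
```
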
@@ -115,6 +116,7 @@ class PagedAttentionCache:
for the sliding-attention group, although it is not needed.
"""

# TODO: this init is quite long, maybe a refactor is in order
def __init__(
self,
config: PreTrainedConfig,
@@ -122,10 +124,8 @@ class PagedAttentionCache:
device: torch.device,
dtype: torch.dtype = torch.float16,
tp_size: Optional[int] = None,
allow_prefix_sharing: bool = True,
) -> None:
"""Initialize a paged attention cache for efficient memory usage. Also turns on prefix sharing if the model has
only full attention layers.
"""Initialize a paged attention cache for efficient memory usage.

Args:
config: Model configuration
@@ -133,7 +133,6 @@ class PagedAttentionCache:
device: Device for the cache tensors
dtype: Data type of the cache
tp_size: Tensor parallelism size
allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
"""
self.config = config
self.dtype = dtype
@@ -174,12 +173,10 @@ class PagedAttentionCache:
page_size = self.head_dim * self.num_key_value_heads

if "flash" in self.config._attn_implementation:
num_attention_masks = 0 # only used to compute the default memory footprint args
elif "sliding_attention" in group_types:
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2
num_attention_masks = 1 # only used to compute the default memory footprint args
else:
num_attention_masks = 1
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2 if "sliding_attention" in group_types else 1

memory_handler = PagedAttentionMemoryHandler(
block_size=self.block_size,
@@ -221,6 +218,7 @@ class PagedAttentionCache:
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")

# Block management data structures
self._free_blocks = deque(range(num_blocks))
self.group_cache_managers: list[CacheAllocator] = []
for i, group_type in enumerate(group_types):
if group_type == "full_attention":
@@ -231,19 +229,13 @@ class PagedAttentionCache:
raise ValueError(f"Invalid group type: {group_type}")
self.group_cache_managers.append(cm)

# We only use prefix sharing if the whole model has only full attention layers
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
self.blocks_to_complete: dict[str, int] = {}
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests

@traced
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
max_allocated = 0
for cm in self.group_cache_managers:
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
if allocated is None:
return None
max_allocated = max(max_allocated, allocated)
@@ -254,11 +246,11 @@ class PagedAttentionCache:
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
by the cache managers."""
for cm in self.group_cache_managers:
cm.free_blocks(request_id, self._block_manager)
cm.free_blocks(request_id, self._free_blocks)

def get_num_free_blocks(self) -> int:
"""Get the current number of unallocated blocks available for new requests."""
return self._block_manager.num_free_blocks
return len(self._free_blocks)

@traced
def extend_read_indices(
@@ -345,44 +337,6 @@ class PagedAttentionCache:
# Return the new KV values
return key_states_with_cache, value_states_with_cache

def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
"""Searches for a prefix match in the cache for the given (prompt_ids). If one is found, we reference the
matching blocks in the (request_id), increase the reference count of the blocks and return the number of blocks
that match. If no prefix match is found, we return 0."""
current_hash = None
allocated_blocks = []
for b in range(len(prompt_ids) // self.block_size):
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
current_hash = self._block_manager.compute_hash(current_hash, tokens)
block_id = self._block_manager._hash_to_id.get(current_hash)
if block_id is not None:
allocated_blocks.append(block_id)
self._block_manager.increase_ref_count(block_id)
else:
break
# If we found a matching prefix, we reference the blocks in the request
if allocated_blocks:
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
cm = self.group_cache_managers[0]
cm.block_table[request_id] = allocated_blocks

prefix_length = len(allocated_blocks) * self.block_size
self._total_prefix_length += prefix_length
return prefix_length

def mark_blocks_as_complete(self, state: RequestState) -> None:
"""Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
a no-op."""
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
if num_complete_blocks == 0:
return None
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
self._block_manager.mark_blocks_as_complete(
num_complete_blocks=num_complete_blocks,
allocated_blocks=cm.block_table[state.request_id],
prompt_ids=(state.full_prompt_ids + state.static_outputs),
)


# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
@@ -517,8 +471,6 @@ class PagedAttentionMemoryHandler:
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
])

If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
"""
cache_memory = self.get_available_memory(max_memory_percent)
logger.info(f"Cache memory: {cache_memory}")
@@ -530,16 +482,11 @@ class PagedAttentionMemoryHandler:
c = -cache_memory
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")

# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
if self.num_attention_masks == 0:
greatest_solution = -c / b
# Otherwise, we solve the quadratic equation
else:
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)

# Compute discriminant and greatest solution
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
if greatest_solution < 0:
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")


@@ -14,211 +14,29 @@
# limitations under the License.
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from math import ceil
from typing import Optional, TypeVar
from typing import Optional

from .requests import logger


T = TypeVar("T")


def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
index = len(xs) - 1
for x in xs[::-1]:
yield index, x
index -= 1


class Block:
"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
sequence. Once a block is complete, it is given a hash, which takes into account the token ids of the block and
its parent's hash (if there is a parent)."""

def __init__(self, id_: int, parent_id: int | None) -> None:
self.id: int = id_
self.parent_id: int | None = parent_id
self.hash: int | None = None
self.ref_count: int = 1

def __repr__(self) -> str:
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"

@property
def is_complete(self) -> bool:
return self.hash is not None


class BlockManager:
"""A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
- in use: one or more requests reference this block, thus it cannot be written over. The number of requests
referencing this block is stored as ref_count in the Block object.
- un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
be given as free blocks to new requests without any overhead.
- initialized: the block is complete and was used by one or more requests that are finished. It contains KV cache
data and its hash is stored in the hash table. If a new request needs a block with the same hash, we increase
the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
hash table.
There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
it is in use.
"""

def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
layers."""
self.num_blocks = num_blocks
self.block_size = block_size
self._uninit_block_ids = deque(range(num_blocks))
self._init_block_ids: dict[int, None] = {} # effectively acts as an ordered set
self._use_prefix_sharing = use_prefix_sharing
self._hash_to_id: dict[int, int] = {}
self._id_to_block: dict[int, Block] = {}

@property
def num_free_blocks(self) -> int:
"""Returns the number of free blocks left. Both initialized and uninitialized blocks are considered free."""
return len(self._uninit_block_ids) + len(self._init_block_ids)

def is_enough_free_blocks(self, n_blocks: int) -> bool:
"""Checks if there are enough free blocks to allocate the requested number of blocks (n_blocks). If there are
not enough uninitialized blocks, we uninitialize the required number of initialized blocks."""
# Exit early if there are enough uninitialized blocks
if len(self._uninit_block_ids) >= n_blocks:
return True
# Exit early if even after uninitializing all initialized blocks, there are not enough free blocks
block_to_unintialize = n_blocks - len(self._uninit_block_ids)
if len(self._init_block_ids) < block_to_unintialize:
return False
# Uninitialize the required number of blocks
for _ in range(block_to_unintialize):
id_to_unintialize = self._init_block_ids.popitem()[0]
block = self._id_to_block[id_to_unintialize]
self._hash_to_id.pop(block.hash)
self._uninit_block_ids.append(id_to_unintialize)
return True

def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
"""Returns a list of (n_blocks) free blocks and marks them as no longer free in the internal data structures. One
can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
the parent block. If the manager cannot find enough free blocks, it returns None."""
if not self.is_enough_free_blocks(n_blocks):
return None
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
if self._use_prefix_sharing:
for block_id in allocated_block_ids:
block = Block(block_id, last_block_id)
self._id_to_block[block_id] = block
last_block_id = block_id
# In both cases, we return the allocated block ids
return allocated_block_ids

def increase_ref_count(self, block_id: int) -> None:
"""Increases the reference count of a given (block_id)."""
block = self._id_to_block[block_id]
block.ref_count += 1
if block.ref_count == 1:
self._init_block_ids.pop(block_id)

def decrease_ref_count(self, block_id: int) -> None:
"""Decreases the reference count of a given (block_id). If the reference count reaches 0, the block is no longer
in use, and becomes initialized (if it was complete) or uninitialized (if it was incomplete)."""
block = self._id_to_block[block_id]
block.ref_count -= 1
if block.ref_count == 0:
if block.is_complete:
self._init_block_ids[block_id] = None
else:
self._id_to_block.pop(block_id)
self._uninit_block_ids.append(block_id)

def free_blocks(self, blocks: list[int]) -> None:
"""Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
blocks queue. Otherwise, their new state depends on whether they are complete."""
if self._use_prefix_sharing:
for block_id in blocks:
self.decrease_ref_count(block_id)
else:
self._uninit_block_ids.extend(blocks)

def mark_blocks_as_complete(
self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
) -> None:
"""Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
of (prompt_ids) is used to compute the hash of the new block."""
# Look for the first complete block, starting from the last block in the sequence
parent_hash = None
incomplete_blocks: list[Block] = []
for i, block_id in reverse_enumerate(allocated_blocks):
block = self._id_to_block[block_id]
if block.is_complete:
parent_hash = block.hash
break
incomplete_blocks.append((i, block))

# Now go through the incomplete blocks and update them
new_parent_id = None
while incomplete_blocks:
i, block = incomplete_blocks.pop()

# If the parent id has been updated, we apply the change
if new_parent_id is not None:
block.parent_id = new_parent_id
new_parent_id = None

# If we have set the hash for all complete blocks, we can stop
if num_complete_blocks == 0:
break

# Otherwise, we compute the hash
num_complete_blocks -= 1
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
block.hash = self.compute_hash(parent_hash, tokens)

existing_block_id = self._hash_to_id.get(block.hash)
# If the block hash is already in the hash to id mapping, we reference the existing block instead
if existing_block_id is not None:
logger.debug(f"Found existing block {existing_block_id} for block {block.id}")
allocated_blocks[i] = existing_block_id
self._id_to_block[existing_block_id].ref_count += 1
new_parent_id = existing_block_id
self.free_blocks([block.id])

# Otherwise, we add the completed block to the hash table
else:
self._hash_to_id[block.hash] = block.id

# Update loop variables
parent_hash = block.hash

def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
"""Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
parent, the parent hash is None."""
return hash((parent_hash, tuple(tokens)))
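
A small standalone sketch of how the chained hashes line up with the prefix matching in `search_prefix_match` above (block size of 4 chosen purely for illustration):

```python
def compute_hash(parent_hash, tokens):
    return hash((parent_hash, tuple(tokens)))

block_size = 4
prompt_a = [1, 2, 3, 4, 5, 6, 7, 8]
prompt_b = [1, 2, 3, 4, 5, 6, 9, 9]  # same first block, different second block

def block_hashes(prompt):
    # Hash each full block, chaining in the parent block's hash
    hashes, parent = [], None
    for b in range(len(prompt) // block_size):
        parent = compute_hash(parent, prompt[b * block_size : (b + 1) * block_size])
        hashes.append(parent)
    return hashes

hashes_a = block_hashes(prompt_a)
hashes_b = block_hashes(prompt_b)
print(hashes_a[0] == hashes_b[0])  # True: the first block can be shared
print(hashes_a[1] == hashes_b[1])  # False: the chains diverge after the shared prefix
```
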


class CacheAllocator(ABC):
"""Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""

_index: int
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
_block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request

@abstractmethod
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocates (n_blocks) for a given (request_id) using the (block_manager). Returns the num of blocks allocated
if successful and None otherwise."""
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None
otherwise."""

def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
"""Frees all blocks associated with a (request_id) using the (block_manager)."""
if request_id in self.block_table:
blocks_to_free = self.block_table.pop(request_id)
block_manager.free_blocks(blocks_to_free)
def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
"""Frees all blocks associated with a request_id."""
if request_id in self._block_table:
blocks_to_free = self._block_table.pop(request_id)
free_blocks.extend(blocks_to_free)
else:
logger.warning(
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
@@ -248,30 +66,23 @@ class FullAttentionCacheAllocator(CacheAllocator):
"""
self._index = index
self.block_size = block_size
self.block_table = {}
self._block_table = {}

def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
allocated if successful and None otherwise. For group of full attention layers, we always allocate the number of
requested blocks."""
# Make sure the request_id is in the block table and get the first block id
if request_id not in self.block_table:
self.block_table[request_id] = [] # TODO: check the impact of making this a deque
last_block_id = None
else:
last_block_id = self.block_table[request_id][-1]
# Actual allocation, return early if failed
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
if allocated_blocks is None:
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For group of full attention layers, we always allocate the number of requested blocks."""
if len(free_blocks) < n_blocks:
return None
self.block_table[request_id].extend(allocated_blocks)
if request_id not in self._block_table:
self._block_table[request_id] = []
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
return n_blocks

def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@@ -286,7 +97,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new
cache as a continuation of the existing cache for the same request."""
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@@ -318,26 +129,25 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
self.block_size = block_size
self.sliding_window = sliding_window
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
self.block_table = {}
self._block_table = {}

def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
allocated if successful and None otherwise. For group of sliding window attention layers, we only allocate up to
the point where we can fit an entire sliding window in the cache tensor."""
if request_id not in self.block_table:
self.block_table[request_id] = []
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For group of sliding window attention layers, we only allocate up to the point where we can fit an
entire sliding window in the cache tensor."""
if request_id not in self._block_table:
self._block_table[request_id] = []
# Early return if we are already at the max number of blocks per request
already_allocated = len(self.block_table[request_id])
already_allocated = len(self._block_table[request_id])
if already_allocated == self._max_blocks_per_request:
return 0
# Compute actual number of blocks to allocate
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
actual_n_blocks = after_allocation - already_allocated
# Classic allocation
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
if allocated_blocks is None:
if len(free_blocks) < actual_n_blocks:
return None
self.block_table[request_id].extend(allocated_blocks)
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
return actual_n_blocks
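
A worked example of the cap applied above, with illustrative numbers:

```python
from math import ceil

sliding_window, block_size = 4096, 32                        # illustrative values
max_blocks_per_request = ceil(sliding_window / block_size)   # 128

# A request that already holds 126 blocks and asks for 8 more only gets 2
already_allocated, n_blocks = 126, 8
after_allocation = min(already_allocated + n_blocks, max_blocks_per_request)
actual_n_blocks = after_allocation - already_allocated
print(actual_n_blocks)  # 2
```
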

def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
@@ -347,7 +157,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding_window - 1 cache page and then manually add the new key / values states after. Hence the -1 indices
which indicate where to store the new key or values indices."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window
@@ -368,7 +178,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding window attention layers, we write the new cache in a rolling-buffer kind of way: if we reach the end of
the allocated physical cache, we start writing from the beginning of the physical cache again."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window
@@ -391,3 +201,22 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
seqlens_k = query_length + min(past_length, self.sliding_window - 1)
return "sliding_attention", seqlens_k


# TODO: test the impact of this
# def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
# # Retrieve the block table for the request and raise an error if it doesn't exist
# block_table = self._block_table.get(request_id)
# if block_table is None:
# raise ValueError(f"No block table found for request {request_id}")
# # Compute the physical indices
# physical_indices = []
# n_left = past_length
# for block_idx in block_table:
# block_physical_index = block_idx * self.block_size
# pages_used = min(self.block_size, n_left)
# physical_indices.extend(block_physical_index + i for i in range(pages_used))
# n_left -= pages_used
# if n_left == 0:
# return physical_indices
# raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }")

@@ -16,13 +16,12 @@
import queue
import threading
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from itertools import count
from math import ceil
from time import perf_counter
from typing import Optional
from typing import Optional, Union

import torch
from torch import nn
@@ -447,7 +446,10 @@ class ContinuousBatchProcessor:
cumulative_seqlens_q = [0]
logits_indices = []

cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
if isinstance(self.cumulative_seqlens_k, dict):
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
else:
cumulative_seqlens_k = [0]

read_index = [[] for _ in range(self.cache.num_groups)]
write_index = [[] for _ in range(self.cache.num_groups)]
@@ -496,7 +498,10 @@ class ContinuousBatchProcessor:
self.metrics.record_kv_cache_memory_metrics(self.cache)

if logger.isEnabledFor(logging.DEBUG):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
if isinstance(self.cumulative_seqlens_k, dict):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
else:
ck = cumulative_seqlens_k[-1]
logger.debug(
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
@@ -512,7 +517,7 @@ class ContinuousBatchProcessor:
read_index: list[list[int]],
write_index: list[list[int]],
cumulative_seqlens_q: list[int],
cumulative_seqlens_k: dict[str, list[int]],
cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
logits_indices: list[int],
) -> None:
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
@@ -556,7 +561,9 @@ class ContinuousBatchProcessor:
@traced
def _maybe_send_output(self, state: RequestState) -> None:
"""Send output to the queue based on streaming mode and request state."""
if state.streaming or state.status == RequestStatus.FINISHED:
if state.streaming:
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())

@traced
@@ -564,27 +571,17 @@ class ContinuousBatchProcessor:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prompt_ids) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
state.prompt_ids = [token]
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
# We mark the completed blocks as such
self.cache.mark_blocks_as_complete(state)
if is_finished:
if state.update_with_token(token):
self.metrics.record_request_completion(state.created_time, state.request_id)
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
self._maybe_send_output(state)
# Otherwise, the request is still prefilling, but the prefill has been split
elif state.status == RequestStatus.PREFILLING_SPLIT:
self.cache.mark_blocks_as_complete(state)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
else:
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")

if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")

@@ -729,7 +726,6 @@ class ContinuousBatchingManager:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> None:
"""Initialize the continuous batching manager.

@@ -739,7 +735,6 @@ class ContinuousBatchingManager:
max_queue_size: Maximum size of the request queue (0 = unlimited)
num_q_cuda_graphs: (optional) Number of CUDA graphs to use for the query dimension
num_kv_cuda_graphs: (optional) Number of CUDA graphs to use for the keys/values dimension
allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
"""
if "paged|" not in model.config._attn_implementation:
attn_implementation = f"paged|{model.config._attn_implementation}"
@@ -772,8 +767,6 @@ class ContinuousBatchingManager:
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None

self._allow_prefix_sharing = allow_prefix_sharing

# If a number of cuda graphs was specified for either Q or KV, we activate cuda graphs
if num_q_cuda_graphs > 0 or num_kv_cuda_graphs > 0:
self.use_cuda_graph = True
@@ -806,6 +799,7 @@ class ContinuousBatchingManager:
logger.warning("Manager thread is already running.")
return

self._result_queue = queue.Queue()
self._generation_thread = threading.Thread(target=self._run_generation_loop)
self._generation_thread.start()

@@ -820,16 +814,6 @@ class ContinuousBatchingManager:
block: Whether to wait for the thread to stop
timeout: Maximum time to wait for the thread to stop
"""
if self.batch_processor is None:
logger.warning("\nBatch processor was not initialized.")
else:
if self.batch_processor.cache.use_prefix_sharing:
logger.warning(
f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
)
else:
logger.warning("\nPrefix sharing was off.")

if self._generation_thread is None:
logger.warning("Manager not started.")
return
@@ -955,6 +939,20 @@ class ContinuousBatchingManager:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)

@traced
def warmup(self, batch_processor: ContinuousBatchProcessor) -> None:
stream = torch.cuda.Stream(device=self.model.device)
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
# Warmup the model with a dummy forward pass
self._generation_step(batch_processor)
torch.cuda.current_stream().wait_stream(stream)

self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, stream=stream):
self._generation_step(batch_processor)
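
The `warmup` method above follows the standard CUDA graph capture-and-replay pattern; a self-contained sketch with a toy module (requires a CUDA device, and static input/output buffers as in the real code):

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream before capture, as the method above does
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(stream)

# Capture the forward pass into a CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replays re-run the captured kernels on whatever is in static_input
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
```
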

@traced
# @torch.compile
def _generation_step(self) -> None:
"""Perform a single generation step. This is cuda graphed"""
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
@@ -970,7 +968,6 @@ class ContinuousBatchingManager:
self.model.device,
self.model.dtype,
tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
allow_prefix_sharing=self._allow_prefix_sharing,
)
logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")

@@ -1062,15 +1059,6 @@ class ContinuousBatchingManager:
class ContinuousMixin:
"""Mixin class for models to add continuous batching capabilities."""

@contextmanager
def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]:
manager = self.init_continuous_batching(**kwargs)
manager.start()
try:
yield manager
finally:
manager.stop(block=True)
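
A usage sketch of the context manager above; the request-submission API on the manager is not shown in this diff, so it is left as a placeholder:

```python
# `model` is assumed to be a causal LM that includes ContinuousMixin,
# `generation_config` a GenerationConfig as in the example script earlier.
with model.continuous_batching_context_manager(generation_config=generation_config) as manager:
    # Submit requests / consume outputs through the manager here.
    # On exit, manager.stop(block=True) is called automatically.
    pass
```
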

def init_continuous_batching(
self,
generation_config: Optional[GenerationConfig] = None,
@@ -1078,7 +1066,6 @@ class ContinuousMixin:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> ContinuousBatchingManager:
"""Initialize a manager for continuous batching inference.

@@ -1111,7 +1098,6 @@ class ContinuousMixin:
max_queue_size=max_queue_size,
num_q_cuda_graphs=num_q_cuda_graphs,
num_kv_cuda_graphs=num_kv_cuda_graphs,
allow_prefix_sharing=allow_prefix_sharing,
)

# TODO: support streaming
@@ -1183,6 +1169,5 @@ class ContinuousMixin:
except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)
finally:
logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show.
manager.stop(block=True, timeout=5.0)
return results

@@ -116,10 +116,10 @@ class RequestState:
error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
"""

# Required fields # TODO: come up with better names / not sure prompt_ids and such are not redundant
# Required fields
request_id: str
full_prompt_ids: Optional[list[int]] = None # Full initial prompt
prompt_ids: Optional[list[int]] = None # Token IDs currently being processed
prompt_ids: Optional[list[int]] = None # Token IDs currently being processed (initial + generated)
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
static_outputs: list[int] = field(default_factory=list) # Generated tokens
allocated_blocks: int = 0 # Number of blocks allocated to the request
@@ -164,7 +164,7 @@ class RequestState:

# TODO: this logic seems one token off, check it out
@traced
def update_and_check_completion(self, token_id: int) -> bool:
def update_with_token(self, token_id: int) -> bool:
"""Update the request with a newly generated token and check for completion.

Args:

@@ -104,7 +104,7 @@ class Scheduler(ABC):
)

@traced
def _allocate_blocks_if_needed(self, state: RequestState) -> bool:
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
"""Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
accommodate the next tokens. It calculates how many blocks are needed based on the request's current
cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
@@ -113,11 +113,10 @@ class Scheduler(ABC):
# 1. we check that the occupancy is less than the requested length
# 2. we allocate enough blocks to cover the requested length
current_len = state.current_len()
len_next_tokens = len(state.prompt_ids)
occupancy = state.allocated_blocks * self.cache.block_size - current_len
if occupancy < len_next_tokens or state.allocated_blocks == 0:
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
allocated = self.cache.allocate_blocks(blocks_needed, state)
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
if allocated is None:
return False
state.allocated_blocks += allocated
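
A worked example of the occupancy check above, with illustrative numbers:

```python
block_size = 32
allocated_blocks, current_len, len_next_tokens = 2, 50, 20  # illustrative values

occupancy = allocated_blocks * block_size - current_len     # 64 - 50 = 14 free slots
if occupancy < len_next_tokens or allocated_blocks == 0:
    blocks_needed = ((len_next_tokens - occupancy + 1) // block_size) + 1
    print(blocks_needed)  # 1 extra block covers the 20 incoming tokens
```
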
@@ -126,29 +125,11 @@ class Scheduler(ABC):
@traced(span_name="prepare_request")
def _prepare_request_for_processing(
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
) -> None:
"""Prepares a request for processing in the current batch. If prefix sharing is enabled, and the request was
pending, this is where we look for a prefix match and split the request if found."""
# If prefix sharing is enabled, we look for a prefix match and split the request if found
if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
prefill_length = self.cache.search_prefix_match(state.request_id, state.prompt_ids)
if prefill_length > 0:
self.active_requests[state.request_id] = state
request_ids_to_remove_from_waiting.add(state.request_id)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
# Even if we match the whole request, we keep at least 1 token to start decoding
prefill_length = min(prefill_length, len(state.prompt_ids) - 1)
state.remaining_prompt_ids = state.prompt_ids[prefill_length:]
state.prompt_ids = state.prompt_ids[prefill_length:]
state.position_offset += prefill_length

# If the request has a split prefill, the tokens to process are the remaining prompt ids
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
request_tokens = state.remaining_prompt_ids
# Otherwise, the tokens to process are the prompt ids, which are the full prompt or the last predicted tokens
else:
request_tokens = state.prompt_ids

):
"""Prepares a request for processing in the current batch."""
request_tokens = (
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
)
if len(request_tokens) < token_budget:
# Can process the entire prompt/remainder
if state.status == RequestStatus.PENDING:
@@ -171,7 +152,6 @@ class Scheduler(ABC):
state.prompt_ids = request_tokens[:token_budget]


# TODO: further common-ize the two classes
@attach_tracer()
class FIFOScheduler(Scheduler):
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
@@ -215,31 +195,30 @@ class FIFOScheduler(Scheduler):

self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
break
continue

# Add the request to the scheduled requests
scheduled_requests.append(state)
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)

_add_to_scheduled_requests(state)

# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
complete_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = complete_blocks
|
||||
|
||||
# Remove the request from the waiting queue and mark it as removed
|
||||
req_id = state.request_id
|
||||
was_waiting = self.waiting_requests.pop(req_id, None) is not None
|
||||
if was_waiting:
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
@traced
|
||||
def _remove_from_waiting_requests(state: RequestState):
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
_remove_from_waiting_requests(state)
|
||||
|
||||
# Early exit of the loop if we have no token budget left
|
||||
if token_budget == 0:
|
||||
break
|
||||
|
||||
@ -270,7 +249,6 @@ class PrefillFirstScheduler(Scheduler):
|
||||
elif state.status == RequestStatus.DECODING:
|
||||
second_priority_states.append(state)
|
||||
|
||||
# Add waiting requests to second priority
|
||||
for req_id in self.waiting_requests_order:
|
||||
second_priority_states.append(self.waiting_requests[req_id])
|
||||
|
||||
@ -281,31 +259,30 @@ class PrefillFirstScheduler(Scheduler):
|
||||
for state in candidates:
|
||||
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
|
||||
request_len = len(state.prompt_ids)
|
||||
# If we can't allocate blocks, do not schedule the request and break if the cache is full
|
||||
if not self._allocate_blocks_if_needed(state):
|
||||
if self.cache.get_num_free_blocks() == 0:
|
||||
if not self._allocate_blocks_if_needed(
|
||||
state, len(state.prompt_ids)
|
||||
): # don't schedule if we can't allocate blocks
|
||||
if len(self.cache._free_blocks) == 0:
|
||||
break
|
||||
continue
|
||||
|
||||
# Add the request to the scheduled requests
|
||||
scheduled_requests.append(state)
|
||||
@traced
|
||||
def _add_to_scheduled_requests(state: RequestState):
|
||||
scheduled_requests.append(state)
|
||||
|
||||
_add_to_scheduled_requests(state)
|
||||
|
||||
# Update the token budget
|
||||
token_budget -= request_len
|
||||
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
|
||||
if self.cache.use_prefix_sharing:
|
||||
tokens_in_current_block = state.current_len() % self.cache.block_size
|
||||
tokens_after_forward = tokens_in_current_block + request_len
|
||||
complete_blocks = tokens_after_forward // self.cache.block_size
|
||||
self.cache.blocks_to_complete[state.request_id] = complete_blocks
|
||||
|
||||
# Remove the request from the waiting queue and mark it as removed
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
@traced
|
||||
def _remove_from_waiting_requests(state: RequestState):
|
||||
req_id = state.request_id
|
||||
if req_id in self.waiting_requests:
|
||||
del self.waiting_requests[req_id]
|
||||
request_ids_to_remove_from_waiting.add(req_id)
|
||||
|
||||
_remove_from_waiting_requests(state)
|
||||
|
||||
# Early exit of the loop if we have no token budget left
|
||||
if token_budget == 0:
|
||||
break
|
||||
|
||||
|
||||
@ -140,6 +140,16 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
|
||||
return [single_size] * blocks
|
||||
|
||||
|
||||
def replace_layer_number_by_wildcard(name: str) -> str:
|
||||
"""
|
||||
Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
|
||||
a dot (`.`) and the end of the string.
|
||||
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
|
||||
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
|
||||
"""
|
||||
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
|
||||
|
||||
|
||||
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
|
||||
"""
|
||||
Get the TP style for a parameter from the TP plan.
|
||||
@ -150,11 +160,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
|
||||
The `is_weight` is important because for weights, we want to support `.weights` and `.bias` cases seamlessly! but
|
||||
not parent classes for `post_init` calls
|
||||
"""
|
||||
generic_param_name = re.sub(r"\d+", "*", parameter_name)
|
||||
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
|
||||
if generic_param_name in tp_plan:
|
||||
return tp_plan[generic_param_name]
|
||||
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
|
||||
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
|
||||
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
|
||||
return tp_plan[module_name]
|
||||
return None
|
||||
|
||||
|
||||
@ -1086,7 +1096,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
|
||||
if tp_plan is None:
|
||||
return
|
||||
|
||||
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
|
||||
generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
|
||||
unsharded_layers = set(generic_keys)
|
||||
unused_rules = tp_plan
|
||||
|
||||
|
||||
@ -128,7 +128,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -580,8 +580,7 @@ def eager_attention_forward(
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -106,7 +106,6 @@ class ApertusConfig(PreTrainedConfig):
|
||||
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -201,8 +201,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -123,7 +123,6 @@ class ApertusConfig(LlamaConfig):
|
||||
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
||||
@ -208,8 +208,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -99,15 +99,14 @@ class AriaTextConfig(PreTrainedConfig):
|
||||
|
||||
model_type = "aria_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `AriaTextModel`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.shared_experts.gate_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.up_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -431,8 +431,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -169,6 +169,15 @@ class AriaTextConfig(LlamaConfig):
|
||||
|
||||
model_type = "aria_text"
|
||||
base_config_key = "text_config"
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.shared_experts.gate_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.up_proj": "colwise",
|
||||
"layers.*.mlp.shared_experts.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -114,7 +114,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -58,8 +58,8 @@ def eager_attention_forward(
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
|
||||
@ -292,8 +292,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -126,7 +126,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -130,7 +130,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -74,7 +74,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -1172,7 +1172,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -106,7 +106,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -139,8 +139,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -121,7 +121,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -105,7 +105,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -284,7 +284,7 @@ class BloomAttention(nn.Module):
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
|
||||
@ -250,8 +250,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -420,7 +420,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -70,7 +70,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -232,8 +232,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -247,8 +247,7 @@ def eager_attention_forward(
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -1061,8 +1061,7 @@ def eager_attention_forward(
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -173,8 +173,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -149,8 +149,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -258,8 +258,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -167,8 +167,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -190,7 +190,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -175,7 +175,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -165,8 +165,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -64,8 +64,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, **kwargs)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
|
||||
@ -252,8 +252,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -300,8 +300,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -179,7 +179,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -254,8 +254,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -256,7 +256,7 @@ class DiffLlamaAttention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
|
||||
@ -132,7 +132,7 @@ class DiffLlamaAttention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
|
||||
@ -167,7 +167,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -187,7 +187,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -206,7 +206,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -137,7 +137,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -118,9 +118,9 @@ class DogeConfig(PreTrainedConfig):
|
||||
"layers.*.self_attn.dt_proj": "rowwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.input_layernorm.weight": "sequence_parallel",
|
||||
"layers.*.input_residual.weight": "sequence_parallel",
|
||||
"layers.*.input_residual": "sequence_parallel",
|
||||
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
|
||||
"layers.*.post_attention_residual.weight": "sequence_parallel",
|
||||
"layers.*.post_attention_residual": "sequence_parallel",
|
||||
"norm.weight": "sequence_parallel",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
|
||||
@ -196,8 +196,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -146,9 +146,9 @@ class DogeConfig(PreTrainedConfig):
|
||||
"layers.*.self_attn.dt_proj": "rowwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.input_layernorm.weight": "sequence_parallel",
|
||||
"layers.*.input_residual.weight": "sequence_parallel",
|
||||
"layers.*.input_residual": "sequence_parallel",
|
||||
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
|
||||
"layers.*.post_attention_residual.weight": "sequence_parallel",
|
||||
"layers.*.post_attention_residual": "sequence_parallel",
|
||||
"norm.weight": "sequence_parallel",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
|
||||
@ -188,8 +188,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -286,7 +286,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -358,8 +358,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -135,7 +135,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -105,8 +105,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -146,7 +146,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -156,8 +156,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -214,8 +214,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -270,7 +270,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -234,7 +234,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -194,8 +194,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -348,8 +348,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -114,9 +114,9 @@ class FlexOlmoConfig(PreTrainedConfig):
|
||||
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.experts.*.gate_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.up_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -168,8 +168,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -125,9 +125,9 @@ class FlexOlmoConfig(OlmoeConfig):
|
||||
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
"layers.*.mlp.experts.*.gate_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.up_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -205,7 +205,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -205,8 +205,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -217,9 +217,8 @@ def eager_attention_forward(
|
||||
attn_weights = attn_weights / softcap
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * softcap
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
@ -291,9 +291,8 @@ def eager_attention_forward(
|
||||
attn_weights = attn_weights / softcap
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * softcap
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
@ -291,9 +291,8 @@ def eager_attention_forward(
|
||||
attn_weights = attn_weights / softcap
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * softcap
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
@ -630,7 +629,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
config: Gemma3TextConfig
|
||||
base_model_prefix = "language_model"
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma3TextConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -715,7 +715,7 @@ class Gemma3TextModel(Gemma2Model):
|
||||
|
||||
class Gemma3ForCausalLM(Gemma2ForCausalLM):
|
||||
config: Gemma3TextConfig
|
||||
base_model_prefix = "language_model"
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: Gemma3TextConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1185,9 +1185,8 @@ def eager_attention_forward(
|
||||
attn_weights = attn_weights / softcap
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * softcap
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
@ -156,8 +156,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -138,8 +138,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -136,8 +136,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -265,8 +265,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -214,8 +214,9 @@ class Glm4vMoeTextConfig(PreTrainedConfig):
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -185,8 +185,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -159,8 +159,9 @@ class Glm4vMoeTextConfig(Glm4MoeConfig):
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
|
||||
@ -74,8 +74,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, **kwargs)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
|
||||
@ -106,8 +106,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -127,7 +127,7 @@ class GPTNeoSelfAttention(nn.Module):
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
@ -218,7 +218,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
|
||||
|
||||
attn_dropout = self.config.attention_dropout if self.training else 0.0
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
|
||||
@ -173,9 +173,8 @@ def eager_attention_forward(
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
|
||||
@ -125,9 +125,8 @@ def eager_attention_forward(
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
|
||||
|
||||
@ -300,7 +300,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
)
|
||||
|
||||
attention_scores = attention_scores.view(batch_size, num_attention_heads, query_length, -1)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attention_scores = attention_scores + causal_mask
|
||||
|
||||
|
||||
@ -281,8 +281,7 @@ def eager_attention_forward(
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
|
||||
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
|
||||
|
||||
@ -219,8 +219,7 @@ def eager_attention_forward(
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
|
||||
combined_logits = torch.cat([attn_weights, sinks], dim=-1)
|
||||
|
||||
@ -151,7 +151,7 @@ class GPTJAttention(nn.Module):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
attn_weights = attn_weights / self.scale_attn
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
|
||||
@ -104,8 +104,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -325,8 +325,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -85,8 +85,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -315,8 +315,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -169,8 +169,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
@ -250,7 +250,6 @@ def eager_attention_forward(
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@ -141,8 +141,7 @@ def eager_attention_forward(
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user