Compare commits


5 Commits

Author SHA1 Message Date
eda8aaa849 remove shape check 2025-11-13 16:17:45 +01:00
39fea75bf3 finalize 2025-11-13 16:17:11 +01:00
70bb3bb300 remove slice 2025-11-13 15:55:16 +01:00
c4cfc2e023 [TP] Fix parameter detection issue and some invalid TP-plans (#42129)
* fix

* add test

* fix test

* fix the obvious

* more fix

* fix

* continue to improve

* more fix

* more

* fix

* fix

* finally

* CI
2025-11-13 15:44:56 +01:00
5c6d6bed4d [PEFT] Fix the general test for prefix tuning (#42185)
fix
2025-11-13 14:40:01 +00:00
207 changed files with 488 additions and 962 deletions

View File

@ -17,7 +17,6 @@ import contextlib
import json
import os
import time
from itertools import cycle
from typing import Optional
import datasets
@ -30,32 +29,42 @@ from transformers.generation import GenerationConfig
from transformers.generation.continuous_batching.requests import logger
def generate_without_cb(
model_id: str, sliding_window: int, attn_impl: str, batched_inputs: list[int], generation_config: GenerationConfig
# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
SLIDING_WINDOW = 0
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
SKIP_SPECIAL_TOKENS = False
def generate_simple(
attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
) -> dict[str, str]:
# Setup model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, attn_implementation=attn_impl)
attn_impl = {
"sdpa": "sdpa",
"eager": "eager",
"paged_attention": "eager", # TODO: this does not work on AMD docker
"flash_paged": "flash_attention_2", # TODO: this does not work on AMD docker
"kernels-community/flash-attn": "eager",
}[attn_impl]
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
model = model.cuda().eval()
if sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
model.config.sliding_window = sliding_window
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate one by one
if getattr(model.config, "sliding_window", None) is not None:
model.config.sliding_window = SLIDING_WINDOW
decoded_outputs = {}
for input_ids in tqdm(batched_inputs, desc="Generating outputs without CB"):
for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
key = " ".join(map(str, input_ids)) # This will be used to identify the output after batched generation
input_ids = torch.tensor([input_ids]).to("cuda")
attention_mask = torch.ones_like(input_ids)
outputs = model.generate(
input_ids, attention_mask=attention_mask, generation_config=generation_config, use_model_defaults=False
)
# attention_mask = torch.ones_like(input_ids)
outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
generated_tokens = outputs[0][input_ids.shape[1] :]
decoded_outputs[key] = tokenizer.decode(generated_tokens, skip_special_tokens=False)
decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS)
decoded_outputs[key] = decoded_output
return decoded_outputs
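The space-joined token ids act as a lookup key so that each continuous-batching output can later be compared against its unbatched reference. A minimal sketch of that matching, with hypothetical data:

reference_outputs = {"1 2 3": "The answer is 4."}  # produced by generate_simple
cb_row = {"key": "1 2 3", "input": "...", "cb_outputs": "The answer is 4."}  # one entry built in batch_generate
assert reference_outputs[cb_row["key"]] == cb_row["cb_outputs"]  # the two generation paths agree on this prompt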
def maybe_setup_metrics(use_metrics: bool) -> None:
if not use_metrics:
return
def setup_metrics():
try:
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -110,14 +119,16 @@ def batch_generate(
token_count = 0
data = []
for i, request in enumerate(batch_outputs):
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=SKIP_SPECIAL_TOKENS)
# The key is used to tie back to the output of unbatched generation
key = " ".join(map(str, batch_outputs[request].prompt_ids))
data.append({"input": input_text, "key": key})
# Try to decode the output
try:
output_text = tokenizer.decode(batch_outputs[request].generated_tokens, skip_special_tokens=False)
output_text = tokenizer.decode(
batch_outputs[request].generated_tokens, skip_special_tokens=SKIP_SPECIAL_TOKENS
)
token_count += len(batch_outputs[request].generated_tokens[1:])
data[-1]["cb_outputs"] = output_text
except Exception as e:
@ -127,7 +138,14 @@ def batch_generate(
# Display sample if asked
if i < displayed_samples:
print("-" * 20, f"{request} Input: {input_text}", f"{request} Output: {output_text}", sep="\n")
if len(output_text) > 0:
print("-" * 20)
print(f"{request} Input: {input_text}")
print(f"{request} Output: {output_text}")
else:
print(f"{request} Input: {input_text}")
print("[WARN]")
print(f"{request} Output was empty!")
# Compare with classic generate if asked
if expected_outputs is not None:
@ -164,102 +182,75 @@ def batch_generate(
if __name__ == "__main__":
# Parse args
parser = argparse.ArgumentParser()
# Continuous batching parameters
parser.add_argument("--num-blocks", "-n", type=int, default=None)
parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
# Model parameters
parser.add_argument("--sliding-window", type=int, default=0)
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
# Performance parameters
parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--force-max-length", action="store_true", help="Force generation to stop at max length")
# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--output-file", type=str, default=None)
parser.add_argument("--compare", action="store_true")
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--profile", type=str, default=None)
args = parser.parse_args()
# Create model
model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
has_system_role = args.sliding_window == 0
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn, dtype=torch.bfloat16)
model = model.cuda().eval()
if args.sliding_window > 0 and getattr(model.config, "sliding_window", None) is not None:
print(f"Setting sliding window from {model.config.sliding_window} to {args.sliding_window}")
model.config.sliding_window = args.sliding_window
# Set up diagnostics
# Set log level
logger.setLevel(args.log_level.upper())
maybe_setup_metrics(args.metrics)
# Set up performance
# If turned on, we setup metrics
if args.metrics:
setup_metrics()
# Set matmul precision if not none
if args.matmul_precision != "none":
torch.set_float32_matmul_precision(args.matmul_precision)
# Parse cuda graph argument
if args.cuda_graph is not None:
use_cuda_graph = {
"none": None,
"yes": True, "y": True, "true": True, "t": True, "1": True,
"no": False, "n": False, "false": False, "f": False, "0": False,
}[args.cuda_graph.lower()] # fmt: skip
else:
use_cuda_graph = None
cuda_graph_arg = args.cuda_graph.lower() if args.cuda_graph is not None else None
use_cuda_graph = {
"none": None, None: None,
"yes": True, "y": True, "true": True, "t": True, "1": True,
"no": False, "n": False, "false": False, "f": False, "0": False,
}[cuda_graph_arg] # fmt: skip
# Prepare model
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
attn_implementation=args.attn,
dtype=torch.bfloat16,
)
model = model.cuda().eval()
if getattr(model.config, "sliding_window", None) is not None:
print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
model.config.sliding_window = SLIDING_WINDOW
# If turned on, we compile the model
if args.compile:
model.forward = torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
# Prepare tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(args.samples))
if args.add_prefix:
possible_prefixes = [
None,
"You are a bot that solves math problems.",
"You are a bot who solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning.",
"You are a bot with the aim to solves math problems. Try to make your answer clear and understandable, and include your stages of reasoning. No loud words or emojis, all responses must be readable by a child. Here is now the problem:",
] # fmt: skip
else:
possible_prefixes = [None]
batched_inputs = []
for item, prefix in zip(dataset, cycle(possible_prefixes)):
messages = []
question = item["question"]
if prefix is not None:
if has_system_role:
messages.append({"role": "system", "content": prefix})
else:
question = prefix + "\n\n" + question
messages.append({"role": "user", "content": question})
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
batched_inputs.append(inputs["input_ids"])
simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
# Prepare generation config
generation_cfg = GenerationConfig(
generation_config = GenerationConfig(
max_new_tokens=512,
use_cuda_graph=use_cuda_graph,
eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=args.do_sample,
do_sample=not args.compare,
temperature=0.8,
top_p=0.9,
num_blocks=args.num_blocks,
@ -267,12 +258,7 @@ if __name__ == "__main__":
)
# If we need to compare, we need to generate the reference outputs
if args.compare:
expected_outputs = generate_without_cb(
model_id, args.sliding_window, args.attn, batched_inputs, generation_cfg
)
else:
expected_outputs = None
expected_outputs = generate_simple(args.attn, simple_batch_inputs, generation_config) if args.compare else None
# If no output file is provided, we pick a name based on the args
if args.output_file is None:
@ -285,8 +271,8 @@ if __name__ == "__main__":
# Run warmup batch generation # TODO: understand why warmup incurs a large overhead during cache creation
batch_generate(
model,
batched_inputs[: min(5, args.samples)],
generation_cfg,
simple_batch_inputs[: min(5, args.samples)],
generation_config,
tokenizer,
displayed_samples=-1,
)
@ -299,8 +285,8 @@ if __name__ == "__main__":
# Run batch generation
gen_time, tok_per_sec = batch_generate(
model,
batched_inputs,
generation_cfg,
simple_batch_inputs,
generation_config,
tokenizer,
displayed_samples=args.displayed,
output_file=args.output_file,
@ -311,5 +297,5 @@ if __name__ == "__main__":
prof.export_chrome_trace(filename)
# Example usage:
# python examples/pytorch/continuous_batching.py --attn sdpa --add-prefix --samples 10 --compare
# python examples/pytorch/continuous_batching.py --attn flash_attention_2 -mp none --add-prefix --samples 500
# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --samples 3 --compare
# python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from math import floor, gcd, sqrt
from typing import Optional
@ -20,8 +21,8 @@ import torch
from ...configuration_utils import PreTrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...utils.metrics import attach_tracer, traced
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import RequestState, get_device_and_memory_breakdown, logger
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import get_device_and_memory_breakdown, logger
def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
@ -31,7 +32,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
- All groups have the same number of layers
For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
We would get four groups: [0, 3], [1, 2], [4,5] and [6,7].
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
"""
# If the config has no layer_type attribute, it means all layers are the same attention type
layer_types = getattr(config, "layer_types", None)
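Reading aid only: a simplified grouping by layer type that ignores the equal-group-size constraint the real function enforces (all names below are made up):

from collections import defaultdict

def naive_group_layers(layer_types: list[str]) -> tuple[list[list[int]], list[str]]:
    # Collect the indices of every layer sharing the same attention type
    groups: dict[str, list[int]] = defaultdict(list)
    for index, layer_type in enumerate(layer_types):
        groups[layer_type].append(index)
    return list(groups.values()), list(groups.keys())

# naive_group_layers(["sliding", "full", "full", "sliding"]) -> ([[0, 3], [1, 2]], ["sliding", "full"])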
@ -115,6 +116,7 @@ class PagedAttentionCache:
for the sliding-attention group, although it is not needed.
"""
# TODO: this init is quite long, maybe a refactor is in order
def __init__(
self,
config: PreTrainedConfig,
@ -122,10 +124,8 @@ class PagedAttentionCache:
device: torch.device,
dtype: torch.dtype = torch.float16,
tp_size: Optional[int] = None,
allow_prefix_sharing: bool = True,
) -> None:
"""Initialize a paged attention cache for efficient memory usage. Also turns in prefix sharing if the model has
only full attention layers.
"""Initialize a paged attention cache for efficient memory usage.
Args:
config: Model configuration
@ -133,7 +133,6 @@ class PagedAttentionCache:
device: Device for the cache tensors
dtype: Data type of the cache
tp_size: Tensor parallelism size
allow_prefix_sharing: A flag to allow prefix sharing if the model has only full attention layers.
"""
self.config = config
self.dtype = dtype
@ -174,12 +173,10 @@ class PagedAttentionCache:
page_size = self.head_dim * self.num_key_value_heads
if "flash" in self.config._attn_implementation:
num_attention_masks = 0 # only used to compute the default memory footprint args
elif "sliding_attention" in group_types:
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2
num_attention_masks = 1 # only used to compute the default memory footprint args
else:
num_attention_masks = 1
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2 if "sliding_attention" in group_types else 1
memory_handler = PagedAttentionMemoryHandler(
block_size=self.block_size,
@ -221,6 +218,7 @@ class PagedAttentionCache:
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
# Block management data structures
self._free_blocks = deque(range(num_blocks))
self.group_cache_managers: list[CacheAllocator] = []
for i, group_type in enumerate(group_types):
if group_type == "full_attention":
@ -231,19 +229,13 @@ class PagedAttentionCache:
raise ValueError(f"Invalid group type: {group_type}")
self.group_cache_managers.append(cm)
# We only use prefix sharing if the whole model has only full attention layers
self.use_prefix_sharing = allow_prefix_sharing and group_types == ["full_attention"]
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
self.blocks_to_complete: dict[str, int] = {}
self._total_prefix_length: int = 0 # a counter to measure the impact of prefix sharing, also used in tests
@traced
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
max_allocated = 0
for cm in self.group_cache_managers:
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
if allocated is None:
return None
max_allocated = max(max_allocated, allocated)
@ -254,11 +246,11 @@ class PagedAttentionCache:
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
by the cache managers."""
for cm in self.group_cache_managers:
cm.free_blocks(request_id, self._block_manager)
cm.free_blocks(request_id, self._free_blocks)
def get_num_free_blocks(self) -> int:
"""Get the current number of unallocated blocks available for new requests."""
return self._block_manager.num_free_blocks
return len(self._free_blocks)
@traced
def extend_read_indices(
@ -345,44 +337,6 @@ class PagedAttentionCache:
# Return the new KV values
return key_states_with_cache, value_states_with_cache
def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
"""Searches for a prefix match in the cache for the given (prompts_ids). If one is found, we reference the
matching blocks in the (request_id), increase the reference count of the blocks and return the number of blocks
that match. If no prefix match is found, we return 0."""
current_hash = None
allocated_blocks = []
for b in range(len(prompt_ids) // self.block_size):
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
current_hash = self._block_manager.compute_hash(current_hash, tokens)
block_id = self._block_manager._hash_to_id.get(current_hash)
if block_id is not None:
allocated_blocks.append(block_id)
self._block_manager.increase_ref_count(block_id)
else:
break
# If we found a matching prefix, we reference the blocks in the request
if allocated_blocks:
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
cm = self.group_cache_managers[0]
cm.block_table[request_id] = allocated_blocks
prefix_length = len(allocated_blocks) * self.block_size
self._total_prefix_length += prefix_length
return prefix_length
def mark_blocks_as_complete(self, state: RequestState) -> None:
"""Marks the blocks that have been computed in the forward pass as complete. If prefix sharing is off, this is
a no-op."""
num_complete_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
if num_complete_blocks == 0:
return None
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
self._block_manager.mark_blocks_as_complete(
num_complete_blocks=num_complete_blocks,
allocated_blocks=cm.block_table[state.request_id],
prompt_ids=(state.full_prompt_ids + state.static_outputs),
)
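The removed search boils down to walking the prompt block by block and chaining hashes until a block is missing from the table; a standalone sketch of that walk, leaving out the reference counting and block tables:

def count_matching_prefix_blocks(prompt_ids: list[int], block_size: int, hash_to_id: dict[int, int]) -> int:
    # Hypothetical helper: hash_to_id maps a chained block hash to a physical block id
    current_hash = None
    matched = 0
    for b in range(len(prompt_ids) // block_size):
        tokens = prompt_ids[b * block_size : (b + 1) * block_size]
        current_hash = hash((current_hash, tuple(tokens)))  # same chaining as compute_hash below
        if current_hash in hash_to_id:
            matched += 1
        else:
            break
    return matched  # number of leading blocks whose KV cache is already computed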
# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
@ -517,8 +471,6 @@ class PagedAttentionMemoryHandler:
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
])
If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
"""
cache_memory = self.get_available_memory(max_memory_percent)
logger.info(f"Cache memory: {cache_memory}")
@ -530,16 +482,11 @@ class PagedAttentionMemoryHandler:
c = -cache_memory
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
if self.num_attention_masks == 0:
greatest_solution = -c / b
# Otherwise, we solve the quadratic equation
else:
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
# Compute discriminant and greatest solution
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
if greatest_solution < 0:
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")

View File

@ -14,211 +14,29 @@
# limitations under the License.
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from math import ceil
from typing import Optional, TypeVar
from typing import Optional
from .requests import logger
T = TypeVar("T")
def reverse_enumerate(xs: list[T]) -> Iterator[tuple[int, T]]:
index = len(xs) - 1
for x in xs[::-1]:
yield index, x
index -= 1
class Block:
"""A class to represent a block managed by the block manager. We say that a block is complete when the physical KV
cache it points to is fully computed. A block can have a parent, which is the block that came before in the
sequence. Once a block is complete, it is given a hash, which takes into account the tokens ids of the block and
its parent's hash (if there is a parent)."""
def __init__(self, id_: int, parent_id: int | None) -> None:
self.id: int = id_
self.parent_id: int | None = parent_id
self.hash: int | None = None
self.ref_count: int = 1
def __repr__(self) -> str:
return f"Block(id={self.id}, parent_id={self.parent_id}, hash={self.hash}, ref_count={self.ref_count})"
@property
def is_complete(self) -> bool:
return self.hash is not None
class BlockManager:
"""A class to manage the number of free blocks and block re-use. If prefix sharing is off, the block manager is a
simple FIFO structure where blocks are either free or in use. If prefix sharing is on, blocks can have 3 states:
- in use: one or more requests references this block, thus it cannot be written over. The number of requests
referencing this block is stored as ref_count in the Block object.
- un-initialized: the block points to a space in the KV cache tensor that contains no data yet. Those blocks can
be given as free blocks to new requests without any overhead.
- initialized: the block is complete and was used by one or more requests that are finished. It contains KV cache
data and its hash is stored in the hash table. If a new request needs a block with the same hash, we increase
the ref_count of the block and remove it from the list of initialized blocks, because it is now in use.
Still, the block can be freed if no un-initialized blocks are left. In that case, we remove its hash from the
hash table.
There is no structure to keep track of the blocks in use: if a block is neither un-initialized nor initialized,
it is in use.
"""
def __init__(self, num_blocks: int, block_size: int, use_prefix_sharing: bool) -> None:
"""Initializes the block manager with a given number of blocks (num_blocks) of size (block_size). Prefix sharing
can be turned on with the (use_prefix_sharing) flag, which only happens if the model has only full attention
layers."""
self.num_blocks = num_blocks
self.block_size = block_size
self._uninit_block_ids = deque(range(num_blocks))
self._init_block_ids: dict[int, None] = {} # effectively act as an ordered set
self._use_prefix_sharing = use_prefix_sharing
self._hash_to_id: dict[int, int] = {}
self._id_to_block: dict[int, Block] = {}
@property
def num_free_blocks(self) -> int:
"""Returns the number of free blocks left. Both initialized and uninitialized blocks are considered free."""
return len(self._uninit_block_ids) + len(self._init_block_ids)
def is_enough_free_blocks(self, n_blocks: int) -> bool:
"""Checks if there are enough free blocks to allocate the requested number of blocks (n_blocks). If there are
not enough uninitialized blocks, we uninitialize the required number of initialized blocks."""
# Exit early if there are enough uninitialized blocks
if len(self._uninit_block_ids) >= n_blocks:
return True
# Exit early if even after uninitializing all initialized blocks, there are not enough free blocks
block_to_unintialize = n_blocks - len(self._uninit_block_ids)
if len(self._init_block_ids) < block_to_unintialize:
return False
# Uninitialize the required amount of blocks
for _ in range(block_to_unintialize):
id_to_unintialize = self._init_block_ids.popitem()[0]
block = self._id_to_block[id_to_unintialize]
self._hash_to_id.pop(block.hash)
self._uninit_block_ids.append(id_to_unintialize)
return True
def get_free_blocks(self, n_blocks: int, last_block_id: int | None) -> list[int] | None:
"""Returns a list of (n_blocks) free block and mark them as no longuer free in the internal data structures. One
can also pass a (last_block_id) to indicate the last block id in the sequence, which is used to keep track of
the parent block. If the manager cannot find enough free blocks, it returns None."""
if not self.is_enough_free_blocks(n_blocks):
return None
allocated_block_ids = [self._uninit_block_ids.popleft() for _ in range(n_blocks)]
# If we use prefix caching, we keep track of the allocated blocks as partial blocks
if self._use_prefix_sharing:
for block_id in allocated_block_ids:
block = Block(block_id, last_block_id)
self._id_to_block[block_id] = block
last_block_id = block_id
# In both cases, we return the allocated block ids
return allocated_block_ids
def increase_ref_count(self, block_id: int) -> None:
"""Increases the reference count of a given (block_id)."""
block = self._id_to_block[block_id]
block.ref_count += 1
if block.ref_count == 1:
self._init_block_ids.pop(block_id)
def decrease_ref_count(self, block_id: int) -> None:
"""Decreases the reference count of a given (block_id). If the reference count reaches 0, the block is no longer
in use, and becomes initialized (if it was complete) or uninitialized (if it was incomplete)."""
block = self._id_to_block[block_id]
block.ref_count -= 1
if block.ref_count == 0:
if block.is_complete:
self._init_block_ids[block_id] = None
else:
self._id_to_block.pop(block_id)
self._uninit_block_ids.append(block_id)
def free_blocks(self, blocks: list[int]) -> None:
"""Marks a list of (blocks) as free. If there is no prefix sharing, we simply add them to the uninitialized
blocks queue. Otherwise, their new state depends on whether they are complete."""
if self._use_prefix_sharing:
for block_id in blocks:
self.decrease_ref_count(block_id)
else:
self._uninit_block_ids.extend(blocks)
def mark_blocks_as_complete(
self, num_complete_blocks: int, allocated_blocks: list[int], prompt_ids: list[int]
) -> None:
"""Among the list of (allocated_blocks), mark (num_complete_blocks) incomplete blocks as now complete. The list
of (prompt_ids) is used to compute the hash of the new block."""
# Look for the first complete block, starting from the last block in the sequence
parent_hash = None
incomplete_blocks: list[Block] = []
for i, block_id in reverse_enumerate(allocated_blocks):
block = self._id_to_block[block_id]
if block.is_complete:
parent_hash = block.hash
break
incomplete_blocks.append((i, block))
# Now go through the incomplete blocks and update them
new_parent_id = None
while incomplete_blocks:
i, block = incomplete_blocks.pop()
# If the parent id has been updated, we apply the change
if new_parent_id is not None:
block.parent_id = new_parent_id
new_parent_id = None
# If we have set the hash for all complete blocks, we can stop
if num_complete_blocks == 0:
break
# Otherwise, we compute the hash
num_complete_blocks -= 1
tokens = prompt_ids[i * self.block_size : (i + 1) * self.block_size]
block.hash = self.compute_hash(parent_hash, tokens)
existing_block_id = self._hash_to_id.get(block.hash)
# If the block hash is already in the hash to id mapping, we reference the existing block instead
if existing_block_id is not None:
logger.debug(f"Found existing block {existing_block_id} for block {block.id}")
allocated_blocks[i] = existing_block_id
self._id_to_block[existing_block_id].ref_count += 1
new_parent_id = existing_block_id
self.free_blocks([block.id])
# Otherwise, we add the completed block to the hash table
else:
self._hash_to_id[block.hash] = block.id
# Update loop variables
parent_hash = block.hash
def compute_hash(self, parent_hash: int | None, tokens: list[int]) -> int:
"""Computes the hash of a block containing the given (tokens) with a given (parent_hash). If the block has no
parent, the parent hash is None."""
return hash((parent_hash, tuple(tokens)))
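A tiny usage sketch of the chaining with hypothetical token ids:

block_size = 4
prompt = [11, 12, 13, 14, 21, 22, 23, 24]
h1 = hash((None, tuple(prompt[0:4])))  # first block has no parent
h2 = hash((h1, tuple(prompt[4:8])))    # second block folds in its parent's hash
# Two requests sharing these first 8 tokens produce the same h1 and h2,
# which is what allows their cache blocks to be shared.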
class CacheAllocator(ABC):
"""Abstract base class for cache managers. Cache managers keep track of per-request cache allocations, determine
when a new physical block needs to be allocated and compute physical indices for reading or writing to the cache."""
_index: int
block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
_block_table: dict[str, list[int]] # request_id -> list of block_ids allocated to the request
@abstractmethod
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocates (n_blocks) for a given (request_id) using the (block_manager). Returns the num of blocks allocated
if successful and None otherwise."""
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocates n_blocks for a given request_id. Returns the num of blocks allocated if successful and None
otherwise."""
def free_blocks(self, request_id: str, block_manager: BlockManager) -> None:
"""Frees all blocks associated with a (request_id) using the (block_manager)."""
if request_id in self.block_table:
blocks_to_free = self.block_table.pop(request_id)
block_manager.free_blocks(blocks_to_free)
def free_blocks(self, request_id: str, free_blocks: deque[int]) -> None:
"""Frees all blocks associated with a request_id."""
if request_id in self._block_table:
blocks_to_free = self._block_table.pop(request_id)
free_blocks.extend(blocks_to_free)
else:
logger.warning(
f"CacheAllocator {self._index} attempted to free blocks for non-existent request_id: {request_id}"
@ -248,30 +66,23 @@ class FullAttentionCacheAllocator(CacheAllocator):
"""
self._index = index
self.block_size = block_size
self.block_table = {}
self._block_table = {}
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
allocated if successful and None otherwise. For a group of full attention layers, we always allocate the number of
requested blocks."""
# Make sure the request_id is in the block table and get the last block id
if request_id not in self.block_table:
self.block_table[request_id] = [] # TODO: check the impact of making this a deque
last_block_id = None
else:
last_block_id = self.block_table[request_id][-1]
# Actual allocation, return early if failed
allocated_blocks = block_manager.get_free_blocks(n_blocks, last_block_id)
if allocated_blocks is None:
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For a group of full attention layers, we always allocate the number of requested blocks."""
if len(free_blocks) < n_blocks:
return None
self.block_table[request_id].extend(allocated_blocks)
if request_id not in self._block_table:
self._block_table[request_id] = []
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(n_blocks))
return n_blocks
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices of where to read request_id's cache. For a group of full attention layers, we
first write the new cache to the cache tensor and then read the entire cache from the beginning to the end."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@ -286,7 +97,7 @@ class FullAttentionCacheAllocator(CacheAllocator):
def get_write_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
"""Returns the physical indices for writing to the cache. For a group of full attention layers, we write the new
cache as a continuation of the existing cache for the same request."""
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Compute the physical indices
@ -318,26 +129,25 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
self.block_size = block_size
self.sliding_window = sliding_window
self._max_blocks_per_request = ceil(self.sliding_window / self.block_size)
self.block_table = {}
self._block_table = {}
def allocate_blocks(self, n_blocks: int, request_id: str, block_manager: BlockManager) -> Optional[int]:
"""Allocate (n_blocks) for a given (request_id) using the (block_manager). Returns the number of blocks
allocated if successful and None otherwise. For a group of sliding window attention layers, we only allocate up to the point where we can
fit an entire sliding window in the cache tensor."""
if request_id not in self.block_table:
self.block_table[request_id] = []
def allocate_blocks(self, n_blocks: int, request_id: str, free_blocks: deque[int]) -> Optional[int]:
"""Allocate blocks for a given request_id. Returns the number of blocks allocated if successful and None
otherwise. For a group of sliding window attention layers, we only allocate up to the point where we can fit an
entire sliding window in the cache tensor."""
if request_id not in self._block_table:
self._block_table[request_id] = []
# Early return if we are already at the max number of blocks per request
already_allocated = len(self.block_table[request_id])
already_allocated = len(self._block_table[request_id])
if already_allocated == self._max_blocks_per_request:
return 0
# Compute actual number of blocks to allocate
after_allocation = min(already_allocated + n_blocks, self._max_blocks_per_request)
actual_n_blocks = after_allocation - already_allocated
# Classic allocation
allocated_blocks = block_manager.get_free_blocks(actual_n_blocks, None) # no prefix caching w/ sliding window
if allocated_blocks is None:
if len(free_blocks) < actual_n_blocks:
return None
self.block_table[request_id].extend(allocated_blocks)
self._block_table[request_id].extend(free_blocks.popleft() for _ in range(actual_n_blocks))
return actual_n_blocks
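A back-of-the-envelope check of the cap, with hypothetical sizes:

from math import ceil

block_size, sliding_window = 32, 4096
max_blocks_per_request = ceil(sliding_window / block_size)  # 128 blocks
# A request already holding 126 blocks that asks for 8 more only receives
# min(126 + 8, 128) - 126 = 2 blocks, since older tokens fall out of the window anyway.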
def get_read_indices(self, request_id: str, past_length: int, query_length: int) -> list[int]:
@ -347,7 +157,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding_window - 1 cache pages and then manually add the new key / value states after. Hence the -1 indices
which indicate where to store the new key or value states.
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window
@ -368,7 +178,7 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
sliding window attention layers, we write the new cache in a rolling-buffer kind of way: if we reach the end of
the allocated physical cache, we start writing from the beginning of the physical cache again."""
# Retrieve the block table for the request and raise an error if it doesn't exist
block_table = self.block_table.get(request_id)
block_table = self._block_table.get(request_id)
if block_table is None:
raise ValueError(f"No block table found for request {request_id}")
# Apply sliding window
@ -391,3 +201,22 @@ class SlidingAttentionCacheAllocator(CacheAllocator):
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""
seqlens_k = query_length + min(past_length, self.sliding_window - 1)
return "sliding_attention", seqlens_k
# TODO: test the impact of this
# def get_read_indices(self, request_id: str, past_length: int) -> list[int]:
# # Retrieve the block table for the request and raise an error if it doesn't exist
# block_table = self._block_table.get(request_id)
# if block_table is None:
# raise ValueError(f"No block table found for request {request_id}")
# # Compute the physical indices
# physical_indices = []
# n_left = past_length
# for block_idx in block_table:
# block_physical_index = block_idx * self.block_size
# pages_used = min(self.block_size, n_left)
# physical_indices.extend(block_physical_index + i for i in range(pages_used))
# n_left -= pages_used
# if n_left == 0:
# return physical_indices
# raise ValueError(f"Request {request_id} required too many indices: {past_length = } and {len(block_table) = }")

View File

@ -16,13 +16,12 @@
import queue
import threading
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from itertools import count
from math import ceil
from time import perf_counter
from typing import Optional
from typing import Optional, Union
import torch
from torch import nn
@ -447,7 +446,10 @@ class ContinuousBatchProcessor:
cumulative_seqlens_q = [0]
logits_indices = []
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
if isinstance(self.cumulative_seqlens_k, dict):
cumulative_seqlens_k = {layer_type: [0] for layer_type in self.cumulative_seqlens_k}
else:
cumulative_seqlens_k = [0]
read_index = [[] for _ in range(self.cache.num_groups)]
write_index = [[] for _ in range(self.cache.num_groups)]
@ -496,7 +498,10 @@ class ContinuousBatchProcessor:
self.metrics.record_kv_cache_memory_metrics(self.cache)
if logger.isEnabledFor(logging.DEBUG):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
if isinstance(self.cumulative_seqlens_k, dict):
ck = max(cumulative_seqlens_k[layer_type][-1] for layer_type in self.cumulative_seqlens_k)
else:
ck = cumulative_seqlens_k[-1]
logger.debug(
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, "
f"Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. "
@ -512,7 +517,7 @@ class ContinuousBatchProcessor:
read_index: list[list[int]],
write_index: list[list[int]],
cumulative_seqlens_q: list[int],
cumulative_seqlens_k: dict[str, list[int]],
cumulative_seqlens_k: Union[list[int], dict[str, list[int]]],
logits_indices: list[int],
) -> None:
"""Builds the actual tensors for the current batch, by modifying the already allocated tensors in place."""
@ -556,7 +561,9 @@ class ContinuousBatchProcessor:
@traced
def _maybe_send_output(self, state: RequestState) -> None:
"""Send output to the queue based on streaming mode and request state."""
if state.streaming or state.status == RequestStatus.FINISHED:
if state.streaming:
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())
@traced
@ -564,27 +571,17 @@ class ContinuousBatchProcessor:
"""Update request states based on generated tokens."""
out_tokens = self._sync()
for i, state in enumerate(self.requests_in_batch):
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prompt_ids) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = out_tokens[self.logits_indices[i]]
state.prompt_ids = [token]
# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
# We mark the completed blocks as such
self.cache.mark_blocks_as_complete(state)
if is_finished:
if state.update_with_token(token):
self.metrics.record_request_completion(state.created_time, state.request_id)
self.scheduler.finish_request(state.request_id, evict_from_cache=(not self.manual_eviction))
self._maybe_send_output(state)
# Otherwise, the request is still prefilling, but the prefill has been split
elif state.status == RequestStatus.PREFILLING_SPLIT:
self.cache.mark_blocks_as_complete(state)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
else:
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")
if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")
@ -729,7 +726,6 @@ class ContinuousBatchingManager:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> None:
"""Initialize the continuous batching manager.
@ -739,7 +735,6 @@ class ContinuousBatchingManager:
max_queue_size: Maximum size of the request queue (0 = unlimited)
num_q_cuda_graphs: (optional) Number of CUDA graphs to use for the query dimension
num_kv_cuda_graphs: (optional) Number of CUDA graphs to use for the keys/values dimension
allow_prefix_sharing: (optional) Whether to allow prefix sharing if the model has only full attention layers
"""
if "paged|" not in model.config._attn_implementation:
attn_implementation = f"paged|{model.config._attn_implementation}"
@ -772,8 +767,6 @@ class ContinuousBatchingManager:
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None
self._allow_prefix_sharing = allow_prefix_sharing
# If a number of cuda graphs was specified for either Q or KV, we activate cuda graphs
if num_q_cuda_graphs > 0 or num_kv_cuda_graphs > 0:
self.use_cuda_graph = True
@ -806,6 +799,7 @@ class ContinuousBatchingManager:
logger.warning("Manager thread is already running.")
return
self._result_queue = queue.Queue()
self._generation_thread = threading.Thread(target=self._run_generation_loop)
self._generation_thread.start()
@ -820,16 +814,6 @@ class ContinuousBatchingManager:
block: Whether to wait for the thread to stop
timeout: Maximum time to wait for the thread to stop
"""
if self.batch_processor is None:
logger.warning("\nBatch processor was not initialized.")
else:
if self.batch_processor.cache.use_prefix_sharing:
logger.warning(
f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
)
else:
logger.warning("\nPrefix sharing was off.")
if self._generation_thread is None:
logger.warning("Manager not started.")
return
@ -955,6 +939,20 @@ class ContinuousBatchingManager:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)
@traced
def warmup(self, batch_processor: ContinuousBatchProcessor) -> None:
stream = torch.cuda.Stream(device=self.model.device)
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
# Warmup the model with a dummy forward pass
self._generation_step(batch_processor)
torch.cuda.current_stream().wait_stream(stream)
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, stream=stream):
self._generation_step(batch_processor)
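The warmup above follows the standard PyTorch capture/replay recipe; a generic, self-contained sketch of that recipe with a toy module (not the batch processor itself):

import torch

model = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so capture starts from a quiescent state
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one step, then replay it after writing new data into the same input buffer
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    static_output = model(static_input)

static_input.copy_(torch.randn(8, 16, device="cuda"))
g.replay()  # reruns the captured kernels; static_output now holds the result for the new input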
@traced
# @torch.compile
def _generation_step(self) -> None:
"""Perform a single generation step. This is cuda graphed"""
self.batch_processor._generation_step(self.model, self.logit_processor, self.do_sample)
@ -970,7 +968,6 @@ class ContinuousBatchingManager:
self.model.device,
self.model.dtype,
tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
allow_prefix_sharing=self._allow_prefix_sharing,
)
logger.debug(f"PagedAttentionCache created in {perf_counter() - t0} seconds")
@ -1062,15 +1059,6 @@ class ContinuousBatchingManager:
class ContinuousMixin:
"""Mixin class for models to add continuous batching capabilities."""
@contextmanager
def continuous_batching_context_manager(self, **kwargs) -> Generator[ContinuousBatchingManager]:
manager = self.init_continuous_batching(**kwargs)
manager.start()
try:
yield manager
finally:
manager.stop(block=True)
def init_continuous_batching(
self,
generation_config: Optional[GenerationConfig] = None,
@ -1078,7 +1066,6 @@ class ContinuousMixin:
max_queue_size: int = 0,
num_q_cuda_graphs: int = 0,
num_kv_cuda_graphs: int = 0,
allow_prefix_sharing: bool = True,
) -> ContinuousBatchingManager:
"""Initialize a manager for continuous batching inference.
@ -1111,7 +1098,6 @@ class ContinuousMixin:
max_queue_size=max_queue_size,
num_q_cuda_graphs=num_q_cuda_graphs,
num_kv_cuda_graphs=num_kv_cuda_graphs,
allow_prefix_sharing=allow_prefix_sharing,
)
# TODO: support streaming
@ -1183,6 +1169,5 @@ class ContinuousMixin:
except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)
finally:
logger.debug("Generate batch is finished.") # a dummy log needed for the logs of stop to show. Won't show.
manager.stop(block=True, timeout=5.0)
return results

View File

@ -116,10 +116,10 @@ class RequestState:
error (Optional[str]): Any error message associated with the request. When None, the request has had no error yet.
"""
# Required fields # TODO: come up with better names / not sure prompt_ids and such are not redundant
# Required fields
request_id: str
full_prompt_ids: Optional[list[int]] = None # Full initial prompt
prompt_ids: Optional[list[int]] = None # Token IDs currently being processed
prompt_ids: Optional[list[int]] = None # Token IDs currently being processed (initial + generated)
remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
static_outputs: list[int] = field(default_factory=list) # Generated tokens
allocated_blocks: int = 0 # Number of blocks allocated to the request
@ -164,7 +164,7 @@ class RequestState:
# TODO: this logic seems one token off, check it out
@traced
def update_and_check_completion(self, token_id: int) -> bool:
def update_with_token(self, token_id: int) -> bool:
"""Update the request with a newly generated token and check for completion.
Args:

View File

@ -104,7 +104,7 @@ class Scheduler(ABC):
)
@traced
def _allocate_blocks_if_needed(self, state: RequestState) -> bool:
def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
"""Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
accommodate the next tokens. It calculates how many blocks are needed based on the request's current
cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
@ -113,11 +113,10 @@ class Scheduler(ABC):
# 1. we check that the occupancy is less than the requested length
# 2. we allocate enough blocks to cover the requested length
current_len = state.current_len()
len_next_tokens = len(state.prompt_ids)
occupancy = state.allocated_blocks * self.cache.block_size - current_len
if occupancy < len_next_tokens or state.allocated_blocks == 0:
blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
allocated = self.cache.allocate_blocks(blocks_needed, state)
allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
if allocated is None:
return False
state.allocated_blocks += allocated
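A numeric sketch of this sizing logic for a hypothetical request:

block_size = 32
current_len, allocated_blocks, len_next_tokens = 70, 3, 40
occupancy = allocated_blocks * block_size - current_len  # 96 - 70 = 26 free slots
blocks_needed = ((len_next_tokens - occupancy + 1) // block_size) + 1  # (15 // 32) + 1 = 1
# One extra block (32 slots) covers the 14-token shortfall for the next forward pass.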
@ -126,29 +125,11 @@ class Scheduler(ABC):
@traced(span_name="prepare_request")
def _prepare_request_for_processing(
self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
) -> None:
"""Prepares a request for processing in the current batch. If prefix sharing is enabled, and the request was
pending, this is where we look for a prefix match and split the request if found."""
# If prefix sharing is enabled, we look for a prefix match and split the request if found
if self.cache.use_prefix_sharing and state.status == RequestStatus.PENDING:
prefill_length = self.cache.search_prefix_match(state.request_id, state.prompt_ids)
if prefill_length > 0:
self.active_requests[state.request_id] = state
request_ids_to_remove_from_waiting.add(state.request_id)
state.status = RequestStatus.SPLIT_PENDING_REMAINDER
# Even if we match the whole request, we keep at least 1 token to start decoding
prefill_length = min(prefill_length, len(state.prompt_ids) - 1)
state.remaining_prompt_ids = state.prompt_ids[prefill_length:]
state.prompt_ids = state.prompt_ids[prefill_length:]
state.position_offset += prefill_length
# If the request has a split prefill, the tokens to process are the remaining prompt ids
if state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
request_tokens = state.remaining_prompt_ids
# Otherwise, the tokens to process are the prompt ids, which are the full prompt or the last predicted tokens
else:
request_tokens = state.prompt_ids
):
"""Prepares a request for processing in the current batch."""
request_tokens = (
state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
)
if len(request_tokens) < token_budget:
# Can process the entire prompt/remainder
if state.status == RequestStatus.PENDING:
@ -171,7 +152,6 @@ class Scheduler(ABC):
state.prompt_ids = request_tokens[:token_budget]
# TODO: further common-ize the two classes
@attach_tracer()
class FIFOScheduler(Scheduler):
"""This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
@ -215,31 +195,30 @@ class FIFOScheduler(Scheduler):
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
break
continue
# Add the request to the scheduled requests
scheduled_requests.append(state)
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
complete_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = complete_blocks
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
was_waiting = self.waiting_requests.pop(req_id, None) is not None
if was_waiting:
request_ids_to_remove_from_waiting.add(req_id)
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
# Early exit of the loop if we have no token budget left
if token_budget == 0:
break
@ -270,7 +249,6 @@ class PrefillFirstScheduler(Scheduler):
elif state.status == RequestStatus.DECODING:
second_priority_states.append(state)
# Add waiting requests to second priority
for req_id in self.waiting_requests_order:
second_priority_states.append(self.waiting_requests[req_id])
@ -281,31 +259,30 @@ class PrefillFirstScheduler(Scheduler):
for state in candidates:
self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
request_len = len(state.prompt_ids)
# If we can't allocate blocks, do not schedule the request and break if the cache is full
if not self._allocate_blocks_if_needed(state):
if self.cache.get_num_free_blocks() == 0:
if not self._allocate_blocks_if_needed(
state, len(state.prompt_ids)
): # don't schedule if we can't allocate blocks
if len(self.cache._free_blocks) == 0:
break
continue
# Add the request to the scheduled requests
scheduled_requests.append(state)
@traced
def _add_to_scheduled_requests(state: RequestState):
scheduled_requests.append(state)
_add_to_scheduled_requests(state)
# Update the token budget
token_budget -= request_len
# If using prefix sharing, we make note of the blocks that will be computed in the forward pass
if self.cache.use_prefix_sharing:
tokens_in_current_block = state.current_len() % self.cache.block_size
tokens_after_forward = tokens_in_current_block + request_len
complete_blocks = tokens_after_forward // self.cache.block_size
self.cache.blocks_to_complete[state.request_id] = complete_blocks
# Remove the request from the waiting queue and mark it as removed
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
@traced
def _remove_from_waiting_requests(state: RequestState):
req_id = state.request_id
if req_id in self.waiting_requests:
del self.waiting_requests[req_id]
request_ids_to_remove_from_waiting.add(req_id)
_remove_from_waiting_requests(state)
# Early exit of the loop if we have no token budget left
if token_budget == 0:
break

View File

@ -140,6 +140,16 @@ def _blocks_to_block_sizes(total_size: int, blocks: int | list[int]) -> list[int
return [single_size] * blocks
def replace_layer_number_by_wildcard(name: str) -> str:
"""
Replace the numbers in the `name` by wildcards, only if they are in-between dots (`.`) or if they are between
a dot (`.`) and the end of the string.
This matches how modules are named/numbered when using a nn.ModuleList or nn.Sequential, but will NOT match
numbers in a parameter name itself, e.g. if the param is named `"w1"` or `"w2"`.
"""
return re.sub(r"\.\d+(\.|$)", lambda m: ".*" + m.group(1), name)
def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weight=True) -> str | None:
"""
Get the TP style for a parameter from the TP plan.
@ -150,11 +160,11 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
The `is_weight` flag is important because for weights, we want to support the `.weight` and `.bias` cases seamlessly, but
not parent classes for `post_init` calls
"""
generic_param_name = re.sub(r"\d+", "*", parameter_name)
generic_param_name = replace_layer_number_by_wildcard(parameter_name)
if generic_param_name in tp_plan:
return tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan and is_weight:
return tp_plan[generic_param_name.rsplit(".", 1)[0]]
elif is_weight and "." in generic_param_name and (module_name := generic_param_name.rsplit(".", 1)[0]) in tp_plan:
return tp_plan[module_name]
return None
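How the exact lookup and the weight fallback interact, sketched with a one-entry plan and hypothetical parameter names:

tp_plan = {"layers.*.self_attn.q_proj": "colwise"}
# "layers.0.self_attn.q_proj"         -> generic name is in the plan           -> "colwise"
# "layers.0.self_attn.q_proj.weight"  -> falls back to the parent module name  -> "colwise"
# "layers.0.self_attn.q_norm.weight"  -> neither form is in the plan           -> None (not sharded)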
@ -1086,7 +1096,7 @@ def verify_tp_plan(expected_keys: list[str], tp_plan: dict[str, str] | None):
if tp_plan is None:
return
generic_keys = {re.sub(r"\d+", "*", key) for key in expected_keys}
generic_keys = {replace_layer_number_by_wildcard(key) for key in expected_keys}
unsharded_layers = set(generic_keys)
unused_rules = tp_plan

View File

@ -128,7 +128,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
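For context, a self-contained toy version of the eager-attention pattern these hunks touch, assuming an additive mask already sized to the key length:

import torch
from torch import nn

batch, heads, q_len, k_len, head_dim = 1, 2, 4, 4, 8
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, k_len, head_dim)
value = torch.randn(batch, heads, k_len, head_dim)
# Additive mask: 0 where attention is allowed, -inf elsewhere (already cropped to k_len,
# which is why the explicit slicing is being dropped in these files)
attention_mask = torch.triu(torch.full((q_len, k_len), float("-inf")), diagonal=1)[None, None]

scaling = head_dim**-0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)  # shape (1, 2, 4, 8)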

View File

@ -580,8 +580,7 @@ def eager_attention_forward(
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -106,7 +106,6 @@ class ApertusConfig(PreTrainedConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -201,8 +201,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -123,7 +123,6 @@ class ApertusConfig(LlamaConfig):
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
}
def __init__(

View File

@ -208,8 +208,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -99,15 +99,14 @@ class AriaTextConfig(PreTrainedConfig):
model_type = "aria_text"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `AriaTextModel`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -431,8 +431,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -169,6 +169,15 @@ class AriaTextConfig(LlamaConfig):
model_type = "aria_text"
base_config_key = "text_config"
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.shared_experts.gate_proj": "colwise",
"layers.*.mlp.shared_experts.up_proj": "colwise",
"layers.*.mlp.shared_experts.down_proj": "rowwise",
}
def __init__(
self,

View File

@ -114,7 +114,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -58,8 +58,8 @@ def eager_attention_forward(
scaling = query.size(-1) ** -0.5
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None and attention_mask.ndim == 4:
attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -292,8 +292,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -126,7 +126,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -130,7 +130,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -74,7 +74,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -1172,7 +1172,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -106,7 +106,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -139,8 +139,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -121,7 +121,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -105,7 +105,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -284,7 +284,7 @@ class BloomAttention(nn.Module):
# change view to [batch_size, num_heads, q_length, kv_length]
attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
if attention_mask is not None: # no matter the length, we just slice it
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
attn_weights = attn_weights + causal_mask

View File

@ -250,8 +250,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -420,7 +420,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -70,7 +70,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -232,8 +232,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -247,8 +247,7 @@ def eager_attention_forward(
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -1061,8 +1061,7 @@ def eager_attention_forward(
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -173,8 +173,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -149,8 +149,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -258,8 +258,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -167,8 +167,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -190,7 +190,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -175,7 +175,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -165,8 +165,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -64,8 +64,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, **kwargs)
if attention_mask is not None:
# Apply the attention mask
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -252,8 +252,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -300,8 +300,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -179,7 +179,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -254,8 +254,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -256,7 +256,7 @@ class DiffLlamaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

View File

@ -132,7 +132,7 @@ class DiffLlamaAttention(nn.Module):
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

View File

@ -167,7 +167,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -187,7 +187,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -206,7 +206,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -137,7 +137,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -118,9 +118,9 @@ class DogeConfig(PreTrainedConfig):
"layers.*.self_attn.dt_proj": "rowwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.input_layernorm.weight": "sequence_parallel",
"layers.*.input_residual.weight": "sequence_parallel",
"layers.*.input_residual": "sequence_parallel",
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
"layers.*.post_attention_residual.weight": "sequence_parallel",
"layers.*.post_attention_residual": "sequence_parallel",
"norm.weight": "sequence_parallel",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",

View File

@ -196,8 +196,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -146,9 +146,9 @@ class DogeConfig(PreTrainedConfig):
"layers.*.self_attn.dt_proj": "rowwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.input_layernorm.weight": "sequence_parallel",
"layers.*.input_residual.weight": "sequence_parallel",
"layers.*.input_residual": "sequence_parallel",
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
"layers.*.post_attention_residual.weight": "sequence_parallel",
"layers.*.post_attention_residual": "sequence_parallel",
"norm.weight": "sequence_parallel",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",

View File

@ -188,8 +188,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -286,7 +286,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -358,8 +358,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -135,7 +135,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -105,8 +105,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -146,7 +146,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -156,8 +156,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -214,8 +214,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -270,7 +270,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -234,7 +234,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -194,8 +194,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -348,8 +348,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -114,9 +114,9 @@ class FlexOlmoConfig(PreTrainedConfig):
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -168,8 +168,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -125,9 +125,9 @@ class FlexOlmoConfig(OlmoeConfig):
"layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.v_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the added norm on q and k
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.experts.*.gate_proj": "colwise",
"layers.*.mlp.experts.*.up_proj": "colwise",
"layers.*.mlp.experts.*.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -205,7 +205,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -205,8 +205,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -217,9 +217,8 @@ def eager_attention_forward(
attn_weights = attn_weights / softcap
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * softcap
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

View File

@ -291,9 +291,8 @@ def eager_attention_forward(
attn_weights = attn_weights / softcap
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * softcap
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

View File

@ -291,9 +291,8 @@ def eager_attention_forward(
attn_weights = attn_weights / softcap
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * softcap
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
@ -630,7 +629,7 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config: Gemma3TextConfig
base_model_prefix = "language_model"
base_model_prefix = "model"
def __init__(self, config: Gemma3TextConfig):
super().__init__(config)

View File

@ -715,7 +715,7 @@ class Gemma3TextModel(Gemma2Model):
class Gemma3ForCausalLM(Gemma2ForCausalLM):
config: Gemma3TextConfig
base_model_prefix = "language_model"
base_model_prefix = "model"
def __init__(self, config: Gemma3TextConfig):
super().__init__(config)

View File

@ -1185,9 +1185,8 @@ def eager_attention_forward(
attn_weights = attn_weights / softcap
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * softcap
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

View File

@ -156,8 +156,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -138,8 +138,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -136,8 +136,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -265,8 +265,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -214,8 +214,9 @@ class Glm4vMoeTextConfig(PreTrainedConfig):
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -185,8 +185,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -159,8 +159,9 @@ class Glm4vMoeTextConfig(Glm4MoeConfig):
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),

View File

@ -74,8 +74,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, **kwargs)
if attention_mask is not None:
# Apply the attention mask
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -106,8 +106,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -127,7 +127,7 @@ class GPTNeoSelfAttention(nn.Module):
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None: # no matter the length, we just slice it
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
@ -218,7 +218,7 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention):
attn_dropout = self.config.attention_dropout if self.training else 0.0
if attention_mask is not None: # no matter the length, we just slice it
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons

View File

@ -173,9 +173,8 @@ def eager_attention_forward(
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

View File

@ -125,9 +125,8 @@ def eager_attention_forward(
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

View File

@ -300,7 +300,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
)
attention_scores = attention_scores.view(batch_size, num_attention_heads, query_length, -1)
if attention_mask is not None: # no matter the length, we just slice it
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attention_scores = attention_scores + causal_mask

View File

@ -281,8 +281,7 @@ def eager_attention_forward(
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
combined_logits = torch.cat([attn_weights, sinks], dim=-1)

View File

@ -219,8 +219,7 @@ def eager_attention_forward(
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
combined_logits = torch.cat([attn_weights, sinks], dim=-1)

View File

@ -151,7 +151,7 @@ class GPTJAttention(nn.Module):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = attn_weights / self.scale_attn
if attention_mask is not None: # no matter the length, we just slice it
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

View File

@ -104,8 +104,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -325,8 +325,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -85,8 +85,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -315,8 +315,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -169,8 +169,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

View File

@ -250,7 +250,6 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)

View File

@ -141,8 +141,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

Some files were not shown because too many files have changed in this diff.