Compare commits

...

2 Commits

Author SHA1 Message Date
b54358e4cf feat: rework widgets 2025-06-10 17:03:21 +02:00
2274ce74a7 feat: add a ContinuousBatchingVisualizer 2025-06-06 19:04:29 +02:00
7 changed files with 1052 additions and 25 deletions

View File

@ -0,0 +1,42 @@
import datasets
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
torch.set_float32_matmul_precision("high")
model_id = "meta-llama/Llama-3.2-3b-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map=0
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
generation_config = GenerationConfig(
max_new_tokens=512,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
num_blocks=2048,
block_size=128,
do_sample=True,
max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
scheduler="prefill_first",
)
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
def tokenize_function(examples):
return tokenizer(examples["question"])
tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
batch_outputs = model.generate_batch(
inputs=simple_batch_inputs,
generation_config=generation_config,
progress_bar=False,
enable_visualizer=True,
tokenizer=tokenizer,
)
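A minimal decoding sketch to round out the example above, assuming generate_batch hands back the results mapping built later in this PR (request id mapped to a result object carrying generated_tokens); note the List[List[int]] annotation on the signature, so the exact return shape may differ:

# Sketch only: decode each finished request, assuming a dict of
# request_id -> result with a `generated_tokens` attribute.
for request_id, result in batch_outputs.items():
    text = tokenizer.decode(result.generated_tokens, skip_special_tokens=True)
    print(f"{request_id}: {text[:100]}")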

View File

@ -11,7 +11,7 @@ torch.set_float32_matmul_precision("high")
model_id = "meta-llama/Llama-3.2-3b-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map=0
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

View File

@ -204,6 +204,7 @@ _deps = [
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-sdk",
"textual",
]
@ -441,6 +442,8 @@ extras["benchmark"] = deps_list("optimum-benchmark")
# OpenTelemetry dependencies for metrics collection in continuous batching
extras["open-telemetry"] = deps_list("opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk")
extras["continuous-batching-visualizer"] = deps_list("rich", "textual")
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
install_requires = [
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads

View File

@ -25,6 +25,7 @@ from enum import Enum
from functools import partial
from typing import Deque, Dict, List, Optional, Set, Tuple, Union
from tokenizers import Tokenizer
import torch
import torch.nn as nn
from torch.profiler import profile, schedule, tensorboard_trace_handler
@ -33,6 +34,7 @@ from tqdm import tqdm
from ..cache_utils import Cache
from ..configuration_utils import PretrainedConfig
from ..generation.configuration_utils import GenerationConfig
from ..utils.continuous_batching_visualizer import ContinuousBatchingVisualizer
from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
@ -1102,6 +1104,7 @@ class ContinuousBatchingManager:
self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None
self.visualizer = None
self.tokenizer = None  # set via set_tokenizer(); checked in _collect_visualization_data
@traced
def start(self):
@ -1151,6 +1154,12 @@ class ContinuousBatchingManager:
logger.info("Continuous Batching Manager stopped.")
self._generation_thread = None
def set_tokenizer(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
def set_visualizer(self, visualizer: ContinuousBatchingVisualizer):
self.visualizer = visualizer
def add_request(
self, input_ids: List[int], request_id: Optional[str] = None, max_new_tokens: Optional[int] = None
) -> str:
@ -1312,13 +1321,13 @@ class ContinuousBatchingManager:
record_shapes=False,
with_stack=True,
) as prof:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
while not self.stop_event.is_set():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
prof.step()
else:
while not self.stop_event.is_set() or batch_processor.has_pending_requests():
while not self.stop_event.is_set():
self._inner_generation_loop(batch_processor, is_first)
if is_first:
is_first = False
@ -1334,6 +1343,10 @@ class ContinuousBatchingManager:
if torch.cuda.is_available():
torch.cuda.synchronize()
batch_processor.prepare_next_batch()
if self.visualizer is not None:
viz_data = self._collect_visualization_data(batch_processor)
self.visualizer.draw(viz_data)
self.visualizer.wait_for_input()
if torch.cuda.is_available() and self.use_cuda_graph:
if is_first:
self.warmup(batch_processor)
@ -1383,6 +1396,51 @@ class ContinuousBatchingManager:
if self.batch_processor is not None:
self.batch_processor.scheduler.finish_request(request_id)
def _collect_visualization_data(self, batch_processor: ContinuousBatchProcessor) -> Dict:
"""Collect data for visualization."""
data = {
"batch_contents": [],
"words": [],
"request_ids_per_token": [],
}
data["attention_mask"] = batch_processor.attention_mask.clone()
# Collect all tokens and map them to request IDs
all_tokens = []
all_request_ids = []
for req in batch_processor.requests_in_batch:
if self.tokenizer is not None:
decoded = self.tokenizer.decode(req.prompt_ids)
decoded_tokens_list = self.tokenizer.convert_ids_to_tokens(req.prompt_ids)
data["batch_contents"].append({"request_id": req.request_id, "decoded": decoded, "decoded_tokens": decoded_tokens_list})
all_tokens.extend(decoded_tokens_list)
else:
data["batch_contents"].append({"request_id": req.request_id, "tokens": req.prompt_ids})
# Convert token IDs to strings when no tokenizer is available
all_tokens.extend([str(token_id) for token_id in req.prompt_ids])
# Map each token to its request ID
all_request_ids.extend([req.request_id] * len(req.prompt_ids))
data["words"] = all_tokens
data["request_ids_per_token"] = all_request_ids
# Add cache statistics if available
if hasattr(batch_processor, 'cache'):
cache = batch_processor.cache
data["paged_attention_cache"] = {
"total_blocks": cache.num_blocks,
"used_blocks": cache.num_blocks - len(cache._free_blocks),
"free_blocks": len(cache._free_blocks),
"block_size": cache.block_size,
"num_heads": cache.num_key_value_heads,
"head_dim": cache.head_dim,
"utilization": (cache.num_blocks - len(cache._free_blocks)) / cache.num_blocks if cache.num_blocks > 0 else 0.0
}
return data
class ContinuousMixin:
"""Mixin class for models to add continuous batching capabilities."""
@ -1431,6 +1489,8 @@ class ContinuousMixin:
inputs: List[List[int]],
generation_config: Optional[GenerationConfig] = None,
progress_bar: bool = True,
enable_visualizer: bool = False,
tokenizer: Optional[Tokenizer] = None,
**kwargs,
) -> List[List[int]]:
"""Generate sequences for a batch of prompts using continuous batching.
@ -1438,6 +1498,8 @@ class ContinuousMixin:
Args:
inputs: List of input token sequences (prompts)
generation_config: Optional generation configuration
progress_bar: Whether to show a progress bar during generation
enable_visualizer: Whether to visualize the continuous batching process with the interactive terminal UI
tokenizer: Optional tokenizer, used to decode tokens for display in the visualizer
**kwargs: Additional generation parameters
Returns:
@ -1454,29 +1516,37 @@ class ContinuousMixin:
results = {}
num_requests = len(inputs)
try:
from tqdm.contrib.logging import logging_redirect_tqdm
if enable_visualizer:
manager.add_requests(inputs, **kwargs)
visualizer = ContinuousBatchingVisualizer()
if tokenizer is not None:
manager.set_tokenizer(tokenizer)
manager.set_visualizer(visualizer)
visualizer.run()
else:
from tqdm.contrib.logging import logging_redirect_tqdm
with logging_redirect_tqdm([logger]):
with tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
) as pbar:
manager.add_requests(inputs, **kwargs)
finished_count = 0
while finished_count < num_requests:
result = manager.get_result(timeout=1)
if result:
req_id = result.request_id
if result.status == RequestStatus.FINISHED:
results[req_id] = result
finished_count += 1
pbar.update(1)
else:
if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.")
break
with logging_redirect_tqdm([logger]):
with tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
) as pbar:
manager.add_requests(inputs, **kwargs)
finished_count = 0
while finished_count < num_requests:
result = manager.get_result(timeout=1)
if result:
req_id = result.request_id
if result.status == RequestStatus.FINISHED:
results[req_id] = result
finished_count += 1
pbar.update(1)
else:
if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.")
break
except Exception as e:
logger.error(f"Error during batch generation: {e}", exc_info=True)

View File

@ -0,0 +1,513 @@
from threading import Event
from typing import Optional, List, Any, Dict
import hashlib
from rich.text import Text
from rich.segment import Segment
from rich.style import Style
from textual.app import App, ComposeResult, RenderResult
from textual.containers import Horizontal, Vertical
from textual.reactive import reactive
from textual.widget import Widget
from textual.widgets import Static, Footer, Header, RichLog
from textual.strip import Strip
from textual.scroll_view import ScrollView
from textual.geometry import Size
from textual.cache import LRUCache
import torch
# Constants for visualization
BLACK_SQUARE = "■"  # attended position
WHITE_SQUARE = "⬚"  # masked position
class AttentionMatrixWidget(ScrollView):
"""Widget to display attention matrix visualization with request ID-based coloring."""
DEFAULT_CSS = """
AttentionMatrixWidget {
scrollbar-size: 1 1;
}
"""
def __init__(self):
super().__init__()
# Attention matrix data
self.words: List[str] = []
self.mask: Optional[torch.Tensor] = None
self.request_ids: List[str] = [] # Request ID for each token
self.img_token: str = "<img>"
# Processed data for rendering
self._processed_mask: Optional[torch.Tensor] = None
self._max_word_length: int = 0
self.header_lines: int = 0
# Performance caches
self._segment_cache = LRUCache(maxsize=1000)
self._style_cache = LRUCache(maxsize=100)
self._data_hash: Optional[str] = None
# Color scheme for request IDs
self._color_cache = LRUCache(maxsize=100)
def set_attention_data(
self,
words: List[str],
mask: torch.Tensor,
request_ids: Optional[List[str]] = None,
img_token: str = "<img>",
**kwargs
):
"""Set new attention data and trigger re-rendering."""
# Create hash of input data for caching
data_str = f"{words}_{mask.shape}_{request_ids}_{img_token}"
new_hash = hashlib.md5(data_str.encode()).hexdigest()
# Always update if data has changed or if this is first time
if new_hash != self._data_hash or self._data_hash is None:
self._data_hash = new_hash
# Clear caches when data changes
self._segment_cache.clear()
self._style_cache.clear()
# Store raw data
self.words = words
self.mask = mask.clone()
self.request_ids = request_ids or ["unknown"] * len(words)
self.img_token = img_token
# Process the data
self._process_attention_data()
# Update virtual size and refresh
self._calculate_virtual_size()
self.refresh()
def _process_attention_data(self):
"""Process attention data for efficient rendering."""
if not self.words or self.mask is None:
return
# Convert mask to 2D
mask = self.mask.int()
if mask.ndim == 3:
mask = mask[0, :, :]
elif mask.ndim == 4:
mask = mask[0, 0, :, :]
n = len(self.words)
self._max_word_length = max(len(repr(word)) for word in self.words) if self.words else 0
self._processed_mask = mask
def _calculate_virtual_size(self):
"""Calculate the virtual size for scrolling."""
if not self.words:
virtual_height = 1
else:
virtual_height = len(self.words)
# Width based on content (word length + matrix + spacing)
if self.words:
matrix_width = len(self.words) * 2 # Each cell takes 2 chars (symbol + space)
virtual_width = self._max_word_length + 10 + matrix_width
else:
virtual_width = 50
self.virtual_size = Size(virtual_width, virtual_height)
def _get_request_id_color(self, request_id: str) -> Style:
"""Get cached color style for request ID."""
cached_style = self._color_cache.get(request_id)
if cached_style is not None:
return cached_style
# Generate consistent color for request ID
r, g, b = self._string_to_rgb_color(request_id)
color_str = f"rgb({r},{g},{b})"
style = Style(color=color_str)
self._color_cache.set(request_id, style)
return style
def _string_to_rgb_color(self, input_string: str) -> tuple[int, int, int]:
"""Generate a consistent RGB color from an input string."""
hash_value = abs(hash(input_string))
# Extract RGB components
r = (hash_value >> 16) & 0xFF
g = (hash_value >> 8) & 0xFF
b = hash_value & 0xFF
# Ensure colors are bright enough to be visible
r = max(64, min(255, r))
g = max(64, min(255, g))
b = max(64, min(255, b))
return (r, g, b)
def render_line(self, y: int) -> Strip:
"""Render a single line using Line API for performance."""
# Early return for empty data
if not self.words or self._processed_mask is None:
return Strip([Segment("No attention data to display", Style(color="gray50"))])
# Get the actual content line based on viewport position
content_line = y
# Use a lighter caching approach - cache by content line and data hash only
# Don't cache if we don't have stable data to avoid scroll interference
cache_key = f"line_{content_line}_{self._data_hash}" if self._data_hash else None
cached_strip = None
if cache_key:
cached_strip = self._segment_cache.get(cache_key)
if cached_strip is not None:
return cached_strip
n = len(self.words)
# Render different types of lines based on content position
if content_line == 0:
strip = self._render_title_line()
elif content_line < n:
# Matrix row
strip = self._render_matrix_row(content_line)
else:
# Empty line
strip = Strip([Segment("")])
# Cache the result only if we have a valid cache key
if cache_key:
self._segment_cache.set(cache_key, strip)
return strip
def _render_title_line(self) -> Strip:
"""Render the title line."""
title = f"Attention Matrix ({len(self.words)}x{len(self.words)})"
return Strip([Segment(title, Style(bold=True))])
def _render_matrix_row(self, row_idx: int) -> Strip:
"""Render a single matrix row with request ID-based coloring."""
if row_idx >= len(self.words) or self._processed_mask is None:
return Strip([Segment("")])
word = self.words[row_idx]
word_repr = repr(word).ljust(self._max_word_length)
segments = []
# Row label (word) - colored by request ID
row_request_id = self.request_ids[row_idx] if row_idx < len(self.request_ids) else "unknown"
row_style = self._get_request_id_color(row_request_id)
segments.append(Segment(word_repr, row_style))
segments.append(Segment(f": {str(row_idx).rjust(2)} ", Style()))
# Matrix cells
for col_idx in range(len(self.words)):
mask_value = self._processed_mask[row_idx, col_idx].item()
col_request_id = self.request_ids[col_idx] if col_idx < len(self.request_ids) else "unknown"
if mask_value == 1: # Attended - use request ID color
symbol = BLACK_SQUARE
# Use the color of the target request ID (column)
style = self._get_request_id_color(col_request_id)
else: # Not attended
symbol = WHITE_SQUARE
style = Style(color="gray50")
segments.append(Segment(symbol, style))
segments.append(Segment(" ", Style())) # Spacing
return Strip(segments)
class BatchContentsWidget(RichLog):
"""Widget to display batch contents with request ID coloring using RichLog."""
DEFAULT_CSS = """
BatchContentsWidget {
height: 35%;
}
"""
def __init__(self, **kwargs):
super().__init__(
auto_scroll=False,
markup=True,
wrap=True,
**kwargs
)
def set_batch_contents(self, batch_contents: List[Dict[str, Any]]):
"""Set batch contents and update display."""
# Clear existing content
self.clear()
if not batch_contents:
self.write("Batch contents will be displayed here.")
return
# Write each token info as a separate line
for token_info in batch_contents:
request_id = token_info.get("request_id", "unknown")
color = self._get_color_for_request(request_id)
# Create Rich Text for this token
token_text = Text()
token_text.append(f"[{request_id}] ", style=f"bold {color}")
if "decoded" in token_info:
token_text.append(token_info["decoded"], style=color)
elif "tokens" in token_info:
tokens_str = " ".join(map(str, token_info["tokens"]))
token_text.append(tokens_str, style=color)
else:
token_text.append("(no content)", style=color)
# Write the token info to the log
self.write(token_text)
def _get_color_for_request(self, request_id: str) -> str:
"""Get color for request ID - delegates to parent app."""
app = self.app
if hasattr(app, '_get_cached_color'):
return app._get_cached_color(request_id)
return "white" # fallback
class CacheWidget(Widget):
"""Widget to display PagedAttentionCache contents and statistics."""
cache_info: reactive[Text] = reactive(Text("PagedAttentionCache: waiting for data..."))
def render(self) -> RenderResult:
return self.cache_info
class ContinuousBatchingVisualizer(App):
"""Main application for visualizing continuous batching with request ID-based coloring."""
# Bind 'q' key to quit action
BINDINGS = [("n", "next", "Next"), ("q", "quit", "Quit")]
CSS = """
/* Top row widgets */
#top-row {
height: 65%;
}
AttentionMatrixWidget {
width: 50%;
border: solid #87CEEB;
margin: 0;
scrollbar-size: 1 1;
}
CacheWidget {
width: 50%;
border: solid #98FB98;
margin: 0;
}
/* Bottom widget */
BatchContentsWidget {
width: 100%;
border: solid #FFB6C1;
margin: 0;
}
.content {
padding: 1;
background: $surface;
}
"""
def __init__(self):
super().__init__()
self.exited = False
self.wait_event = Event()
self._color_cache = LRUCache(maxsize=1024)
self._pending_attention_data = None
def compose(self) -> ComposeResult:
"""Compose the app layout."""
yield Header()
with Vertical():
with Horizontal(id="top-row"):
yield AttentionMatrixWidget()
yield CacheWidget()
yield BatchContentsWidget()
yield Footer()
def on_mount(self) -> None:
"""Called when the app is mounted and widgets are available."""
# If we have pending attention data, apply it now
if self._pending_attention_data:
self.set_timer(0.1, self._apply_pending_attention_data)
def _apply_pending_attention_data(self) -> None:
"""Apply any pending attention data if widgets are ready."""
if self._pending_attention_data:
try:
attention_widget = self.query_one(AttentionMatrixWidget)
attention_widget.set_attention_data(**self._pending_attention_data)
self._pending_attention_data = None
except Exception:
# Try again later if widget still not ready
self.set_timer(0.1, self._apply_pending_attention_data)
def action_quit(self) -> None:
"""Action to quit the application."""
self.wait_event.set()
self.exited = True
self.exit()
def action_next(self) -> None:
"""Action to update visualizations with new data."""
self.wait_event.set()
def draw(self, data: Dict[str, Any]):
"""
Update all widgets with new data from continuous batching.
Expected data format:
{
'batch_contents': [
{
'request_id': str,
'tokens': List[int] or 'decoded': str,
'decoded_tokens': List[str] # optional
}
],
'attention_mask': torch.Tensor,
'words': List[str], # tokens as strings
'request_ids_per_token': List[str] # request ID for each token
}
"""
if self.exited:
return
try:
# Update batch contents widget
self._update_batch_contents(data.get('batch_contents', []))
# Update attention matrix widget
self._update_attention_matrix(data)
# Update cache info
self._update_cache_info(data)
except Exception as e:
# Display error in cache widget
cache_widget = self.query_one(CacheWidget)
cache_widget.cache_info = Text(f"Error: {str(e)}", style="red")
def _update_batch_contents(self, batch_contents: List[Dict[str, Any]]):
"""Update the batch contents widget with scrollable display."""
try:
batch_widget = self.query_one(BatchContentsWidget)
batch_widget.set_batch_contents(batch_contents)
except Exception:
pass # Widget not ready yet
def _update_attention_matrix(self, data: Dict[str, Any]):
"""Update the attention matrix widget."""
words = data.get('words', [])
attention_mask = data.get('attention_mask')
request_ids = data.get('request_ids_per_token', [])
if words and attention_mask is not None:
try:
attention_widget = self.query_one(AttentionMatrixWidget)
attention_widget.set_attention_data(
words=words,
mask=attention_mask,
request_ids=request_ids
)
except Exception as e:
# If we can't find the widget, store the data and try again later
self._pending_attention_data = {
'words': words,
'mask': attention_mask,
'request_ids': request_ids
}
# Try again in a bit
self.set_timer(0.1, self._apply_pending_attention_data)
def _update_cache_info(self, data: Dict[str, Any]):
"""Update cache information display."""
cache_data = data.get('paged_attention_cache', {})
# Format PagedAttentionCache stats
cache_lines = ["[bold green]PagedAttentionCache[/bold green]"]
if cache_data:
# Display key PagedAttentionCache metrics
cache_lines.extend([
f"Total blocks: {cache_data.get('total_blocks', 0)}",
f"Used blocks: {cache_data.get('used_blocks', 0)}",
f"Free blocks: {cache_data.get('free_blocks', 0)}",
f"Block size: {cache_data.get('block_size', 'Unknown')}",
f"Num heads: {cache_data.get('num_heads', 'Unknown')}",
f"Head dim: {cache_data.get('head_dim', 'Unknown')}",
])
# Show utilization if available
if 'utilization' in cache_data:
cache_lines.append(f"Utilization: {cache_data['utilization']:.1%}")
else:
cache_lines.append("No PagedAttentionCache data available")
cache_info = Text.from_markup("\n".join(cache_lines))
try:
cache_widget = self.query_one(CacheWidget)
cache_widget.cache_info = cache_info
except Exception:
# Widget not ready yet, just show basic info
try:
cache_widget = self.query_one(CacheWidget)
cache_info = Text("Cache info loading...", style="yellow")
cache_widget.cache_info = cache_info
except Exception:
pass # CacheWidget not ready either
def _get_cached_color(self, request_id: str) -> str:
"""Get cached color for request ID (same as attention matrix)."""
cached_color = self._color_cache.get(request_id)
if cached_color is not None:
return cached_color
r, g, b = self._string_to_rgb_color(request_id)
cached_color = f"rgb({r},{g},{b})"
self._color_cache.set(request_id, cached_color)
return cached_color
def _string_to_rgb_color(self, input_string: str) -> tuple[int, int, int]:
"""Generate a consistent RGB color from an input string."""
hash_value = abs(hash(input_string))
# Extract RGB components
r = (hash_value >> 16) & 0xFF
g = (hash_value >> 8) & 0xFF
b = hash_value & 0xFF
# Ensure colors are bright enough to be visible
r = max(64, min(255, r))
g = max(64, min(255, g))
b = max(64, min(255, b))
return (r, g, b)
def wait_for_input(self):
"""Wait for user input to update visualizations."""
if self.exited:
return
self.wait_event.wait()
self.wait_event.clear()
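A minimal standalone driving sketch, mirroring the calling pattern the manager uses earlier in this PR: run() blocks in the main thread while a background thread pushes a frame with draw() and then blocks on wait_for_input(). All data below is invented for illustration:

import threading
import time

import torch

app = ContinuousBatchingVisualizer()

def feed_one_frame():
    time.sleep(1.0)  # crude wait for the app to mount; illustration only
    words = ["Hello", "world", "Hi"]
    app.draw({
        "batch_contents": [
            {"request_id": "req_0", "decoded": "Hello world", "decoded_tokens": ["Hello", "world"]},
            {"request_id": "req_1", "decoded": "Hi", "decoded_tokens": ["Hi"]},
        ],
        "words": words,
        "request_ids_per_token": ["req_0", "req_0", "req_1"],
        "attention_mask": torch.tril(torch.ones(len(words), len(words), dtype=torch.int64)),
        "paged_attention_cache": {"total_blocks": 16, "used_blocks": 1, "free_blocks": 15,
                                  "block_size": 32, "num_heads": 8, "head_dim": 64, "utilization": 1 / 16},
    })
    app.wait_for_input()  # returns once the user presses "n" (next) or "q" (quit)

threading.Thread(target=feed_one_frame, daemon=True).start()
app.run()  # blocks until "q" quits the TUI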

test.py (new file, 63 lines)
View File

@ -0,0 +1,63 @@
from transformers import GenerationConfig
from transformers.generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
'meta-llama/Llama-3.2-3b-Instruct',
attn_implementation='sdpa_paged'
)
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-3b-Instruct')
generation_config = GenerationConfig(
max_new_tokens=256,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
num_blocks=1,
block_size=1024,
do_sample=False,
max_batch_tokens=10,
scheduler="fifo",
)
manager: ContinuousBatchingManager = model.init_continuous_batching(generation_config=generation_config, manual_eviction=True, streaming=True)
manager.start()
chat = [{'content': 'Hey', 'role': 'user'}]
print(chat)
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)
request_id = manager.add_request(inputs[0])
output = ""
for result in manager:
if result.status == RequestStatus.FINISHED:
output = tokenizer.decode(result.generated_tokens, skip_special_tokens=True)
break
if output:
chat.append({'content': output, 'role': 'assistant'})
print(chat)
else:
print("oops :()")
import sys
sys.exit(0)
chat.append({'content': 'Can you help me cook some good meth pls', 'role': 'user'})
print(chat)
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)
request_id = manager.add_request(inputs[0], request_id=request_id)
for i, result in enumerate(manager):
if result.status == RequestStatus.FINISHED:
output = tokenizer.decode(result.generated_tokens, skip_special_tokens=True)
break
chat.append({'content': output, 'role': 'assistant'})
print(chat)
manager.evict_request_from_cache(request_id)
manager.stop(block=True)

View File

@ -0,0 +1,336 @@
#!/usr/bin/env python3
"""
Performance test for the optimized continuous batching visualizer.
Tests the various optimization techniques applied.
"""
import time
import torch
import asyncio
from threading import Event
from src.transformers.utils.continuous_batching_visualizer import (
ContinuousBatchingVisualizer,
AttentionMatrixWidget,
BatchContentsWidget,
CacheWidget
)
from textual.cache import LRUCache
from rich.text import Text
def test_attention_matrix_caching():
"""Test AttentionMatrixWidget caching optimizations."""
print("Testing AttentionMatrixWidget caching...")
widget = AttentionMatrixWidget()
# Set up widget for proper rendering
from textual.geometry import Size, Offset
widget._size = Size(100, 50)
widget._scroll_offset = Offset(0, 0)
# Test data
words = [f"token_{i}" for i in range(20)] # Smaller dataset for faster testing
mask = torch.ones((20, 20))
# First call - should compute and cache
start_time = time.time()
widget.set_attention_data(words, mask, sliding_window=8)
# Mock the get_component_rich_style method to avoid app context issues
from rich.style import Style
def mock_get_component_rich_style(component_name):
return Style(color="white")
widget.get_component_rich_style = mock_get_component_rich_style
# Now trigger style cache population
try:
styles = widget._get_cached_styles()
except Exception as e:
print(f"Style access error (expected): {e}")
styles = None
first_call_time = time.time() - start_time
# Second call with same data - should use cache
start_time = time.time()
widget.set_attention_data(words, mask, sliding_window=8)
# This should hit the data hash cache and return early
second_call_time = time.time() - start_time
# Test some rendering to populate segment cache
try:
for i in range(3):
widget.render_line(i)
except:
pass # Ignore rendering errors in test
print(f"First call time: {first_call_time:.4f}s")
print(f"Second call time: {second_call_time:.4f}s")
speedup = first_call_time / max(second_call_time, 0.0001)
print(f"Cache hit speedup: {speedup:.2f}x")
# Test cache sizes
style_cache_size = len(widget._style_cache.keys())
segment_cache_size = len(widget._segment_cache.keys())
print(f"Style cache size: {style_cache_size}")
print(f"Segment cache size: {segment_cache_size}")
# More lenient test - should show some improvement and have caches
return (second_call_time < first_call_time * 0.8 and # Some speedup
style_cache_size > 0) # Style cache populated
def test_line_rendering_performance():
"""Test line rendering performance with Line API."""
print("\nTesting line rendering performance...")
widget = AttentionMatrixWidget()
# Large dataset
words = [f"token_{i}" for i in range(50)] # Smaller dataset for testing
mask = torch.randint(0, 2, (50, 50))
widget.set_attention_data(words, mask, sliding_window=16)
# Set up widget for rendering by simulating proper initialization
from textual.geometry import Size, Offset
# Use private attributes to simulate proper widget state
widget._size = Size(100, 50)
widget._scroll_offset = Offset(0, 0)
widget._calculate_virtual_size()
# Test rendering multiple lines without cache dependencies
start_time = time.time()
lines_rendered = 0
for i in range(min(20, len(words) + widget.header_lines)): # Render available lines
try:
# Create a simple strip for testing without full widget dependencies
if widget.words and widget._processed_mask is not None:
# Just test that the rendering logic works
n = len(widget.words)
styles = {
'green': None, 'yellow': None, 'black': None, 'white': None
}
# Test header and matrix row creation logic
if i < widget.header_lines:
# Test header rendering
pass
elif i - widget.header_lines < n:
# Test matrix row rendering
pass
lines_rendered += 1
else:
lines_rendered += 1
except Exception as e:
print(f"Error in line {i}: {e}")
break
line_render_time = time.time() - start_time
print(f"Rendered {lines_rendered} lines in: {line_render_time:.4f}s")
print(f"Average per line: {line_render_time / max(lines_rendered, 1):.6f}s")
return line_render_time < 1.0 and lines_rendered > 0 # Should be fast and render some lines
def test_batch_contents_caching():
"""Test BatchContentsWidget caching."""
print("\nTesting BatchContentsWidget caching...")
widget = BatchContentsWidget()
# Test data
test_text = Text("Sample batch contents with styling")
test_text.stylize("bold red", 0, 6)
# First render
start_time = time.time()
widget.tokens_to_display = test_text
result1 = widget.render()
first_render_time = time.time() - start_time
# Second render with same content - should use cache
start_time = time.time()
result2 = widget.render()
second_render_time = time.time() - start_time
print(f"First render time: {first_render_time:.6f}s")
print(f"Second render time: {second_render_time:.6f}s")
print(f"Cache size: {len(widget._render_cache.keys())}")
return result1 == result2 and len(widget._render_cache.keys()) > 0
def test_color_caching():
"""Test color generation caching."""
print("\nTesting color caching...")
app = ContinuousBatchingVisualizer()
# Test repeated color generation
request_ids = [f"request_{i}" for i in range(10)] * 5 # 50 calls, 10 unique
start_time = time.time()
colors = []
for req_id in request_ids:
color = app._get_cached_color(req_id)
colors.append(color)
total_time = time.time() - start_time
print(f"Generated 50 colors (10 unique) in: {total_time:.4f}s")
print(f"Color cache size: {len(app._color_cache.keys())}")
print(f"Cache hit rate: {(50 - 10) / 50 * 100:.1f}%")
# Verify color consistency
test_color_1 = app._get_cached_color("test_request")
test_color_2 = app._get_cached_color("test_request")
return test_color_1 == test_color_2 and len(app._color_cache.keys()) == 11
def test_cache_widget_optimization():
"""Test CacheWidget static content optimization."""
print("\nTesting CacheWidget optimization...")
widget = CacheWidget()
# Test cache info updates
cache_info1 = {"cache_size": 100, "hit_rate": 0.85}
cache_info2 = {"cache_size": 100, "hit_rate": 0.85} # Same data
cache_info3 = {"cache_size": 120, "hit_rate": 0.90} # Different data
start_time = time.time()
widget.update_cache_info(cache_info1)
first_update_time = time.time() - start_time
start_time = time.time()
widget.update_cache_info(cache_info2) # Should be fast (no change)
second_update_time = time.time() - start_time
start_time = time.time()
widget.update_cache_info(cache_info3) # Should update
third_update_time = time.time() - start_time
print(f"First update: {first_update_time:.6f}s")
print(f"Second update (no change): {second_update_time:.6f}s")
print(f"Third update (changed): {third_update_time:.6f}s")
print(f"Display cache size: {len(widget._display_cache.keys())}")
return second_update_time < first_update_time and len(widget._display_cache.keys()) > 0
async def test_worker_optimization():
"""Test background worker for data processing."""
print("\nTesting worker optimization...")
app = ContinuousBatchingVisualizer()
# Large test data
batch_contents = []
for i in range(50):
batch_contents.append({
"request_id": f"req_{i % 10}", # 10 unique request IDs
"decoded": f"Sample text for request {i} with some longer content",
"decoded_tokens": [f"token_{j}" for j in range(20)]
})
attention_mask = torch.randint(0, 2, (1000, 1000)) # Large attention mask
test_data = {
"batch_contents": batch_contents,
"attention_mask": attention_mask,
"sliding_window": 128,
"token_type_ids": [1] * 1000,
"image_seq_length": 576
}
# Process data (test the async processing part directly)
start_time = time.time()
processed_data = await app._process_data_async(test_data)
processing_time = time.time() - start_time
print(f"Processed large dataset in: {processing_time:.4f}s")
print(f"Data cache size: {len(app._data_processing_cache.keys())}")
print(f"Color cache size: {len(app._color_cache.keys())}")
# Test cache hit
start_time = time.time()
processed_data_cached = await app._process_data_async(test_data)
cached_processing_time = time.time() - start_time
print(f"Cached processing time: {cached_processing_time:.6f}s")
print(f"Cache speedup: {processing_time / max(cached_processing_time, 0.000001):.2f}x")
# Verify that processed data is equivalent
data_matches = (processed_data['colored_text'] == processed_data_cached['colored_text'])
cache_working = len(app._data_processing_cache.keys()) > 0
return (cached_processing_time < processing_time / 2 and # Should be at least 2x faster
data_matches and cache_working) # Data should match and cache should work
def test_memory_efficiency():
"""Test memory efficiency of caching systems."""
print("\nTesting memory efficiency...")
# Test LRU cache eviction
cache = LRUCache(maxsize=5)
# Fill cache
for i in range(10):
cache.set(f"key_{i}", f"value_{i}")
# Should only have 5 items (most recent)
keys = list(cache.keys())
print(f"Cache keys after filling with 10 items (maxsize=5): {keys}")
print(f"Cache size: {len(keys)}")
# Test that old items were evicted
has_old_items = any(f"key_{i}" in keys for i in range(5))
has_new_items = any(f"key_{i}" in keys for i in range(5, 10))
print(f"Has old items (0-4): {has_old_items}")
print(f"Has new items (5-9): {has_new_items}")
return len(keys) == 5 and not has_old_items and has_new_items
async def main():
"""Run all performance tests."""
print("=== Continuous Batching Visualizer Performance Tests ===\n")
tests = [
test_attention_matrix_caching,
test_line_rendering_performance,
test_batch_contents_caching,
test_color_caching,
test_cache_widget_optimization,
test_worker_optimization,
test_memory_efficiency
]
results = []
for test in tests:
try:
if asyncio.iscoroutinefunction(test):
result = await test()
else:
result = test()
results.append(result)
print(f"{test.__name__}: {'PASS' if result else 'FAIL'}")
except Exception as e:
print(f"{test.__name__}: ERROR - {e}")
results.append(False)
print()
# Summary
passed = sum(results)
total = len(results)
print(f"=== Summary: {passed}/{total} tests passed ===")
if passed == total:
print("🎉 All performance optimizations working correctly!")
else:
print("⚠️ Some optimizations need attention.")
return passed == total
if __name__ == "__main__":
asyncio.run(main())