Streaming should be handled at the request-level rather than at the instance level (#41444)

* Streaming should be handled at the request-level rather than at the instance level

* Add tests

* Require torch GPU
Lysandre Debut
2025-10-10 10:24:55 +02:00
committed by GitHub
parent b28902c86b
commit 17c31a98ac
3 changed files with 106 additions and 13 deletions
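
As a rough usage sketch of the request-level API exercised by the new tests: model, manager setup, and chat template mirror the tests; the `max_new_tokens` value, the decoding loop, and the assumption that the request iterator is exhausted once a request completes are illustrative, not taken from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

manager = model.init_continuous_batching()
manager.start()

messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)[0]

# Streaming is chosen per request, not once for the whole manager.
streaming_id = manager.add_request(inputs, max_new_tokens=32, streaming=True)
blocking_id = manager.add_request(inputs, max_new_tokens=32, streaming=False)

# Streaming request: each chunk carries all tokens generated so far
# (assumes the iterator ends once the request finishes).
for chunk in manager.request_id_iter(streaming_id):
    print(tokenizer.decode(chunk.generated_tokens))

# Non-streaming request: a single chunk holding all max_new_tokens tokens.
result = next(manager.request_id_iter(blocking_id))
print(tokenizer.decode(result.generated_tokens))

manager.stop(block=True)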


@@ -18,7 +18,7 @@ from typing import Optional
import torch
from parameterized import parameterized
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
from transformers.generation.continuous_batching.continuous_api import build_attention_mask
from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
@@ -337,6 +337,102 @@ class ContinuousBatchingTest(unittest.TestCase):
        manager = model.init_continuous_batching()
        assert "paged|eager" == manager.model.config._attn_implementation

    @require_torch_gpu
    def test_streaming_request(self) -> None:
        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
        max_new_tokens = 3
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

        manager = model.init_continuous_batching()
        manager.logit_processor = LogitsProcessorList()
        manager.start()

        messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
            model.device
        )[0]

        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)

        # In streaming mode, the total number of generated tokens is incremented by 1 on each iteration
        chunk_1 = next(manager.request_id_iter(request_id))
        self.assertEqual(len(chunk_1.generated_tokens), 1)

        chunk_2 = next(manager.request_id_iter(request_id))
        self.assertEqual(len(chunk_2.generated_tokens), 2)

        chunk_3 = next(manager.request_id_iter(request_id))
        self.assertEqual(len(chunk_3.generated_tokens), 3)

        manager.stop(block=True)

    @require_torch_gpu
    def test_non_streaming_request(self) -> None:
        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
        max_new_tokens = 3
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

        manager = model.init_continuous_batching()
        manager.logit_processor = LogitsProcessorList()
        manager.start()

        messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
            model.device
        )[0]

        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
        chunk = next(manager.request_id_iter(request_id))

        # In non-streaming mode, the total number of generated tokens is equal to the max new tokens
        self.assertEqual(len(chunk.generated_tokens), max_new_tokens)

        manager.stop(block=True)

    @require_torch_gpu
    def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
        max_new_tokens = 3
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

        manager = model.init_continuous_batching()
        manager.logit_processor = LogitsProcessorList()
        manager.start()

        messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
            model.device
        )[0]

        # Non-streaming request
        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
        chunk = next(manager.request_id_iter(request_id))
        self.assertEqual(len(chunk.generated_tokens), max_new_tokens)

        # Streaming request works afterward
        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
        chunk_1 = next(manager.request_id_iter(request_id))
        self.assertEqual(len(chunk_1.generated_tokens), 1)

        chunk_2 = next(manager.request_id_iter(request_id))
        self.assertEqual(len(chunk_2.generated_tokens), 2)

        chunk_3 = next(manager.request_id_iter(request_id))
        self.assertEqual(len(chunk_3.generated_tokens), 3)

        manager.stop(block=True)

    # FIXME: the gemma test seems broken, there is a message about cuda graphs, and the sdpa and flash expectations are
    # inverted on CUDA. On AMD they do fine.