mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Streaming should be handled at the request-level rather than at the instance level (#41444)
* Streaming should be handled at the request-level rather than at the instance level
* Add tests
* Require torch GPU
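For context, here is a minimal consumer-side sketch of the behaviour these tests pin down. It relies only on calls exercised in the diff below (init_continuous_batching, add_request with a per-request streaming flag, request_id_iter, generated_tokens, stop); whether request_id_iter is exhausted once a request completes is an assumption here, since the tests only pull a fixed number of chunks with next().

    # Sketch only: mirrors the API used by the tests below, not an official usage example.
    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    # One manager now serves both streaming and non-streaming requests;
    # streaming is chosen per request, not when the manager is created.
    manager = model.init_continuous_batching()
    manager.logit_processor = LogitsProcessorList()  # same setup as the tests below
    manager.start()

    messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)[0]

    # Streaming request: each chunk carries the cumulative generated_tokens so far
    # (one more token per iteration, as asserted in test_streaming_request).
    stream_id = manager.add_request(inputs, max_new_tokens=32, streaming=True)
    for chunk in manager.request_id_iter(stream_id):  # assumed to stop once the request finishes
        print(tokenizer.decode(chunk.generated_tokens, skip_special_tokens=True))

    # Non-streaming request: a single chunk holding all max_new_tokens tokens.
    batch_id = manager.add_request(inputs, max_new_tokens=32, streaming=False)
    result = next(manager.request_id_iter(batch_id))
    print(tokenizer.decode(result.generated_tokens, skip_special_tokens=True))

    manager.stop(block=True)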
@@ -18,7 +18,7 @@ from typing import Optional
 import torch
 from parameterized import parameterized

-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
 from transformers.generation.continuous_batching.cache import group_layers_by_attn_type
 from transformers.generation.continuous_batching.continuous_api import build_attention_mask
 from transformers.testing_utils import Expectations, require_kernels, require_torch_gpu, slow
@@ -337,6 +337,102 @@ class ContinuousBatchingTest(unittest.TestCase):
         manager = model.init_continuous_batching()
         assert "paged|eager" == manager.model.config._attn_implementation

+    @require_torch_gpu
+    def test_streaming_request(self) -> None:
+        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
+        max_new_tokens = 3
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+        manager = model.init_continuous_batching()
+        manager.logit_processor = LogitsProcessorList()
+        manager.start()
+
+        messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
+
+        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
+            model.device
+        )[0]
+
+        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
+
+        # In streaming mode, the total number of generated tokens is incremented by 1 on each iteration
+        chunk_1 = next(manager.request_id_iter(request_id))
+        self.assertEqual(len(chunk_1.generated_tokens), 1)
+
+        chunk_2 = next(manager.request_id_iter(request_id))
+        self.assertEqual(len(chunk_2.generated_tokens), 2)
+
+        chunk_3 = next(manager.request_id_iter(request_id))
+        self.assertEqual(len(chunk_3.generated_tokens), 3)
+
+        manager.stop(block=True)
+
+    @require_torch_gpu
+    def test_non_streaming_request(self) -> None:
+        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
+        max_new_tokens = 3
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+        manager = model.init_continuous_batching()
+        manager.logit_processor = LogitsProcessorList()
+        manager.start()
+
+        messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
+
+        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
+            model.device
+        )[0]
+
+        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
+
+        chunk = next(manager.request_id_iter(request_id))
+
+        # In non-streaming mode, the total number of generated tokens is equal to the max new tokens
+        self.assertEqual(len(chunk.generated_tokens), max_new_tokens)
+
+        manager.stop(block=True)
+
+    @require_torch_gpu
+    def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
+        model_id = "Qwen/Qwen2.5-0.5B-Instruct"
+        max_new_tokens = 3
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+        manager = model.init_continuous_batching()
+        manager.logit_processor = LogitsProcessorList()
+        manager.start()
+
+        messages = [{"content": "What is the Transformers library known for?", "role": "user"}]
+
+        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
+            model.device
+        )[0]
+
+        # Non-streaming request
+        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
+        chunk = next(manager.request_id_iter(request_id))
+        self.assertEqual(len(chunk.generated_tokens), max_new_tokens)
+
+        # Streaming request works afterward
+        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
+
+        chunk_1 = next(manager.request_id_iter(request_id))
+        self.assertEqual(len(chunk_1.generated_tokens), 1)
+
+        chunk_2 = next(manager.request_id_iter(request_id))
+        self.assertEqual(len(chunk_2.generated_tokens), 2)
+
+        chunk_3 = next(manager.request_id_iter(request_id))
+        self.assertEqual(len(chunk_3.generated_tokens), 3)
+
+        manager.stop(block=True)
+

 # FIXME: the gemma test seems broken, there is a message about cuda graphs and the sdpa and flash expectations are
 # inverted on CUDA. On AMD they do fine.