mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-26 19:14:33 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			v0.11.1rc1
			...
			remove-met
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 69dbcc56bf | 
| @ -208,22 +208,6 @@ steps: | ||||
|   commands: | ||||
|   - pytest -v -s distributed/test_eplb_execute.py | ||||
|  | ||||
| - label: Metrics, Tracing Test # 10min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   num_gpus: 2 | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/metrics | ||||
|   - tests/tracing | ||||
|   commands: | ||||
|   - pytest -v -s metrics | ||||
|   - "pip install \ | ||||
|       'opentelemetry-sdk>=1.26.0' \ | ||||
|       'opentelemetry-api>=1.26.0' \ | ||||
|       'opentelemetry-exporter-otlp>=1.26.0' \ | ||||
|       'opentelemetry-semantic-conventions-ai>=0.4.1'" | ||||
|   - pytest -v -s tracing | ||||
|  | ||||
| ##### fast check tests  ##### | ||||
| #####  1 GPU test  ##### | ||||
|  | ||||
|  | ||||
| @ -1,268 +0,0 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import pytest | ||||
| import ray | ||||
| from prometheus_client import REGISTRY | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm import EngineArgs, LLMEngine | ||||
| from vllm.engine.arg_utils import AsyncEngineArgs | ||||
| from vllm.engine.async_llm_engine import AsyncLLMEngine | ||||
| from vllm.engine.metrics import RayPrometheusStatLogger | ||||
| from vllm.sampling_params import SamplingParams | ||||
| from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET | ||||
|  | ||||
|  | ||||
| @pytest.fixture(scope="function", autouse=True) | ||||
| def use_v0_only(monkeypatch): | ||||
|     """ | ||||
|     This module tests V0 internals, so set VLLM_USE_V1=0. | ||||
|     """ | ||||
|     monkeypatch.setenv('VLLM_USE_V1', '0') | ||||
|  | ||||
|  | ||||
| MODELS = [ | ||||
|     "distilbert/distilgpt2", | ||||
| ] | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("model", MODELS) | ||||
| @pytest.mark.parametrize("dtype", ["float"]) | ||||
| @pytest.mark.parametrize("max_tokens", [128]) | ||||
| def test_metric_counter_prompt_tokens( | ||||
|     vllm_runner, | ||||
|     example_prompts, | ||||
|     model: str, | ||||
|     dtype: str, | ||||
|     max_tokens: int, | ||||
| ) -> None: | ||||
|     with vllm_runner(model, | ||||
|                      dtype=dtype, | ||||
|                      disable_log_stats=False, | ||||
|                      gpu_memory_utilization=0.4) as vllm_model: | ||||
|         tokenizer = vllm_model.llm.get_tokenizer() | ||||
|         prompt_token_counts = [ | ||||
|             len(tokenizer.encode(p)) for p in example_prompts | ||||
|         ] | ||||
|         # This test needs at least 2 prompts in a batch of different lengths to | ||||
|         # verify their token count is correct despite padding. | ||||
|         assert len(example_prompts) > 1, "at least 2 prompts are required" | ||||
|         assert prompt_token_counts[0] != prompt_token_counts[1], ( | ||||
|             "prompts of different lengths are required") | ||||
|         vllm_prompt_token_count = sum(prompt_token_counts) | ||||
|  | ||||
|         _ = vllm_model.generate_greedy(example_prompts, max_tokens) | ||||
|         stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] | ||||
|         metric_count = stat_logger.metrics.counter_prompt_tokens.labels( | ||||
|             **stat_logger.labels)._value.get() | ||||
|  | ||||
|     assert vllm_prompt_token_count == metric_count, ( | ||||
|         f"prompt token count: {vllm_prompt_token_count!r}\n" | ||||
|         f"metric: {metric_count!r}") | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("model", MODELS) | ||||
| @pytest.mark.parametrize("dtype", ["float"]) | ||||
| @pytest.mark.parametrize("max_tokens", [128]) | ||||
| def test_metric_counter_generation_tokens( | ||||
|     vllm_runner, | ||||
|     example_prompts, | ||||
|     model: str, | ||||
|     dtype: str, | ||||
|     max_tokens: int, | ||||
| ) -> None: | ||||
|     with vllm_runner(model, | ||||
|                      dtype=dtype, | ||||
|                      disable_log_stats=False, | ||||
|                      gpu_memory_utilization=0.4) as vllm_model: | ||||
|         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) | ||||
|         tokenizer = vllm_model.llm.get_tokenizer() | ||||
|         stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] | ||||
|         metric_count = stat_logger.metrics.counter_generation_tokens.labels( | ||||
|             **stat_logger.labels)._value.get() | ||||
|         vllm_generation_count = 0 | ||||
|         for i in range(len(example_prompts)): | ||||
|             vllm_output_ids, vllm_output_str = vllm_outputs[i] | ||||
|             prompt_ids = tokenizer.encode(example_prompts[i]) | ||||
|             # vllm_output_ids contains both prompt tokens and generation tokens. | ||||
|             # We're interested only in the count of the generation tokens. | ||||
|             vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) | ||||
|  | ||||
|     assert vllm_generation_count == metric_count, ( | ||||
|         f"generation token count: {vllm_generation_count!r}\n" | ||||
|         f"metric: {metric_count!r}") | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("model", MODELS) | ||||
| @pytest.mark.parametrize("dtype", ["float"]) | ||||
| @pytest.mark.parametrize( | ||||
|     "served_model_name", | ||||
|     [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) | ||||
| def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, | ||||
|                                    served_model_name: list[str]) -> None: | ||||
|     with vllm_runner(model, | ||||
|                      dtype=dtype, | ||||
|                      disable_log_stats=False, | ||||
|                      gpu_memory_utilization=0.3, | ||||
|                      served_model_name=served_model_name) as vllm_model: | ||||
|         stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] | ||||
|         metrics_tag_content = stat_logger.labels["model_name"] | ||||
|  | ||||
|     if envs.VLLM_CI_USE_S3: | ||||
|         model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" | ||||
|     if served_model_name is None or served_model_name == []: | ||||
|         assert metrics_tag_content == model, ( | ||||
|             f"Metrics tag model_name is wrong! expect: {model!r}\n" | ||||
|             f"actual: {metrics_tag_content!r}") | ||||
|     else: | ||||
|         assert metrics_tag_content == served_model_name[0], ( | ||||
|             f"Metrics tag model_name is wrong! expect: " | ||||
|             f"{served_model_name[0]!r}\n" | ||||
|             f"actual: {metrics_tag_content!r}") | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("model", MODELS) | ||||
| @pytest.mark.parametrize("dtype", ["half"]) | ||||
| @pytest.mark.parametrize("max_tokens", [4]) | ||||
| @pytest.mark.parametrize("disable_log_stats", [True, False]) | ||||
| @pytest.mark.asyncio | ||||
| async def test_async_engine_log_metrics_regression( | ||||
|     example_prompts, | ||||
|     model: str, | ||||
|     dtype: str, | ||||
|     max_tokens: int, | ||||
|     disable_log_stats: bool, | ||||
| ) -> None: | ||||
|     """ | ||||
|     Regression test ensuring async engine generates metrics | ||||
|     when disable_log_stats=False | ||||
|     (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678) | ||||
|     """ | ||||
|     engine_args = AsyncEngineArgs( | ||||
|         model=model, | ||||
|         dtype=dtype, | ||||
|         disable_log_stats=disable_log_stats, | ||||
|     ) | ||||
|     async_engine = AsyncLLMEngine.from_engine_args(engine_args) | ||||
|     for i, prompt in enumerate(example_prompts): | ||||
|         results = async_engine.generate( | ||||
|             prompt, | ||||
|             SamplingParams(max_tokens=max_tokens), | ||||
|             f"request-id-{i}", | ||||
|         ) | ||||
|         # Exhaust the async iterator to make the async engine work | ||||
|         async for _ in results: | ||||
|             pass | ||||
|  | ||||
|     assert_metrics(model, async_engine.engine, disable_log_stats, | ||||
|                    len(example_prompts)) | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("model", MODELS) | ||||
| @pytest.mark.parametrize("dtype", ["half"]) | ||||
| @pytest.mark.parametrize("max_tokens", [4]) | ||||
| @pytest.mark.parametrize("disable_log_stats", [True, False]) | ||||
| def test_engine_log_metrics_regression( | ||||
|     example_prompts, | ||||
|     model: str, | ||||
|     dtype: str, | ||||
|     max_tokens: int, | ||||
|     disable_log_stats: bool, | ||||
| ) -> None: | ||||
|     engine_args = EngineArgs( | ||||
|         model=model, | ||||
|         dtype=dtype, | ||||
|         disable_log_stats=disable_log_stats, | ||||
|     ) | ||||
|     engine = LLMEngine.from_engine_args(engine_args) | ||||
|     for i, prompt in enumerate(example_prompts): | ||||
|         engine.add_request( | ||||
|             f"request-id-{i}", | ||||
|             prompt, | ||||
|             SamplingParams(max_tokens=max_tokens), | ||||
|         ) | ||||
|     while engine.has_unfinished_requests(): | ||||
|         engine.step() | ||||
|  | ||||
|     if envs.VLLM_CI_USE_S3: | ||||
|         model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}" | ||||
|     assert_metrics(model, engine, disable_log_stats, len(example_prompts)) | ||||
|  | ||||
|  | ||||
| def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool, | ||||
|                    num_requests: int) -> None: | ||||
|     if disable_log_stats: | ||||
|         with pytest.raises(AttributeError): | ||||
|             _ = engine.stat_loggers | ||||
|     else: | ||||
|         assert (engine.stat_loggers | ||||
|                 is not None), "engine.stat_loggers should be set" | ||||
|         # Ensure the count bucket of request-level histogram metrics matches | ||||
|         # the number of requests as a simple sanity check to ensure metrics are | ||||
|         # generated | ||||
|         labels = {'model_name': model} | ||||
|         request_histogram_metrics = [ | ||||
|             "vllm:e2e_request_latency_seconds", | ||||
|             "vllm:request_prompt_tokens", | ||||
|             "vllm:request_generation_tokens", | ||||
|             "vllm:request_params_n", | ||||
|             "vllm:request_params_max_tokens", | ||||
|         ] | ||||
|         for metric_name in request_histogram_metrics: | ||||
|             metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", | ||||
|                                                      labels) | ||||
|             assert ( | ||||
|                 metric_value == num_requests), "Metrics should be collected" | ||||
|  | ||||
|  | ||||
| @pytest.mark.parametrize("model", MODELS) | ||||
| @pytest.mark.parametrize("dtype", ["half"]) | ||||
| @pytest.mark.parametrize("max_tokens", [16]) | ||||
| def test_engine_log_metrics_ray( | ||||
|     example_prompts, | ||||
|     model: str, | ||||
|     dtype: str, | ||||
|     max_tokens: int, | ||||
| ) -> None: | ||||
|     # This test is quite weak - it only checks that we can use | ||||
|     # RayPrometheusStatLogger without exceptions. | ||||
|     # Checking whether the metrics are actually emitted is unfortunately | ||||
|     # non-trivial. | ||||
|  | ||||
|     # We have to run in a Ray task for Ray metrics to be emitted correctly | ||||
|     @ray.remote(num_gpus=1) | ||||
|     def _inner(): | ||||
|  | ||||
|         class _RayPrometheusStatLogger(RayPrometheusStatLogger): | ||||
|  | ||||
|             def __init__(self, *args, **kwargs): | ||||
|                 self._i = 0 | ||||
|                 super().__init__(*args, **kwargs) | ||||
|  | ||||
|             def log(self, *args, **kwargs): | ||||
|                 self._i += 1 | ||||
|                 return super().log(*args, **kwargs) | ||||
|  | ||||
|         engine_args = EngineArgs( | ||||
|             model=model, | ||||
|             dtype=dtype, | ||||
|             disable_log_stats=False, | ||||
|         ) | ||||
|         engine = LLMEngine.from_engine_args(engine_args) | ||||
|         logger = _RayPrometheusStatLogger( | ||||
|             local_interval=0.5, | ||||
|             labels=dict(model_name=engine.model_config.served_model_name), | ||||
|             vllm_config=engine.vllm_config) | ||||
|         engine.add_logger("ray", logger) | ||||
|         for i, prompt in enumerate(example_prompts): | ||||
|             engine.add_request( | ||||
|                 f"request-id-{i}", | ||||
|                 prompt, | ||||
|                 SamplingParams(max_tokens=max_tokens), | ||||
|             ) | ||||
|         while engine.has_unfinished_requests(): | ||||
|             engine.step() | ||||
|         assert logger._i > 0, ".log must be called at least once" | ||||
|  | ||||
|     ray.get(_inner.remote()) | ||||
| @ -1,237 +0,0 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| # ruff: noqa | ||||
| # type: ignore | ||||
| from __future__ import annotations | ||||
|  | ||||
| import threading | ||||
| from collections.abc import Iterable | ||||
| from concurrent import futures | ||||
| from typing import Callable, Generator, Literal | ||||
|  | ||||
| import grpc | ||||
| import pytest | ||||
| from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ( | ||||
|     ExportTraceServiceResponse) | ||||
| from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import ( | ||||
|     TraceServiceServicer, add_TraceServiceServicer_to_server) | ||||
| from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue | ||||
| from opentelemetry.sdk.environment_variables import ( | ||||
|     OTEL_EXPORTER_OTLP_TRACES_INSECURE) | ||||
|  | ||||
| from vllm import LLM, SamplingParams | ||||
| from vllm.tracing import SpanAttributes | ||||
|  | ||||
|  | ||||
| @pytest.fixture(scope="function", autouse=True) | ||||
| def use_v0_only(monkeypatch: pytest.MonkeyPatch): | ||||
|     """ | ||||
|     Since this module is V0 only, set VLLM_USE_V1=0 for | ||||
|     all tests in the module. | ||||
|     """ | ||||
|     with monkeypatch.context() as m: | ||||
|         m.setenv('VLLM_USE_V1', '0') | ||||
|         yield | ||||
|  | ||||
|  | ||||
| FAKE_TRACE_SERVER_ADDRESS = "localhost:4317" | ||||
|  | ||||
| FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value', | ||||
|                     'array_value'] | ||||
|  | ||||
|  | ||||
| def decode_value(value: AnyValue): | ||||
|     field_decoders: dict[FieldName, Callable] = { | ||||
|         "bool_value": (lambda v: v.bool_value), | ||||
|         "string_value": (lambda v: v.string_value), | ||||
|         "int_value": (lambda v: v.int_value), | ||||
|         "double_value": (lambda v: v.double_value), | ||||
|         "array_value": | ||||
|         (lambda v: [decode_value(item) for item in v.array_value.values]), | ||||
|     } | ||||
|     for field, decoder in field_decoders.items(): | ||||
|         if value.HasField(field): | ||||
|             return decoder(value) | ||||
|     raise ValueError(f"Couldn't decode value: {value}") | ||||
|  | ||||
|  | ||||
| def decode_attributes(attributes: Iterable[KeyValue]): | ||||
|     return {kv.key: decode_value(kv.value) for kv in attributes} | ||||
|  | ||||
|  | ||||
| class FakeTraceService(TraceServiceServicer): | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.request = None | ||||
|         self.evt = threading.Event() | ||||
|  | ||||
|     def Export(self, request, context): | ||||
|         self.request = request | ||||
|         self.evt.set() | ||||
|         return ExportTraceServiceResponse() | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def trace_service() -> Generator[FakeTraceService, None, None]: | ||||
|     """Fixture to set up a fake gRPC trace service""" | ||||
|     server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) | ||||
|     service = FakeTraceService() | ||||
|     add_TraceServiceServicer_to_server(service, server) | ||||
|     server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS) | ||||
|     server.start() | ||||
|  | ||||
|     yield service | ||||
|  | ||||
|     server.stop(None) | ||||
|  | ||||
|  | ||||
| def test_traces( | ||||
|     monkeypatch: pytest.MonkeyPatch, | ||||
|     trace_service: FakeTraceService, | ||||
| ): | ||||
|     with monkeypatch.context() as m: | ||||
|         m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") | ||||
|  | ||||
|         sampling_params = SamplingParams( | ||||
|             temperature=0.01, | ||||
|             top_p=0.1, | ||||
|             max_tokens=256, | ||||
|         ) | ||||
|         model = "facebook/opt-125m" | ||||
|         llm = LLM( | ||||
|             model=model, | ||||
|             otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, | ||||
|         ) | ||||
|         prompts = ["This is a short prompt"] | ||||
|         outputs = llm.generate(prompts, sampling_params=sampling_params) | ||||
|  | ||||
|         timeout = 5 | ||||
|         if not trace_service.evt.wait(timeout): | ||||
|             raise TimeoutError( | ||||
|                 f"The fake trace service didn't receive a trace within " | ||||
|                 f"the {timeout} seconds timeout") | ||||
|  | ||||
|         request = trace_service.request | ||||
|         assert len(request.resource_spans) == 1, ( | ||||
|             f"Expected 1 resource span, " | ||||
|             f"but got {len(request.resource_spans)}") | ||||
|         assert len(request.resource_spans[0].scope_spans) == 1, ( | ||||
|             f"Expected 1 scope span, " | ||||
|             f"but got {len(request.resource_spans[0].scope_spans)}") | ||||
|         assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( | ||||
|             f"Expected 1 span, " | ||||
|             f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") | ||||
|  | ||||
|         attributes = decode_attributes( | ||||
|             request.resource_spans[0].scope_spans[0].spans[0].attributes) | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE | ||||
|                               ) == sampling_params.temperature | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS | ||||
|                               ) == sampling_params.max_tokens | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( | ||||
|                 outputs[0].prompt_token_ids) | ||||
|         completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens | ||||
|         metrics = outputs[0].metrics | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE | ||||
|                               ) == metrics.time_in_queue | ||||
|         ttft = metrics.first_token_time - metrics.arrival_time | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft | ||||
|         e2e_time = metrics.finished_time - metrics.arrival_time | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time | ||||
|         assert metrics.scheduler_time > 0 | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER | ||||
|                               ) == metrics.scheduler_time | ||||
|         # Model forward and model execute should be none, since detailed traces is | ||||
|         # not enabled. | ||||
|         assert metrics.model_forward_time is None | ||||
|         assert metrics.model_execute_time is None | ||||
|  | ||||
|  | ||||
| def test_traces_with_detailed_steps( | ||||
|     monkeypatch: pytest.MonkeyPatch, | ||||
|     trace_service: FakeTraceService, | ||||
| ): | ||||
|     with monkeypatch.context() as m: | ||||
|         m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true") | ||||
|  | ||||
|         sampling_params = SamplingParams( | ||||
|             temperature=0.01, | ||||
|             top_p=0.1, | ||||
|             max_tokens=256, | ||||
|         ) | ||||
|         model = "facebook/opt-125m" | ||||
|         llm = LLM( | ||||
|             model=model, | ||||
|             otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS, | ||||
|             collect_detailed_traces=["all"], | ||||
|         ) | ||||
|         prompts = ["This is a short prompt"] | ||||
|         outputs = llm.generate(prompts, sampling_params=sampling_params) | ||||
|  | ||||
|         timeout = 5 | ||||
|         if not trace_service.evt.wait(timeout): | ||||
|             raise TimeoutError( | ||||
|                 f"The fake trace service didn't receive a trace within " | ||||
|                 f"the {timeout} seconds timeout") | ||||
|  | ||||
|         request = trace_service.request | ||||
|         assert len(request.resource_spans) == 1, ( | ||||
|             f"Expected 1 resource span, " | ||||
|             f"but got {len(request.resource_spans)}") | ||||
|         assert len(request.resource_spans[0].scope_spans) == 1, ( | ||||
|             f"Expected 1 scope span, " | ||||
|             f"but got {len(request.resource_spans[0].scope_spans)}") | ||||
|         assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( | ||||
|             f"Expected 1 span, " | ||||
|             f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") | ||||
|  | ||||
|         attributes = decode_attributes( | ||||
|             request.resource_spans[0].scope_spans[0].spans[0].attributes) | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE | ||||
|                               ) == sampling_params.temperature | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS | ||||
|                               ) == sampling_params.max_tokens | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len( | ||||
|                 outputs[0].prompt_token_ids) | ||||
|         completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs) | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens | ||||
|         metrics = outputs[0].metrics | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE | ||||
|                               ) == metrics.time_in_queue | ||||
|         ttft = metrics.first_token_time - metrics.arrival_time | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft | ||||
|         e2e_time = metrics.finished_time - metrics.arrival_time | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time | ||||
|         assert metrics.scheduler_time > 0 | ||||
|         assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER | ||||
|                               ) == metrics.scheduler_time | ||||
|         assert metrics.model_forward_time > 0 | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD | ||||
|         ) == pytest.approx(metrics.model_forward_time / 1000) | ||||
|         assert metrics.model_execute_time > 0 | ||||
|         assert attributes.get( | ||||
|             SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE | ||||
|         ) == metrics.model_execute_time | ||||
|         assert metrics.model_forward_time < 1000 * metrics.model_execute_time | ||||
		Reference in New Issue
	
	Block a user
	