From 6ac5e06f7c5d4658c9fb119826a92d9910730fb4 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sun, 19 Oct 2025 00:48:22 +0800
Subject: [PATCH] [Chore] Clean up pytorch helper functions in `vllm.utils` (#26908)

Signed-off-by: Isotr0py
Signed-off-by: isotr0py <2037008807@qq.com>
---
 .../kernels/bench_per_token_quant_fp8.py | 3 +-
 benchmarks/kernels/benchmark_activation.py | 3 +-
 benchmarks/kernels/benchmark_layernorm.py | 3 +-
 .../kernels/benchmark_paged_attention.py | 4 +-
 benchmarks/kernels/benchmark_quant.py | 3 +-
 .../kernels/benchmark_reshape_and_cache.py | 4 +-
 .../benchmark_reshape_and_cache_flash.py | 4 +-
 .../compile/piecewise/test_full_cudagraph.py | 2 +-
 .../compile/piecewise/test_multiple_graphs.py | 2 +-
 tests/compile/piecewise/test_simple.py | 2 +-
 tests/compile/piecewise/test_toy_llama.py | 2 +-
 tests/compile/silly_attention.py | 2 +-
 tests/compile/test_aot_compile.py | 2 +-
 tests/compile/test_basic_correctness.py | 2 +-
 tests/compile/test_config.py | 2 +-
 tests/compile/test_decorator.py | 2 +-
 tests/compile/test_full_graph.py | 2 +-
 tests/compile/test_fusions_e2e.py | 2 +-
 tests/conftest.py | 2 +-
 tests/distributed/test_sequence_parallel.py | 2 +-
 tests/distributed/test_utils.py | 2 +-
 tests/kernels/attention/conftest.py | 5 +-
 .../kernels/attention/test_prefix_prefill.py | 2 +-
 tests/kernels/core/test_uva.py | 3 +-
 .../moe/test_modular_kernel_combinations.py | 3 +-
 tests/kernels/utils.py | 2 +-
 .../multimodal/pooling/test_intern_vit.py | 2 +-
 tests/models/multimodal/pooling/test_radio.py | 2 +-
 .../processing/test_tensor_schema.py | 2 +-
 tests/utils.py | 2 +-
 tests/utils_/test_utils.py | 10 +-
 tests/v1/attention/test_attention_backends.py | 3 +-
 tests/v1/attention/test_mla_backends.py | 3 +-
 tests/v1/engine/test_async_llm.py | 2 +-
 tests/v1/engine/test_engine_core.py | 2 +-
 tests/v1/engine/test_engine_core_client.py | 2 +-
 tests/v1/sample/test_sampler.py | 3 +-
 tests/v1/sample/utils.py | 2 +-
 tests/v1/shutdown/test_delete.py | 2 +-
 tests/v1/shutdown/test_forward_error.py | 2 +-
 tests/v1/shutdown/test_startup_error.py | 2 +-
 tests/v1/worker/test_gpu_input_batch.py | 3 +-
 vllm/attention/layer.py | 2 +-
 vllm/attention/ops/rocm_aiter_mla.py | 2 +-
 vllm/compilation/backends.py | 2 +-
 vllm/compilation/collective_fusion.py | 2 +-
 vllm/compilation/compiler_interface.py | 2 +-
 vllm/compilation/cuda_graph.py | 2 +-
 vllm/compilation/decorators.py | 2 +-
 vllm/compilation/inductor_pass.py | 2 +-
 vllm/config/compilation.py | 2 +-
 vllm/config/model.py | 3 +-
 vllm/config/parallel.py | 3 +-
 .../device_communicators/all_reduce_utils.py | 3 +-
 .../device_communicators/custom_all_reduce.py | 2 +-
 .../device_communicators/pynccl.py | 4 +-
 .../device_communicators/quick_all_reduce.py | 2 +-
 .../device_communicators/ray_communicator.py | 2 +-
 .../kv_connector/v1/p2p/p2p_nccl_engine.py | 3 +-
 vllm/distributed/parallel_state.py | 6 +-
 vllm/distributed/utils.py | 3 +-
 vllm/env_override.py | 2 +-
 vllm/envs.py | 2 +-
 vllm/lora/ops/triton_ops/lora_expand_op.py | 2 +-
 vllm/lora/ops/triton_ops/lora_shrink_op.py | 2 +-
 .../layers/fused_moe/flashinfer_trtllm_moe.py | 2 +-
 .../layers/fused_moe/fused_moe.py | 2 +-
 vllm/model_executor/layers/fused_moe/layer.py | 3 +-
 .../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +-
 vllm/model_executor/layers/fused_moe/utils.py | 3 +-
 vllm/model_executor/layers/layernorm.py | 2 +-
 .../layers/mamba/linear_attn.py | 2 +-
 .../layers/mamba/mamba_mixer.py | 2 +-
 .../layers/mamba/mamba_mixer2.py | 2 +-
 .../layers/mamba/mamba_utils.py | 5 +-
 .../model_executor/layers/mamba/short_conv.py | 2 +-
 .../layers/quantization/bitsandbytes.py | 2 +-
 .../layers/quantization/fp_quant.py | 2 +-
 .../layers/quantization/gguf.py | 2 +-
 .../quantization/kernels/scaled_mm/aiter.py | 2 +-
 .../layers/quantization/mxfp4.py | 2 +-
 .../quark/schemes/quark_ocp_mx.py | 2 +-
 .../layers/quantization/utils/fp8_utils.py | 2 +-
 .../layers/quantization/utils/mxfp4_utils.py | 2 +-
 .../layers/quantization/utils/mxfp6_utils.py | 2 +-
 .../layers/quantization/utils/w8a8_utils.py | 2 +-
 .../layers/rotary_embedding/common.py | 2 +-
 .../rotary_embedding/rocm_aiter_rope_ops.py | 2 +-
 vllm/model_executor/layers/utils.py | 2 +-
 .../model_loader/base_loader.py | 2 +-
 .../model_loader/bitsandbytes_loader.py | 3 +-
 .../model_loader/gguf_loader.py | 2 +-
 .../model_loader/tensorizer_loader.py | 2 +-
 vllm/model_executor/model_loader/tpu.py | 2 +-
 vllm/model_executor/model_loader/utils.py | 10 -
 vllm/model_executor/models/config.py | 3 +-
 vllm/model_executor/models/deepseek_v2.py | 2 +-
 vllm/model_executor/models/deepseek_vl2.py | 2 +-
 vllm/model_executor/models/internvl.py | 2 +-
 vllm/model_executor/models/minicpmv.py | 2 +-
 vllm/model_executor/models/plamo2.py | 2 +-
 vllm/model_executor/models/qwen3_next.py | 2 +-
 .../model_executor/models/transformers/moe.py | 2 +-
 vllm/model_executor/models/utils.py | 6 +-
 vllm/model_executor/models/whisper.py | 2 +-
 vllm/model_executor/utils.py | 2 +-
 vllm/platforms/__init__.py | 2 +-
 vllm/platforms/cuda.py | 3 +-
 vllm/platforms/rocm.py | 2 +-
 vllm/usage/usage_lib.py | 3 +-
 vllm/utils/__init__.py | 578 +----------------
 vllm/utils/torch_utils.py | 605 ++++++++++++++++++
 vllm/v1/attention/backends/flex_attention.py | 3 +-
 vllm/v1/attention/backends/rocm_aiter_fa.py | 2 +-
 vllm/v1/kv_cache_interface.py | 3 +-
 vllm/v1/sample/ops/penalties.py | 3 +-
 vllm/v1/worker/gpu_model_runner.py | 8 +-
 vllm/v1/worker/tpu_worker.py | 3 +-
 vllm/v1/worker/ubatching.py | 2 +-
 119 files changed, 772 insertions(+), 714 deletions(-)
 create mode 100644 vllm/utils/torch_utils.py
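The hunks below are mechanical: torch-specific helpers that used to be re-exported from `vllm.utils` now live in the new `vllm.utils.torch_utils` module, while generic helpers such as `FlexibleArgumentParser` stay where they were. A minimal sketch of the resulting import pattern, using only names that appear in the hunks in this patch (illustrative, not an exhaustive list of the moved helpers):

    # Old call sites pulled torch helpers from the package root, e.g.:
    #     from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, direct_register_custom_op
    # After this patch, torch-specific helpers come from vllm.utils.torch_utils,
    # while non-torch utilities remain importable from vllm.utils.
    from vllm.utils import FlexibleArgumentParser
    from vllm.utils.torch_utils import (
        STR_DTYPE_TO_TORCH_DTYPE,
        direct_register_custom_op,
        is_torch_equal_or_newer,
        set_default_torch_dtype,
    )

Test doubles have to follow the new module path as well; for example, the tests/utils_/test_utils.py hunk now patches `vllm.utils.torch_utils.cuda_device_count_stateless` instead of `vllm.utils.cuda_device_count_stateless`.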
diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py
index 9a52ea7f47..d33b84fc36 100644
--- a/benchmarks/kernels/bench_per_token_quant_fp8.py
+++ b/benchmarks/kernels/bench_per_token_quant_fp8.py
@@ -10,7 +10,8 @@ import torch
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.triton_utils import triton
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 
 
 def with_triton_mode(fn):
diff --git a/benchmarks/kernels/benchmark_activation.py b/benchmarks/kernels/benchmark_activation.py
index 93edbcc939..7662655b5e 100644
--- a/benchmarks/kernels/benchmark_activation.py
+++ b/benchmarks/kernels/benchmark_activation.py
@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation  # noqa F401
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 
 batch_size_range = [1, 16, 32, 64, 128]
 seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py
index 69978ec6b2..bcfa64c3f4 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -7,7 +7,8 @@ import torch from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 8f9907952d..1b1e71adee 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -9,9 +9,9 @@ import torch from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 6ab26f5f1a..61427a77b4 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -7,7 +7,8 @@ import torch from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE @torch.inference_mode() diff --git a/benchmarks/kernels/benchmark_reshape_and_cache.py b/benchmarks/kernels/benchmark_reshape_and_cache.py index d4b564d2ec..e0ff09d4b3 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache.py @@ -9,9 +9,9 @@ from tabulate import tabulate from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random, ) diff --git a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py index 93df14f0d9..29f1b2ccdc 100644 --- a/benchmarks/kernels/benchmark_reshape_and_cache_flash.py +++ b/benchmarks/kernels/benchmark_reshape_and_cache_flash.py @@ -12,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import ( ) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils import FlexibleArgumentParser +from vllm.utils.torch_utils import ( STR_DTYPE_TO_TORCH_DTYPE, - FlexibleArgumentParser, create_kv_caches_with_random_flash, ) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index e01b582209..c6d4b5272d 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -11,7 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs from vllm import LLM, SamplingParams from vllm.config import CompilationConfig from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer @contextlib.contextmanager diff --git a/tests/compile/piecewise/test_multiple_graphs.py 
b/tests/compile/piecewise/test_multiple_graphs.py index 246239b87d..700f57ffb0 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -20,7 +20,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index e39ee21b4d..228859532e 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -19,7 +19,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from ..silly_attention import get_global_counter, reset_global_counter diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 500cca87d9..175ca4a230 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -27,7 +27,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from .. import silly_attention # noqa: F401 diff --git a/tests/compile/silly_attention.py b/tests/compile/silly_attention.py index f33c577290..29c02f6e6a 100644 --- a/tests/compile/silly_attention.py +++ b/tests/compile/silly_attention.py @@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations. 
import torch from torch.library import Library -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op # Shared library for all compilation test operations # Using "silly" namespace to match existing test expectations diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index 1701d85fe8..b2734af575 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -15,7 +15,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer def reference_fn(x: torch.Tensor): diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 954774a8e3..132a838b8d 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -5,7 +5,7 @@ import dataclasses import pytest from vllm.config import CompilationMode -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 87b5d167d1..c6fe65ab51 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -8,7 +8,7 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.config.compilation import CompilationMode -from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer +from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer def test_version(): diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index e459bc539f..c9d01f2317 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -15,7 +15,7 @@ from vllm.config import ( set_current_vllm_config, ) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer # This import automatically registers `torch.ops.silly.attention` from . 
import silly_attention # noqa: F401 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index fb511dd8f7..7a4e859b3e 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..utils import create_new_process_for_each_test diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 7399abaec5..efb5774b78 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -15,8 +15,8 @@ from tests.v1.attention.utils import _Backend from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..utils import flat_product, multi_gpu_test diff --git a/tests/conftest.py b/tests/conftest.py index 8537e8fd03..5e94a8322e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,8 +60,8 @@ from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import set_default_torch_num_threads from vllm.utils.collections import is_list_of +from vllm.utils.torch_utils import set_default_torch_num_threads logger = init_logger(__name__) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index fc55a39bd5..3646f48426 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -18,7 +18,7 @@ import pytest from vllm.config.compilation import CompilationMode from vllm.config.model import RunnerOption from vllm.logger import init_logger -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer from ..models.registry import HF_EXAMPLE_MODELS from ..utils import compare_two_settings, create_new_process_for_each_test diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 2a6936fcd4..c10c256581 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -11,10 +11,10 @@ import vllm.envs as envs from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.utils import StatelessProcessGroup from vllm.utils import ( - cuda_device_count_stateless, get_open_port, update_environment_variables, ) +from vllm.utils.torch_utils import cuda_device_count_stateless from ..utils import multi_gpu_test diff --git a/tests/kernels/attention/conftest.py b/tests/kernels/attention/conftest.py index b080a71bd5..e520267320 100644 --- a/tests/kernels/attention/conftest.py +++ b/tests/kernels/attention/conftest.py @@ -3,7 +3,10 @@ import pytest -from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash +from vllm.utils.torch_utils import ( + create_kv_caches_with_random, + create_kv_caches_with_random_flash, +) @pytest.fixture() diff --git a/tests/kernels/attention/test_prefix_prefill.py 
b/tests/kernels/attention/test_prefix_prefill.py index 5ff2624cd7..65972d02f2 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -15,7 +15,7 @@ from tests.kernels.utils import make_alibi_bias from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 64] diff --git a/tests/kernels/core/test_uva.py b/tests/kernels/core/test_uva.py index 73738175e5..dee92976eb 100644 --- a/tests/kernels/core/test_uva.py +++ b/tests/kernels/core/test_uva.py @@ -3,7 +3,8 @@ import pytest import torch -from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available +from vllm.utils import is_uva_available +from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index a86185a2dc..a7beb31301 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -13,8 +13,9 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, set_current_vllm_config from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.torch_utils import cuda_device_count_stateless from .modular_kernel_tools.common import ( Config, diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 6c7ff984b4..eb00bc72b4 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -22,8 +22,8 @@ from vllm.utils import ( STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, - make_tensor_with_pad, ) +from vllm.utils.torch_utils import make_tensor_with_pad # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. 
diff --git a/tests/models/multimodal/pooling/test_intern_vit.py b/tests/models/multimodal/pooling/test_intern_vit.py index 74e30c4307..5a97848216 100644 --- a/tests/models/multimodal/pooling/test_intern_vit.py +++ b/tests/models/multimodal/pooling/test_intern_vit.py @@ -7,7 +7,7 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/models/multimodal/pooling/test_radio.py b/tests/models/multimodal/pooling/test_radio.py index 414e99a71e..8929563d8b 100644 --- a/tests/models/multimodal/pooling/test_radio.py +++ b/tests/models/multimodal/pooling/test_radio.py @@ -9,7 +9,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.models.radio import RadioModel from vllm.transformers_utils.configs.radio import RadioConfig -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import ImageTestAssets diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 166709329a..093898dd4b 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -26,7 +26,6 @@ from vllm.distributed import ( init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.interfaces import ( SupportsMultiModal, supports_multimodal, @@ -36,6 +35,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingC from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.collections import is_list_of +from vllm.utils.torch_utils import set_default_torch_dtype from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS from ...utils import dummy_hf_overrides diff --git a/tests/utils.py b/tests/utils.py index d17dbbeefc..a6d188cd67 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -46,10 +46,10 @@ from vllm.platforms import current_platform from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import ( FlexibleArgumentParser, - cuda_device_count_stateless, get_open_port, ) from vllm.utils.mem_constants import GB_bytes +from vllm.utils.torch_utils import cuda_device_count_stateless if current_platform.is_rocm(): from amdsmi import ( diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 32a2072396..b3a27460df 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -24,11 +24,8 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens from vllm.utils import ( FlexibleArgumentParser, bind_kv_cache, - common_broadcastable_dtype, - current_stream, get_open_port, get_tcp_uri, - is_lossless_cast, join_host_port, make_zmq_path, make_zmq_socket, @@ -37,6 +34,11 @@ from vllm.utils import ( split_zmq_path, unique_filepath, ) +from vllm.utils.torch_utils import ( + common_broadcastable_dtype, + current_stream, + is_lossless_cast, +) from vllm.utils.mem_utils import MemorySnapshot, memory_profiling from ..utils import 
create_new_process_for_each_test, flat_product @@ -408,7 +410,7 @@ def test_bind_kv_cache_non_attention(): def test_bind_kv_cache_pp(): - with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): + with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2): # this test runs with 1 GPU, but we simulate 2 GPUs cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) with set_current_vllm_config(cfg): diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 174642123d..12f7fc66d1 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -18,7 +18,8 @@ from tests.v1.attention.utils import ( from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, set_kv_cache_layout, diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index f41f63ed2a..81fd6433b0 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -22,7 +22,8 @@ from vllm import _custom_ops as ops from vllm.attention.backends.registry import _Backend from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.config.vllm import set_current_vllm_config -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index b9fa553142..c9605ea1b0 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -15,7 +15,7 @@ from vllm.inputs import PromptType from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.metrics.loggers import ( AggregatedLoggingStatLogger, diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 2eb391d676..341a1f3357 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -12,7 +12,7 @@ from transformers import AutoTokenizer from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.executor.abstract import Executor, UniProcExecutor diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 32eeaebbca..770560a5e5 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -21,7 +21,7 @@ from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublis from vllm.engine.arg_utils import EngineArgs from 
vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext -from vllm.utils import set_default_torch_num_threads +from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index edc6acae84..a1513acc7b 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -7,7 +7,8 @@ import torch from tests.v1.sample.utils import create_allowed_token_ids from vllm.platforms import current_platform -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index 5d457762fc..a0abb3b4c6 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -9,7 +9,7 @@ import regex as re import torch from vllm import CompletionOutput -from vllm.utils import make_tensor_with_pad +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata diff --git a/tests/v1/shutdown/test_delete.py b/tests/v1/shutdown/test_delete.py index 286575b0d5..ee04dfad39 100644 --- a/tests/v1/shutdown/test_delete.py +++ b/tests/v1/shutdown/test_delete.py @@ -12,7 +12,7 @@ from tests.v1.shutdown.utils import ( from vllm import LLM, SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.sampling_params import RequestOutputKind -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] diff --git a/tests/v1/shutdown/test_forward_error.py b/tests/v1/shutdown/test_forward_error.py index cacc71be43..a751b2d919 100644 --- a/tests/v1/shutdown/test_forward_error.py +++ b/tests/v1/shutdown/test_forward_error.py @@ -14,7 +14,7 @@ from tests.v1.shutdown.utils import ( from vllm import LLM, AsyncEngineArgs, SamplingParams from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.exceptions import EngineDeadError diff --git a/tests/v1/shutdown/test_startup_error.py b/tests/v1/shutdown/test_startup_error.py index 66c9a52b6d..c1594cc2e8 100644 --- a/tests/v1/shutdown/test_startup_error.py +++ b/tests/v1/shutdown/test_startup_error.py @@ -13,7 +13,7 @@ from vllm import LLM from vllm.distributed import get_tensor_model_parallel_rank from vllm.engine.arg_utils import AsyncEngineArgs from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.v1.engine.async_llm import AsyncLLM MODELS = ["hmellor/tiny-random-LlamaForCausalLM"] diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 5ab67dcf76..132f0a58bb 100644 
--- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -10,7 +10,8 @@ import torch from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f23fc9ba52..a028be6ce7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import current_platform -from vllm.utils import ( +from vllm.utils.torch_utils import ( direct_register_custom_op, kv_cache_dtype_str_to_dtype, ) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 8fc034dd72..6308f63cc4 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -5,7 +5,7 @@ import torch from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer def get_aiter_mla_metadata( diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 9e6053bc30..556222936e 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -24,8 +24,8 @@ from vllm.compilation.partition_rules import ( from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import is_torch_equal_or_newer from .caching import VllmSerializableFunction from .compiler_interface import ( diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index c1ed058ded..7294ddce64 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .inductor_pass import enable_fake_mode from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index e2369a635a..0a3f0769db 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,7 +16,7 @@ import torch.fx as fx import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer class CompilerInterface: diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index fe20a5f7e6..a2e0abfebc 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -17,7 
+17,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import weak_ref_tensors +from vllm.utils.torch_utils import weak_ref_tensors logger = init_logger(__name__) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 4cbe3044e4..4a4903035c 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -21,8 +21,8 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors -from vllm.utils import supports_dynamo from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import supports_dynamo from .monitor import start_monitoring_torch_compile diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 4b263fa6f5..9af635a929 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -14,7 +14,7 @@ import torch from torch import fx from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer if is_torch_equal_or_newer("2.6"): from torch._inductor.custom_graph_pass import CustomGraphPass diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ff43e4e826..2aaf4ba51f 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -16,8 +16,8 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.config.utils import config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import is_torch_equal_or_newer if TYPE_CHECKING: from vllm.config import VllmConfig diff --git a/vllm/config/model.py b/vllm/config/model.py index 4ed8b63297..c99451aa2a 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -41,8 +41,9 @@ from vllm.transformers_utils.config import ( ) from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import LayerBlockType, common_broadcastable_dtype +from vllm.utils import LayerBlockType from vllm.utils.import_utils import LazyLoader +from vllm.utils.torch_utils import common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 953aa1a147..999576bab9 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -18,7 +18,8 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, get_open_ports_list +from vllm.utils import get_open_ports_list +from vllm.utils.torch_utils import cuda_device_count_stateless if TYPE_CHECKING: from ray.runtime_env import RuntimeEnv diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 2eb3ce2976..7ccc04cf55 100644 --- 
a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -22,7 +22,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) -from vllm.utils import cuda_device_count_stateless, update_environment_variables +from vllm.utils import update_environment_variables +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 4bc737494c..4b82f3b5d3 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import ( from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless try: ops.meta_size() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index f083308791..ad3c8676fa 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -19,7 +19,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import ( ) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) @@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm): from vllm.distributed.device_communicators.pynccl_allocator import ( nccl_symm_mem_context, ) - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op global _NCCL_SYMM_OPS_REGISTERED if _NCCL_SYMM_OPS_REGISTERED: diff --git a/vllm/distributed/device_communicators/quick_all_reduce.py b/vllm/distributed/device_communicators/quick_all_reduce.py index 7a95749635..9c7765883c 100644 --- a/vllm/distributed/device_communicators/quick_all_reduce.py +++ b/vllm/distributed/device_communicators/quick_all_reduce.py @@ -13,7 +13,7 @@ from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) diff --git a/vllm/distributed/device_communicators/ray_communicator.py b/vllm/distributed/device_communicators/ray_communicator.py index 732a40770f..3b02b885e7 100644 --- a/vllm/distributed/device_communicators/ray_communicator.py +++ b/vllm/distributed/device_communicators/ray_communicator.py @@ -14,7 +14,7 @@ from vllm.distributed.device_communicators.base_device_communicator import ( ) from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 7714359a50..5b32a97566 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -25,7 +25,8 @@ from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501 TensorMemoryPool, ) -from vllm.utils import current_stream, get_ip +from vllm.utils import get_ip +from vllm.utils.torch_utils import current_stream logger = logging.getLogger(__name__) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 132fb90491..f3f5318200 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -50,11 +50,13 @@ from vllm.distributed.device_communicators.base_device_communicator import ( from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import ( - direct_register_custom_op, get_distributed_init_method, - supports_custom_op, ) from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import ( + direct_register_custom_op, + supports_custom_op, +) @dataclass diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index a3d9dbe83a..a5df81e55e 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -29,7 +29,8 @@ from torch.distributed.rendezvous import rendezvous import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import get_tcp_uri, is_torch_equal_or_newer +from vllm.utils import get_tcp_uri +from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/env_override.py b/vllm/env_override.py index f4ac48584c..30071f8ea4 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -5,7 +5,7 @@ import os import torch from vllm.logger import init_logger -from vllm.utils import is_torch_equal +from vllm.utils.torch_utils import is_torch_equal logger = init_logger(__name__) diff --git a/vllm/envs.py b/vllm/envs.py index 3cf3444e20..e7ab320b4e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None: def use_aot_compile() -> bool: - from vllm.utils import is_torch_equal_or_newer + from vllm.utils.torch_utils import is_torch_equal_or_newer default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index c833045598..fd4c1364de 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -12,7 +12,7 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 9cba8f4944..8c58915e3f 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -12,7 +12,7 @@ import torch from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs from vllm.triton_utils import tl, triton -from vllm.utils import 
direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op @triton.jit diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 698d12d5ea..f21fe16c51 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def flashinfer_fused_moe_blockscale_fp8( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d9007d50e3..f5760fea65 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -52,8 +52,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc from vllm.model_executor.utils import maybe_disable_graph_partition from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3bb544a49f..04d8e91b0d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -52,8 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up +from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.worker.ubatching import dbo_current_ubatch_id if current_platform.is_cuda_alike(): diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 350713698a..e18514ad43 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class QuantMethod(IntEnum): diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index e595747463..0627ea50d8 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -23,8 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( mxfp8_e4m3_quantize, ) from vllm.triton_utils import tl, triton -from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv from vllm.utils.flashinfer import flashinfer_fp4_quantize +from 
vllm.utils.torch_utils import is_torch_equal_or_newer @triton.jit diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index dac5d129c3..65432c0fb2 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_aiter_rmsnorm_enabled() -> bool: diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index ce8f50bb27..fd4567ee47 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateShapeCalculator, ) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata if TYPE_CHECKING: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 8f7317556f..a9a0c21647 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_state_update, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index b0ee327a82..fb45afa33d 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import ( sharded_weight_loader, ) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index 41ab7f3fec..91a4562358 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -6,7 +6,10 @@ import torch from vllm.config.cache import MambaDType from vllm.config.model import ModelDType from vllm.distributed import divide -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype +from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + get_kv_cache_torch_dtype, +) class MambaStateDtypeCalculator: diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index afaa706929..04efa8a8b3 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update, ) -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import 
direct_register_custom_op from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 81cf86a7d0..ccd9b311cc 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import ( QuantizationMethods, ) from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class BitsAndBytesConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/fp_quant.py b/vllm/model_executor/layers/quantization/fp_quant.py index f00ea17ab6..15a253cef0 100644 --- a/vllm/model_executor/layers/quantization/fp_quant.py +++ b/vllm/model_executor/layers/quantization/fp_quant.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op class FPQuantConfig(QuantizationConfig): diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 84cd07a0c1..8a914c57a9 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 5e133aac10..a19396a162 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -7,7 +7,7 @@ import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 5b35cf6df8..12b0c208dd 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -49,10 +49,10 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import ( has_triton_kernels, - is_torch_equal_or_newer, round_up, ) from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index 1bc1171843..c25c522dea 100644 --- 
a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -45,7 +45,7 @@ try: from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 from aiter.ops.triton.quant import dynamic_mxfp4_quant - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op if is_rocm_aiter_fp4_asm_gemm_enabled(): from aiter import gemm_a4w4, per_1x32_f4_quant_hip diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 7af1e0a5c8..f25148abb6 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -28,13 +28,13 @@ from vllm.model_executor.parameter import ( ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op from vllm.utils.deep_gemm import ( fp8_gemm_nt, is_deep_gemm_e8m0_used, is_deep_gemm_supported, should_use_deepgemm_for_fp8_linear, ) +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 231d7dc6ce..5e87cadfb1 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -7,7 +7,7 @@ import torch from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py index 2249e96589..2b5659e300 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp6_utils.py @@ -3,7 +3,7 @@ import torch from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def _quant_dequant_mxfp6( diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 17da125d5e..380431e864 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -12,8 +12,8 @@ from vllm.config import CompilationMode, get_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer +from vllm.utils.torch_utils import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index f1b34f1785..9e6ec9fdd5 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -10,7 +10,7 @@ import torch from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op if current_platform.is_cuda(): from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb diff --git a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py index 223350d432..a01d14f7b3 100644 --- a/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py +++ b/vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py @@ -5,7 +5,7 @@ import torch import vllm.envs as envs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def is_rocm_triton_rotary_embedding_enabled() -> bool: diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 87ffcb48c8..c1a48fa200 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -9,7 +9,7 @@ import torch from vllm import _custom_ops as ops from vllm import envs from vllm.platforms import CpuArchEnum, current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op def shuffle_weight(w: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 6106a1ab8a..94dfa47824 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -11,8 +11,8 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 71df96cb3e..97c7a20bc4 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype +from vllm.model_executor.model_loader.utils import ParamMapping from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, @@ -48,6 +48,7 @@ from vllm.model_executor.utils import ( set_weight_attrs, ) from vllm.platforms import current_platform +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index dbcd864516..7db1fc167c 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -15,13 +15,13 @@ from 
vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) from vllm.model_executor.model_loader.weight_utils import ( get_gguf_extra_tensor_names, get_gguf_weight_type_map, gguf_quant_weights_iterator, ) +from vllm.utils.torch_utils import set_default_torch_dtype class GGUFModelLoader(BaseModelLoader): diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index ba72d576ba..2b3704cfeb 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -22,8 +22,8 @@ from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.utils import ( get_model_architecture, initialize_model, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py index ec42e3a1ea..fc142f1f07 100644 --- a/vllm/model_executor/model_loader/tpu.py +++ b/vllm/model_executor/model_loader/tpu.py @@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, - set_default_torch_dtype, ) +from vllm.utils.torch_utils import set_default_torch_dtype logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index c68ac61155..88dfbc33e1 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utilities for selecting and loading models.""" -import contextlib import inspect import warnings from contextlib import contextmanager @@ -32,15 +31,6 @@ from vllm.utils import is_pin_memory_available logger = init_logger(__name__) -@contextlib.contextmanager -def set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - def initialize_model( vllm_config: VllmConfig, *, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 662f2c9209..da5d80f982 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -6,7 +6,8 @@ from typing import TYPE_CHECKING import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.models import ModelRegistry -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up +from vllm.utils import cdiv, round_up +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8ad85357ae..58133aa55d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -79,8 +79,8 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import 
direct_register_custom_op from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 759f2a18d3..56bbaf0da1 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -18,7 +18,6 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers.utils import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -51,6 +50,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processo from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.utils.collections import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import ( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 05b822d6fd..e2d2647f01 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -51,8 +51,8 @@ from vllm.multimodal.processing import ( from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import set_default_torch_num_threads from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_num_threads from .interfaces import ( MultiModalEmbeddings, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 1c2e9042a6..b4a558ad69 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -49,7 +49,6 @@ from vllm.model_executor.layers.resampler import ( Resampler2, get_2d_sincos_pos_embed, ) -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -88,6 +87,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.collections import flatten_2d_lists from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import ( diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index b35a8c6b66..09293f63f7 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -64,7 +64,7 @@ from vllm.model_executor.models.utils import ( ) from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import 
direct_register_custom_op from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index f891a4961d..e81ad5f68d 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -71,7 +71,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata from .interfaces import ( diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py index ed56fd7399..5de786f995 100644 --- a/vllm/model_executor/models/transformers/moe.py +++ b/vllm/model_executor/models/transformers/moe.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.models.interfaces import MixtureOfExperts from vllm.model_executor.models.utils import maybe_prefix from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils.torch_utils import direct_register_custom_op from .utils import log_replacement diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 4cac6e6133..022cd0fd23 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -24,11 +24,13 @@ from vllm.multimodal import NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import ( cdiv, - direct_register_custom_op, - get_cuda_view_from_cpu_tensor, is_pin_memory_available, is_uva_available, ) +from vllm.utils.torch_utils import ( + direct_register_custom_op, + get_cuda_view_from_cpu_tensor, +) logger = init_logger(__name__) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 0246e0739b..ccfe1871ef 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -53,6 +52,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.jsontree import json_map_leaves from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription from .utils import ( diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 5ffee6cb8d..759b809433 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -7,7 +7,7 @@ from typing import Any import torch -from vllm.utils import is_torch_equal_or_newer +from vllm.utils.torch_utils import is_torch_equal_or_newer def set_random_seed(seed: int) -> None: diff --git 
a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 8942a3206e..30dd7cade2 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -7,8 +7,8 @@ from typing import TYPE_CHECKING from vllm import envs from vllm.plugins import PLATFORM_PLUGINS_GROUP, load_plugins_by_group -from vllm.utils import supports_xccl from vllm.utils.import_utils import resolve_obj_by_qualname +from vllm.utils.torch_utils import supports_xccl from .interface import CpuArchEnum, Platform, PlatformEnum diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a6b9df7c14..c736e084a3 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -16,7 +16,8 @@ from typing_extensions import ParamSpec import vllm._C # noqa import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, import_pynvml +from vllm.utils import import_pynvml +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4680050965..9788bfeca1 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -9,7 +9,7 @@ import torch import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless +from vllm.utils.torch_utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 27a4f89e00..4211535131 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -21,7 +21,8 @@ import torch import vllm.envs as envs from vllm.connections import global_http_connection from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties +from vllm.utils import cuda_get_device_properties +from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index dd83f9fc96..7cb3805fcb 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -35,12 +35,11 @@ from argparse import ( from collections import defaultdict from collections.abc import ( Callable, - Collection, Iterator, Sequence, ) from concurrent.futures.process import ProcessPoolExecutor -from functools import cache, lru_cache, partial, wraps +from functools import cache, partial, wraps from pathlib import Path from typing import TYPE_CHECKING, Any, TextIO, TypeVar from urllib.parse import urlparse @@ -48,8 +47,6 @@ from uuid import uuid4 import cbor2 import cloudpickle -import numpy as np -import numpy.typing as npt import psutil import regex as re import setproctitle @@ -57,9 +54,6 @@ import torch import yaml import zmq import zmq.asyncio -from packaging import version -from packaging.version import Version -from torch.library import Library import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger @@ -69,13 +63,11 @@ if TYPE_CHECKING: from argparse import Namespace from vllm.config import ModelConfig, VllmConfig - from vllm.sequence import IntermediateTensors else: Namespace = object ModelConfig = object VllmConfig = object - IntermediateTensors = object logger = init_logger(__name__) @@ -105,46 +97,6 @@ STR_INVALID_VAL: str = "INVALID" CYAN = "\033[1;36m" RESET = "\033[0;0m" -STR_DTYPE_TO_TORCH_DTYPE = { - "float32": torch.float32, - "half": torch.half, - "bfloat16": 
torch.bfloat16, - "float": torch.float, - "fp8": torch.uint8, - "fp8_e4m3": torch.uint8, - "fp8_e5m2": torch.uint8, - "int8": torch.int8, - "fp8_inc": torch.float8_e4m3fn, - "fp8_ds_mla": torch.uint8, -} - -TORCH_DTYPE_TO_NUMPY_DTYPE = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int32: np.int32, - torch.int64: np.int64, -} - - -@contextlib.contextmanager -def set_default_torch_num_threads(num_threads: int): - """Sets the default number of threads for PyTorch to the given value.""" - old_num_threads = torch.get_num_threads() - torch.set_num_threads(num_threads) - yield - torch.set_num_threads(old_num_threads) - - -def kv_cache_dtype_str_to_dtype( - kv_cache_dtype: str, model_config: ModelConfig -) -> torch.dtype: - if kv_cache_dtype == "auto": - # Model config may not be specified for unit tests, default to float16 - return model_config.dtype if model_config else torch.half - return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] - T = TypeVar("T") U = TypeVar("U") @@ -407,141 +359,6 @@ def round_down(x: int, y: int) -> int: return (x // y) * y -def _generate_random_fp8( - tensor: torch.Tensor, - low: float, - high: float, -) -> None: - # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, - # it may occur Inf or NaN if we directly use torch.randint - # to generate random data for fp8 data. - # For example, s.11111.00 in fp8e5m2 format represents Inf. - # | E4M3 | E5M2 - # -----|-------------|------------------- - # Inf | N/A | s.11111.00 - # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm import _custom_ops as ops - - tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) - tensor_tmp.uniform_(low, high) - ops.convert_fp8(tensor, tensor_tmp) - del tensor_tmp - - -def get_kv_cache_torch_dtype( - cache_dtype: str | torch.dtype | None, - model_dtype: str | torch.dtype | None = None, -) -> torch.dtype: - if isinstance(cache_dtype, str): - if cache_dtype == "auto": - if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] - elif isinstance(model_dtype, torch.dtype): - torch_dtype = model_dtype - else: - raise ValueError(f"Invalid model dtype: {model_dtype}") - elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] - else: - raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") - elif isinstance(cache_dtype, torch.dtype): - torch_dtype = cache_dtype - else: - raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") - return torch_dtype - - -def create_kv_caches_with_random_flash( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: str | torch.dtype | None, - model_dtype: str | torch.dtype | None = None, - seed: int | None = None, - device: str | None = "cuda", - cache_layout: str | None = "NHD", -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - from vllm.platforms import current_platform - - current_platform.seed_everything(seed) - - dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) - assert cache_layout in ("NHD", "HND") - stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) - - kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) - scale = head_size**-0.5 - - key_caches: list[torch.Tensor] = [] - value_caches: list[torch.Tensor] = [] - - for _ in range(num_layers): - 
key_value_cache = torch.empty( - size=kv_cache_allocation_shape, dtype=dtype, device=device - ).permute(*stride_order) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - key_value_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(key_value_cache, -scale, scale) - else: - raise ValueError(f"Does not support key cache of type {cache_dtype}") - key_caches.append(key_value_cache[:, 0]) - value_caches.append(key_value_cache[:, 1]) - return key_caches, value_caches - - -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: str | torch.dtype | None, - model_dtype: str | torch.dtype | None = None, - seed: int | None = None, - device: str | None = "cuda", -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - if cache_dtype == "fp8" and head_size % 16: - raise ValueError( - f"Does not support key cache of type fp8 with head_size {head_size}" - ) - from vllm.platforms import current_platform - - current_platform.seed_everything(seed) - - dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) - - scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches: list[torch.Tensor] = [] - for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - key_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(key_cache, -scale, scale) - else: - raise ValueError(f"Does not support key cache of type {cache_dtype}") - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches: list[torch.Tensor] = [] - for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - if cache_dtype in ["auto", "half", "bfloat16", "float"]: - value_cache.uniform_(-scale, scale) - elif cache_dtype == "fp8": - _generate_random_fp8(value_cache, -scale, scale) - else: - raise ValueError(f"Does not support value cache of type {cache_dtype}") - value_caches.append(value_cache) - return key_caches, value_caches - - @cache def is_pin_memory_available() -> bool: from vllm.platforms import current_platform @@ -557,121 +374,6 @@ def is_uva_available() -> bool: return is_pin_memory_available() -def make_ndarray_with_pad( - x: list[list[T]], - pad: T, - dtype: npt.DTypeLike, - *, - max_len: int | None = None, -) -> npt.NDArray: - """ - Make a padded array from 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. - """ - if max_len is None: - # Unlike for most functions, map is faster than a genexpr over `len` - max_len = max(map(len, x), default=0) - - padded_x = np.full((len(x), max_len), pad, dtype=dtype) - for ind, blocktb in enumerate(x): - assert len(blocktb) <= max_len - padded_x[ind, : len(blocktb)] = blocktb - - return padded_x - - -def make_tensor_with_pad( - x: list[list[T]], - pad: T, - dtype: torch.dtype, - *, - max_len: int | None = None, - device: str | torch.device | None = None, - pin_memory: bool = False, -) -> torch.Tensor: - """ - Make a padded tensor from 2D inputs. - - The padding is applied to the end of each inner list until it reaches - `max_len`. 
- """ - np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] - padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) - - tensor = torch.from_numpy(padded_x).to(device) - if pin_memory: - tensor = tensor.pin_memory() - - return tensor - - -def async_tensor_h2d( - data: list, - dtype: torch.dtype, - target_device: str | torch.device, - pin_memory: bool, -) -> torch.Tensor: - """Asynchronously create a tensor and copy it from host to device.""" - t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") - return t.to(device=target_device, non_blocking=True) - - -def get_dtype_size(dtype: torch.dtype) -> int: - """Get the size of the data type in bytes.""" - return torch.tensor([], dtype=dtype).element_size() - - -# bool = 0, int = 1, float = 2, complex = 3 -def _get_precision_level(dtype: torch.dtype) -> int: - # NOTE: Complex dtypes return `is_floating_point=False` - return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 - - -def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): - """ - Test whether it is lossless to cast a tensor from - `src_dtype` to `tgt_dtype`. - """ - if src_dtype == tgt_dtype: - return True - - src_level = _get_precision_level(src_dtype) - tgt_level = _get_precision_level(tgt_dtype) - - if src_level < tgt_level: - return True - if src_level > tgt_level: - return False - - # Compare integral types - if not src_dtype.is_floating_point and not src_dtype.is_complex: - src_info = torch.iinfo(src_dtype) - tgt_info = torch.iinfo(tgt_dtype) - return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max - - # Compare floating-point types - src_info = torch.finfo(src_dtype) - tgt_info = torch.finfo(tgt_dtype) - return ( - src_info.min >= tgt_info.min - and src_info.max <= tgt_info.max - and src_info.resolution >= tgt_info.resolution - ) - - -def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): - """ - Get the common `dtype` where all of the other `dtypes` can be - cast to it without losing any information. - """ - return max( - dtypes, - key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), - ) - - # TODO: This function can be removed if transformer_modules classes are # serialized by value when communicating between processes def init_cached_hf_modules() -> None: @@ -767,60 +469,6 @@ def find_nccl_include_paths() -> list[str] | None: return out or None -prev_set_stream = torch.cuda.set_stream - -_current_stream_tls = threading.local() - - -def _patched_set_stream(stream: torch.cuda.Stream) -> None: - _current_stream_tls.value = stream - prev_set_stream(stream) - - -torch.cuda.set_stream = _patched_set_stream - - -class _StreamPlaceholder: - def __init__(self): - self.synchronize = lambda: None - - -def current_stream() -> torch.cuda.Stream: - """ - replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. - it turns out that `torch.cuda.current_stream()` is quite expensive, - as it will construct a new stream object at each call. - here we patch `torch.cuda.set_stream` to keep track of the current stream - directly, so that we can avoid calling `torch.cuda.current_stream()`. - - the underlying hypothesis is that we do not call `torch._C._cuda_setStream` - from C/C++ code. - """ - from vllm.platforms import current_platform - - if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: - # when this function is called before any stream is set, - # we return the default stream. 
- # On ROCm using the default 0 stream in combination with RCCL - # is hurting performance. Therefore creating a dedicated stream - # per process - if current_platform.is_rocm(): - # torch.cuda.set_stream here is the alias of _pathed_set_stream - torch.cuda.set_stream(torch.cuda.Stream()) - elif current_platform.is_cpu(): - _current_stream_tls.value = _StreamPlaceholder() - else: - current_stream = current_platform.current_stream - if current_stream is not None: - _current_stream_tls.value = current_stream() - else: - raise ValueError( - "Fail to set current stream, current platform " - "may not support current_stream with torch API" - ) - return _current_stream_tls.value - - def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: """Set up function tracing for the current thread, if enabled via the VLLM_TRACE_FUNCTION environment variable @@ -842,48 +490,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: enable_trace_function_call(log_path) -@lru_cache(maxsize=8) -def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: - # Note: cuda_visible_devices is not used, but we keep it as an argument for - # LRU Cache purposes. - - # Code below is based on - # https://github.com/pytorch/pytorch/blob/ - # c1cd946818442aca8c7f812b16d187ce1586c3bc/ - # torch/cuda/__init__.py#L831C1-L831C17 - import torch.cuda - - from vllm.platforms import current_platform - - if not torch.cuda._is_compiled(): - return 0 - if current_platform.is_rocm(): - # ROCm uses amdsmi instead of nvml for stateless device count - # This requires a sufficiently modern version of Torch 2.4.0 - raw_count = ( - torch.cuda._device_count_amdsmi() - if (hasattr(torch.cuda, "_device_count_amdsmi")) - else -1 - ) - else: - raw_count = torch.cuda._device_count_nvml() - r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count - return r - - -def cuda_device_count_stateless() -> int: - """Get number of CUDA devices, caching based on the value of - CUDA_VISIBLE_DEVICES at the time of call. - - This should be used instead of torch.cuda.device_count() - unless CUDA_VISIBLE_DEVICES has already been set to the desired - value.""" - - # This can be removed and simply replaced with torch.cuda.get_device_count - # after https://github.com/pytorch/pytorch/pull/122815 is released. - return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) - - def cuda_is_initialized() -> bool: """Check if CUDA is initialized.""" if not torch.cuda._is_compiled(): @@ -1411,27 +1017,6 @@ class FlexibleArgumentParser(ArgumentParser): return processed_args -# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. -# In particular, the FakeScalarType is not supported for earlier versions of -# PyTorch which breaks dynamo for any ops registered using ScalarType. -def supports_dynamo() -> bool: - base_torch_version = Version(Version(torch.__version__).base_version) - return base_torch_version >= Version("2.4.0") - - -# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform -def supports_xccl() -> bool: - return ( - is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() - ) - - -# Some backends use pytorch version < 2.4.0 which doesn't -# support `torch.library.custom_op`. 
-def supports_custom_op() -> bool: - return hasattr(torch.library, "custom_op") - - class AtomicCounter: """An atomic, thread-safe counter""" @@ -1457,118 +1042,6 @@ class AtomicCounter: return self._value -def weak_ref_tensor(tensor: Any) -> Any: - """ - Create a weak reference to a tensor. - The new tensor will share the same data as the original tensor, - but will not keep the original tensor alive. - """ - if isinstance(tensor, torch.Tensor): - return torch.ops._C.weak_ref_tensor(tensor) - else: - return tensor - - -def weak_ref_tensors( - tensors: torch.Tensor - | list[torch.Tensor] - | tuple[torch.Tensor] - | IntermediateTensors, -) -> torch.Tensor | list[Any] | tuple[Any] | Any: - """ - Convenience function to create weak references to tensors, - for single tensor, list of tensors or tuple of tensors. - """ - if isinstance(tensors, torch.Tensor): - return weak_ref_tensor(tensors) - if isinstance(tensors, list): - return [weak_ref_tensor(t) for t in tensors] - if isinstance(tensors, tuple): - return tuple(weak_ref_tensor(t) for t in tensors) - - # For IntermediateTensors used in pipeline parallelism - from vllm.sequence import IntermediateTensors - - if isinstance(tensors, IntermediateTensors): - ret = IntermediateTensors( - {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} - ) - return ret - raise ValueError("Invalid type for tensors") - - -def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: - """ - Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). - """ - assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" - return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) - - -# create a library to hold the custom op -vllm_lib = Library("vllm", "FRAGMENT") # noqa - - -def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: list[str] | None = None, - fake_impl: Callable | None = None, - target_lib: Library | None = None, - dispatch_key: str | None = None, - tags: tuple[torch.Tag, ...] = (), -): - """ - `torch.library.custom_op` can have significant overhead because it - needs to consider complicated dispatching logic. This function - directly registers a custom op and dispatches it to the CUDA backend. - See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 - for more details. - - By default, the custom op is registered to the vLLM library. If you - want to register it to a different library, you can pass the library - object to the `target_lib` argument. - - IMPORTANT: the lifetime of the operator is tied to the lifetime of the - library object. If you want to bind the operator to a different library, - make sure the library object is alive when the operator is used. - """ - if not supports_custom_op(): - from vllm.platforms import current_platform - - assert not current_platform.is_cuda_alike(), ( - "cuda platform needs torch>=2.4 to support custom op, " - "chances are you are using an old version of pytorch " - "or a custom build of pytorch. It is recommended to " - "use vLLM in a fresh new environment and let it install " - "the required dependencies." 
- ) - return - - if mutates_args is None: - mutates_args = [] - - if dispatch_key is None: - from vllm.platforms import current_platform - - dispatch_key = current_platform.dispatch_key - - import torch.library - - if hasattr(torch.library, "infer_schema"): - schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) - else: - # for pytorch 2.4 - import torch._custom_op.impl - - schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) - my_lib = target_lib or vllm_lib - my_lib.define(op_name + schema_str, tags=tags) - my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) - if fake_impl is not None: - my_lib._register_fake(op_name, fake_impl) - - def kill_process_tree(pid: int): """ Kills all descendant processes of the given pid by sending SIGKILL. @@ -2063,55 +1536,6 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: raise ValueError(f"Unsupported hash function: {hash_fn_name}") -def is_torch_equal_or_newer(target: str) -> bool: - """Check if the installed torch version is >= the target version. - - Args: - target: a version string, like "2.6.0". - - Returns: - Whether the condition meets. - """ - try: - return _is_torch_equal_or_newer(str(torch.__version__), target) - except Exception: - # Fallback to PKG-INFO to load the package info, needed by the doc gen. - return Version(importlib.metadata.version("torch")) >= Version(target) - - -# Helper function used in testing. -def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: - torch_version = version.parse(torch_version) - return torch_version >= version.parse(target) - - -def _is_torch_equal(target: str) -> bool: - assert target.count(".") == 2 - torch_version = str(torch.__version__) - torch_version = version.parse(torch_version) - # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" - # or "2.6.0+cu128" but never "2.6.0.1" - return ( - torch_version >= version.parse(target) - and version.parse(target + ".1") > torch_version - ) - - -def is_torch_equal(target: str) -> bool: - """Check if the installed torch version is == the target version. - - Args: - target: a version string, like "2.6.0". - - Returns: - Whether the condition meets. - """ - try: - return _is_torch_equal(target) - except Exception: - return Version(importlib.metadata.version("torch")) == Version(target) - - @cache def _has_module(module_name: str) -> bool: """Return True if *module_name* can be found in the current environment. 
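A minimal sketch, not part of the patch, of what a call site looks like after this change: the module path and helper names match the import updates above and the new file added below, while the tensor shape is illustrative only.

import torch

from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_default_torch_dtype

with set_default_torch_dtype(STR_DTYPE_TO_TORCH_DTYPE["bfloat16"]):
    # Tensors created inside the context pick up the bfloat16 default.
    weight = torch.empty(16, 16)

# The context manager restores the previous default dtype on exit.
assert weight.dtype == torch.bfloat16

Call sites that previously spelled `from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE` only change the import line; behaviour is unchanged.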
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py new file mode 100644 index 0000000000..adcacb34cb --- /dev/null +++ b/vllm/utils/torch_utils.py @@ -0,0 +1,605 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import importlib.metadata +import threading +from collections.abc import Callable, Collection +from functools import lru_cache +from typing import TYPE_CHECKING, Any, TypeVar + +import numpy as np +import numpy.typing as npt +import torch +from packaging import version +from packaging.version import Version +from torch.library import Library + +import vllm.envs as envs + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import IntermediateTensors +else: + ModelConfig = object + IntermediateTensors = object + + +STR_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, + "int8": torch.int8, + "fp8_inc": torch.float8_e4m3fn, + "fp8_ds_mla": torch.uint8, +} + +TORCH_DTYPE_TO_NUMPY_DTYPE = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int32: np.int32, + torch.int64: np.int64, +} + + +T = TypeVar("T") + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +@contextlib.contextmanager +def set_default_torch_num_threads(num_threads: int): + """Sets the default number of threads for PyTorch to the given value.""" + old_num_threads = torch.get_num_threads() + torch.set_num_threads(num_threads) + yield + torch.set_num_threads(old_num_threads) + + +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + +# bool = 0, int = 1, float = 2, complex = 3 +def _get_precision_level(dtype: torch.dtype) -> int: + # NOTE: Complex dtypes return `is_floating_point=False` + return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 + + +def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): + """ + Test whether it is lossless to cast a tensor from + `src_dtype` to `tgt_dtype`. + """ + if src_dtype == tgt_dtype: + return True + + src_level = _get_precision_level(src_dtype) + tgt_level = _get_precision_level(tgt_dtype) + + if src_level < tgt_level: + return True + if src_level > tgt_level: + return False + + # Compare integral types + if not src_dtype.is_floating_point and not src_dtype.is_complex: + src_info = torch.iinfo(src_dtype) + tgt_info = torch.iinfo(tgt_dtype) + return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + + # Compare floating-point types + src_info = torch.finfo(src_dtype) + tgt_info = torch.finfo(tgt_dtype) + return ( + src_info.min >= tgt_info.min + and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution + ) + + +def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): + """ + Get the common `dtype` where all of the other `dtypes` can be + cast to it without losing any information. 
+ """ + return max( + dtypes, + key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), + ) + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. + # | E4M3 | E5M2 + # -----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm import _custom_ops as ops + + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp + + +def get_kv_cache_torch_dtype( + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, +) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def kv_cache_dtype_str_to_dtype( + kv_cache_dtype: str, model_config: ModelConfig +) -> torch.dtype: + if kv_cache_dtype == "auto": + # Model config may not be specified for unit tests, default to float16 + return model_config.dtype if model_config else torch.half + return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, + seed: int | None = None, + device: str | None = "cuda", + cache_layout: str | None = "NHD", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + assert cache_layout in ("NHD", "HND") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) + + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) + scale = head_size**-0.5 + + key_caches: list[torch.Tensor] = [] + value_caches: list[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty( + size=kv_cache_allocation_shape, dtype=dtype, device=device + ).permute(*stride_order) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | 
None = None, + seed: int | None = None, + device: str | None = "cuda", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError(f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + + +def async_tensor_h2d( + data: list, + dtype: torch.dtype, + target_device: str | torch.device, + pin_memory: bool, +) -> torch.Tensor: + """Asynchronously create a tensor and copy it from host to device.""" + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) + + +def make_ndarray_with_pad( + x: list[list[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len: int | None = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + if max_len is None: + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + + padded_x = np.full((len(x), max_len), pad, dtype=dtype) + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, : len(blocktb)] = blocktb + + return padded_x + + +def make_tensor_with_pad( + x: list[list[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: int | None = None, + device: str | torch.device | None = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. 
+ """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor + + +prev_set_stream = torch.cuda.set_stream + +_current_stream_tls = threading.local() + + +def _patched_set_stream(stream: torch.cuda.Stream) -> None: + _current_stream_tls.value = stream + prev_set_stream(stream) + + +torch.cuda.set_stream = _patched_set_stream + + +class _StreamPlaceholder: + def __init__(self): + self.synchronize = lambda: None + + +def current_stream() -> torch.cuda.Stream: + """ + replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. + it turns out that `torch.cuda.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.cuda.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.cuda.current_stream()`. + + the underlying hypothesis is that we do not call `torch._C._cuda_setStream` + from C/C++ code. + """ + from vllm.platforms import current_platform + + if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: + # when this function is called before any stream is set, + # we return the default stream. + # On ROCm using the default 0 stream in combination with RCCL + # is hurting performance. Therefore creating a dedicated stream + # per process + if current_platform.is_rocm(): + # torch.cuda.set_stream here is the alias of _pathed_set_stream + torch.cuda.set_stream(torch.cuda.Stream()) + elif current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API" + ) + return _current_stream_tls.value + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. + + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + from vllm.platforms import current_platform + + if not torch.cuda._is_compiled(): + return 0 + if current_platform.is_rocm(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. 
+ The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return torch.ops._C.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: torch.Tensor + | list[torch.Tensor] + | tuple[torch.Tensor] + | IntermediateTensors, +) -> torch.Tensor | list[Any] | tuple[Any] | Any: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + + # For IntermediateTensors used in pipeline parallelism + from vllm.sequence import IntermediateTensors + + if isinstance(tensors, IntermediateTensors): + ret = IntermediateTensors( + {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} + ) + return ret + raise ValueError("Invalid type for tensors") + + +def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: + """ + Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). + """ + assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" + return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + + +# Helper function used in testing. +def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: + torch_version = version.parse(torch_version) + return torch_version >= version.parse(target) + + +def is_torch_equal_or_newer(target: str) -> bool: + """Check if the installed torch version is >= the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal_or_newer(str(torch.__version__), target) + except Exception: + # Fallback to PKG-INFO to load the package info, needed by the doc gen. + return Version(importlib.metadata.version("torch")) >= Version(target) + + +def _is_torch_equal(target: str) -> bool: + assert target.count(".") == 2 + torch_version = str(torch.__version__) + torch_version = version.parse(torch_version) + # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" + # or "2.6.0+cu128" but never "2.6.0.1" + return ( + torch_version >= version.parse(target) + and version.parse(target + ".1") > torch_version + ) + + +def is_torch_equal(target: str) -> bool: + """Check if the installed torch version is == the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal(target) + except Exception: + return Version(importlib.metadata.version("torch")) == Version(target) + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + return is_torch_equal_or_newer("2.4.0") + + +# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform +def supports_xccl() -> bool: + return ( + is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() + ) + + +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. 
+def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str] | None = None, + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str | None = None, + tags: tuple[torch.Tag, ...] = (), +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. + """ + if not supports_custom_op(): + from vllm.platforms import current_platform + + assert not current_platform.is_cuda_alike(), ( + "cuda platform needs torch>=2.4 to support custom op, " + "chances are you are using an old version of pytorch " + "or a custom build of pytorch. It is recommended to " + "use vLLM in a fresh new environment and let it install " + "the required dependencies." + ) + return + + if mutates_args is None: + mutates_args = [] + + if dispatch_key is None: + from vllm.platforms import current_platform + + dispatch_key = current_platform.dispatch_key + + import torch.library + + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) + my_lib = target_lib or vllm_lib + my_lib.define(op_name + schema_str, tags=tags) + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 29884700d9..e1fb48b309 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -28,7 +28,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) -from vllm.utils import cdiv, is_torch_equal_or_newer +from vllm.utils import cdiv +from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 7c73611d4a..f7a4114a0a 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -29,7 +29,7 @@ if current_platform.is_rocm(): import aiter from vllm.triton_utils import tl, triton - from vllm.utils import direct_register_custom_op + from vllm.utils.torch_utils import direct_register_custom_op @triton.jit def _vllm_layout_trans_kernel( diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index a9ef1b92c2..392519f8fa 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -10,7 +10,8 @@ from 
typing_extensions import Self from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import cdiv, get_dtype_size +from vllm.utils import cdiv +from vllm.utils.torch_utils import get_dtype_size logger = init_logger(__name__) diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index e49b8db478..44f53d95dd 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -4,7 +4,8 @@ import torch from vllm.model_executor.layers.utils import apply_penalties -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad def apply_all_penalties( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 72fc00f6ed..258f2f460f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -71,16 +71,18 @@ from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import ( cdiv, check_use_alibi, - get_dtype_size, is_pin_memory_available, - kv_cache_dtype_str_to_dtype, length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo, ) from vllm.utils.jsontree import json_map_leaves from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_utils import DeviceMemoryProfiler +from vllm.utils.torch_utils import ( + get_dtype_size, + kv_cache_dtype_str_to_dtype, + supports_dynamo, +) from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index c19ed1fc0b..fae1f8e37b 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -25,7 +25,8 @@ from vllm.model_executor import set_random_seed from vllm.platforms import current_platform from vllm.platforms.tpu import USE_TPU_INFERENCE from vllm.tasks import SupportedTask -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import cdiv +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 867ce2b930..6edcb78486 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -7,7 +7,7 @@ import torch from vllm import forward_context from vllm.forward_context import ForwardContext -from vllm.utils import current_stream +from vllm.utils.torch_utils import current_stream _THREAD_ID_TO_CONTEXT: dict = {} _CURRENT_CONTEXTS: list[Optional["UBatchContext"]] = [None, None]
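A sketch, not part of the patch, of registering a custom op through the relocated helper defined in vllm/utils/torch_utils.py above; the op name `scale_by_two` and both implementations are made up for illustration.

import torch

from vllm.utils.torch_utils import direct_register_custom_op


def scale_by_two(x: torch.Tensor) -> torch.Tensor:
    # Real (eager) implementation of the hypothetical op.
    return x * 2


def scale_by_two_fake(x: torch.Tensor) -> torch.Tensor:
    # Shape-only fake implementation used during tracing and compilation.
    return torch.empty_like(x)


direct_register_custom_op(
    op_name="scale_by_two",
    op_func=scale_by_two,
    mutates_args=[],
    fake_impl=scale_by_two_fake,
)

# The op becomes callable as torch.ops.vllm.scale_by_two(t); it is registered
# under the current platform's dispatch key, so `t` should live on that device.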
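Likewise, a small sketch of the dtype helpers that gpu_model_runner.py and tpu_worker.py now pull from vllm.utils.torch_utils; the "fp8" cache setting and the block shape are made-up inputs, everything else comes from the new module above.

from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    get_dtype_size,
    get_kv_cache_torch_dtype,
)

# "fp8" maps to a uint8 storage dtype regardless of the bfloat16 model dtype.
cache_dtype = get_kv_cache_torch_dtype("fp8", STR_DTYPE_TO_TORCH_DTYPE["bfloat16"])

# Bytes for one key+value block of 16 tokens, 8 heads, head size 128.
block_bytes = 2 * 16 * 8 * 128 * get_dtype_size(cache_dtype)
print(cache_dtype, block_bytes)  # torch.uint8 32768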