[Chore] Clean up pytorch helper functions in vllm.utils (#26908)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: isotr0py <2037008807@qq.com>
Author: Isotr0py
Date: 2025-10-19 00:48:22 +08:00
Committed by: GitHub
Parent: 5c2acb270a
Commit: 6ac5e06f7c
119 changed files with 772 additions and 714 deletions
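
Every hunk in this diff follows the same mechanical pattern: torch-specific helpers (STR_DTYPE_TO_TORCH_DTYPE, direct_register_custom_op, is_torch_equal_or_newer, cuda_device_count_stateless, current_stream, make_tensor_with_pad, set_default_torch_dtype, and friends) move from the top-level vllm.utils namespace into vllm.utils.torch_utils, while framework-agnostic helpers such as FlexibleArgumentParser, cdiv, and get_open_port keep their old import path. A minimal before/after sketch of the import change, built only from symbols that appear in the hunks below:

# Before this commit: torch helpers were re-exported from the top-level package.
# from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser

# After this commit: torch helpers come from the dedicated submodule,
# non-torch utilities stay where they were.
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE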

View File

@ -10,7 +10,8 @@ import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
def with_triton_mode(fn):

View File

@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

View File

@ -7,7 +7,8 @@ import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode()

View File

@ -9,9 +9,9 @@ import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)

View File

@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
@torch.inference_mode()

View File

@ -9,9 +9,9 @@ from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)

View File

@ -12,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random_flash,
)

View File

@ -11,7 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
@contextlib.contextmanager

View File

@ -20,7 +20,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401

View File

@ -19,7 +19,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter

View File

@ -27,7 +27,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401

View File

@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
import torch
from torch.library import Library
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations

View File

@ -15,7 +15,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
def reference_fn(x: torch.Tensor):

View File

@ -5,7 +5,7 @@ import dataclasses
import pytest
from vllm.config import CompilationMode
from vllm.utils import cuda_device_count_stateless
from vllm.utils.torch_utils import cuda_device_count_stateless
from ..utils import compare_all_settings

View File

@ -8,7 +8,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
def test_version():

View File

@ -15,7 +15,7 @@ from vllm.config import (
set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401

View File

@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test

View File

@ -15,8 +15,8 @@ from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..utils import flat_product, multi_gpu_test

View File

@ -60,8 +60,8 @@ from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import set_default_torch_num_threads
from vllm.utils.collections import is_list_of
from vllm.utils.torch_utils import set_default_torch_num_threads
logger = init_logger(__name__)

View File

@ -18,7 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

View File

@ -11,10 +11,10 @@ import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import (
cuda_device_count_stateless,
get_open_port,
update_environment_variables,
)
from vllm.utils.torch_utils import cuda_device_count_stateless
from ..utils import multi_gpu_test

View File

@ -3,7 +3,10 @@
import pytest
from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash
from vllm.utils.torch_utils import (
create_kv_caches_with_random,
create_kv_caches_with_random_flash,
)
@pytest.fixture()

View File

@ -15,7 +15,7 @@ from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]

View File

@ -3,7 +3,8 @@
import pytest
import torch
from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
from vllm.utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

View File

@ -13,8 +13,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import cuda_device_count_stateless
from .modular_kernel_tools.common import (
Config,

View File

@ -22,8 +22,8 @@ from vllm.utils import (
STR_BACKEND_ENV_VAR,
STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad,
)
from vllm.utils.torch_utils import make_tensor_with_pad
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.

View File

@ -7,7 +7,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets

View File

@ -9,7 +9,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.models.radio import RadioModel
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets

View File

@ -26,7 +26,6 @@ from vllm.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.interfaces import (
SupportsMultiModal,
supports_multimodal,
@ -36,6 +35,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingC
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.collections import is_list_of
from vllm.utils.torch_utils import set_default_torch_dtype
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides

View File

@ -46,10 +46,10 @@ from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (
FlexibleArgumentParser,
cuda_device_count_stateless,
get_open_port,
)
from vllm.utils.mem_constants import GB_bytes
from vllm.utils.torch_utils import cuda_device_count_stateless
if current_platform.is_rocm():
from amdsmi import (

View File

@ -24,11 +24,8 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
from vllm.utils import (
FlexibleArgumentParser,
bind_kv_cache,
common_broadcastable_dtype,
current_stream,
get_open_port,
get_tcp_uri,
is_lossless_cast,
join_host_port,
make_zmq_path,
make_zmq_socket,
@ -37,6 +34,11 @@ from vllm.utils import (
split_zmq_path,
unique_filepath,
)
from vllm.utils.torch_utils import (
common_broadcastable_dtype,
current_stream,
is_lossless_cast,
)
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from ..utils import create_new_process_for_each_test, flat_product
@ -408,7 +410,7 @@ def test_bind_kv_cache_non_attention():
def test_bind_kv_cache_pp():
with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
# this test runs with 1 GPU, but we simulate 2 GPUs
cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
with set_current_vllm_config(cfg):
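
The change above is a behavioral subtlety in this kind of move: string-based mock targets must follow the symbol to the namespace where the code under test now looks it up. A minimal sketch, relying only on the zero-argument call shape that the patched lambda above already implies:

from unittest.mock import patch

# Patch the new defining module, not the old vllm.utils re-export path.
with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
    ...  # code under test now sees two GPUs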

View File

@ -18,7 +18,8 @@ from tests.v1.attention.utils import (
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
from vllm.utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout,

View File

@ -22,7 +22,8 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.config.vllm import set_current_vllm_config
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@ -15,7 +15,7 @@ from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import (
AggregatedLoggingStatLogger,

View File

@ -12,7 +12,7 @@ from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils import set_default_torch_num_threads
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor

View File

@ -21,7 +21,7 @@ from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublis
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.utils import set_default_torch_num_threads
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient

View File

@ -7,7 +7,8 @@ import torch
from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.utils import is_pin_memory_available
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

View File

@ -9,7 +9,7 @@ import regex as re
import torch
from vllm import CompletionOutput
from vllm.utils import make_tensor_with_pad
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata

View File

@ -12,7 +12,7 @@ from tests.v1.shutdown.utils import (
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.utils import cuda_device_count_stateless
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]

View File

@ -14,7 +14,7 @@ from tests.v1.shutdown.utils import (
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError

View File

@ -13,7 +13,7 @@ from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]

View File

@ -10,7 +10,8 @@ import torch
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.utils import is_pin_memory_available
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata

View File

@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils import (
from vllm.utils.torch_utils import (
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)

View File

@ -5,7 +5,7 @@
import torch
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
def get_aiter_mla_metadata(

View File

@ -24,8 +24,8 @@ from vllm.compilation.partition_rules import (
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .caching import VllmSerializableFunction
from .compiler_interface import (

View File

@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm

View File

@ -16,7 +16,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
class CompilerInterface:

View File

@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors
from vllm.utils.torch_utils import weak_ref_tensors
logger = init_logger(__name__)

View File

@ -21,8 +21,8 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_dynamo
from .monitor import start_monitoring_torch_compile

View File

@ -14,7 +14,7 @@ import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass

View File

@ -16,8 +16,8 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING:
from vllm.config import VllmConfig

View File

@ -41,8 +41,9 @@ from vllm.transformers_utils.config import (
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import LayerBlockType, common_broadcastable_dtype
from vllm.utils import LayerBlockType
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
if TYPE_CHECKING:
from transformers import PretrainedConfig

View File

@ -18,7 +18,8 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
from vllm.utils import get_open_ports_list
from vllm.utils.torch_utils import cuda_device_count_stateless
if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv

View File

@ -22,7 +22,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.utils import cuda_device_count_stateless, update_environment_variables
from vllm.utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)

View File

@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless
from vllm.utils.torch_utils import cuda_device_count_stateless
try:
ops.meta_size()

View File

@ -19,7 +19,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import current_stream
from vllm.utils.torch_utils import current_stream
logger = init_logger(__name__)
@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
from vllm.distributed.device_communicators.pynccl_allocator import (
nccl_symm_mem_context,
)
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
global _NCCL_SYMM_OPS_REGISTERED
if _NCCL_SYMM_OPS_REGISTERED:

View File

@ -13,7 +13,7 @@ from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless
from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)

View File

@ -14,7 +14,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.utils import current_stream
from vllm.utils.torch_utils import current_stream
logger = init_logger(__name__)

View File

@ -25,7 +25,8 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool,
)
from vllm.utils import current_stream, get_ip
from vllm.utils import get_ip
from vllm.utils.torch_utils import current_stream
logger = logging.getLogger(__name__)

View File

@ -50,11 +50,13 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (
direct_register_custom_op,
get_distributed_init_method,
supports_custom_op,
)
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import (
direct_register_custom_op,
supports_custom_op,
)
@dataclass

View File

@ -29,7 +29,8 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
from vllm.utils import get_tcp_uri
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)

View File

@ -5,7 +5,7 @@ import os
import torch
from vllm.logger import init_logger
from vllm.utils import is_torch_equal
from vllm.utils.torch_utils import is_torch_equal
logger = init_logger(__name__)

View File

@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None:
def use_aot_compile() -> bool:
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"

View File

@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit

View File

@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit

View File

@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def flashinfer_fused_moe_blockscale_fp8(

View File

@ -52,8 +52,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

View File

@ -52,8 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up
from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():

View File

@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
class QuantMethod(IntEnum):

View File

@ -23,8 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_torch_equal_or_newer
from vllm.utils import cdiv
from vllm.utils.flashinfer import flashinfer_fp4_quantize
from vllm.utils.torch_utils import is_torch_equal_or_newer
@triton.jit

View File

@ -13,7 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:

View File

@ -34,7 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
if TYPE_CHECKING:

View File

@ -37,7 +37,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata

View File

@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024

View File

@ -6,7 +6,10 @@ import torch
from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType
from vllm.distributed import divide
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_kv_cache_torch_dtype,
)
class MambaStateDtypeCalculator:

View File

@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata

View File

@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import (
QuantizationMethods,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
class BitsAndBytesConfig(QuantizationConfig):

View File

@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
class FPQuantConfig(QuantizationConfig):

View File

@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)

View File

@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig

View File

@ -49,10 +49,10 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import (
has_triton_kernels,
is_torch_equal_or_newer,
round_up,
)
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)

View File

@ -45,7 +45,7 @@ try:
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
if is_rocm_aiter_fp4_asm_gemm_enabled():
from aiter import gemm_a4w4, per_1x32_f4_quant_hip

View File

@ -28,13 +28,13 @@ from vllm.model_executor.parameter import (
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)

View File

@ -7,7 +7,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
logger = init_logger(__name__)

View File

@ -3,7 +3,7 @@
import torch
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def _quant_dequant_mxfp6(

View File

@ -12,8 +12,8 @@ from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.torch_utils import direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale

View File

@ -10,7 +10,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

View File

@ -5,7 +5,7 @@ import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool:

View File

@ -9,7 +9,7 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.torch_utils import direct_register_custom_op
def shuffle_weight(w: torch.Tensor) -> torch.Tensor:

View File

@ -11,8 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
set_default_torch_dtype,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)

View File

@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
from vllm.model_executor.model_loader.utils import ParamMapping
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
@ -48,6 +48,7 @@ from vllm.model_executor.utils import (
set_weight_attrs,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)

View File

@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
set_default_torch_dtype,
)
from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names,
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
)
from vllm.utils.torch_utils import set_default_torch_dtype
class GGUFModelLoader(BaseModelLoader):

View File

@ -22,8 +22,8 @@ from vllm.model_executor.model_loader.tensorizer import (
from vllm.model_executor.model_loader.utils import (
get_model_architecture,
initialize_model,
set_default_torch_dtype,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)

View File

@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
set_default_torch_dtype,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""
import contextlib
import inspect
import warnings
from contextlib import contextmanager
@ -32,15 +31,6 @@ from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
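
The context manager removed above is not deleted outright; as the import hunks elsewhere in this diff show, callers now obtain it from vllm.utils.torch_utils. A minimal usage sketch, assuming the relocated definition matches the body removed here:

import torch

from vllm.utils.torch_utils import set_default_torch_dtype

with set_default_torch_dtype(torch.bfloat16):
    # Tensors and parameters created inside the block default to bfloat16;
    # the previous default dtype is restored on exit.
    layer = torch.nn.Linear(16, 16)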
def initialize_model(
vllm_config: VllmConfig,
*,

View File

@ -6,7 +6,8 @@ from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
from vllm.utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
if TYPE_CHECKING:

View File

@ -79,8 +79,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata,

View File

@ -18,7 +18,6 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers.utils import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
@ -51,6 +50,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processo
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.collections import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (

View File

@ -51,8 +51,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import set_default_torch_num_threads
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_num_threads
from .interfaces import (
MultiModalEmbeddings,

View File

@ -49,7 +49,6 @@ from vllm.model_executor.layers.resampler import (
Resampler2,
get_2d_sincos_pos_embed,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -88,6 +87,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.collections import flatten_2d_lists
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (

Some files were not shown because too many files have changed in this diff.