Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Chore] Clean up pytorch helper functions in vllm.utils (#26908)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: isotr0py <2037008807@qq.com>
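The change is mechanical: torch-specific helpers (dtype maps, custom-op registration, KV-cache helpers, version checks, and so on) move out of the top-level vllm.utils module into vllm.utils.torch_utils, and every caller updates its imports accordingly. A rough sketch of the pattern repeated throughout the hunks below (symbols taken from the first hunk):

# Before the change, torch helpers were imported from the top-level module:
# from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser

# After the change, torch-specific helpers come from vllm.utils.torch_utils,
# while non-torch utilities such as FlexibleArgumentParser stay in vllm.utils:
from vllm.utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE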
@@ -10,7 +10,8 @@ import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


def with_triton_mode(fn):

@@ -10,7 +10,8 @@ import vllm.model_executor.layers.activation # noqa F401
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.triton_utils import triton
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

@@ -7,7 +7,8 @@ import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


@torch.inference_mode()

@@ -9,9 +9,9 @@ import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
-    FlexibleArgumentParser,
    create_kv_caches_with_random,
)

@@ -7,7 +7,8 @@ import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


@torch.inference_mode()

@@ -9,9 +9,9 @@ from tabulate import tabulate
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
-    FlexibleArgumentParser,
    create_kv_caches_with_random,
)
@@ -12,9 +12,9 @@ from vllm.attention.ops.triton_reshape_and_cache_flash import (
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils import FlexibleArgumentParser
+from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
-    FlexibleArgumentParser,
    create_kv_caches_with_random_flash,
)

@@ -11,7 +11,7 @@ from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer


@contextlib.contextmanager

@@ -20,7 +20,7 @@ from vllm.config import (
    set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401

@@ -19,7 +19,7 @@ from vllm.config import (
    set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter

@@ -27,7 +27,7 @@ from vllm.config import (
    set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from .. import silly_attention # noqa: F401

@@ -8,7 +8,7 @@ Centralizes custom operation definitions to avoid duplicate registrations.
import torch
from torch.library import Library

-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op

# Shared library for all compilation test operations
# Using "silly" namespace to match existing test expectations
@@ -15,7 +15,7 @@ from vllm.config import (
    set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer


def reference_fn(x: torch.Tensor):

@@ -5,7 +5,7 @@ import dataclasses
import pytest

from vllm.config import CompilationMode
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless

from ..utils import compare_all_settings

@@ -8,7 +8,7 @@ from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
-from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
+from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer


def test_version():

@@ -15,7 +15,7 @@ from vllm.config import (
    set_current_vllm_config,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401

@@ -12,7 +12,7 @@ from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

from ..utils import create_new_process_for_each_test

@@ -15,8 +15,8 @@ from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

from ..utils import flat_product, multi_gpu_test
@@ -60,8 +60,8 @@ from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import set_default_torch_num_threads
from vllm.utils.collections import is_list_of
+from vllm.utils.torch_utils import set_default_torch_num_threads

logger = init_logger(__name__)

@@ -18,7 +18,7 @@ import pytest
from vllm.config.compilation import CompilationMode
from vllm.config.model import RunnerOption
from vllm.logger import init_logger
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test

@@ -11,10 +11,10 @@ import vllm.envs as envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import (
-    cuda_device_count_stateless,
    get_open_port,
    update_environment_variables,
)
+from vllm.utils.torch_utils import cuda_device_count_stateless

from ..utils import multi_gpu_test

@@ -3,7 +3,10 @@

import pytest

-from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash
+from vllm.utils.torch_utils import (
+    create_kv_caches_with_random,
+    create_kv_caches_with_random_flash,
+)


@pytest.fixture()

@@ -15,7 +15,7 @@ from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 64]
@@ -3,7 +3,8 @@
import pytest
import torch

-from vllm.utils import get_cuda_view_from_cpu_tensor, is_uva_available
+from vllm.utils import is_uva_available
+from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor

CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

@@ -13,8 +13,9 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx
+from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.torch_utils import cuda_device_count_stateless

from .modular_kernel_tools.common import (
    Config,

@@ -22,8 +22,8 @@ from vllm.utils import (
    STR_BACKEND_ENV_VAR,
    STR_FLASH_ATTN_VAL,
    STR_XFORMERS_ATTN_VAL,
-    make_tensor_with_pad,
)
+from vllm.utils.torch_utils import make_tensor_with_pad

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.

@@ -7,7 +7,7 @@ from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor

from vllm.distributed import cleanup_dist_env_and_memory
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

from ....conftest import ImageTestAssets

@@ -9,7 +9,7 @@ from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.models.radio import RadioModel
from vllm.transformers_utils.configs.radio import RadioConfig
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

from ....conftest import ImageTestAssets

@@ -26,7 +26,6 @@ from vllm.distributed import (
    init_distributed_environment,
    initialize_model_parallel,
)
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.interfaces import (
    SupportsMultiModal,
    supports_multimodal,

@@ -36,6 +35,7 @@ from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingC
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.collections import is_list_of
+from vllm.utils.torch_utils import set_default_torch_dtype

from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides
@@ -46,10 +46,10 @@ from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (
    FlexibleArgumentParser,
-    cuda_device_count_stateless,
    get_open_port,
)
from vllm.utils.mem_constants import GB_bytes
+from vllm.utils.torch_utils import cuda_device_count_stateless

if current_platform.is_rocm():
    from amdsmi import (

@@ -24,11 +24,8 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
from vllm.utils import (
    FlexibleArgumentParser,
    bind_kv_cache,
-    common_broadcastable_dtype,
-    current_stream,
    get_open_port,
    get_tcp_uri,
-    is_lossless_cast,
    join_host_port,
    make_zmq_path,
    make_zmq_socket,

@@ -37,6 +34,11 @@ from vllm.utils import (
    split_zmq_path,
    unique_filepath,
)
+from vllm.utils.torch_utils import (
+    common_broadcastable_dtype,
+    current_stream,
+    is_lossless_cast,
+)

from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from ..utils import create_new_process_for_each_test, flat_product
@@ -408,7 +410,7 @@ def test_bind_kv_cache_non_attention():


def test_bind_kv_cache_pp():
-    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
+    with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
        # this test runs with 1 GPU, but we simulate 2 GPUs
        cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2))
        with set_current_vllm_config(cfg):
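Note: because cuda_device_count_stateless now lives in vllm.utils.torch_utils, tests that monkeypatch it have to target the new module path, as the hunk above does. A minimal sketch of the same idea using the standard-library unittest.mock (the lambda value is only illustrative):

from unittest.mock import patch

# Patch the symbol where it is now defined; patching the old
# "vllm.utils.cuda_device_count_stateless" path would no longer intercept
# lookups that go through vllm.utils.torch_utils.
with patch("vllm.utils.torch_utils.cuda_device_count_stateless", lambda: 2):
    ...  # code under test sees two (simulated) GPUs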
@@ -18,7 +18,8 @@ from tests.v1.attention.utils import (
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
+from vllm.utils import cdiv
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    set_kv_cache_layout,

@@ -22,7 +22,8 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.config.vllm import set_current_vllm_config
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.utils import cdiv
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

@@ -15,7 +15,7 @@ from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import (
    AggregatedLoggingStatLogger,

@@ -12,7 +12,7 @@ from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor

@@ -21,7 +21,7 @@ from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublis
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient
@@ -7,7 +7,8 @@ import torch

from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import is_pin_memory_available
+from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

@@ -9,7 +9,7 @@ import regex as re
import torch

from vllm import CompletionOutput
-from vllm.utils import make_tensor_with_pad
+from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata

@@ -12,7 +12,7 @@ from tests.v1.shutdown.utils import (
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM

MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]

@@ -14,7 +14,7 @@ from tests.v1.shutdown.utils import (
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError

@@ -13,7 +13,7 @@ from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM

MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]

@@ -10,7 +10,8 @@ import torch

from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.utils import is_pin_memory_available
+from vllm.utils.torch_utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata

@@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
-from vllm.utils import (
+from vllm.utils.torch_utils import (
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
@@ -5,7 +5,7 @@
import torch

from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
+from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer


def get_aiter_mla_metadata(

@@ -24,8 +24,8 @@ from vllm.compilation.partition_rules import (
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import is_torch_equal_or_newer

from .caching import VllmSerializableFunction
from .compiler_interface import (

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op

from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm

@@ -16,7 +16,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer


class CompilerInterface:

@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import weak_ref_tensors
+from vllm.utils.torch_utils import weak_ref_tensors

logger = init_logger(__name__)

@@ -21,8 +21,8 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
-from vllm.utils import supports_dynamo
from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import supports_dynamo

from .monitor import start_monitoring_torch_compile

@@ -14,7 +14,7 @@ import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

-from vllm.utils import is_torch_equal_or_newer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

if is_torch_equal_or_newer("2.6"):
    from torch._inductor.custom_graph_pass import CustomGraphPass

@@ -16,8 +16,8 @@ from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import is_torch_equal_or_newer
from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import is_torch_equal_or_newer

if TYPE_CHECKING:
    from vllm.config import VllmConfig
@@ -41,8 +41,9 @@ from vllm.transformers_utils.config import (
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import LayerBlockType, common_broadcastable_dtype
+from vllm.utils import LayerBlockType
from vllm.utils.import_utils import LazyLoader
+from vllm.utils.torch_utils import common_broadcastable_dtype

if TYPE_CHECKING:
    from transformers import PretrainedConfig

@@ -18,7 +18,8 @@ from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless, get_open_ports_list
+from vllm.utils import get_open_ports_list
+from vllm.utils.torch_utils import cuda_device_count_stateless

if TYPE_CHECKING:
    from ray.runtime_env import RuntimeEnv

@@ -22,7 +22,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
-from vllm.utils import cuda_device_count_stateless, update_environment_variables
+from vllm.utils import update_environment_variables
+from vllm.utils.torch_utils import cuda_device_count_stateless

logger = init_logger(__name__)

@@ -17,7 +17,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless

try:
    ops.meta_size()

@@ -19,7 +19,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
-from vllm.utils import current_stream
+from vllm.utils.torch_utils import current_stream

logger = init_logger(__name__)

@@ -30,7 +30,7 @@ def register_nccl_symmetric_ops(pynccl_comm):
    from vllm.distributed.device_communicators.pynccl_allocator import (
        nccl_symm_mem_context,
    )
-    from vllm.utils import direct_register_custom_op
+    from vllm.utils.torch_utils import direct_register_custom_op

    global _NCCL_SYMM_OPS_REGISTERED
    if _NCCL_SYMM_OPS_REGISTERED:
@@ -13,7 +13,7 @@ from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils.torch_utils import cuda_device_count_stateless

logger = init_logger(__name__)

@@ -14,7 +14,7 @@ from vllm.distributed.device_communicators.base_device_communicator import (
)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
-from vllm.utils import current_stream
+from vllm.utils.torch_utils import current_stream

logger = init_logger(__name__)

@@ -25,7 +25,8 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
    TensorMemoryPool,
)
-from vllm.utils import current_stream, get_ip
+from vllm.utils import get_ip
+from vllm.utils.torch_utils import current_stream

logger = logging.getLogger(__name__)

@@ -50,11 +50,13 @@ from vllm.distributed.device_communicators.base_device_communicator import (
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (
-    direct_register_custom_op,
    get_distributed_init_method,
-    supports_custom_op,
)
from vllm.utils.import_utils import resolve_obj_by_qualname
+from vllm.utils.torch_utils import (
+    direct_register_custom_op,
+    supports_custom_op,
+)


@dataclass

@@ -29,7 +29,8 @@ from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
-from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
+from vllm.utils import get_tcp_uri
+from vllm.utils.torch_utils import is_torch_equal_or_newer

logger = init_logger(__name__)

@@ -5,7 +5,7 @@ import os
import torch

from vllm.logger import init_logger
-from vllm.utils import is_torch_equal
+from vllm.utils.torch_utils import is_torch_equal

logger = init_logger(__name__)

@@ -246,7 +246,7 @@ def maybe_convert_bool(value: str | None) -> bool | None:


def use_aot_compile() -> bool:
-    from vllm.utils import is_torch_equal_or_newer
+    from vllm.utils.torch_utils import is_torch_equal_or_newer

    default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
@@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


@triton.jit

@@ -12,7 +12,7 @@ import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


@triton.jit

@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
)
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


def flashinfer_fused_moe_blockscale_fp8(

@@ -52,8 +52,8 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Sc
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
+from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer

from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

@@ -52,8 +52,9 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import cdiv, direct_register_custom_op, has_deep_ep, has_pplx, round_up
+from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id

if current_platform.is_cuda_alike():

@@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
)
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


class QuantMethod(IntEnum):

@@ -23,8 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
    mxfp8_e4m3_quantize,
)
from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, is_torch_equal_or_newer
+from vllm.utils import cdiv
from vllm.utils.flashinfer import flashinfer_fp4_quantize
+from vllm.utils.torch_utils import is_torch_equal_or_newer


@triton.jit

@@ -13,7 +13,7 @@ from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -34,7 +34,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

if TYPE_CHECKING:

@@ -37,7 +37,7 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata


@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import (
    sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata

# Added by the IBM Team, 2024

@@ -6,7 +6,10 @@ import torch
from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType
from vllm.distributed import divide
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
+from vllm.utils.torch_utils import (
+    STR_DTYPE_TO_TORCH_DTYPE,
+    get_kv_cache_torch_dtype,
+)


class MambaStateDtypeCalculator:

@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn,
    causal_conv1d_update,
)
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata


@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import (
    QuantizationMethods,
)
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


class BitsAndBytesConfig(QuantizationConfig):

@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


class FPQuantConfig(QuantizationConfig):

@@ -28,7 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)
@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op

from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig

@@ -49,10 +49,10 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import (
    has_triton_kernels,
-    is_torch_equal_or_newer,
    round_up,
)
from vllm.utils.flashinfer import has_flashinfer
+from vllm.utils.torch_utils import is_torch_equal_or_newer

logger = init_logger(__name__)

@@ -45,7 +45,7 @@ try:
    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
    from aiter.ops.triton.quant import dynamic_mxfp4_quant

-    from vllm.utils import direct_register_custom_op
+    from vllm.utils.torch_utils import direct_register_custom_op

    if is_rocm_aiter_fp4_asm_gemm_enabled():
        from aiter import gemm_a4w4, per_1x32_f4_quant_hip

@@ -28,13 +28,13 @@ from vllm.model_executor.parameter import (
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import (
    fp8_gemm_nt,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
    should_use_deepgemm_for_fp8_linear,
)
+from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)

@@ -7,7 +7,7 @@ import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
+from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer

logger = init_logger(__name__)

@@ -3,7 +3,7 @@
import torch

from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


def _quant_dequant_mxfp6(

@@ -12,8 +12,8 @@ from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
+from vllm.utils.torch_utils import direct_register_custom_op

# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -10,7 +10,7 @@ import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op

if current_platform.is_cuda():
    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

@@ -5,7 +5,7 @@ import torch

import vllm.envs as envs
from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


def is_rocm_triton_rotary_embedding_enabled() -> bool:

@@ -9,7 +9,7 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import CpuArchEnum, current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils.torch_utils import direct_register_custom_op


def shuffle_weight(w: torch.Tensor) -> torch.Tensor:

@@ -11,8 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
-    set_default_torch_dtype,
)
+from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)

@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (
    RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.utils import ParamMapping, set_default_torch_dtype
+from vllm.model_executor.model_loader.utils import ParamMapping
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,

@@ -48,6 +48,7 @@ from vllm.model_executor.utils import (
    set_weight_attrs,
)
from vllm.platforms import current_platform
+from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)
@@ -15,13 +15,13 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
-    set_default_torch_dtype,
)
from vllm.model_executor.model_loader.weight_utils import (
    get_gguf_extra_tensor_names,
    get_gguf_weight_type_map,
    gguf_quant_weights_iterator,
)
+from vllm.utils.torch_utils import set_default_torch_dtype


class GGUFModelLoader(BaseModelLoader):

@@ -22,8 +22,8 @@ from vllm.model_executor.model_loader.tensorizer import (
from vllm.model_executor.model_loader.utils import (
    get_model_architecture,
    initialize_model,
-    set_default_torch_dtype,
)
+from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)

@@ -14,8 +14,8 @@ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
-    set_default_torch_dtype,
)
+from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""

-import contextlib
import inspect
import warnings
from contextlib import contextmanager

@@ -32,15 +31,6 @@ from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)


-@contextlib.contextmanager
-def set_default_torch_dtype(dtype: torch.dtype):
-    """Sets the default torch dtype to the given dtype."""
-    old_dtype = torch.get_default_dtype()
-    torch.set_default_dtype(dtype)
-    yield
-    torch.set_default_dtype(old_dtype)
-
-
def initialize_model(
    vllm_config: VllmConfig,
    *,
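The set_default_torch_dtype context manager removed here is not deleted outright; per the loader hunks above, callers now import it from vllm.utils.torch_utils. A minimal usage sketch (the tensor below is only for illustration):

import torch

from vllm.utils.torch_utils import set_default_torch_dtype

# Temporarily switch the default dtype while building weights; the previous
# default is restored when the block exits.
with set_default_torch_dtype(torch.bfloat16):
    weight = torch.empty(16, 16)  # allocated as bfloat16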
@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, round_up
+from vllm.utils import cdiv, round_up
+from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec

if TYPE_CHECKING:

@@ -79,8 +79,8 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
-from vllm.utils import direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
+from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
    DeepseekV32IndexerMetadata,

@@ -18,7 +18,6 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers.utils import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (

@@ -51,6 +50,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processo
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.collections import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
+from vllm.utils.torch_utils import set_default_torch_dtype

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (

@@ -51,8 +51,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import set_default_torch_num_threads
from vllm.utils.tensor_schema import TensorSchema, TensorShape
+from vllm.utils.torch_utils import set_default_torch_num_threads

from .interfaces import (
    MultiModalEmbeddings,

@@ -49,7 +49,6 @@ from vllm.model_executor.layers.resampler import (
    Resampler2,
    get_2d_sincos_pos_embed,
)
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys

@@ -88,6 +87,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.collections import flatten_2d_lists
from vllm.utils.tensor_schema import TensorSchema, TensorShape
+from vllm.utils.torch_utils import set_default_torch_dtype

from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import (
Some files were not shown because too many files have changed in this diff.