Update deprecated type hinting in vllm/compilation (#18072)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-05-13 16:32:48 +01:00
committed by GitHub
parent fc407a1425
commit 19324d660c
13 changed files with 70 additions and 69 deletions
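The change is mechanical across all thirteen files: container annotations move from the deprecated `typing` aliases (`Dict`, `List`, `Set`, `Tuple`) to the built-in generics allowed since Python 3.9 (PEP 585), and abstract collection types such as `Sequence` and `Iterable` are imported from `collections.abc` instead of `typing`. `Optional` and `Union` are kept as-is. A minimal before/after sketch of the pattern (illustrative only, not copied verbatim from any single file below):

    # before
    from typing import Any, Dict, List, Optional, Sequence, Tuple

    def load(example_inputs: List[Any],
             cache: Dict[Tuple[Optional[int], int, str], Any]) -> Optional[Sequence[int]]:
        ...

    # after
    from collections.abc import Sequence
    from typing import Any, Optional

    def load(example_inputs: list[Any],
             cache: dict[tuple[Optional[int], int, str], Any]) -> Optional[Sequence[int]]:
        ...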

View File

@ -74,7 +74,6 @@ exclude = [
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0
"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
"vllm/attention/**/*.py" = ["UP006", "UP035"]
"vllm/compilation/**/*.py" = ["UP006", "UP035"]
"vllm/core/**/*.py" = ["UP006", "UP035"]
"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
"vllm/distributed/**/*.py" = ["UP006", "UP035"]

View File

@ -5,8 +5,9 @@ import dataclasses
import os
import pprint
import time
from collections.abc import Sequence
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
@ -56,7 +57,7 @@ class CompilerManager:
"""
def __init__(self, compilation_config: CompilationConfig):
self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)
@ -90,7 +91,7 @@ class CompilerManager:
def load(self,
graph: fx.GraphModule,
example_inputs: List[Any],
example_inputs: list[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Optional[Callable]:
if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
@ -186,7 +187,7 @@ class SplitItem:
def split_graph(graph: fx.GraphModule,
ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
# split graph by ops
subgraph_id = 0
node_to_subgraph_id = {}
@ -252,7 +253,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
"""
def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: List[str], vllm_config: VllmConfig,
compile_submod_names: list[str], vllm_config: VllmConfig,
graph_pool, vllm_backend: "VllmBackend"):
super().__init__(module)
from torch._guards import detect_fake_mode
@ -274,8 +275,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
return super().run(*fake_args)
def call_module(self, target: torch.fx.node.Target,
args: Tuple[torch.fx.node.Argument,
...], kwargs: Dict[str, Any]) -> Any:
args: tuple[torch.fx.node.Argument,
...], kwargs: dict[str, Any]) -> Any:
assert isinstance(target, str)
output = super().call_module(target, args, kwargs)
@ -326,12 +327,12 @@ class VllmBackend:
graph: fx.GraphModule
# the stitching graph module for all the piecewise graphs
split_gm: fx.GraphModule
piecewise_graphs: List[SplitItem]
piecewise_graphs: list[SplitItem]
returned_callable: Callable
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable]
sym_tensor_indices: List[int]
input_buffers: List[torch.Tensor]
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager
def __init__(
@ -573,14 +574,14 @@ class ConcreteSizeEntry:
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[List[int]] = None
input_addresses: Optional[list[int]] = None
class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: List[int],
total_piecewise_compiles: int, sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend):
"""
@ -608,9 +609,9 @@ class PiecewiseBackend:
self.is_last_graph = (
piecewise_compile_index == total_piecewise_compiles - 1)
self.compile_sizes: Set[int] = set(
self.compile_sizes: set[int] = set(
self.compilation_config.compile_sizes)
self.cudagraph_capture_sizes: Set[int] = set(
self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()
@ -624,11 +625,11 @@ class PiecewiseBackend:
# the entries for different shapes that we need to either
# compile or capture cudagraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
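For context on the hunk above: PiecewiseBackend seeds one ConcreteSizeEntry per shape that it will either compile or capture a CUDA graph for, and keeps a shrinking copy of the sizes still awaiting compilation. A toy illustration of that bookkeeping (sizes invented for the example):

    compile_sizes = {1, 8}
    cudagraph_capture_sizes = {1, 2, 4, 8}
    shapes_needing_entries = compile_sizes | cudagraph_capture_sizes  # {1, 2, 4, 8}
    to_be_compiled_sizes = compile_sizes.copy()  # emptied as each size gets compiled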

View File

@ -4,7 +4,7 @@ import copy
import hashlib
import os
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
@ -48,11 +48,11 @@ class CompilerInterface:
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
) -> tuple[Optional[Callable], Optional[Any]]:
"""
Compile the graph with the given example inputs and compiler config,
with a runtime shape. If the `runtime_shape` is None, it means
@ -82,7 +82,7 @@ class CompilerInterface:
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
example_inputs: list[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
"""
@ -120,7 +120,7 @@ class AlwaysHitShapeEnv:
"""
def __init__(self) -> None:
self.guards: List[Any] = []
self.guards: list[Any] = []
def evaluate_guards_expression(self, *args, **kwargs):
return True
@ -132,8 +132,8 @@ class AlwaysHitShapeEnv:
return ""
def get_inductor_factors() -> List[Any]:
factors: List[Any] = []
def get_inductor_factors() -> list[Any]:
factors: list[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
@ -169,11 +169,11 @@ class InductorStandaloneAdaptor(CompilerInterface):
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
) -> tuple[Optional[Callable], Optional[Any]]:
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
@ -201,7 +201,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
example_inputs: list[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
assert isinstance(handle, tuple)
@ -256,11 +256,11 @@ class InductorAdaptor(CompilerInterface):
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
) -> tuple[Optional[Callable], Optional[Any]]:
from torch._inductor.compile_fx import compile_fx
current_config = {}
if compiler_config is not None:
@ -420,7 +420,7 @@ class InductorAdaptor(CompilerInterface):
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
example_inputs: list[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
assert isinstance(handle, tuple)
@ -522,11 +522,11 @@ class EagerAdaptor(CompilerInterface):
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
) -> tuple[Optional[Callable], Optional[Any]]:
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
return graph, None
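The EagerAdaptor above is already the smallest conforming implementation of the updated `compile` signature: it returns the graph untouched and a `None` handle, meaning there is nothing to cache or reload later. A call-site sketch under that contract (simplified and assuming the adaptor needs no constructor arguments; the real invocation goes through CompilerManager):

    compiler = EagerAdaptor()
    compiled_fn, handle = compiler.compile(
        graph,               # a torch.fx.GraphModule
        example_inputs=[],   # list[Any]
        compiler_config={},  # dict[str, Any]
        runtime_shape=None,  # per the docstring above, None stands for the general shape
    )
    assert handle is None    # the eager backend produces no cacheable artifact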

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
from typing import Callable, Optional, TypeVar, Union, overload
from unittest.mock import patch
import torch
@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
@overload
def support_torch_compile(
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
) -> Callable[[_T], _T]:
...
@ -38,7 +38,7 @@ def support_torch_compile(cls: _T) -> _T:
def support_torch_compile(
cls: Optional[_T] = None,
*,
dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
) -> Union[Callable[[_T], _T], _T]:
"""
A decorator to add support for compiling the forward method of a class.
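For orientation, the decorator whose annotations change here is applied to a model class, either bare or with an explicit mapping from forward-argument name to the dimension(s) that are dynamic. A sketch of the two overloads shown above (class and argument names are made up for illustration):

    @support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": 0})
    class MyDecoder(nn.Module):
        def forward(self, input_ids: torch.Tensor, positions: torch.Tensor):
            ...

    @support_torch_compile  # bare form from the second overload; dynamic_arg_dims stays None
    class MyOtherDecoder(nn.Module):
        ...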
@ -131,7 +131,7 @@ def support_torch_compile(
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: Dict[str, Union[int, List[int]]],
dynamic_arg_dims: dict[str, Union[int, list[int]]],
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.

View File

@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import operator
from typing import Dict, Iterable, List, Optional, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
@ -27,7 +28,7 @@ class FixFunctionalizationPass(VllmInductorPass):
self.begin()
self.dump_graph(graph, "before_fix_functionalization")
self.nodes_to_remove: List[torch.fx.Node] = []
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
@ -117,8 +118,8 @@ class FixFunctionalizationPass(VllmInductorPass):
def defunctionalize(self,
graph: torch.fx.Graph,
node: torch.fx.Node,
mutated_args: Dict[int, Union[torch.fx.Node, str]],
args: Optional[Tuple[Union[torch.fx.Node, str],
mutated_args: dict[int, Union[torch.fx.Node, str]],
args: Optional[tuple[Union[torch.fx.Node, str],
...]] = None):
"""
De-functionalize a node by replacing it with a call to the original.
@ -130,7 +131,7 @@ class FixFunctionalizationPass(VllmInductorPass):
self._remove(node)
def replace_users_with_mutated_args(self, node: torch.fx.Node,
mutated_args: Dict[int,
mutated_args: dict[int,
Union[torch.fx.Node,
str]]):
"""
@ -146,7 +147,7 @@ class FixFunctionalizationPass(VllmInductorPass):
user.replace_all_uses_with(arg)
self._remove(user)
def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
"""
Returns the operator.getitem users of the auto-functionalized node,
indexed by the index they are getting.
@ -161,7 +162,7 @@ class FixFunctionalizationPass(VllmInductorPass):
def insert_defunctionalized(self,
graph: torch.fx.Graph,
node: torch.fx.Node,
args: Optional[Tuple[Union[torch.fx.Node, str],
args: Optional[tuple[Union[torch.fx.Node, str],
...]] = None):
"""
Insert a new defunctionalized node into the graph before node.
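For readers skimming this pass: during compilation, mutating custom ops get wrapped in `torch.ops.higher_order.auto_functionalized`, which returns fresh tensors that downstream nodes read back through `operator.getitem`. FixFunctionalizationPass undoes that where safe, calling the original in-place op and rewiring the getitem users to the mutated arguments. A rough before/after with a hypothetical in-place op `my_inplace_op` (not an op from this diff):

    # functionalized form produced during tracing
    ret = torch.ops.higher_order.auto_functionalized(
        torch.ops._C.my_inplace_op.default, x=buf, scale=s)
    # downstream nodes read the "new" tensor via operator.getitem(ret, i)

    # after the pass: the op runs in place and users reference buf directly
    torch.ops._C.my_inplace_op.default(x=buf, scale=s)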

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
from typing import Callable, NamedTuple, Optional
import torch
import torch._inductor.pattern_matcher as pm
@ -57,7 +57,7 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
QUANT_OPS: Dict[QuantKey, OpOverload] = {
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa
kFp8DynamicTensorSym:
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa
@ -80,7 +80,7 @@ class FusedRMSQuantKey(NamedTuple):
f"{'' if self.fused_add else 'out'} residual)")
FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(kFp8StaticTensorSym, False):
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa
FusedRMSQuantKey(kFp8StaticTensorSym, True):
@ -101,7 +101,7 @@ class QuantMultiOutputMatch(MultiOutputMatch):
self.QUANT_OP = quant_op # in-place quant op
self.FUSED_OP = fused_op # in-place fused quant op
def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
int]],
**kwargs):
"""
@ -548,7 +548,7 @@ class FusionPass(VllmInductorPass):
"FusionPass singleton instance already exists"
super().__init__(config)
self.matches: List[MultiOutputMatch] = []
self.matches: list[MultiOutputMatch] = []
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass")

View File

@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import operator
from typing import Iterable, Optional
from collections.abc import Iterable
from typing import Optional
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized

View File

@ -5,7 +5,7 @@ import inspect
import json
import types
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from torch import fx
@ -83,7 +83,7 @@ class InductorPass(CustomGraphPass):
return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: Dict[Any, Any]):
def hash_dict(dict_: dict[Any, Any]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.

View File

@ -3,7 +3,7 @@
import abc
import operator
from abc import abstractmethod
from typing import Iterable, List, Tuple
from collections.abc import Iterable
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
@ -56,7 +56,7 @@ class MultiOutputMatch(abc.ABC):
raise NotImplementedError
@property
def nodes(self) -> List[fx.Node]:
def nodes(self) -> list[fx.Node]:
return self.match.nodes
@property
@ -87,7 +87,7 @@ class MultiOutputMatch(abc.ABC):
return self.graph.inserting_after(last_node_in_match)
def insert_getitems(self, tuple_node: fx.Node,
indices: Iterable[int]) -> Tuple[fx.Node, ...]:
indices: Iterable[int]) -> tuple[fx.Node, ...]:
"""
Insert operator.getitem nodes to extract elements from a tuple node.

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Union
from collections.abc import Iterable
from typing import Union
import torch.fx
from torch import SymInt

View File

@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
from torch import fx as fx
from vllm.config import VllmConfig
@ -34,7 +32,7 @@ class PostGradPassManager(CustomGraphPass):
"""
def __init__(self):
self.passes: List[VllmInductorPass] = []
self.passes: list[VllmInductorPass] = []
def __call__(self, graph: fx.Graph):
shape = get_pass_context().runtime_shape

View File

@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
from typing import Optional
import torch
import torch._inductor.pattern_matcher as pm
@ -125,7 +125,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
@ -142,7 +142,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@ -190,7 +190,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
all_reduce = tensor_model_parallel_all_reduce(mm_1)
rmsnorm = torch.ops.higher_order.auto_functionalized(
@ -207,7 +207,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
residual: torch.Tensor,
mm_1: torch.Tensor,
rms_norm_weights: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
tp = get_tp_group()
tp_size = get_tensor_model_parallel_world_size()
reduce_scatter = torch.ops.vllm.reduce_scatter.default(

View File

@ -5,7 +5,7 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
@ -48,7 +48,7 @@ class TorchCompileWrapperWithCustomDispatcher:
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: List[CodeType] = []
self.compiled_codes: list[CodeType] = []
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
# read the env var to determine whether to use the custom dispatcher