Update deprecated type hinting in vllm/compilation (#18072)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -74,7 +74,6 @@ exclude = [
 # Python 3.8 typing. TODO: Remove these excludes after v1.0.0
 "vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
 "vllm/attention/**/*.py" = ["UP006", "UP035"]
-"vllm/compilation/**/*.py" = ["UP006", "UP035"]
 "vllm/core/**/*.py" = ["UP006", "UP035"]
 "vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
 "vllm/distributed/**/*.py" = ["UP006", "UP035"]
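Note: UP006 and UP035 are the Ruff pyupgrade rules behind this exclude list. UP035 flags imports of deprecated typing names such as typing.List or typing.Dict, and UP006 flags annotations that use them instead of the built-in generics from PEP 585. Removing the vllm/compilation entry above re-enables both rules for that package, which is what drives every hunk below. A minimal sketch of what the rules flag and the fix; the function is invented for illustration:

# Before: flagged on Python 3.9+ (UP035 on the import, UP006 on the usages)
from typing import Dict, List

def token_counts(prompts: List[str]) -> Dict[str, int]:
    return {p: len(p.split()) for p in prompts}

# After: built-in generics per PEP 585; no typing import needed for these
def token_counts(prompts: list[str]) -> dict[str, int]:
    return {p: len(p.split()) for p in prompts}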
@@ -5,8 +5,9 @@ import dataclasses
 import os
 import pprint
 import time
+from collections.abc import Sequence
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
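Note: the import hunks in this commit all follow the same recipe: abstract collection types such as Sequence and Iterable now come from collections.abc, the concrete generics (Dict, List, Set, Tuple) are dropped in favor of the built-ins, and Optional/Union stay in typing because the X | None spelling from PEP 604 requires Python 3.10, which the codebase did not yet assume. A runnable sketch of the resulting style, mirroring the cache annotation in the next hunk (the stored values are placeholders):

from collections.abc import Sequence
from typing import Any, Optional

# PEP 585: the built-in types are subscriptable at runtime on Python 3.9+
cache: dict[tuple[Optional[int], int, str], Any] = {}
cache[(None, 0, "inductor")] = "artifact-placeholder"

def first(items: Sequence[int]) -> Optional[int]:
    return items[0] if items else None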
@@ -56,7 +57,7 @@ class CompilerManager:
     """
 
     def __init__(self, compilation_config: CompilationConfig):
-        self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
+        self.cache: dict[tuple[Optional[int], int, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -90,7 +91,7 @@ class CompilerManager:
 
     def load(self,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Optional[Callable]:
         if (runtime_shape, graph_index, self.compiler.name) not in self.cache:
@@ -186,7 +187,7 @@ class SplitItem:
 
 
 def split_graph(graph: fx.GraphModule,
-                ops: List[str]) -> Tuple[fx.GraphModule, List[SplitItem]]:
+                ops: list[str]) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
     node_to_subgraph_id = {}
@@ -252,7 +253,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """
 
     def __init__(self, module: torch.fx.GraphModule,
-                 compile_submod_names: List[str], vllm_config: VllmConfig,
+                 compile_submod_names: list[str], vllm_config: VllmConfig,
                  graph_pool, vllm_backend: "VllmBackend"):
         super().__init__(module)
         from torch._guards import detect_fake_mode
@@ -274,8 +275,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
         return super().run(*fake_args)
 
     def call_module(self, target: torch.fx.node.Target,
-                    args: Tuple[torch.fx.node.Argument,
-                                ...], kwargs: Dict[str, Any]) -> Any:
+                    args: tuple[torch.fx.node.Argument,
+                                ...], kwargs: dict[str, Any]) -> Any:
         assert isinstance(target, str)
         output = super().call_module(target, args, kwargs)
 
@@ -326,12 +327,12 @@ class VllmBackend:
     graph: fx.GraphModule
     # the stiching graph module for all the piecewise graphs
     split_gm: fx.GraphModule
-    piecewise_graphs: List[SplitItem]
+    piecewise_graphs: list[SplitItem]
     returned_callable: Callable
     # Inductor passes to run on the graph pre-defunctionalization
     post_grad_passes: Sequence[Callable]
-    sym_tensor_indices: List[int]
-    input_buffers: List[torch.Tensor]
+    sym_tensor_indices: list[int]
+    input_buffers: list[torch.Tensor]
     compiler_manager: CompilerManager
 
     def __init__(
@@ -573,14 +574,14 @@ class ConcreteSizeEntry:
 
     # for cudagraph debugging, track the input addresses
     # during capture, and check if they are the same during replay
-    input_addresses: Optional[List[int]] = None
+    input_addresses: Optional[list[int]] = None
 
 
 class PiecewiseBackend:
 
     def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                  graph_pool: Any, piecewise_compile_index: int,
-                 total_piecewise_compiles: int, sym_shape_indices: List[int],
+                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                  compiled_graph_for_general_shape: Callable,
                  vllm_backend: VllmBackend):
         """
@@ -608,9 +609,9 @@ class PiecewiseBackend:
         self.is_last_graph = (
             piecewise_compile_index == total_piecewise_compiles - 1)
 
-        self.compile_sizes: Set[int] = set(
+        self.compile_sizes: set[int] = set(
             self.compilation_config.compile_sizes)
-        self.cudagraph_capture_sizes: Set[int] = set(
+        self.cudagraph_capture_sizes: set[int] = set(
             self.compilation_config.cudagraph_capture_sizes
         ) if self.compilation_config.use_cudagraph else set()
 
@@ -624,11 +625,11 @@ class PiecewiseBackend:
 
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
-        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
 
         # to_be_compiled_sizes tracks the remaining sizes to compile,
         # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
+        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
         for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
@@ -4,7 +4,7 @@ import copy
 import hashlib
 import os
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Optional
 from unittest.mock import patch
 
 import torch
@@ -48,11 +48,11 @@ class CompilerInterface:
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
         with a runtime shape. If the `runtime_shape` is None, it means
@@ -82,7 +82,7 @@ class CompilerInterface:
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         """
@@ -120,7 +120,7 @@ class AlwaysHitShapeEnv:
     """
 
     def __init__(self) -> None:
-        self.guards: List[Any] = []
+        self.guards: list[Any] = []
 
     def evaluate_guards_expression(self, *args, **kwargs):
         return True
@@ -132,8 +132,8 @@ class AlwaysHitShapeEnv:
         return ""
 
 
-def get_inductor_factors() -> List[Any]:
-    factors: List[Any] = []
+def get_inductor_factors() -> list[Any]:
+    factors: list[Any] = []
     # summarize system state
     from torch._inductor.codecache import CacheBase
     system_factors = CacheBase.get_system()
@@ -169,11 +169,11 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         current_config = {}
         if compiler_config is not None:
             current_config.update(compiler_config)
@@ -201,7 +201,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -256,11 +256,11 @@ class InductorAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         from torch._inductor.compile_fx import compile_fx
         current_config = {}
         if compiler_config is not None:
@@ -420,7 +420,7 @@ class InductorAdaptor(CompilerInterface):
     def load(self,
              handle: Any,
              graph: fx.GraphModule,
-             example_inputs: List[Any],
+             example_inputs: list[Any],
              graph_index: int,
              runtime_shape: Optional[int] = None) -> Callable:
         assert isinstance(handle, tuple)
@@ -522,11 +522,11 @@ class EagerAdaptor(CompilerInterface):
     def compile(
         self,
         graph: fx.GraphModule,
-        example_inputs: List[Any],
-        compiler_config: Dict[str, Any],
+        example_inputs: list[Any],
+        compiler_config: dict[str, Any],
         runtime_shape: Optional[int] = None,
         key: Optional[str] = None,
-    ) -> Tuple[Optional[Callable], Optional[Any]]:
+    ) -> tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
         return graph, None
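Note: the EagerAdaptor hunk above makes the CompilerInterface contract easy to see: compile() returns a pair (compiled_callable, handle), where the handle is whatever load() later needs to restore a cached artifact, and None means "not cacheable". A standalone sketch of that contract, with the function body invented to match the eager behavior shown above:

from typing import Any, Callable, Optional

def eager_compile(
    graph: Callable,
    example_inputs: list[Any],
    compiler_config: dict[str, Any],
    runtime_shape: Optional[int] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
    # Eager mode: no compilation and nothing to cache, so the handle is
    # None and the caller must fall through to compile() on the next run.
    return graph, None

compiled, handle = eager_compile(lambda x: x * 2, [1], {})
assert compiled(21) == 42 and handle is None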
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import inspect
-from typing import Callable, Dict, List, Optional, TypeVar, Union, overload
+from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
 
 import torch
@@ -25,7 +25,7 @@ _T = TypeVar("_T", bound=type[nn.Module])
 @overload
 def support_torch_compile(
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]],
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]],
 ) -> Callable[[_T], _T]:
     ...
 
@@ -38,7 +38,7 @@ def support_torch_compile(cls: _T) -> _T:
 def support_torch_compile(
     cls: Optional[_T] = None,
     *,
-    dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None,
+    dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
 ) -> Union[Callable[[_T], _T], _T]:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -131,7 +131,7 @@ def support_torch_compile(
 
 def _support_torch_compile(
     cls: _T,
-    dynamic_arg_dims: Dict[str, Union[int, List[int]]],
+    dynamic_arg_dims: dict[str, Union[int, list[int]]],
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
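Note: dynamic_arg_dims maps a forward() argument name to the dimension index (or indices) that torch.compile should treat as dynamic, which is what avoids recompiling for every batch or sequence length. A hypothetical usage sketch; the model class and argument names are invented, and only the decorator's signature comes from this diff (the decorator line is left commented out because decorated classes need the rest of the vLLM runtime to construct):

import torch
from torch import nn

# from vllm.compilation.decorators import support_torch_compile

# @support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": 0})
class ToyModel(nn.Module):
    # dim 0 of both tensors (the token count) would be marked dynamic
    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return input_ids + positions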
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from collections.abc import Iterable
+from typing import Optional, Union
 
 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -27,7 +28,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self.begin()
         self.dump_graph(graph, "before_fix_functionalization")
 
-        self.nodes_to_remove: List[torch.fx.Node] = []
+        self.nodes_to_remove: list[torch.fx.Node] = []
         count = 0
         for node in graph.nodes:
             if not is_func(node, auto_functionalized):
@@ -117,8 +118,8 @@ class FixFunctionalizationPass(VllmInductorPass):
     def defunctionalize(self,
                         graph: torch.fx.Graph,
                         node: torch.fx.Node,
-                        mutated_args: Dict[int, Union[torch.fx.Node, str]],
-                        args: Optional[Tuple[Union[torch.fx.Node, str],
+                        mutated_args: dict[int, Union[torch.fx.Node, str]],
+                        args: Optional[tuple[Union[torch.fx.Node, str],
                                              ...]] = None):
         """
         De-functionalize a node by replacing it with a call to the original.
@@ -130,7 +131,7 @@ class FixFunctionalizationPass(VllmInductorPass):
         self._remove(node)
 
     def replace_users_with_mutated_args(self, node: torch.fx.Node,
-                                        mutated_args: Dict[int,
+                                        mutated_args: dict[int,
                                                            Union[torch.fx.Node,
                                                                  str]]):
         """
@@ -146,7 +147,7 @@ class FixFunctionalizationPass(VllmInductorPass):
             user.replace_all_uses_with(arg)
             self._remove(user)
 
-    def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]:
+    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
         """
         Returns the operator.getitem users of the auto-functionalized node,
         indexed by the index they are getting.
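Note: auto_functionalized nodes return a tuple, and consumers read elements through operator.getitem nodes; getitem_users() collects those consumers keyed by the tuple index they extract. The same bookkeeping can be seen on a plain FX trace; the traced function below is invented for illustration:

import operator

import torch

def f(x: torch.Tensor):
    out = torch.max(x, dim=0)  # returns a (values, indices) pair
    return out[0], out[1]

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.max:
        # Collect operator.getitem users keyed by the extracted index,
        # mirroring what getitem_users() does for auto_functionalized nodes.
        users = {
            u.args[1]: u
            for u in node.users
            if u.op == "call_function" and u.target == operator.getitem
        }
        print(sorted(users))  # [0, 1]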
@@ -161,7 +162,7 @@ class FixFunctionalizationPass(VllmInductorPass):
     def insert_defunctionalized(self,
                                 graph: torch.fx.Graph,
                                 node: torch.fx.Node,
-                                args: Optional[Tuple[Union[torch.fx.Node, str],
+                                args: Optional[tuple[Union[torch.fx.Node, str],
                                                      ...]] = None):
         """
         Insert a new defunctionalized node into the graph before node.
 
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -57,7 +57,7 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
 kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
 kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
 
-QUANT_OPS: Dict[QuantKey, OpOverload] = {
+QUANT_OPS: dict[QuantKey, OpOverload] = {
     kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
     kFp8DynamicTensorSym:
     torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
@@ -80,7 +80,7 @@ class FusedRMSQuantKey(NamedTuple):
                 f"{'' if self.fused_add else 'out'} residual)")
 
 
-FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
+FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
     FusedRMSQuantKey(kFp8StaticTensorSym, False):
     torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
     FusedRMSQuantKey(kFp8StaticTensorSym, True):
@@ -101,7 +101,7 @@ class QuantMultiOutputMatch(MultiOutputMatch):
         self.QUANT_OP = quant_op  # in-place quant op
         self.FUSED_OP = fused_op  # in-place fused quant op
 
-    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
+    def insert_fused_node(self, fused_return_mapping: dict[int, tuple[fx.Node,
                                                                       int]],
                           **kwargs):
         """
@@ -548,7 +548,7 @@ class FusionPass(VllmInductorPass):
             "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: List[MultiOutputMatch] = []
+        self.matches: list[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="fusion_pass")
 
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import operator
-from typing import Iterable, Optional
+from collections.abc import Iterable
+from typing import Optional
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -5,7 +5,7 @@ import inspect
 import json
 import types
 from contextlib import contextmanager
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 from torch import fx
@@ -83,7 +83,7 @@ class InductorPass(CustomGraphPass):
         return hasher.hexdigest()
 
     @staticmethod
-    def hash_dict(dict_: Dict[Any, Any]):
+    def hash_dict(dict_: dict[Any, Any]):
         """
         Utility method to hash a dictionary, can alternatively be used for uuid.
         :return: A sha256 hash of the json rep of the dictionary.
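Note: the docstring pins down the behavior: serialize the dict to JSON and take a sha256 of it. A minimal sketch consistent with that description; whether the real helper sorts keys is not visible in this hunk, so sort_keys below is an assumption added for determinism:

import hashlib
import json
from typing import Any

def hash_dict(dict_: dict[Any, Any]) -> str:
    # sha256 of the JSON representation; sort_keys (an assumption) makes
    # logically-equal dicts hash to the same value.
    encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()

print(hash_dict({"pass": "fusion", "enabled": True})[:16])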
@@ -3,7 +3,7 @@
 import abc
 import operator
 from abc import abstractmethod
-from typing import Iterable, List, Tuple
+from collections.abc import Iterable
 
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
@@ -56,7 +56,7 @@ class MultiOutputMatch(abc.ABC):
         raise NotImplementedError
 
     @property
-    def nodes(self) -> List[fx.Node]:
+    def nodes(self) -> list[fx.Node]:
         return self.match.nodes
 
     @property
@@ -87,7 +87,7 @@ class MultiOutputMatch(abc.ABC):
         return self.graph.inserting_after(last_node_in_match)
 
     def insert_getitems(self, tuple_node: fx.Node,
-                        indices: Iterable[int]) -> Tuple[fx.Node, ...]:
+                        indices: Iterable[int]) -> tuple[fx.Node, ...]:
         """
         Insert operator.getitem nodes to extract elements from a tuple node.
 
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Iterable, Union
+from collections.abc import Iterable
+from typing import Union
 
 import torch.fx
 from torch import SymInt
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List
-
 from torch import fx as fx
 
 from vllm.config import VllmConfig
@@ -34,7 +32,7 @@ class PostGradPassManager(CustomGraphPass):
     """
 
     def __init__(self):
-        self.passes: List[VllmInductorPass] = []
+        self.passes: list[VllmInductorPass] = []
 
     def __call__(self, graph: fx.Graph):
         shape = get_pass_context().runtime_shape
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch._inductor.pattern_matcher as pm
@@ -125,7 +125,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
             rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -142,7 +142,7 @@ class MiddleAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             tp = get_tp_group()
             tp_size = get_tensor_model_parallel_world_size()
             reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@@ -190,7 +190,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             all_reduce = tensor_model_parallel_all_reduce(mm_1)
 
             rmsnorm = torch.ops.higher_order.auto_functionalized(
@@ -207,7 +207,7 @@ class LastAllReduceRMSNormPattern(AllReduceRMSNormPattern):
             residual: torch.Tensor,
             mm_1: torch.Tensor,
             rms_norm_weights: torch.Tensor,
-        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ) -> tuple[torch.Tensor, torch.Tensor]:
             tp = get_tp_group()
             tp_size = get_tensor_model_parallel_world_size()
             reduce_scatter = torch.ops.vllm.reduce_scatter.default(
@@ -5,7 +5,7 @@ import sys
 from abc import abstractmethod
 from contextlib import contextmanager
 from types import CodeType
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
 import torch
 
@@ -48,7 +48,7 @@ class TorchCompileWrapperWithCustomDispatcher:
 
         self.compiled_callable = compiled_callable
         self.original_code_object = self.__class__.forward.__code__
-        self.compiled_codes: List[CodeType] = []
+        self.compiled_codes: list[CodeType] = []
         torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
 
         # read the env var to determine whether to use the custom dispatcher