Mirror of https://github.com/vllm-project/vllm.git
[bugfix] fix weak ref in piecewise cudagraph and tractable test (#10048)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -1,6 +1,10 @@
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.

This is a tractable model: the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
import os
from dataclasses import dataclass
@@ -49,6 +53,12 @@ class LlamaConfig:
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2
    init_value: float = 1.0
    tractable_init: bool = False
    random_seed: int = 0

    def __post_init__(self):
        assert self.mlp_size >= self.hidden_size


class LlamaMLP(nn.Module):
@@ -66,10 +76,23 @@ class LlamaMLP(nn.Module):
            bias=False,
        )

        self.gate_up_projection.weight.data.fill_(0.0)
        self.down_projection.weight.data.fill_(0.0)
        if config.tractable_init:
            nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
            nn.init.eye_(self.down_projection.weight.data)
        else:
            nn.init.xavier_normal_(self.gate_up_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)
            nn.init.xavier_normal_(self.down_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)

    def forward(self, x):
        # for tractable_init and positive input, this is
        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
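For intuition, here is a minimal standalone sketch (not part of the patch; it assumes the tractable_init case simplified to mlp_size == hidden_size, where the stacked gate_up projection maps x to [x, x] and the down projection is the identity) showing why this forward pass squares a positive input elementwise:

    import torch

    x = torch.tensor([[1.0, 2.0, 3.0]])
    gate_up = torch.cat([x, x], dim=1)  # identity gate_up projection: x -> [x, x]
    out = gate_up[:, :3] * torch.nn.functional.relu(gate_up[:, 3:])  # x * relu(x)
    assert torch.equal(out, x * x)  # elementwise square for positive x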
@@ -84,21 +107,39 @@ class LlamaAttention(nn.Module):
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

        self.qkv_projection.weight.data.fill_(0.0)
        self.output_projection.weight.data.fill_(0.0)
        if config.tractable_init:
            nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
            nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
                                                         config.hidden_size])
            nn.init.eye_(self.qkv_projection.weight.data[2 *
                                                         config.hidden_size:])
            nn.init.eye_(self.output_projection.weight.data)
        else:
            nn.init.xavier_normal_(self.qkv_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)
            nn.init.xavier_normal_(self.output_projection.weight.data,
                                   generator=torch.Generator().manual_seed(
                                       config.random_seed),
                                   gain=0.001)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # for tractable_init, this is:
        # output = (hidden_states * 3 + positions * 2)
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
@@ -126,20 +167,29 @@ class LlamaDecoderLayer(nn.Module):
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        For tractable computation:
        - if residual is None, the outputs are:
            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        - if residual is not None, the outputs are:
            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
            - hidden_states = (residual + 1) ** 2
        """  # noqa
        if residual is None:
            residual = hidden_states
            hidden_states = hidden_states / 2
            hidden_states = hidden_states + 1
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
            hidden_states = hidden_states / 2
            hidden_states = hidden_states + 1

        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)

        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = hidden_states / 2
        hidden_states = hidden_states + 1
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
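As a quick sanity check of the closed form in this docstring (a sketch, not part of the patch; it assumes the tractable_init path, where attention computes hidden_states * 3 + positions * 2, the MLP squares its input, and the pre-attention step amounts to adding 1), the first layer with init_value 1.0 at position 0 gives:

    h, p = 1.0, 0.0                      # initial hidden state and position
    attn_out = 3 * (h + 1) + 2 * p       # tractable attention applied to (h + 1)
    residual = attn_out + h              # equals h * 4 + p * 2 + 3  -> 7.0
    hidden = (residual + 1) ** 2         # MLP squares (residual + 1) -> 64.0
    assert (residual, hidden) == (7.0, 64.0)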
@@ -156,7 +206,8 @@ class LlamaModel(nn.Module):
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])

        self.embedding_tokens.weight.data.fill_(0.0)
        # this is the initial value of the hidden states
        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
@@ -170,6 +221,28 @@ class LlamaModel(nn.Module):
        return hidden_states


def tractable_computation(input_ids: torch.Tensor,
                          positions: torch.Tensor,
                          config: LlamaConfig,
                          init_value: float = 1.0) -> torch.Tensor:
    hidden_states = torch.ones(input_ids.size(0),
                               config.hidden_size,
                               device=input_ids.device,
                               dtype=input_ids.dtype) * init_value

    # first layer
    residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
    hidden_states = (residual + 1)**2

    # following layers
    for _ in range(config.num_layers - 1):
        hidden_states = hidden_states + residual
        residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
        hidden_states = (residual + 1)**2

    return hidden_states

@torch.inference_mode
def run_model(llama_config,
              use_compile: bool,
@@ -213,7 +286,15 @@ def run_model(llama_config,
    del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
    set_compilation_config(None)

    return output.cpu()
    output = output.cpu()

    if llama_config.tractable_init:
        expected_output = tractable_computation(input_ids[:2], positions[:2],
                                                llama_config).cpu()

        assert torch.allclose(output, expected_output)
    else:
        return output.cpu()


def test_toy_llama():
@@ -222,7 +303,13 @@ def test_toy_llama():
    llama_config = LlamaConfig(hidden_size=128,
                               mlp_size=256,
                               vocab_size=128,
                               num_layers=2)
                               num_layers=12)

    tractable_config = LlamaConfig(hidden_size=128,
                                   mlp_size=256,
                                   vocab_size=128,
                                   num_layers=2,
                                   tractable_init=True)

    outputs = []
    with compilation_counter.expect(
@@ -233,6 +320,8 @@ def test_toy_llama():
            num_cudagraph_caputured=0,
    ):
        outputs.append(run_model(llama_config, use_compile=False))
        run_model(tractable_config, use_compile=False)

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=1,
@@ -242,6 +331,7 @@ def test_toy_llama():
            2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(run_model(llama_config, use_compile=True))
        run_model(tractable_config, use_compile=True)

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
@@ -257,6 +347,7 @@ def test_toy_llama():
    ):
        outputs.append(
            run_model(llama_config, use_compile=True, split_attn=True))
        run_model(tractable_config, use_compile=True, split_attn=True)

    for i in range(1, len(outputs)):
        assert torch.allclose(outputs[0], outputs[i])

@@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.fx as fx

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors

@@ -193,6 +194,7 @@ def wrap_inductor(graph,
@dataclasses.dataclass
class SplitItem:
    submod_name: str
    graph_id: int
    is_splitting_graph: bool
    graph: fx.GraphModule

@@ -226,9 +228,7 @@ def split_graph(graph: fx.GraphModule,

    outputs = []

    # sort the names to make sure the order is deterministic
    names = [name for (name, module) in split_gm.named_modules()]
    names.sort()

    for name in names:
        if "." in name or name == "":
@@ -238,7 +238,11 @@
        module = getattr(split_gm, name)

        graph_id = int(name.replace("submod_", ""))
        outputs.append(SplitItem(name, graph_id in split_op_graphs, module))
        outputs.append(
            SplitItem(name, graph_id, (graph_id in split_op_graphs), module))

    # sort by integer graph_id, rather than string name
    outputs.sort(key=lambda x: x.graph_id)

    return split_gm, outputs

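For context on the "sort by integer graph_id" comment above: lexicographic sorting of the submodule names breaks down once there are ten or more submodules. A minimal illustration (not part of the patch):

    names = ["submod_2", "submod_10", "submod_1"]
    print(sorted(names))
    # ['submod_1', 'submod_10', 'submod_2']   -- string sort, wrong order
    print(sorted(names, key=lambda n: int(n.replace("submod_", ""))))
    # ['submod_1', 'submod_2', 'submod_10']   -- integer graph_id, execution order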
@@ -252,6 +256,11 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
    It runs the given graph with fake inputs, and compiles some
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it will be used to determine the order of the compiled piecewise
    graphs. The first graph will handle logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(self, module: torch.fx.GraphModule,
@@ -263,7 +272,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
        self.compile_submod_names = compile_submod_names
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool
        self.have_seen_first_graph = False

    def run(self, *args):
        fake_args = [
@@ -279,6 +287,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
@@ -288,15 +297,14 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
                args,
                self.compilation_configs.inductor_compile_config,
                runtime_shape=None,
                do_logging=not self.have_seen_first_graph,
                do_logging=index == 0,
                use_inductor=self.compilation_configs.use_inductor)

            self.module.__dict__[target] = PiecewiseBackend(
                submod, self.compilation_configs, self.graph_pool,
                not self.have_seen_first_graph, sym_shape_indices,
                submod, self.compilation_configs, self.graph_pool, index,
                len(self.compile_submod_names), sym_shape_indices,
                compiled_graph_for_general_shape)

            self.have_seen_first_graph = True
            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
@@ -352,8 +360,9 @@ class VllmBackend:
            graph, self.compilation_configs.non_cudagraph_ops)

        from torch._dynamo.utils import lazy_format_graph_code
        logger.debug("%s",
                     lazy_format_graph_code("stiching module", self.split_gm))
        logger.debug("%s", lazy_format_graph_code("before split", self.graph))
        logger.debug("%s", lazy_format_graph_code("after split",
                                                  self.split_gm))

        compilation_counter.num_piecewise_graphs_seen += len(
            self.piecewise_graphs)
@@ -385,12 +394,17 @@ class ConcreteSizeEntry:
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None


class PiecewiseBackend:

    def __init__(self, graph: fx.GraphModule,
                 compilation_configs: CompilationConfig, graph_pool: Any,
                 is_first_graph: bool, sym_shape_indices: List[int],
                 piecewise_compile_index: int, total_piecewise_compiles: int,
                 sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable):
        """
        The backend for piecewise compilation.
@@ -408,7 +422,12 @@ class PiecewiseBackend:
        self.graph = graph
        self.compilation_configs = compilation_configs
        self.graph_pool = graph_pool
        self.is_first_graph = is_first_graph
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_configs.compile_sizes)
@@ -422,6 +441,8 @@ class PiecewiseBackend:

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture cudagraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
@@ -476,14 +497,45 @@ class PiecewiseBackend:
                logger.info("Capturing a cudagraph for shape %s",
                            runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            # mind-exploding: carefully manage the reference and memory.
            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                entry.output = weak_ref_tensors(entry.runnable(*args))
                # `output` is managed by pytorch's cudagraph pool
                output = entry.runnable(*args)
                if self.is_last_graph:
                    # by converting it to weak ref,
                    # the original `output` will immediately be released
                    # to save memory. It is only safe to do this for
                    # the last graph, because the output of the last graph
                    # will not be used by any other cuda graph.
                    output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_caputured += 1

            entry.cudagraph = cudagraph
            return entry.output
            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

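The capture logic above appears to be what the commit title refers to: intermediate outputs must stay strongly referenced while later pieces are captured into the same memory pool, and only the last graph's output can safely be turned into a weak reference early. A rough standalone sketch of the underlying PyTorch pattern (illustrative only, CUDA required; this is not vLLM code):

    import torch

    pool = torch.cuda.graph_pool_handle()
    x = torch.randn(16, device="cuda")
    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()

    with torch.cuda.graph(g1, pool=pool):
        y = x * 2  # keep a strong ref to `y`: the next capture reads this buffer

    with torch.cuda.graph(g2, pool=pool):
        z = y + 1  # reads the buffer recorded above

    g1.replay()
    g2.replay()
    torch.cuda.synchronize()
    assert torch.allclose(z, 2 * x + 1)

If `y` were released before the second capture, the pool could hand its memory to another allocation, which is exactly the hazard the comments above describe.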
        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.cudagraph.replay()
        return entry.output
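A small sketch of why the debug check above compares data_ptr() values (illustrative, CUDA required; not vLLM code): replay re-reads the exact buffers recorded at capture time, so inputs must be updated in place rather than rebound to new tensors:

    import torch

    static_input = torch.zeros(8, device="cuda")
    captured_addr = static_input.data_ptr()

    static_input.copy_(torch.arange(8.0, device="cuda"))  # in-place write: address unchanged
    assert static_input.data_ptr() == captured_addr

    new_input = torch.arange(8.0, device="cuda")  # a different tensor: different storage
    assert new_input.data_ptr() != captured_addr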