[bugfix] fix weak ref in piecewise cudagraph and tractable test (#10048)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-05 14:49:20 -08:00
Committed by: GitHub
Parent: 235366fe2e
Commit: ca9844b340
2 changed files with 168 additions and 25 deletions

View File

@ -1,6 +1,10 @@
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.
This is a tractable model: if the config `tractable_init` is set to True,
the weights and computation are specially designed so that the output can be
computed in closed form. Otherwise, the weights are initialized randomly with
a fixed seed.
"""
import os
from dataclasses import dataclass
@ -49,6 +53,12 @@ class LlamaConfig:
mlp_size: int = 256
vocab_size: int = 128
num_layers: int = 2
init_value: float = 1.0
tractable_init: bool = False
random_seed: int = 0
def __post_init__(self):
assert self.mlp_size >= self.hidden_size
class LlamaMLP(nn.Module):
@ -66,10 +76,23 @@ class LlamaMLP(nn.Module):
bias=False,
)
self.gate_up_projection.weight.data.fill_(0.0)
self.down_projection.weight.data.fill_(0.0)
if config.tractable_init:
nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
nn.init.eye_(self.down_projection.weight.data)
else:
nn.init.xavier_normal_(self.gate_up_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(self.down_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
def forward(self, x):
# for tractable_init and positive input, this is
# essentially an elementwise-square
x = self.gate_up_projection(x)
x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
x[:, x.size(1) // 2:])
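
As an aside, the tractable initialization above can be checked in isolation (a minimal standalone sketch, not part of the diff; the sizes are arbitrary): with two identity blocks stacked in the gate/up projection and an identity down projection, the MLP squares every positive element.

import torch
import torch.nn as nn

hidden, mlp = 4, 4
gate_up = nn.Linear(hidden, 2 * mlp, bias=False)
down = nn.Linear(mlp, hidden, bias=False)
nn.init.eye_(gate_up.weight.data[:mlp])   # "gate" half is the identity
nn.init.eye_(gate_up.weight.data[mlp:])   # "up" half is the identity
nn.init.eye_(down.weight.data)

x = torch.rand(3, hidden) + 0.1           # strictly positive input
y = gate_up(x)
y = y[:, :mlp] * torch.relu(y[:, mlp:])   # x * relu(x) == x**2 when x > 0
assert torch.allclose(down(y), x * x)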
@ -84,21 +107,39 @@ class LlamaAttention(nn.Module):
self.qkv_projection = nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size * 3,
bias=False,
)
self.output_projection = nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size,
bias=False,
)
self.qkv_projection.weight.data.fill_(0.0)
self.output_projection.weight.data.fill_(0.0)
if config.tractable_init:
nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
config.hidden_size])
nn.init.eye_(self.qkv_projection.weight.data[2 *
config.hidden_size:])
nn.init.eye_(self.output_projection.weight.data)
else:
nn.init.xavier_normal_(self.qkv_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
nn.init.xavier_normal_(self.output_projection.weight.data,
generator=torch.Generator().manual_seed(
config.random_seed),
gain=0.001)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# for tractable_init, this is:
# output = (hidden_states * 3 + positions * 2)
qkv = self.qkv_projection(hidden_states)
hidden_size = qkv.size(-1) // 3
q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
@ -126,20 +167,29 @@ class LlamaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
For tractable computation:
- if residual is None, the outputs are:
- residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2
- if residual is not None, the outputs are:
- residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
- hidden_states = (residual + 1) ** 2
""" # noqa
if residual is None:
residual = hidden_states
hidden_states = hidden_states / 2
hidden_states = hidden_states + 1
else:
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = hidden_states / 2
hidden_states = hidden_states + 1
hidden_states = self.self_attention(positions=positions,
hidden_states=hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = hidden_states / 2
hidden_states = hidden_states + 1
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@ -156,7 +206,8 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList(
[LlamaDecoderLayer(config) for _ in range(config.num_layers)])
self.embedding_tokens.weight.data.fill_(0.0)
# this is the initial value of the hidden states
self.embedding_tokens.weight.data.fill_(config.init_value)
def forward(
self,
@ -170,6 +221,28 @@ class LlamaModel(nn.Module):
return hidden_states
def tractable_computation(input_ids: torch.Tensor,
positions: torch.Tensor,
config: LlamaConfig,
init_value: float = 1.0) -> torch.Tensor:
hidden_states = torch.ones(input_ids.size(0),
config.hidden_size,
device=input_ids.device,
dtype=input_ids.dtype) * init_value
# first layer
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2
# following layers
for _ in range(config.num_layers - 1):
hidden_states = hidden_states + residual
residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
hidden_states = (residual + 1)**2
return hidden_states
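
As a quick numeric sanity check of the closed form above (plain arithmetic, not part of the commit), with init_value=1.0 and position 0 the first two layers evaluate to:

h, p = 1.0, 0.0
# first layer
residual = h * 4 + p * 2 + 3         # 7.0
h = (residual + 1) ** 2              # 64.0
# second layer
h = h + residual                     # 71.0
residual = h * 4 + p * 2 + 3         # 287.0
h = (residual + 1) ** 2              # 82944.0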
@torch.inference_mode
def run_model(llama_config,
use_compile: bool,
@ -213,7 +286,15 @@ def run_model(llama_config,
del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
set_compilation_config(None)
return output.cpu()
output = output.cpu()
if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2], positions[:2],
llama_config).cpu()
assert torch.allclose(output, expected_output)
else:
return output.cpu()
def test_toy_llama():
@ -222,7 +303,13 @@ def test_toy_llama():
llama_config = LlamaConfig(hidden_size=128,
mlp_size=256,
vocab_size=128,
num_layers=2)
num_layers=12)
tractable_config = LlamaConfig(hidden_size=128,
mlp_size=256,
vocab_size=128,
num_layers=2,
tractable_init=True)
outputs = []
with compilation_counter.expect(
@ -233,6 +320,8 @@ def test_toy_llama():
num_cudagraph_caputured=0,
):
outputs.append(run_model(llama_config, use_compile=False))
run_model(tractable_config, use_compile=False)
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=1,
@ -242,6 +331,7 @@ def test_toy_llama():
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
outputs.append(run_model(llama_config, use_compile=True))
run_model(tractable_config, use_compile=True)
with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
@ -257,6 +347,7 @@ def test_toy_llama():
):
outputs.append(
run_model(llama_config, use_compile=True, split_attn=True))
run_model(tractable_config, use_compile=True, split_attn=True)
for i in range(1, len(outputs)):
assert torch.allclose(outputs[0], outputs[i])
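
For context, the invariant the test enforces, eager and compiled runs of the same model agreeing numerically, can be sketched with plain torch.compile instead of vLLM's piecewise backend (a minimal illustration; the toy model and tolerance here are made up):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(4, 8)
eager_out = model(x)
compiled_out = torch.compile(model)(x)   # requires a working torch.compile setup
assert torch.allclose(eager_out, compiled_out, atol=1e-6)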

View File

@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
@ -193,6 +194,7 @@ def wrap_inductor(graph,
@dataclasses.dataclass
class SplitItem:
submod_name: str
graph_id: int
is_splitting_graph: bool
graph: fx.GraphModule
@ -226,9 +228,7 @@ def split_graph(graph: fx.GraphModule,
outputs = []
# sort the names to make sure the order is deterministic
names = [name for (name, module) in split_gm.named_modules()]
names.sort()
for name in names:
if "." in name or name == "":
@ -238,7 +238,11 @@ def split_graph(graph: fx.GraphModule,
module = getattr(split_gm, name)
graph_id = int(name.replace("submod_", ""))
outputs.append(SplitItem(name, graph_id in split_op_graphs, module))
outputs.append(
SplitItem(name, graph_id, (graph_id in split_op_graphs), module))
# sort by integer graph_id, rather than string name
outputs.sort(key=lambda x: x.graph_id)
return split_gm, outputs
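
The `graph_id` field added above exists because sorting submodule names as strings misorders them once there are ten or more pieces; a standalone illustration of the difference:

names = [f"submod_{i}" for i in range(12)]
print(sorted(names)[:3])
# ['submod_0', 'submod_1', 'submod_10']  <- lexicographic order puts submod_10 before submod_2
print(sorted(names, key=lambda n: int(n.replace("submod_", "")))[:3])
# ['submod_0', 'submod_1', 'submod_2']   <- integer order matches execution order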
@ -252,6 +256,11 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
compilation configs.
NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise
graphs. The first graph will handle logging, and the last graph
has some special cudagraph output handling.
"""
def __init__(self, module: torch.fx.GraphModule,
@ -263,7 +272,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.compile_submod_names = compile_submod_names
self.compilation_configs = compilation_configs
self.graph_pool = graph_pool
self.have_seen_first_graph = False
def run(self, *args):
fake_args = [
@ -279,6 +287,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
output = super().call_module(target, args, kwargs)
if target in self.compile_submod_names:
index = self.compile_submod_names.index(target)
submod = self.fetch_attr(target)
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
@ -288,15 +297,14 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
args,
self.compilation_configs.inductor_compile_config,
runtime_shape=None,
do_logging=not self.have_seen_first_graph,
do_logging=index == 0,
use_inductor=self.compilation_configs.use_inductor)
self.module.__dict__[target] = PiecewiseBackend(
submod, self.compilation_configs, self.graph_pool,
not self.have_seen_first_graph, sym_shape_indices,
submod, self.compilation_configs, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape)
self.have_seen_first_graph = True
compilation_counter.num_piecewise_capturable_graphs_seen += 1
return output
@ -352,8 +360,9 @@ class VllmBackend:
graph, self.compilation_configs.non_cudagraph_ops)
from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s",
lazy_format_graph_code("stiching module", self.split_gm))
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
logger.debug("%s", lazy_format_graph_code("after split",
self.split_gm))
compilation_counter.num_piecewise_graphs_seen += len(
self.piecewise_graphs)
@ -385,12 +394,17 @@ class ConcreteSizeEntry:
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[List[int]] = None
class PiecewiseBackend:
def __init__(self, graph: fx.GraphModule,
compilation_configs: CompilationConfig, graph_pool: Any,
is_first_graph: bool, sym_shape_indices: List[int],
piecewise_compile_index: int, total_piecewise_compiles: int,
sym_shape_indices: List[int],
compiled_graph_for_general_shape: Callable):
"""
The backend for piecewise compilation.
@ -408,7 +422,12 @@ class PiecewiseBackend:
self.graph = graph
self.compilation_configs = compilation_configs
self.graph_pool = graph_pool
self.is_first_graph = is_first_graph
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = (
piecewise_compile_index == total_piecewise_compiles - 1)
self.compile_sizes: Set[int] = set(
self.compilation_configs.compile_sizes)
@ -422,6 +441,8 @@ class PiecewiseBackend:
self.sym_shape_indices = sym_shape_indices
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
# the entries for different shapes that we need to either
# compile or capture cudagraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
@ -476,14 +497,45 @@ class PiecewiseBackend:
logger.info("Capturing a cudagraph for shape %s",
runtime_shape)
input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
entry.output = weak_ref_tensors(entry.runnable(*args))
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)
# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph
compilation_counter.num_cudagraph_caputured += 1
entry.cudagraph = cudagraph
return entry.output
# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output
if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)
entry.cudagraph.replay()
return entry.output
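
For readers unfamiliar with the capture/replay contract that motivates both the weak-ref handling and the input-address check above, here is a minimal standalone PyTorch sketch (assumes a CUDA device and is unrelated to vLLM's graph pool): a captured graph always reads and writes the same memory addresses, so replay only picks up new data that has been copied into the original input buffer.

import torch

static_input = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
# a warm-up run on a side stream is recommended before capture; omitted for brevity
with torch.cuda.graph(g):
    static_output = static_input * 2 + 1
# new data must be copied into the captured input buffer before replay
static_input.copy_(torch.arange(8.0, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_output)  # tensor([ 1., 3., 5., ..., 15.]) because the addresses are fixed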