Compare commits

...

3 Commits

SHA1 Message Date
037fb8d6bb Move torch.compile to forward of Transformer 2024-06-21 14:52:26 -07:00
053a9d5129 Final numbers 2024-06-20 17:45:07 -07:00
863d6a389a Initial changes 2024-06-18 23:46:23 -07:00
6 changed files with 28 additions and 209 deletions

View File

@@ -3,12 +3,8 @@ import csv
import dataclasses
import os
from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8
from triton.testing import do_bench
from generate import run_llama2_7b_bf16
import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode
WARMUP_ITER = 5
@@ -25,189 +21,6 @@ class Experiment:
device: str
class SimpleMLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, dtype):
super().__init__()
self.layers = nn.ModuleList(
[
nn.Linear(input_dim, hidden_dim, dtype=dtype),
nn.LayerNorm(hidden_dim, dtype=dtype),
nn.Linear(hidden_dim, output_dim, dtype=dtype),
nn.LayerNorm(output_dim, dtype=dtype),
]
)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def run_mlp_layer_norm_gelu(device: str = "cuda"):
dtype_flops_utilization_map = {
torch.bfloat16: "0.71",
}
input_shapes = [1024, 4096, 8192, 16384]
intermediate_size = 14336
results = []
for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
flops_utilization = 0
for D in input_shapes:
mod = SimpleMLP(
input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
).to(device)
x = torch.randn(D, device=device, dtype=torch.bfloat16)
with FlopCounterMode(display=False) as mode:
mod(x)
flops = mode.get_total_flops()
compiled_mod = torch.compile(mod, dynamic=False)
for _ in range(WARMUP_ITER):
compiled_mod(x)
us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
flops_utilization += us_per_iter * flops / 1e9 / A100_80G_BF16_TFLOPS
flops_utilization = flops_utilization / len(input_shapes)
dtype_str = str(dtype).replace("torch.", "")
results.append(
Experiment(
f"mlp_layer_norm_gelu_{dtype_str}",
"flops_utilization",
expected_flops_utilization,
f"{flops_utilization:.02f}",
dtype_str,
device,
)
)
return results
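The utilization figure above is derived from FlopCounterMode, which analytically counts FLOPs for every op run under its context. A self-contained sketch of just that counting step, with toy sizes (nothing here comes from the deleted benchmark):

import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode

mod = nn.Linear(1024, 1024)
x = torch.randn(1, 1024)
with FlopCounterMode(display=False) as mode:
    mod(x)
print(mode.get_total_flops())  # 2 * 1 * 1024 * 1024 for the single matmul
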
def run_layer_norm(device: str = "cuda"):
dtype_memory_bandwidth_map = {
torch.bfloat16: "1017",
}
input_shapes = [1024, 4096, 8192, 16384]
BS = 4096
results = []
for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
memory_bandwidth = 0
for D in input_shapes:
mod = nn.LayerNorm(D).to(device)
x = torch.randn(BS, D, device=device, dtype=dtype)
compiled_mod = torch.compile(mod, dynamic=False)
for _ in range(WARMUP_ITER):
compiled_mod(x)
us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9
memory_bandwidth = memory_bandwidth / len(input_shapes)
dtype_str = str(dtype).replace("torch.", "")
results.append(
Experiment(
f"layer_norm_{dtype_str}",
"memory_bandwidth(GB/s)",
expected_memory_bandwidth,
f"{memory_bandwidth:.02f}",
dtype_str,
device,
)
)
return results
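The bandwidth metric is bytes moved divided by measured time: LayerNorm reads BS*D elements and writes BS*D back, hence the 2 * BS * D * dtype.itemsize term. The same arithmetic in isolation, with a made-up 100 µs timing standing in for the do_bench result:

BS, D, itemsize = 4096, 8192, 2        # bf16 elements are 2 bytes
us_per_iter = 100.0                    # hypothetical measured time (µs)
bytes_moved = 2 * BS * D * itemsize    # one read + one write of the activation
gb_per_s = (1e6 / us_per_iter) * bytes_moved / 1e9
print(f"{gb_per_s:.0f} GB/s")          # ~1342 GB/s for these assumed numbers
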
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
E = 8
dtype_memory_bandwidth_map = {
torch.int8: "1113",
torch.bfloat16: "1249",
}
input_shapes = [1024, 4096, 8192, 16384]
results = []
for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
memory_bandwidth = 0
for D in input_shapes:
def gather_gemv(W, score_idxs, x):
return W[score_idxs].to(x.dtype) @ x
W = torch.randn(E, D, D, device=device).to(dtype=dtype)
x = torch.randn(D, device=device, dtype=torch.bfloat16)
score_idxs = torch.tensor([3, 5], device=device)
compiled_fn = torch.compile(gather_gemv, dynamic=False)
for _ in range(WARMUP_ITER):
compiled_fn(W, score_idxs, x)
us_per_iter = do_bench(lambda: compiled_fn(W, score_idxs, x)) * 1000
memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9
memory_bandwidth = memory_bandwidth / len(input_shapes)
dtype_str = str(dtype).replace("torch.", "")
results.append(
Experiment(
f"gather_gemv_{dtype_str}",
"memory_bandwidth(GB/s)",
expected_memory_bandwidth,
f"{memory_bandwidth:.02f}",
dtype_str,
device,
)
)
return results
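gather_gemv is the Mixtral-style expert lookup: index a subset of expert weight matrices out of W, then do a matrix-vector product with each. The 2 * D * D term in the bandwidth formula counts the two gathered matrices, which dominate the traffic. An eager sanity check of the pattern on CPU with toy sizes:

import torch

E, D = 8, 64
W = torch.randn(E, D, D, dtype=torch.bfloat16)
x = torch.randn(D, dtype=torch.bfloat16)
score_idxs = torch.tensor([3, 5])       # two selected experts
out = W[score_idxs].to(x.dtype) @ x     # (2, D, D) @ (D,) -> (2, D)
assert out.shape == (2, D)
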
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
dtype_memory_bandwidth_map = {
torch.int8: "990",
torch.bfloat16: "1137",
}
input_shapes = [1024, 4096, 8192, 16384]
results = []
for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
memory_bandwidth = 0
for D in input_shapes:
def gemv(W, x):
return W.to(x.dtype) @ x
W = torch.randn(D, D, device="cuda").to(dtype=dtype)
x = torch.randn(D, device="cuda", dtype=torch.bfloat16)
compiled_fn = torch.compile(gemv, dynamic=False)
for _ in range(WARMUP_ITER):
compiled_fn(W, x)
us_per_iter = do_bench(lambda: compiled_fn(W, x)) * 1000
memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9
memory_bandwidth = memory_bandwidth / len(input_shapes)
dtype_str = str(dtype).replace("torch.", "")
results.append(
Experiment(
f"gemv_{dtype_str}",
"memory_bandwidth(GB/s)",
expected_memory_bandwidth,
f"{memory_bandwidth:.02f}",
dtype_str,
device,
)
)
return results
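All four deleted micro-benchmarks share the same timing core: triton's do_bench returns milliseconds per call, and the * 1000 converts to microseconds. Stripped to its essentials (assumes a CUDA device and triton installed):

import torch
from triton.testing import do_bench

W = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
us_per_iter = do_bench(lambda: W @ x) * 1000   # ms -> µs
print(f"{us_per_iter:.1f} µs/iter")
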
def output_csv(output_file, headers, row):
if os.path.exists(output_file):
with open(output_file) as fd:
@@ -234,13 +47,6 @@ DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"
all_experiments = {
# A list of GPT models: LlaMa, Mixtral, etc.
run_llama2_7b_bf16,
run_llama2_7b_int8,
run_mixtral_8x7b_int8,
# A list of micro-benchmarks.
run_mlp_layer_norm_gelu,
run_layer_norm,
run_gather_gemv,
run_gemv,
}
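After this change only the llama2-7b bf16 run remains registered. How the set is consumed is not shown in this diff; a hedged sketch of the likely dispatch (the repo's actual main() and signatures may differ):

# Hypothetical runner: each entry returns Experiment rows for the CSV writer.
for experiment in all_experiments:
    for row in experiment(device="cuda"):
        print(row)
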

View File

@@ -62,7 +62,6 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
return idx_next, probs
@torch.compile(fullgraph=True)
def prefill(
model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
@@ -71,14 +70,17 @@ def prefill(
return sample(logits, **sampling_kwargs)[0]
@torch.compile(fullgraph=True, mode="reduce-overhead")
def decode_one_token(
model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [B, 1]
assert input_pos.shape[-1] == 1
# torch.cuda.synchronize()
# with torch.profiler.profile() as prof:
logits = model(x, input_pos)
return sample(logits, **sampling_kwargs)
assert input_pos.shape[-1] == 1
res = sample(logits, **sampling_kwargs)
# prof.export_chrome_trace(f"llama_trace.json")
return res
def decode_n_tokens(
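This hunk removes the torch.compile decorator (compilation moves into Transformer.forward, per the model.py change below) and leaves commented-out profiling scaffolding behind. That scaffolding, made runnable as a generic sketch with a stand-in workload:

import torch

def step():
    a = torch.randn(128, 128)
    return a @ a

if torch.cuda.is_available():
    torch.cuda.synchronize()            # drain pending GPU work before tracing
with torch.profiler.profile() as prof:
    step()
prof.export_chrome_trace("llama_trace.json")  # viewable in chrome://tracing
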
@@ -199,7 +201,7 @@ def run_experiment(
model, prompt, max_new_tokens, temperature=temperature, top_k=top_k
)
if i == -1:
if i < 1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue

View File

@@ -94,8 +94,8 @@ class KVCache(nn.Module):
):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
self.k_cache = torch.nn.Parameter(torch.zeros(cache_shape, dtype=dtype))
self.v_cache = torch.nn.Parameter(torch.zeros(cache_shape, dtype=dtype))
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
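The caches switch from registered buffers to Parameters. For inference the two behave the same, but cudagraph trees treat parameters as static inputs whose addresses are baked into the recorded graph, which is what the _functorch and _inductor hunks below key off. A toy illustration of the difference (requires_grad=False is added here to keep autograd out of the picture; the diff itself omits it):

import torch
import torch.nn as nn

m = nn.Module()
shape = (1, 2, 8, 4)
# buffer: saved in state_dict and moved by .to(), but not a parameter
m.register_buffer("k_cache", torch.zeros(shape))
# parameter: shows up in m.parameters(), so cudagraphs assumes its storage
# is static and skips re-copying it on every replay
m.v_cache = nn.Parameter(torch.zeros(shape), requires_grad=False)
print([name for name, _ in m.named_parameters()])  # ['v_cache']
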
@@ -120,6 +120,7 @@ class Transformer(nn.Module):
)
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.output_norm = lambda x: self.output(self.norm(x))
self.freqs_cis: Optional[Tensor] = None
self.mask_cache: Optional[Tensor] = None
@@ -156,11 +157,12 @@ class Transformer(nn.Module):
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)
torch.compiler.cudagraph_mark_step_begin()
for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
x = self.norm(x)
logits = self.output(x)
return logits
x = torch.compile(layer, mode="reduce-overhead")(
x, input_pos, freqs_cis, mask
)
return torch.compile(self.output_norm, mode="reduce-overhead")(x)
@classmethod
def from_name(cls, name: str):
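This is the heart of the series: rather than compiling prefill/decode_one_token from the outside, each layer plus the fused norm-and-output head is compiled in reduce-overhead (cudagraphs) mode inside forward, with cudagraph_mark_step_begin() marking iteration boundaries. Dynamo caches compiled code on the underlying callable, so re-wrapping on every call should reuse the cached graph rather than recompile. The same pattern in miniature (assumes a CUDA device):

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(2)]).cuda()
x = torch.randn(8, 64, device="cuda")

for _ in range(3):
    # a new step: outputs of the previous cudagraph replay may now be
    # overwritten instead of being kept alive for the caller
    torch.compiler.cudagraph_mark_step_begin()
    y = x
    for layer in layers:
        y = torch.compile(layer, mode="reduce-overhead")(y)
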

View File

@@ -367,9 +367,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1") == "1"
# Inline inbuilt nn modules
inline_inbuilt_nn_modules = (
os.environ.get("TORCHDYNAMO_INLINE_INBUILT_NN_MODULES", "0") == "1"
)
inline_inbuilt_nn_modules = True
def default_debug_dir_root():
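The env-var gate is replaced with a hard-coded True for this experiment. The same effect is available at runtime without patching the file, since this is an ordinary dynamo config knob:

import torch

# one-off override, equivalent to the hard-coded True above
torch._dynamo.config.inline_inbuilt_nn_modules = True
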

View File

@@ -674,6 +674,16 @@ from a multi-output view call"
]
else:
static_parameter_input_indices = []
if False:
if (
5 in static_parameter_input_indices
and 6 in static_parameter_input_indices
):
static_parameter_input_indices = [5, 6]
else:
static_parameter_input_indices = []
# print(static_parameter_input_indices)
f_mutated_inputs = [
inp

View File

@@ -945,6 +945,7 @@ class CUDAGraphNode:
def _copy_inputs_and_remove_from_src(self, dsts, srcs):
dst_tensors = []
src_tensors = []
# print(f"non static inds:{self.non_static_input_idx}")
for idx in self.non_static_input_idx:
if not isinstance(srcs[idx], torch.Tensor):
continue
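For context on the hunk above: before each replay, _copy_inputs_and_remove_from_src refreshes only the non-static inputs into the graph's recorded placeholder storages; static inputs (parameters, and now the KV caches thanks to the model.py change) are skipped because their addresses were captured at record time. The gist, as a hypothetical free function (dsts being the recorded placeholders, srcs the caller's tensors for this iteration):

def copy_non_static(dsts, srcs, non_static_input_idx):
    # static inputs (e.g. nn.Parameters) keep their recorded addresses
    # and are never re-copied; everything else is refreshed per replay
    for idx in non_static_input_idx:
        dsts[idx].copy_(srcs[idx])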