Compare commits

1 commit

SHA1        Message           Date
9cc8e3c438  Add graph breaks  2024-06-21 15:19:34 -07:00
3 changed files with 6 additions and 11 deletions

Changed file 1 of 3:

@@ -3,7 +3,7 @@ import csv
 import dataclasses
 import os

-from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8
+from generate import run_llama2_7b_bf16
 from triton.testing import do_bench

 import torch
@@ -232,15 +232,7 @@ def output_csv(output_file, headers, row):
 DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"

 all_experiments = {
     # A list of GPT models: LlaMa, Mixtral, etc.
     run_llama2_7b_bf16,
-    run_llama2_7b_int8,
-    run_mixtral_8x7b_int8,
-    # A list of micro-benchmarks.
-    run_mlp_layer_norm_gelu,
-    run_layer_norm,
-    run_gather_gemv,
-    run_gemv,
 }
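
Note: `all_experiments` is a registry of benchmark entry points, and this hunk trims it to the single `run_llama2_7b_bf16` experiment, so the run exercises only the model that receives the new graph breaks. As a rough sketch of how such a registry is typically consumed (the file's real runner is not shown in this diff, and the `(headers, row)` return shape is an assumption):

    # Hypothetical driver for the registry above -- not the file's actual main().
    def run_all(output_file: str = DEFAULT_OUTPUT_FILE) -> None:
        for experiment in all_experiments:
            for headers, row in experiment():  # assumed return shape
                output_csv(output_file, headers, row)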

Changed file 2 of 3:

@@ -64,7 +64,7 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
     return idx_next, probs

-@torch.compile(fullgraph=True)
+@torch.compile()
 def prefill(
     model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> torch.Tensor:
@@ -73,7 +73,7 @@ def prefill(
     return sample(logits, **sampling_kwargs)[0]

-@torch.compile(fullgraph=True, mode="reduce-overhead")
+@torch.compile(mode="reduce-overhead")
 def decode_one_token(
     model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> Tuple[torch.Tensor, torch.Tensor]:
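
Note: dropping `fullgraph=True` in both decorators is what makes the graph breaks added in the next file legal. With `fullgraph=True`, `torch.compile` raises an error at the first graph break instead of falling back to compiling multiple smaller graphs. A minimal standalone sketch of that behavior (the function `f` is illustrative, not from this repo):

    import torch

    @torch.compile(fullgraph=True)
    def f(x):
        torch._dynamo.graph_break()  # any break is fatal under fullgraph=True
        return x.sin()

    try:
        f(torch.randn(8))
    except Exception as err:
        # Dynamo refuses to compile: the function cannot be captured
        # as the single graph that fullgraph=True demands.
        print(type(err).__name__)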

Changed file 3 of 3:

@@ -157,8 +157,11 @@ class Transformer(nn.Module):
         x = self.tok_embeddings(idx)

         for i, layer in enumerate(self.layers):
+            torch._dynamo.graph_break()
             x = layer(x, input_pos, freqs_cis, mask)
+        torch._dynamo.graph_break()
         x = self.norm(x)
+        torch._dynamo.graph_break()
         logits = self.output(x)
         return logits
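
Note: each `torch._dynamo.graph_break()` ends the current FX graph and starts a new one, so the forward pass now compiles as many small graphs (one per layer, plus the norm and output projections) rather than one large graph, which is presumably the overhead this commit sets out to benchmark. A minimal sketch (illustrative function, not from this repo) that makes the splitting visible by counting graphs as they reach the backend:

    import torch

    num_graphs = 0

    def counting_backend(gm, example_inputs):
        # Called once per captured FX graph; run each graph unoptimized.
        global num_graphs
        num_graphs += 1
        return gm.forward

    @torch.compile(backend=counting_backend)
    def forward(x):
        x = x.sin()
        torch._dynamo.graph_break()  # ends graph 1, starts graph 2
        x = x.cos()
        torch._dynamo.graph_break()  # ends graph 2, starts graph 3
        return x * 2

    forward(torch.randn(8))
    print(num_graphs)  # 3: two breaks yield three graphs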