Compare commits

...

35 Commits

Author SHA1 Message Date
b4413c2d80 add profiler run 2024-10-01 16:45:51 -07:00
5afb4bef46 add nvtx 2024-10-01 15:30:57 -07:00
00ceaa3bcf fix missing baseline argument 2024-10-01 15:29:01 -07:00
c0943bdaa7 add baseline attribution; remove some comments 2024-10-01 14:18:17 -07:00
569bf6edca update profile wrapper 2024-10-01 13:45:28 -07:00
f2654ae713 update installation for unit test 2024-09-30 15:48:56 -07:00
ab40b51c5d Merge remote-tracking branch 'origin' into findhao/operatorbench2 2024-09-30 13:50:08 -07:00
9f2936931a only enable test when triton installed 2024-09-27 09:43:25 -07:00
5fa8031ae5 fix lint 2024-09-26 17:35:17 -07:00
a245137d76 update docstring 2024-09-26 17:07:23 -07:00
2d93c5f720 add unit test and fix input issues 2024-09-26 17:00:46 -07:00
6f3b42a073 add unit test; add profile-folder; 2024-09-26 10:56:06 -07:00
db4c9a54a2 fix lint 2024-09-25 14:02:28 -04:00
f78da95bc5 remove single_run; add prepare_input_and_functions; add type annotations 2024-09-25 13:51:37 -04:00
7c2bc74a72 temporary saved 2024-09-25 12:45:20 -04:00
7b366a2b70 fix input compatibility 2024-09-24 17:38:23 -04:00
900671f799 collect instances rather than classes. it is better for compatibility with original operatorbench 2024-09-20 16:48:08 -04:00
1f30017712 fix lint 2024-09-19 13:19:40 -04:00
1fdf24d9a5 fix lint 2024-09-16 20:10:34 -04:00
8a4bc3cc09 update comment 2024-09-16 19:22:31 -04:00
78f5027b48 use mean of results for each input 2024-09-16 19:19:22 -04:00
6e2d4c661a fix docs and default configs; remove unused function; 2024-09-16 19:10:56 -04:00
30dd419560 fix MetricResult 2024-09-16 18:51:09 -04:00
f280038562 add profile 2024-09-16 18:12:25 -04:00
0bb482185c format output 2024-09-16 17:30:37 -04:00
18c2804981 fix lint 2024-09-16 17:18:14 -04:00
a6b6bbc293 add inductor variant 2024-09-16 17:17:32 -04:00
ebd4755b0d fix lint 2024-09-16 16:17:58 -04:00
8779577950 add requirements.txt 2024-09-16 16:02:24 -04:00
a6d9a506c3 make the resultmetrics more clear 2024-09-16 14:03:40 -04:00
a0ecd4f45d fix bug for full 2024-09-13 20:00:23 -04:00
9cdac0662b add metrics; convert some argument from string to enum etc.; 2024-09-13 17:17:38 -04:00
57be1aae4b add benchmarkconfig; fix subclass inheritance 2024-09-13 13:04:39 -04:00
22ee74895b remove multirun 2024-09-12 16:18:30 -04:00
ea97de291b init 2024-09-11 17:45:24 -04:00
13 changed files with 853 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
import os
import sys
# Add the current directory to the system path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

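For illustration only (not part of this change): with the benchmark directory appended to sys.path as above, the sibling packages referenced throughout this diff become importable without installing anything, assuming the benchmark's dependencies are available. A minimal sketch:

# Hypothetical usage once this __init__ has run; `operators` is the package added later in this diff.
import operators
print([op.name for op in operators.list_operators()])  # operator variant classes discovered under operators/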
View File

@@ -0,0 +1,65 @@
from typing import Any, Callable, List
from utils.common import BenchmarkConfig, Phase
import torch
from .. import BaseOperator
H = 4096
V = 128256
# Each file defines an operator variant
valid_operator_files = ["baseline.py", "custom.py", "inductor.py"]
# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py
class FusedLinearCrossEntropyOperator(BaseOperator):
# The base operator name
name = "FusedLinearCrossEntropy"
# The variant placeholder. No need to set in the base operator class
variant = None
def __init__(self, benchmark_config: BenchmarkConfig, is_baseline: bool = False):
super().__init__(benchmark_config, is_baseline)
self.forward_output = None
@classmethod
def generate_inputs(cls, benchmark_config: BenchmarkConfig):
example_inputs_list = []
# May need OOM check
for BT in [2**i for i in range(12, 16)]:
_input = torch.randn(
BT,
H,
requires_grad=True,
dtype=benchmark_config.dtype,
device=benchmark_config.device.value,
)
target = torch.randint(
V, (BT, 1), dtype=torch.long, device=benchmark_config.device.value
).squeeze(1)
# This operator needs two inputs
example_inputs_list.append((_input, target))
return example_inputs_list
def forward(self, input: Any):
return self.operator(input)
# backward doesn't use the input, but it is accepted to keep the interface uniform
def backward(self, input: Any):
assert self.forward_output is not None
return self.forward_output.backward(retain_graph=True)
def full(self, input: Any):
y = self.forward(input)
y.backward()
return y
def prepare_input_and_functions(self, input: Any, phase: Phase):
if phase == Phase.BACKWARD:
self.forward_output = self.forward(input)
return input

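As an illustrative aside (not part of the diff): a minimal sketch of how a harness could drive one sample through this operator, following the generate_inputs / prepare_input_and_functions / forward / backward contract above. The import path and config values are assumptions; run.py later in this diff does the same thing with more bookkeeping.

import torch
from utils.common import BenchmarkConfig, Phase
from utils.metrics import Device, Metrics
from operators.FusedLinearCrossEntropy.baseline import Operator  # hypothetical path; directory names are not shown in this diff

config = BenchmarkConfig(
    device=Device.CUDA, dtype=torch.bfloat16, phase=Phase.BACKWARD,
    max_samples=1, repeat=1, metrics=[Metrics.EXECUTION_TIME],
    profile=False, profile_folder="./log", enable_nvtx=False,
)
op = Operator(config)
sample = op.generate_inputs(config)[0]                           # an (_input, target) pair
sample = op.prepare_input_and_functions(sample, Phase.BACKWARD)  # caches the forward output
op.backward(sample)                                              # runs loss.backward(retain_graph=True)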
View File

@@ -0,0 +1,43 @@
from utils.common import BenchmarkConfig
import torch
from . import FusedLinearCrossEntropyOperator, H, V
# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L17
class TorchLMHeadCE(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.
:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.ce_loss = torch.nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction="mean"
)
def forward(self, inputs):
x, y = inputs
logits = self.lin(x)
return self.ce_loss(logits, y)
class Operator(FusedLinearCrossEntropyOperator):
variant = "Baseline"
def __init__(self, benchmark_config: BenchmarkConfig):
super().__init__(benchmark_config, is_baseline=True)
self.operator = TorchLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
self.benchmark_config.device.value
)

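For context (not part of the diff): the baseline materializes the full (BT, V) logits tensor before computing cross entropy, which is exactly the memory traffic the fused variants avoid. A toy sketch with assumed small sizes, using the TorchLMHeadCE class defined above:

import torch

h, v, bt = 64, 1000, 8                                 # toy sizes, not the benchmark's H/V
model = TorchLMHeadCE(H=h, V=v, dtype=torch.float32)
x = torch.randn(bt, h, requires_grad=True)
y = torch.randint(v, (bt,), dtype=torch.long)
loss = model((x, y))                                   # materializes a (bt, v) logits tensor
loss.backward()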
View File

@@ -0,0 +1,34 @@
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
from utils.common import BenchmarkConfig
import torch
from . import FusedLinearCrossEntropyOperator, H, V
# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L40
class LigerLMHeadCE(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.ce_loss = LigerFusedLinearCrossEntropyLoss(
ignore_index=ignore_index, reduction="mean"
)
def forward(self, inputs):
return self.ce_loss(self.lin.weight, *inputs)
class Operator(FusedLinearCrossEntropyOperator):
variant = "Liger"
def __init__(self, benchmark_config: BenchmarkConfig):
super().__init__(benchmark_config)
self.operator = LigerLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
self.benchmark_config.device.value
)

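A note on the design (illustrative, not part of the diff): LigerFusedLinearCrossEntropyLoss takes the projection weight and the hidden states directly, so the logits are never materialized as a separate tensor. A hedged sketch of the call-shape difference, assuming a CUDA device and an installed liger-kernel, reusing the class and constants above:

import torch

# baseline path:  logits = lin(x); loss = ce_loss(logits, y)
# liger path:     loss = ce_loss(lin.weight, x, y)   (projection + cross entropy fused)
liger_op = LigerLMHeadCE(H=H, V=V, dtype=torch.bfloat16).to("cuda")
x = torch.randn(4096, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.randint(V, (4096,), dtype=torch.long, device="cuda")
loss = liger_op((x, y))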
View File

@@ -0,0 +1,22 @@
from utils.common import BenchmarkConfig
import torch
from . import FusedLinearCrossEntropyOperator, H, V
from .baseline import TorchLMHeadCE
class TorchLMHeadCECompiled(TorchLMHeadCE):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__(H, V, dtype, ignore_index)
class Operator(FusedLinearCrossEntropyOperator):
variant = "Inductor"
def __init__(self, benchmark_config: BenchmarkConfig):
super().__init__(benchmark_config)
self.operator = TorchLMHeadCECompiled(
H=H, V=V, dtype=self.benchmark_config.dtype
).to(self.benchmark_config.device.value)
self.operator = torch.compile(self.operator)

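Aside (not part of the diff): the Inductor variant is simply the baseline module wrapped in torch.compile with default settings, so Inductor is the active backend. A hedged sketch of the equivalent explicit call; the commented mode is an assumption, not something this file enables:

import torch

compiled = torch.compile(
    TorchLMHeadCECompiled(H=H, V=V, dtype=torch.bfloat16).to("cuda"),
    backend="inductor",        # the default backend, spelled out for clarity
    # mode="max-autotune",     # assumed alternative tuning mode; not used in this diff
)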
View File

@@ -0,0 +1,208 @@
import importlib
import os
import pathlib
import sys
import types
from typing import Dict, List, Optional
from utils.common import BenchmarkConfig
from utils.metrics import Device
import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import gen_gm_and_inputs
from torch.utils._pytree import tree_map_only
class OperatorNotFoundError(RuntimeError):
"""Custom exception raised when an operator is not found."""
class BaseOperator:
"""
Base class for operators.
This class defines the structure for operator implementations.
The forward, backward, and full methods should **only contain**
the code that users want to benchmark.
Attributes:
name (str): The main name of the operator, e.g. "FusedLinearCrossEntropy".
variant (str): The variant of the operator, e.g. "baseline".
benchmark_config (BenchmarkConfig): Configuration for the benchmark.
full_name (str): The full name of the operator (name.variant). It is only valid for variants.
It can be either assigned in the operator file or generated from name and variant.
"""
name = None
variant = None
benchmark_config = None
full_name = None
def __init__(self, benchmark_config: BenchmarkConfig, is_baseline: bool = False):
"""
Initialize the BaseOperator.
Args:
benchmark_config (BenchmarkConfig): Configuration for the benchmark.
is_baseline (bool): Whether the operator is a baseline variant.
"""
self.benchmark_config = benchmark_config
if self.full_name is None:
self.full_name = f"{self.name}.{self.variant}"
self.is_baseline = is_baseline
@classmethod
def get_inputs(
cls,
input_mapping: Dict[str, List],
benchmark_config: Optional[BenchmarkConfig] = None,
):
"""
Get or generate example inputs for the operator.
The format of the inputs is important and should meet the requirements
of the operator. It is not necessary to have a unified format for
different operators, but the format should be consistent within the
same operator.
Unlike generate_inputs, this function first looks up inputs cached in
input_mapping from previous runs and only calls generate_inputs when no
cached inputs exist for this operator.
Args:
input_mapping (Dict[str, List]): Mapping from operator name to the input list.
benchmark_config (Optional[BenchmarkConfig]): Configuration for the benchmark.
Returns:
list: List of example inputs.
"""
if cls.name not in input_mapping:
assert (
benchmark_config is not None
), "Benchmark config is required to generate inputs"
generated_inputs = cls.generate_inputs(benchmark_config)
input_mapping[cls.name] = generated_inputs
return input_mapping[cls.name]
@classmethod
def generate_inputs(cls, benchmark_config: BenchmarkConfig):
"""
Generate example inputs for the operator. Each operator should implement
this method and the format should be consistent with the operator.
"""
raise NotImplementedError("Subclasses must implement this method.")
def forward(self):
"""Perform the forward pass of the operator."""
raise NotImplementedError("Subclasses must implement this method.")
def backward(self):
"""Perform the backward pass of the operator. It can be bypassed if the operator does not have a backward pass."""
raise NotImplementedError("Subclasses must implement this method.")
def full(self):
"""Perform the full (forward + backward) pass of the operator."""
raise NotImplementedError("Subclasses must implement this method.")
def prepare_input_and_functions(self, input, phase):
"""
If needed, process the input before running the operator. This can be
used to prepare the forward output for the backward benchmarking. By default,
we return the input directly.
Args:
input: The input to the operator.
phase: The benchmark phase being run (forward, backward, or full).
Returns:
The processed input.
"""
return input
def _list_operator_paths() -> List[str]:
"""
List the paths of all operator directories.
Returns:
List[str]: A sorted list of absolute paths to operator directories.
"""
p = pathlib.Path(__file__).parent
# Only load the operator directories that contain an "__init__.py" file
return sorted(
str(child.absolute())
for child in p.iterdir()
if child.is_dir() and os.path.exists(os.path.join(child, "__init__.py"))
)
def _load_valid_operators(module_path: str, operator_name: str) -> List:
"""
Load valid operators from a given module path.
Args:
module_path (str): The path to the operator module.
operator_name (str): The name of the operator.
Returns:
List: A list of loaded operator classes.
Raises:
OperatorNotFoundError: If the operator module fails to load.
"""
loaded_operators = []
cls_name = "Operator"
# Import the operator module
try:
operator_module = importlib.import_module(module_path, package=__name__)
# Only load the operator files listed in the module's valid_operator_files attribute
valid_operator_files = getattr(operator_module, "valid_operator_files", None)
if valid_operator_files is None:
raise ImportError(f"{module_path} does not define valid_operator_files")
except ImportError as e:
raise OperatorNotFoundError(
f"Failed to load operator module {module_path}: {str(e)}"
) from e
for file_name in valid_operator_files:
tmp_file_name = file_name
if file_name.endswith(".py"):
tmp_file_name = file_name[:-3]
operator_file_module_path = f"{module_path}.{tmp_file_name}"
try:
file_module = importlib.import_module(
operator_file_module_path, package=__name__
)
Operator = getattr(file_module, cls_name, None)
if Operator is None:
print(
f"Warning: {file_module} does not define attribute '{cls_name}', skipping."
)
else:
if not hasattr(Operator, "name") or Operator.name is None:
Operator.name = f"{operator_name}"
loaded_operators.append(Operator)
except ImportError as e:
print(
f"Warning: Failed to load operator from {operator_file_module_path}: {str(e)}"
)
return loaded_operators
def list_operators():
"""
List all available operators. Each operator class represents a variant of a base operator.
Returns:
List: A list of all operator classes.
"""
# This list is used to store all the operator classes, not instances
operators = []
for operator_path in _list_operator_paths():
operator_name = os.path.basename(operator_path)
module_path = f"operators.{operator_name}"
loaded_operators = _load_valid_operators(module_path, operator_name)
operators.extend(loaded_operators)
return operators

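To summarize the discovery protocol implemented above (illustration, not part of the diff): each operator lives in its own directory under operators/, its __init__.py defines a shared base class plus a valid_operator_files list, and every listed file must expose a class named Operator. A hypothetical layout for a new operator:

# operators/
#   my_new_op/               <- hypothetical directory name
#     __init__.py            <- defines MyNewOpOperator(BaseOperator) and
#                               valid_operator_files = ["baseline.py"]
#     baseline.py            <- defines class Operator(MyNewOpOperator) with variant = "Baseline"
#
# list_operators() then imports operators.my_new_op.baseline and picks up Operator;
# _load_valid_operators() fills in Operator.name = "my_new_op" if the class leaves it unset.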
View File

@@ -0,0 +1,3 @@
liger-kernel
transformers>=4.38.1
torch>=2.1.2

View File

@@ -0,0 +1,226 @@
import warnings
from collections import defaultdict
from contextlib import nullcontext
import click
import operators
from operators import BaseOperator
from utils.common import (
BenchmarkConfig,
Device,
dtype_mapping,
maybe_record_function,
Phase,
)
from utils.metrics import get_execution_time, MetricResult, Metrics, do_profile_warmup, do_profile_bench
import torch
# mapping from operator name to the input list.
# We use the same input list for different variants of the same operator.
# {operator_name: input_list}
input_mapping = {}
# Create operator instances from desired operator names
# Return a dict of {operator_name: [variant_instances]}
def create_operator_instances(
operator_names: list[str],
name_to_variant_list: dict[str, list[BaseOperator]],
benchmark_config: BenchmarkConfig,
skip_variants: list[str],
) -> dict[str, list[BaseOperator]]:
operator_instances = defaultdict(list)
for operator_name in operator_names:
variant_classes = name_to_variant_list.get(operator_name, [])
if not variant_classes:
warnings.warn(f"Operator {operator_name} not found")
continue
for VariantClass in variant_classes:
if VariantClass.variant and VariantClass.variant.lower() in skip_variants:
continue
operator_instances[operator_name].append(VariantClass(benchmark_config))
return operator_instances
def benchmark_operator(operator: BaseOperator, benchmark_config: BenchmarkConfig):
print(f"Benchmarking {operator.full_name}")
phase = benchmark_config.phase
max_samples = benchmark_config.max_samples
repeat = benchmark_config.repeat
device = benchmark_config.device
metrics = benchmark_config.metrics
num_samples = min(
max_samples, len(operator.get_inputs(input_mapping, benchmark_config))
)
metric_result = MetricResult()
metric_result.op_name = operator.name
metric_result.op_variant = operator.variant
profiler_context = (
torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
record_shapes=False,
profile_memory=False,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
f"{benchmark_config.profile_folder}/operator_{operator.full_name}",
use_gzip=True,
),
)
if benchmark_config.profile
else nullcontext()
)
with profiler_context:
for i in range(num_samples):
input = operator.get_inputs(input_mapping, benchmark_config)[i]
input = operator.prepare_input_and_functions(input, phase)
if phase == Phase.FORWARD:
phase_fn = operator.forward
elif phase == Phase.BACKWARD:
phase_fn = operator.backward
else:
phase_fn = operator.full
metric_result.input.append(input)
execution_time = []
def fn():
return phase_fn(input)
if benchmark_config.enable_nvtx:
do_profile_warmup(fn, warmup=25, fast_flush=True)
# DO NOT CHANGE THE NAME OF THE RECORD FUNCTION. It is used in ncu_analyzer.
with maybe_record_function(f"{operator.full_name}___sample_{i}", benchmark_config, sample_idx=i):
for repeat_idx in range(repeat):
with maybe_record_function(
f"repeat_{repeat_idx}", benchmark_config, repeat_idx=repeat_idx
):
if benchmark_config.enable_nvtx:
do_profile_bench(fn, grad_to_none=None)
elif Metrics.EXECUTION_TIME in metrics:
execution_time.append(
get_execution_time(
fn,
grad_to_none=None,
device=device,
)
)
metric_result.execution_time.append(execution_time)
return metric_result
@click.command()
@click.option("--op", help="operator overload to benchmark. split by ','.")
@click.option(
"--dtype",
help="dtype to benchmark. [bfloat16, float16, float32]",
default="bfloat16",
)
@click.option(
"--max-samples",
help="max samples per op. each operator may have different inputs. this is the number of inputs to sample.",
default=15,
)
@click.option(
"--device",
help=f"device to benchmark, {[device.value.lower() for device in Device]}. ",
default=Device.CUDA.value,
)
@click.option(
"--phase",
help=f"phase to benchmark. {[phase.value.lower() for phase in Phase]}. ",
default="forward",
)
@click.option("--repeat", help="repeat", default=5)
@click.option(
"--metrics",
help=f"metrics to benchmark. {[metric.value.lower() for metric in Metrics]}. split by ','",
default=Metrics.EXECUTION_TIME.value,
)
@click.option(
"--skip-variants",
help="variants to be skipped, [liger, baseline, inductor]. split by ','",
default="",
)
@click.option("--profile", help="profile", is_flag=True, default=False)
@click.option(
"--profile-folder",
help="set profile folder",
default="./log",
)
@click.option("--enable-nvtx", help="enable nvtx", is_flag=True, default=False)
def run_benchmarks(
op,
dtype,
max_samples,
device,
phase,
repeat,
metrics,
skip_variants,
profile,
profile_folder,
enable_nvtx,
):
global input_mapping
# Reset input mapping to avoid OOM and mismatch in different unit tests
input_mapping = {}
# process arguments and generate benchmark config
dtype = dtype_mapping.get(dtype)
metrics = [
Metrics[metric.strip().upper()]
for metric in metrics.split(",")
if metric.strip().upper() in Metrics.__members__
]
device = Device[device.upper()]
if device != Device.CUDA and Metrics.GPU_PEAK_MEM in metrics:
print(f"{Metrics.GPU_PEAK_MEM.value} is only supported on cuda")
metrics.remove(Metrics.GPU_PEAK_MEM)
phase = Phase[phase.upper()]
benchmark_config = BenchmarkConfig(
device=device,
dtype=dtype,
phase=phase,
max_samples=max_samples,
repeat=repeat,
metrics=metrics,
profile=profile,
profile_folder=profile_folder,
enable_nvtx=enable_nvtx,
)
# This is a list of classes, not instances
operator_class_list: list[BaseOperator] = operators.list_operators()
name_to_variant_list = defaultdict(list)
for OperatorClass in operator_class_list:
name_to_variant_list[OperatorClass.name].append(OperatorClass)
desired_op_names = None
if op is not None:
desired_op_names = op.split(",")
else:
desired_op_names = name_to_variant_list.keys()
skip_variants = skip_variants.split(",")
skip_variants = [
variant.lower().strip() for variant in skip_variants if variant.strip()
]
operator_metric_results = {}
operator_instances = create_operator_instances(
desired_op_names, name_to_variant_list, benchmark_config, skip_variants
)
for operator_name, variants in operator_instances.items():
for variant in variants:
metric_result = benchmark_operator(variant, benchmark_config)
operator_metric_results[
f"{operator_name}.{variant.variant}"
] = metric_result
for metric_result in operator_metric_results.values():
print(metric_result)
if __name__ == "__main__":
run_benchmarks()

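For reference (not part of the diff): the command line can also be exercised programmatically through click's test runner, which is what the unit test at the end of this diff does. A minimal sketch with assumed option values and an assumed module name for this file:

from click.testing import CliRunner
from run import run_benchmarks   # assumed module name; the test below imports operatorbench.run

runner = CliRunner()
result = runner.invoke(
    run_benchmarks,
    ["--op", "FusedLinearCrossEntropy", "--dtype", "bfloat16",
     "--max-samples", "1", "--repeat", "1", "--metrics", "execution_time"],
)
print(result.output)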
View File

@@ -0,0 +1,49 @@
import dataclasses
from contextlib import nullcontext
from enum import Enum
from typing import List, Optional
import torch
from .metrics import Device, Metrics, profile_range
@dataclasses.dataclass
class BenchmarkConfig:
device: Device
dtype: torch.dtype
phase: str
max_samples: int
repeat: int
metrics: List[Metrics]
profile: bool
profile_folder: str
enable_nvtx: bool
class Phase(Enum):
FORWARD = "forward"
BACKWARD = "backward"
FULL = "full"
dtype_mapping = {
"bfloat16": torch.bfloat16,
"float16": torch.float16,
"float32": torch.float32,
}
def maybe_record_function(name: str, benchmark_config: BenchmarkConfig, sample_idx: Optional[int] = None, repeat_idx: Optional[int] = None):
if benchmark_config.enable_nvtx:
if sample_idx is not None:
return profile_range(name)
elif repeat_idx is not None and repeat_idx == benchmark_config.repeat - 1:
# only record the last repeat
return profile_range(name)
else:
return nullcontext()
elif benchmark_config.profile:
return torch.profiler.record_function(name)
else:
return nullcontext()

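Illustration (not part of the diff) of how maybe_record_function behaves: with enable_nvtx set, the per-sample range is always opened while the per-repeat range is only opened for the last repeat; with profile set instead, every call becomes a torch.profiler.record_function block; otherwise it is a no-op. A small sketch, assuming the nvtx package is available:

import torch
from utils.common import BenchmarkConfig, Phase, maybe_record_function
from utils.metrics import Device, Metrics

cfg = BenchmarkConfig(device=Device.CUDA, dtype=torch.bfloat16, phase=Phase.FORWARD,
                      max_samples=1, repeat=2, metrics=[Metrics.EXECUTION_TIME],
                      profile=False, profile_folder="./log", enable_nvtx=True)
with maybe_record_function("FusedLinearCrossEntropy.Liger___sample_0", cfg, sample_idx=0):
    for r in range(cfg.repeat):
        # repeat_0 is a nullcontext here; only repeat_1 (the last repeat) opens an NVTX range
        with maybe_record_function(f"repeat_{r}", cfg, repeat_idx=r):
            pass  # fn() would run here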
View File

@@ -0,0 +1,115 @@
from enum import Enum
from typing import Any, List, Optional, Tuple
from triton.testing import do_bench
import nvtx
from contextlib import contextmanager
import torch
class MetricResult:
def __init__(self) -> None:
self.op_name: str = ""
self.op_variant: str = ""
# The first dimension is the sample index, the second dimension is the metric value for each repeat
self.execution_time: List[List[float]] = [] # List of lists for execution times
self.mem_throughput: List[
List[float]
] = [] # List of lists for memory throughput
self.cpu_peak_mem: Optional[float] = None # Peak CPU memory usage
self.gpu_peak_mem: Optional[float] = None # Peak GPU memory usage
self.input: List[
Tuple[Any, Any]
] = [] # Correlate metrics with inputs, indexed by sample
def __str__(self) -> str:
return (
f"MetricResult(op_name={self.op_name}, "
f"op_variantant={self.op_variantant}, "
f"execution_time={self.execution_time}, "
f"mem_throughput={self.mem_throughput}, "
f"cpu_peak_mem={self.cpu_peak_mem}, "
f"gpu_peak_mem={self.gpu_peak_mem})"
)
# Define an Enum for metrics
class Metrics(Enum):
EXECUTION_TIME = "execution_time"
MEM_THROUGHPUT = "mem_throughput"
CPU_PEAK_MEM = "cpu_peak_mem"
GPU_PEAK_MEM = "gpu_peak_mem"
class Device(Enum):
CPU = "cpu"
CUDA = "cuda"
@contextmanager
def profile_range(range_name):
with nvtx.annotate(range_name):
yield
def get_execution_time(fn, grad_to_none=None, device=None, **kwargs):
"""
Get the execution time of a function.
For CUDA, we use triton's do_bench. Note: its defaults are 25 ms of warmup and 100 ms of measurement.
"""
if device == Device.CUDA:
return do_bench(fn, grad_to_none=grad_to_none, **kwargs)
else:
raise ValueError(f"Device {device} is not supported")
def do_profile_bench(fn, n_repeat=5, grad_to_none=None):
"""
:param fn: Function to benchmark
:type fn: Callable
:param n_repeat: Number of repetitions. Because this is for ncu profiling,
we repeat a fixed number of times rather than for a fixed amount of time.
:type n_repeat: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
"""
torch.cuda.synchronize()
for _ in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
fn()
torch.cuda.synchronize()
def do_profile_warmup(fn, warmup=25, fast_flush=True):
"""
:param warmup: Warmup time (in ms)
:type warmup: int
:param fast_flush: Use faster kernel to flush L2 between measurements
:type fast_flush: bool
"""
fn()
torch.cuda.synchronize()
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
if fast_flush:
cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
else:
cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
# Estimate the runtime of the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
# Warm-up
for _ in range(n_warmup):
fn()
torch.cuda.synchronize()

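Illustration (not part of the diff): get_execution_time defers to triton's do_bench, so it returns a time in milliseconds for CUDA and raises for any other device. A minimal sketch, assuming a CUDA device and triton installed:

import torch
from utils.metrics import Device, get_execution_time

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
ms = get_execution_time(lambda: a @ b, device=Device.CUDA)
print(f"matmul took ~{ms:.3f} ms")  # exact aggregation (mean/median) depends on the triton version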
View File

@@ -0,0 +1,81 @@
# Owner(s): ["module: inductor"]
import os
import subprocess
import sys
import unittest
try:
import triton # noqa: F401
except ImportError:
if __name__ == "__main__":
sys.exit(0)
raise unittest.SkipTest("requires triton") # noqa: B904
current_dir = os.path.dirname(os.path.abspath(__file__))
operatorbench_dir = os.path.join(current_dir, "..", "benchmarks", "dynamo")
sys.path.append(operatorbench_dir)
from click.testing import CliRunner
from operatorbench.run import run_benchmarks
import torch
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
def check_and_install_liger_kernel():
try:
import liger_kernel # noqa: F401
except ImportError:
print("liger-kernel not found. Installing...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "liger-kernel", "--no-deps"]
)
@instantiate_parametrized_tests
class OperatorBenchTestCase(TestCase):
def setUp(self):
super().setUp()
torch.manual_seed(23456)
counters.clear()
@parametrize("device", [GPU_TYPE])
@parametrize("op", ["FusedLinearCrossEntropy"])
@parametrize("dtype", ["float32", "float16", "bfloat16"])
@parametrize("phase", ["forward", "backward", "full"])
def test_FusedLinearCrossEntropy(self, device, op, dtype, phase):
args = [
"--op",
op,
"--dtype",
dtype,
"--max-samples",
"1",
"--device",
device,
"--phase",
phase,
"--repeat",
"1",
"--metrics",
"execution_time",
]
runner = CliRunner()
result = runner.invoke(run_benchmarks, args)
if result.exit_code != 0:
print("args:", args)
print("Error:", result.output)
print(result)
raise RuntimeError("Failed to run benchmarks")
if __name__ == "__main__":
if HAS_GPU:
check_and_install_liger_kernel()
run_tests(needs="filelock")