Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-16 15:34:57 +08:00)
Compare commits — 35 commits, between branches ciflow/tru… and findhao/op…
| SHA1 |
|---|
| b4413c2d80 |
| 5afb4bef46 |
| 00ceaa3bcf |
| c0943bdaa7 |
| 569bf6edca |
| f2654ae713 |
| ab40b51c5d |
| 9f2936931a |
| 5fa8031ae5 |
| a245137d76 |
| 2d93c5f720 |
| 6f3b42a073 |
| db4c9a54a2 |
| f78da95bc5 |
| 7c2bc74a72 |
| 7b366a2b70 |
| 900671f799 |
| 1f30017712 |
| 1fdf24d9a5 |
| 8a4bc3cc09 |
| 78f5027b48 |
| 6e2d4c661a |
| 30dd419560 |
| f280038562 |
| 0bb482185c |
| 18c2804981 |
| a6b6bbc293 |
| ebd4755b0d |
| 8779577950 |
| a6d9a506c3 |
| a0ecd4f45d |
| 9cdac0662b |
| 57be1aae4b |
| 22ee74895b |
| ea97de291b |
benchmarks/dynamo/operatorbench/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
import os
import sys


# Add the current directory to the system path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
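The `sys.path` append above is what lets every file in this suite use flat imports such as `import operators` and `from utils.common import BenchmarkConfig`, regardless of where the package is imported from. A minimal sketch of picking the suite up from outside the directory (this mirrors what `test/test_operatorbench.py` below does; the checkout path is an assumption):

```python
import os
import sys

# Assumed location of a PyTorch checkout; adjust for your environment.
repo_root = "/path/to/pytorch"
sys.path.append(os.path.join(repo_root, "benchmarks", "dynamo"))

# Importing the package executes operatorbench/__init__.py, which appends its own
# directory to sys.path, so run.py's flat "operators"/"utils" imports resolve.
from operatorbench.run import run_benchmarks  # noqa: E402
```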
benchmarks/dynamo/operatorbench/operators/FusedLinearCrossEntropy/__init__.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from typing import Any, Callable, List

from utils.common import BenchmarkConfig, Phase

import torch

from .. import BaseOperator


H = 4096
V = 128256
# Each file defines an operator variant
valid_operator_files = ["baseline.py", "custom.py", "inductor.py"]


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py


class FusedLinearCrossEntropyOperator(BaseOperator):
    # The base operator name
    name = "FusedLinearCrossEntropy"
    # The variant placeholder. No need to set in the base operator class
    variant = None

    def __init__(self, benchmark_config: BenchmarkConfig, is_baseline: bool = False):
        super().__init__(benchmark_config, is_baseline)
        self.forward_output = None

    @classmethod
    def generate_inputs(cls, benchmark_config: BenchmarkConfig):
        example_inputs_list = []
        # May need OOM check
        for BT in [2**i for i in range(12, 16)]:
            _input = torch.randn(
                BT,
                H,
                requires_grad=True,
                dtype=benchmark_config.dtype,
                device=benchmark_config.device.value,
            )
            target = torch.randint(
                V, (BT, 1), dtype=torch.long, device=benchmark_config.device.value
            ).squeeze(1)
            # This operator needs two inputs
            example_inputs_list.append((_input, target))
        return example_inputs_list

    def forward(self, input: Any):
        return self.operator(input)

    # backward doesn't need inputs, but we need to pass it to match the interface
    def backward(self, input: Any):
        assert self.forward_output is not None
        return self.forward_output.backward(retain_graph=True)

    def full(self, input: Any):
        y = self.forward(input)
        y.backward()
        return y

    def prepare_input_and_functions(self, input: Any, phase: Phase):
        if phase == Phase.BACKWARD:
            self.forward_output = self.forward(input)
        return input
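For orientation, here is a rough sketch of how one benchmark sample flows through the hooks defined above: `generate_inputs` produces `(activation, target)` tuples, `prepare_input_and_functions` runs and caches a forward output when the backward phase is benchmarked, and `backward` replays it. The sketch assumes a CUDA machine, the `utils` helpers from this PR, and the Baseline variant defined in `baseline.py` below (its module path is inferred from the operator directory name):

```python
import torch

from utils.common import BenchmarkConfig, Phase
from utils.metrics import Device, Metrics

# The Baseline variant defined in baseline.py below (import path inferred).
from operators.FusedLinearCrossEntropy.baseline import Operator as BaselineOperator

config = BenchmarkConfig(
    device=Device.CUDA,
    dtype=torch.bfloat16,
    phase=Phase.BACKWARD,
    max_samples=1,
    repeat=1,
    metrics=[Metrics.EXECUTION_TIME],
    profile=False,
    profile_folder="./log",
    enable_nvtx=False,
)

op = BaselineOperator(config)
inputs = op.generate_inputs(config)  # list of (activation, target) tuples
sample = op.prepare_input_and_functions(inputs[0], Phase.BACKWARD)  # runs and caches forward
op.backward(sample)  # replays backward on the cached forward output
```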
benchmarks/dynamo/operatorbench/operators/FusedLinearCrossEntropy/baseline.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L17


class TorchLMHeadCE(torch.nn.Module):
    """Ground-truth implementation: a linear projection followed by torch's cross-entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param ignore_index: index to ignore
    :param reduction: reduction method
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, inputs):
        x, y = inputs
        logits = self.lin(x)
        return self.ce_loss(logits, y)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Baseline"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config, is_baseline=True)
        self.operator = TorchLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
            self.benchmark_config.device.value
        )
benchmarks/dynamo/operatorbench/operators/FusedLinearCrossEntropy/custom.py (new file, 34 lines; the Liger variant)
@@ -0,0 +1,34 @@
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L40
class LigerLMHeadCE(torch.nn.Module):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, inputs):
        return self.ce_loss(self.lin.weight, *inputs)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Liger"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = LigerLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
            self.benchmark_config.device.value
        )
benchmarks/dynamo/operatorbench/operators/FusedLinearCrossEntropy/inductor.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V
from .baseline import TorchLMHeadCE


class TorchLMHeadCECompiled(TorchLMHeadCE):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__(H, V, dtype, ignore_index)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Inductor"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = TorchLMHeadCECompiled(
            H=H, V=V, dtype=self.benchmark_config.dtype
        ).to(self.benchmark_config.device.value)
        self.operator = torch.compile(self.operator)
benchmarks/dynamo/operatorbench/operators/__init__.py (new file, 208 lines)
@@ -0,0 +1,208 @@
import importlib
import os
import pathlib
import sys
import types
from typing import Dict, List, Optional

from utils.common import BenchmarkConfig
from utils.metrics import Device

import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import gen_gm_and_inputs
from torch.utils._pytree import tree_map_only


class OperatorNotFoundError(RuntimeError):
    """Custom exception raised when an operator is not found."""


class BaseOperator:
    """
    Base class for operators.

    This class defines the structure for operator implementations.
    The forward, backward, full methods should **only contain**
    the code that users want to benchmark.

    Attributes:
        name (str): The main name of the operator, e.g. "FusedLinearCrossEntropy".
        variant (str): The variant of the operator, e.g. "baseline".
        benchmark_config (BenchmarkConfig): Configuration for the benchmark.
        full_name (str): The full name of the operator (name.variant). It is only valid for variants.
            It can be either assigned in the operator file or generated from name and variant.
    """

    name = None
    variant = None
    benchmark_config = None
    full_name = None

    def __init__(self, benchmark_config: BenchmarkConfig, is_baseline: bool = False):
        """
        Initialize the BaseOperator.

        Args:
            benchmark_config (BenchmarkConfig): Configuration for the benchmark.
            is_baseline (bool): Whether the operator is a baseline variant.
        """
        self.benchmark_config = benchmark_config
        if self.full_name is None:
            self.full_name = f"{self.name}.{self.variant}"
        self.is_baseline = is_baseline

    @classmethod
    def get_inputs(
        cls,
        input_mapping: Dict[str, List],
        benchmark_config: Optional[BenchmarkConfig] = None,
    ):
        """
        Get or generate example inputs for the operator.

        The format of the inputs is important and should meet the requirements
        of the operator. It is not necessary to have a unified format for
        different operators, but the format should be consistent within the
        same operator.

        This function is different from generate_inputs in that it does not
        generate inputs, but returns the inputs that have been generated in
        previous runs.

        Args:
            input_mapping (Dict[str, List]): Mapping from operator name to the input list.
            benchmark_config (Optional[BenchmarkConfig]): Configuration for the benchmark.

        Returns:
            list: List of example inputs.
        """
        if cls.name not in input_mapping:
            assert (
                benchmark_config is not None
            ), "Benchmark config is required to generate inputs"
            generated_inputs = cls.generate_inputs(benchmark_config)
            input_mapping[cls.name] = generated_inputs
        return input_mapping[cls.name]

    @classmethod
    def generate_inputs(cls, benchmark_config: BenchmarkConfig):
        """
        Generate example inputs for the operator. Each operator should implement
        this method and the format should be consistent with the operator.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def forward(self):
        """Perform the forward pass of the operator."""
        raise NotImplementedError("Subclasses must implement this method.")

    def backward(self):
        """Perform the backward pass of the operator. It can be bypassed if the operator does not have a backward pass."""
        raise NotImplementedError("Subclasses must implement this method.")

    def full(self):
        """Perform the full (forward + backward) pass of the operator."""
        raise NotImplementedError("Subclasses must implement this method.")

    def prepare_input_and_functions(self, input):
        """
        If needed, process the input before running the operator. This can be
        used to prepare the forward output for the backward benchmarking. By default,
        we return the input directly.

        Args:
            input: The input to the operator.

        Returns:
            The processed input.
        """
        return input


def _list_operator_paths() -> List[str]:
    """
    List the paths of all operator directories.

    Returns:
        List[str]: A sorted list of absolute paths to operator directories.
    """
    p = pathlib.Path(__file__).parent
    # Only load the operator directories that contain an "__init__.py" file
    return sorted(
        str(child.absolute())
        for child in p.iterdir()
        if child.is_dir() and os.path.exists(os.path.join(child, "__init__.py"))
    )


def _load_valid_operators(module_path: str, operator_name: str) -> List:
    """
    Load valid operators from a given module path.

    Args:
        module_path (str): The path to the operator module.
        operator_name (str): The name of the operator.

    Returns:
        List: A list of loaded operator classes.

    Raises:
        OperatorNotFoundError: If the operator module fails to load.
    """
    loaded_operators = []
    cls_name = "Operator"

    # Import the operator module
    try:
        operator_module = importlib.import_module(module_path, package=__name__)
        # We only load the operator files that define the valid_operator_files attribute in the operator module
        valid_operator_files = getattr(operator_module, "valid_operator_files", None)
        if valid_operator_files is None:
            raise ImportError(f"{module_path} does not define valid_operator_files")
    except ImportError as e:
        raise OperatorNotFoundError(
            f"Failed to load operator module {module_path}: {str(e)}"
        ) from e

    for file_name in valid_operator_files:
        tmp_file_name = file_name
        if file_name.endswith(".py"):
            tmp_file_name = file_name[:-3]
        operator_file_module_path = f"{module_path}.{tmp_file_name}"
        try:
            file_module = importlib.import_module(
                operator_file_module_path, package=__name__
            )
            Operator = getattr(file_module, cls_name, None)
            if Operator is None:
                print(
                    f"Warning: {file_module} does not define attribute '{cls_name}', skipping."
                )
            else:
                if not hasattr(Operator, "name") or Operator.name is None:
                    Operator.name = f"{operator_name}"
                loaded_operators.append(Operator)
        except ImportError as e:
            print(
                f"Warning: Failed to load operator from {operator_file_module_path}: {str(e)}"
            )
    return loaded_operators


def list_operators():
    """
    List all available operators. Each operator represents a variant of a base operator.

    Returns:
        List: A list of all operator classes.
    """
    # This list is used to store all the operator classes, not instances
    operators = []
    for operator_path in _list_operator_paths():
        operator_name = os.path.basename(operator_path)
        module_path = f"operators.{operator_name}"
        loaded_operators = _load_valid_operators(module_path, operator_name)
        operators.extend(loaded_operators)
    return operators
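Putting the loader together: `list_operators` walks the sibling directories of this file that contain an `__init__.py`, imports each as `operators.<dirname>`, reads its `valid_operator_files` list, and collects the class literally named `Operator` from every listed file. A hedged skeleton of what a hypothetical new operator directory would need in order to be discovered (names are illustrative, not part of this PR):

```python
# operators/MyNewOp/__init__.py  -- hypothetical operator directory
from .. import BaseOperator

# Only files listed here are imported by _load_valid_operators.
valid_operator_files = ["baseline.py"]


class MyNewOpOperator(BaseOperator):
    name = "MyNewOp"

    @classmethod
    def generate_inputs(cls, benchmark_config):
        return []  # return a list of example inputs in whatever shape forward() expects


# operators/MyNewOp/baseline.py  -- hypothetical variant file
#
# from . import MyNewOpOperator
#
# class Operator(MyNewOpOperator):  # the loader looks up the class named "Operator"
#     variant = "Baseline"
```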
benchmarks/dynamo/operatorbench/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
liger-kernel
transformers>=4.38.1
torch>=2.1.2
benchmarks/dynamo/operatorbench/run.py (new file, 226 lines)
@@ -0,0 +1,226 @@
import warnings
from collections import defaultdict
from contextlib import nullcontext

import click
import operators
from operators import BaseOperator
from utils.common import (
    BenchmarkConfig,
    Device,
    dtype_mapping,
    maybe_record_function,
    Phase,
)
from utils.metrics import get_execution_time, MetricResult, Metrics, do_profile_warmup, do_profile_bench

import torch


# mapping from operator name to the input list.
# We use the same input list for different variants of the same operator.
# {operator_name: input_list}
input_mapping = {}


# Create operator instances from desired operator names
# Return a dict of {operator_name: [variant_instances]}
def create_operator_instances(
    operator_names: list[str],
    name_to_variant_list: dict[str, list[BaseOperator]],
    benchmark_config: BenchmarkConfig,
    skip_variants: list[str],
) -> dict[str, list[BaseOperator]]:
    operator_instances = defaultdict(list)
    for operator_name in operator_names:
        variant_classes = name_to_variant_list.get(operator_name, [])
        if not variant_classes:
            warnings.warn(f"Operator {operator_name} not found")
            continue
        for VariantClass in variant_classes:
            if VariantClass.variant.lower() in skip_variants:
                continue
            operator_instances[operator_name].append(VariantClass(benchmark_config))
    return operator_instances


def benchmark_operator(operator: BaseOperator, benchmark_config: BenchmarkConfig):
    print(f"Benchmarking {operator.full_name}")
    phase = benchmark_config.phase
    max_samples = benchmark_config.max_samples
    repeat = benchmark_config.repeat
    device = benchmark_config.device
    metrics = benchmark_config.metrics
    num_samples = min(
        max_samples, len(operator.get_inputs(input_mapping, benchmark_config))
    )

    metric_result = MetricResult()
    metric_result.op_name = operator.name
    metric_result.op_variant = operator.variant
    profiler_context = (
        torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=False,
            profile_memory=False,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                f"{benchmark_config.profile_folder}/operator_{operator.full_name}",
                use_gzip=True,
            ),
        )
        if benchmark_config.profile
        else nullcontext()
    )
    with profiler_context:
        for i in range(num_samples):
            input = operator.get_inputs(input_mapping, benchmark_config)[i]
            input = operator.prepare_input_and_functions(input, phase)
            if phase == Phase.FORWARD:
                phase_fn = operator.forward
            elif phase == Phase.BACKWARD:
                phase_fn = operator.backward
            else:
                phase_fn = operator.full
            metric_result.input.append(input)
            execution_time = []
            def fn():
                return phase_fn(input)
            if benchmark_config.enable_nvtx:
                do_profile_warmup(fn, warmup=25, fast_flush=True)
            # DO NOT CHANGE THE NAME OF THE RECORD FUNCTION. It is used in ncu_analyzer.
            with maybe_record_function(f"{operator.full_name}___sample_{i}", benchmark_config, sample_idx=i):
                for repeat_idx in range(repeat):
                    with maybe_record_function(
                        f"repeat_{repeat_idx}", benchmark_config, repeat_idx=repeat_idx
                    ):
                        if benchmark_config.enable_nvtx:
                            do_profile_bench(fn, grad_to_none=None)
                        elif Metrics.EXECUTION_TIME in metrics:
                            execution_time.append(
                                get_execution_time(
                                    fn,
                                    grad_to_none=None,
                                    device=device,
                                )
                            )
            metric_result.execution_time.append(execution_time)
    return metric_result


@click.command()
@click.option("--op", help="operator overload to benchmark. split by ','.")
@click.option(
    "--dtype",
    help="dtype to benchmark. [bfloat16, float16, float32]",
    default="bfloat16",
)
@click.option(
    "--max-samples",
    help="max samples per op. each operator may have different inputs. this is the number of inputs to sample.",
    default=15,
)
@click.option(
    "--device",
    help=f"device to benchmark, {[device.value.lower() for device in Device]}. ",
    default=Device.CUDA.value,
)
@click.option(
    "--phase",
    help=f"phase to benchmark. {[phase.value.lower() for phase in Phase]}. ",
    default="forward",
)
@click.option("--repeat", help="repeat", default=5)
@click.option(
    "--metrics",
    help=f"metrics to benchmark. {[metric.value.lower() for metric in Metrics]}. split by ','",
    default=Metrics.EXECUTION_TIME.value,
)
@click.option(
    "--skip-variants",
    help="variants to be skipped, [liger, baseline, inductor]. split by ','",
    default="",
)
@click.option("--profile", help="profile", is_flag=True, default=False)
@click.option(
    "--profile-folder",
    help="set profile folder",
    default="./log",
)
@click.option("--enable-nvtx", help="enable nvtx", is_flag=True, default=False)
def run_benchmarks(
    op,
    dtype,
    max_samples,
    device,
    phase,
    repeat,
    metrics,
    skip_variants,
    profile,
    profile_folder,
    enable_nvtx,
):
    global input_mapping
    # Reset input mapping to avoid OOM and mismatch in different unit tests
    input_mapping = {}
    # process arguments and generate benchmark config
    dtype = dtype_mapping.get(dtype)
    metrics = [
        Metrics[metric.strip().upper()]
        for metric in metrics.split(",")
        if metric.strip().upper() in Metrics.__members__
    ]
    device = Device[device.upper()]
    if device != Device.CUDA and Metrics.GPU_PEAK_MEM in metrics:
        print(f"{Metrics.GPU_PEAK_MEM.value} is only supported on cuda")
        metrics.remove(Metrics.GPU_PEAK_MEM)
    phase = Phase[phase.upper()]
    benchmark_config = BenchmarkConfig(
        device=device,
        dtype=dtype,
        phase=phase,
        max_samples=max_samples,
        repeat=repeat,
        metrics=metrics,
        profile=profile,
        profile_folder=profile_folder,
        enable_nvtx=enable_nvtx,
    )

    # This is a list of classes, not instances
    operator_class_list: list[BaseOperator] = operators.list_operators()
    name_to_variant_list = defaultdict(list)
    for OperatorClass in operator_class_list:
        name_to_variant_list[OperatorClass.name].append(OperatorClass)
    desired_op_names = None
    if op is not None:
        desired_op_names = op.split(",")
    else:
        desired_op_names = name_to_variant_list.keys()

    skip_variants = skip_variants.split(",")
    skip_variants = [
        variant.lower().strip() for variant in skip_variants if variant.strip()
    ]

    operator_metric_results = {}

    operator_instances = create_operator_instances(
        desired_op_names, name_to_variant_list, benchmark_config, skip_variants
    )
    for operator_name, variants in operator_instances.items():
        for variant in variants:
            metric_result = benchmark_operator(variant, benchmark_config)
            operator_metric_results[
                f"{operator_name}.{variant.variant}"
            ] = metric_result

    for metric_result in operator_metric_results.values():
        print(metric_result)


if __name__ == "__main__":
    run_benchmarks()
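The click options above are the whole CLI surface. A sketch of driving it programmatically, mirroring how the unit test below invokes it (running from the `operatorbench` directory, or having it on `sys.path`, is assumed so that the flat imports resolve):

```python
from click.testing import CliRunner

from run import run_benchmarks  # assumes benchmarks/dynamo/operatorbench is the CWD or on sys.path

result = CliRunner().invoke(
    run_benchmarks,
    [
        "--op", "FusedLinearCrossEntropy",
        "--dtype", "bfloat16",
        "--device", "cuda",
        "--phase", "forward",
        "--max-samples", "2",
        "--repeat", "3",
        "--metrics", "execution_time",
    ],
)
print(result.output)  # one MetricResult line per operator variant
```

From a shell the equivalent is `python run.py --op FusedLinearCrossEntropy --dtype bfloat16 --phase forward ...` executed from the `operatorbench` directory.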
benchmarks/dynamo/operatorbench/utils/__init__.py (new empty file)
benchmarks/dynamo/operatorbench/utils/common.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import dataclasses
from contextlib import nullcontext
from enum import Enum
from typing import List

import torch

from .metrics import Device, Metrics, profile_range


@dataclasses.dataclass
class BenchmarkConfig:
    device: Device
    dtype: torch.dtype
    phase: str
    max_samples: int
    repeat: int
    metrics: List[Metrics]
    profile: bool
    profile_folder: str
    enable_nvtx: bool


class Phase(Enum):
    FORWARD = "forward"
    BACKWARD = "backward"
    FULL = "full"


dtype_mapping = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


def maybe_record_function(name: str, benchmark_config: BenchmarkConfig, sample_idx: int = None, repeat_idx: int = None):
    if benchmark_config.enable_nvtx:
        if sample_idx is not None:
            return profile_range(name)
        elif repeat_idx is not None and repeat_idx == benchmark_config.repeat - 1:
            # only record the last repeat
            return profile_range(name)
        else:
            return nullcontext()
    elif benchmark_config.profile:
        return torch.profiler.record_function(name)
    else:
        return nullcontext()
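`maybe_record_function` is the single switch that decides how a benchmarked region is marked: an NVTX range when `--enable-nvtx` is set (for sample markers, and only the last repeat of the repeat markers), a `torch.profiler.record_function` range when `--profile` is set, and a no-op otherwise. A small sketch of the three branches (it assumes `nvtx` and `triton` are importable, since `utils.metrics` imports them at module load):

```python
import dataclasses
from contextlib import nullcontext

from utils.common import BenchmarkConfig, Phase, dtype_mapping, maybe_record_function
from utils.metrics import Device, Metrics

base = BenchmarkConfig(
    device=Device.CUDA, dtype=dtype_mapping["bfloat16"], phase=Phase.FORWARD,
    max_samples=1, repeat=3, metrics=[Metrics.EXECUTION_TIME],
    profile=False, profile_folder="./log", enable_nvtx=False,
)

# Neither profiling mode enabled: a nullcontext, i.e. zero overhead.
assert isinstance(maybe_record_function("x", base, sample_idx=0), nullcontext)

# --profile: a torch.profiler.record_function range.
prof_ctx = maybe_record_function("x", dataclasses.replace(base, profile=True), sample_idx=0)

# --enable-nvtx: an NVTX range; repeat markers fire only on the last repeat.
nvtx_cfg = dataclasses.replace(base, enable_nvtx=True)
nvtx_ctx = maybe_record_function(f"repeat_{nvtx_cfg.repeat - 1}", nvtx_cfg, repeat_idx=nvtx_cfg.repeat - 1)
```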
benchmarks/dynamo/operatorbench/utils/metrics.py (new file, 115 lines)
@@ -0,0 +1,115 @@
from enum import Enum
from typing import Any, List, Tuple

from triton.testing import do_bench
import nvtx
from contextlib import contextmanager
import torch
class MetricResult:
    def __init__(self) -> None:
        self.op_name: str = ""
        self.op_variant: str = ""
        # The first dimension is the sample index, the second dimension is the metric value for each repeat
        self.execution_time: List[List[float]] = []  # List of lists for execution times
        self.mem_throughput: List[
            List[float]
        ] = []  # List of lists for memory throughput
        self.cpu_peak_mem: float = None  # Peak CPU memory usage
        self.gpu_peak_mem: float = None  # Peak GPU memory usage
        self.input: List[
            Tuple[Any, Any]
        ] = []  # Correlate metrics with inputs, indexed by sample

    def __str__(self) -> str:
        return (
            f"MetricResult(op_name={self.op_name}, "
            f"op_variant={self.op_variant}, "
            f"execution_time={self.execution_time}, "
            f"mem_throughput={self.mem_throughput}, "
            f"cpu_peak_mem={self.cpu_peak_mem}, "
            f"gpu_peak_mem={self.gpu_peak_mem})"
        )


# Define an Enum for metrics
class Metrics(Enum):
    EXECUTION_TIME = "execution_time"
    MEM_THROUGHPUT = "mem_throughput"
    CPU_PEAK_MEM = "cpu_peak_mem"
    GPU_PEAK_MEM = "gpu_peak_mem"


class Device(Enum):
    CPU = "cpu"
    CUDA = "cuda"

@contextmanager
def profile_range(range_name):
    with nvtx.annotate(range_name):
        yield

def get_execution_time(fn, grad_to_none=None, device=None, **kwargs):
    """
    Get the execution time of a function.
    For CUDA, we use triton's do_bench. Note: it has a default repeat of 100 and warmup of 25.
    """
    if device == Device.CUDA:
        return do_bench(fn, grad_to_none=grad_to_none, **kwargs)
    else:
        raise ValueError(f"Device {device} is not supported")


def do_profile_bench(fn, n_repeat=5, grad_to_none=None):
    """
    :param fn: Function to benchmark
    :type fn: Callable
    :param n_repeat: Repetition number. Because this is for ncu profiling,
        we don't need to repeat the function many times, so we use a count instead of a time budget.
    :type n_repeat: int
    :param grad_to_none: Reset the gradient of the provided tensors to None
    :type grad_to_none: torch.tensor, optional
    """
    torch.cuda.synchronize()
    for _ in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        fn()
    torch.cuda.synchronize()

def do_profile_warmup(fn, warmup=25, fast_flush=True):
    """
    :param warmup: Warmup time (in ms)
    :type warmup: int
    :param fast_flush: Use faster kernel to flush L2 between measurements
    :type fast_flush: bool
    """
    fn()
    torch.cuda.synchronize()
    # We maintain a buffer of 256 MB that we clear
    # before each kernel call to make sure that the L2
    # doesn't contain any input data before the run
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    # Warm-up
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()
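`get_execution_time` is a thin CUDA-only wrapper around `triton.testing.do_bench`, so it inherits do_bench's default warmup/measurement behavior and returns a latency in milliseconds. A toy usage sketch (CUDA and triton assumed available):

```python
import torch

from utils.metrics import Device, get_execution_time

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

def fn():
    return x @ x  # the region being timed

ms = get_execution_time(fn, device=Device.CUDA)  # milliseconds via triton.testing.do_bench
print(f"matmul: {ms:.3f} ms")
```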
test/test_operatorbench.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# Owner(s): ["module: inductor"]
import os
import subprocess
import sys
import unittest


try:
    import triton  # noqa: F401
except ImportError:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton")  # noqa: B904

current_dir = os.path.dirname(os.path.abspath(__file__))
operatorbench_dir = os.path.join(current_dir, "..", "benchmarks", "dynamo")
sys.path.append(operatorbench_dir)
from click.testing import CliRunner
from operatorbench.run import run_benchmarks

import torch
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


def check_and_install_liger_kernel():
    try:
        import liger_kernel  # noqa: F401
    except ImportError:
        print("liger-kernel not found. Installing...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "liger-kernel", "--no-deps"]
        )


@instantiate_parametrized_tests
class OperatorBenchTestCase(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(23456)
        counters.clear()

    @parametrize("device", [GPU_TYPE])
    @parametrize("op", ["FusedLinearCrossEntropy"])
    @parametrize("dtype", ["float32", "float16", "bfloat16"])
    @parametrize("phase", ["forward", "backward", "full"])
    def test_FusedLinearCrossEntropy(self, device, op, dtype, phase):
        args = [
            "--op",
            op,
            "--dtype",
            dtype,
            "--max-samples",
            "1",
            "--device",
            device,
            "--phase",
            phase,
            "--repeat",
            "1",
            "--metrics",
            "execution_time",
        ]
        runner = CliRunner()
        result = runner.invoke(run_benchmarks, args)
        if result.exit_code != 0:
            print("args:", args)
            print("Error:", result.output)
            print(result)
            raise RuntimeError("Failed to run benchmarks")


if __name__ == "__main__":
    if HAS_GPU:
        check_and_install_liger_kernel()
        run_tests(needs="filelock")