Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-18 17:45:09 +08:00)
Compare commits: ciflow/b20 ... findhao/op (38 commits)
SHA1: 1d983f0775, f320e4ba86, 73dfb2bc2d, f2654ae713, ab40b51c5d, 8c1b793071,
de205901f3, 9f2936931a, 5fa8031ae5, a245137d76, 2d93c5f720, 6f3b42a073,
db4c9a54a2, f78da95bc5, 7c2bc74a72, 7b366a2b70, b437ffe8b0, 085e2f5416,
425ad9ccdb, 900671f799, 1f30017712, 1fdf24d9a5, 8a4bc3cc09, 78f5027b48,
6e2d4c661a, 30dd419560, f280038562, 0bb482185c, 18c2804981, a6b6bbc293,
ebd4755b0d, 8779577950, a6d9a506c3, a0ecd4f45d, 9cdac0662b, 57be1aae4b,
22ee74895b, ea97de291b
@@ -4785,9 +4785,9 @@ def log_operator_inputs(model, example_inputs, model_iter_fn, name, args):

     print(f"Running {name}")
     try:
-        from .microbenchmarks.operator_inp_utils import OperatorInputsMode
+        from operatorbench.operators.operator_inp_utils import OperatorInputsMode
     except ImportError:
-        from microbenchmarks.operator_inp_utils import OperatorInputsMode
+        from .operatorbench.operators.operator_inp_utils import OperatorInputsMode

     operator_mode = OperatorInputsMode()
     fake_tensor_mode = FakeTensorMode()
benchmarks/dynamo/operatorbench/__init__.py (new file)
@@ -0,0 +1,7 @@
import os
import sys


# Add the current directory to the system path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
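Why the two-line path tweak above matters: once the operatorbench directory is appended to sys.path, its subpackages can be imported by bare name (for example `operators` and `utils.common`), which is what the new files in this diff rely on. A minimal, self-contained sketch of the same idea, using a throwaway temporary directory and made-up module contents rather than the real package:

# Minimal sketch (hypothetical paths and module contents): appending a directory
# to sys.path makes its subpackages importable by bare name, mirroring
# operatorbench/__init__.py above.
import importlib
import os
import sys
import tempfile

root = tempfile.mkdtemp()              # stands in for .../dynamo/operatorbench
pkg = os.path.join(root, "operators")  # stands in for operatorbench/operators
os.makedirs(pkg)
with open(os.path.join(pkg, "__init__.py"), "w") as f:
    f.write("GREETING = 'hello from operators'\n")

sys.path.append(root)                  # the same trick as the new __init__.py
operators = importlib.import_module("operators")
print(operators.GREETING)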
@@ -0,0 +1,65 @@
from typing import Any, Callable, List

from utils.common import BenchmarkConfig, Phase

import torch

from .. import BaseOperator


H = 4096
V = 128256
# Each file defines an operator variant
valid_operator_files = ["baseline.py", "custom.py", "inductor.py"]


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py


class FusedLinearCrossEntropyOperator(BaseOperator):
    # The base operator name
    name = "FusedLinearCrossEntropy"
    # The variant placeholder. No need to set in the base operator class
    variant = None

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.forward_output = None

    @classmethod
    def generate_inputs(cls, benchmark_config: BenchmarkConfig):
        example_inputs_list = []
        # May need OOM check
        for BT in [2**i for i in range(12, 16)]:
            _input = torch.randn(
                BT,
                H,
                requires_grad=True,
                dtype=benchmark_config.dtype,
                device=benchmark_config.device.value,
            )
            target = torch.randint(
                V, (BT, 1), dtype=torch.long, device=benchmark_config.device.value
            ).squeeze(1)
            # This operator needs two inputs
            example_inputs_list.append((_input, target))
        return example_inputs_list

    def forward(self, input: Any):
        return self.operator(input)

    # backward doesn't need inputs, but we need to pass it to match the interface
    def backward(self, input: Any):
        assert self.forward_output is not None
        return self.forward_output.backward(retain_graph=True)

    def full(self, input: Any):
        y = self.forward(input)
        y.backward()
        return y

    def prepare_input_and_functions(self, input: Any, phase: Phase):
        if phase == Phase.BACKWARD:
            self.forward_output = self.forward(input)
        return input
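The class above pins down the benchmarking contract for this operator: generate_inputs builds (activation, target) pairs of growing batch size, prepare_input_and_functions runs and caches one forward pass when the backward phase is about to be timed, and forward/backward/full are the only bodies a timer should wrap. A rough sketch of how a harness could drive the backward phase of one variant; the naive timing loop and the way the variant instance `op` is obtained are assumptions, not the runner shipped in this PR:

# Hypothetical driver sketch; `op` is any concrete variant of the operator above
# (for example the Baseline Operator in the next file), and the timing is
# deliberately naive.
import time

from utils.common import Phase  # same enum used by prepare_input_and_functions


def time_backward(op, benchmark_config):
    for example in op.generate_inputs(benchmark_config):
        # Runs the forward pass once and caches its output so that backward()
        # has something to call .backward() on.
        example = op.prepare_input_and_functions(example, Phase.BACKWARD)
        start = time.perf_counter()
        op.backward(example)
        elapsed = time.perf_counter() - start
        print(f"BT={example[0].shape[0]}: backward took {elapsed:.6f}s")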
@@ -0,0 +1,43 @@
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L17


class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation of the linear fused with torch based cross entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param ignore_index: index to ignore
    :param reduction: reduction method
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, inputs):
        x, y = inputs
        logits = self.lin(x)
        return self.ce_loss(logits, y)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Baseline"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = TorchLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
            self.benchmark_config.device.value
        )
@@ -0,0 +1,34 @@
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L40
class LigerLMHeadCE(torch.nn.Module):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, inputs):
        return self.ce_loss(self.lin.weight, *inputs)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Liger"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = LigerLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
            self.benchmark_config.device.value
        )
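The Liger variant above feeds the linear layer's weight straight into the fused loss, so the full [BT, V] logits tensor is never materialized; the Baseline variant materializes it and then applies torch.nn.CrossEntropyLoss, which is the difference being benchmarked. A hedged parity-check sketch comparing the two modules on one small batch; it assumes liger_kernel is installed, a CUDA device is available, and that TorchLMHeadCE and LigerLMHeadCE are importable from the two new files, with sizes and tolerances chosen only for illustration:

# Hedged sanity-check sketch: same weights, same batch, compare the two losses.
import torch

H, V, BT = 256, 1024, 32        # small sizes for a quick check, not benchmark sizes
dtype, device = torch.float32, "cuda"
torch.manual_seed(0)

baseline = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
fused = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
fused.lin.weight.data.copy_(baseline.lin.weight.data)  # share weights for a fair comparison

x = torch.randn(BT, H, dtype=dtype, device=device, requires_grad=True)
y = torch.randint(V, (BT,), dtype=torch.long, device=device)

print(torch.allclose(baseline((x, y)), fused((x, y)), rtol=1e-4, atol=1e-5))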
@@ -0,0 +1,22 @@
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V
from .baseline import TorchLMHeadCE


class TorchLMHeadCECompiled(TorchLMHeadCE):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__(H, V, dtype, ignore_index)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Inductor"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = TorchLMHeadCECompiled(
            H=H, V=V, dtype=self.benchmark_config.dtype
        ).to(self.benchmark_config.device.value)
        self.operator = torch.compile(self.operator)
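The Inductor variant reuses the eager TorchLMHeadCE module unchanged and only wraps the instance with torch.compile, so any difference against the baseline comes from Inductor's code generation rather than a different algorithm. A minimal sketch of the same wrapping on a small standalone module; the module and sizes here are made up and do not appear in this PR:

# Minimal sketch: wrap an eager nn.Module with torch.compile, as the Inductor
# variant above does. Everything here is illustrative.
import torch


class TinyLMHead(torch.nn.Module):
    def __init__(self, hidden, vocab):
        super().__init__()
        self.lin = torch.nn.Linear(hidden, vocab, bias=False)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, inputs):
        x, y = inputs
        return self.loss(self.lin(x), y)


model = TinyLMHead(64, 128)
compiled = torch.compile(model)   # the same call the Inductor variant makes
x = torch.randn(8, 64, requires_grad=True)
y = torch.randint(128, (8,))
print(compiled((x, y)).item())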
benchmarks/dynamo/operatorbench/operators/__init__.py (new file)
@@ -0,0 +1,342 @@
import importlib
import os
import pathlib
import sys
import types
from typing import Any, Dict, List, Optional

from utils.common import BenchmarkConfig, Phase
from utils.metrics import Device

import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import gen_gm_and_inputs
from torch.utils._pytree import tree_map_only

from .operator_inp_utils import OperatorInputsLoader, to_channels_last


class OperatorNotFoundError(RuntimeError):
    """Custom exception raised when an operator is not found."""


class BaseOperator:
    """
    Base class for operators.

    This class defines the structure for operator implementations.
    The forward, backward, and full methods should **only contain**
    the code that users want to benchmark.

    Attributes:
        name (str): The main name of the operator, e.g. "FusedLinearCrossEntropy".
        variant (str): The variant of the operator, e.g. "baseline".
        benchmark_config (BenchmarkConfig): Configuration for the benchmark.
        full_name (str): The full name of the operator (name.variant). It is only valid for variants.
            It can be either assigned in the operator file or generated from name and variant.
        example_inputs_list (list): List of example inputs for the operator.
    """

    name = None
    variant = None
    benchmark_config = None
    full_name = None

    def __init__(self, benchmark_config: BenchmarkConfig):
        """
        Initialize the BaseOperator.

        Args:
            benchmark_config (BenchmarkConfig): Configuration for the benchmark.
        """
        self.benchmark_config = benchmark_config
        if self.full_name is None:
            self.full_name = f"{self.name}.{self.variant}"

    @classmethod
    def get_inputs(
        cls,
        input_mapping: Dict[str, List],
        benchmark_config: Optional[BenchmarkConfig] = None,
    ):
        """
        Get or generate example inputs for the operator.

        The format of the inputs is important and should meet the requirements
        of the operator. It is not necessary to have a unified format for
        different operators, but the format should be consistent within the
        same operator.

        This function is different from generate_inputs in that it does not
        generate inputs, but returns the inputs that have been generated in
        previous runs.

        Args:
            input_mapping (Dict[str, List]): Mapping from operator name to the input list.
            benchmark_config (Optional[BenchmarkConfig]): Configuration for the benchmark.

        Returns:
            list: List of example inputs.
        """
        if cls.name not in input_mapping:
            assert (
                benchmark_config is not None
            ), "Benchmark config is required to generate inputs"
            generated_inputs = cls.generate_inputs(benchmark_config)
            input_mapping[cls.name] = generated_inputs
        return input_mapping[cls.name]

    @classmethod
    def generate_inputs(cls, benchmark_config: BenchmarkConfig):
        """
        Generate example inputs for the operator. Each operator should implement
        this method and the format should be consistent with the operator.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def forward(self):
        """Perform the forward pass of the operator."""
        raise NotImplementedError("Subclasses must implement this method.")

    def backward(self):
        """Perform the backward pass of the operator. It can be bypassed if the operator does not have a backward pass."""
        raise NotImplementedError("Subclasses must implement this method.")

    def full(self):
        """Perform the full (forward + backward) pass of the operator."""
        raise NotImplementedError("Subclasses must implement this method.")

    def prepare_input_and_functions(self, input):
        """
        If needed, process the input before running the operator. This can be
        used to prepare the forward output for the backward benchmarking. By default,
        we return the input directly.

        Args:
            input: The input to the operator.

        Returns:
            The processed input.
        """
        return input


def _list_operator_paths() -> List[str]:
    """
    List the paths of all operator directories.

    Returns:
        List[str]: A sorted list of absolute paths to operator directories.
    """
    p = pathlib.Path(__file__).parent
    # Only load the operator directories that contain an "__init__.py" file
    return sorted(
        str(child.absolute())
        for child in p.iterdir()
        if child.is_dir() and os.path.exists(os.path.join(child, "__init__.py"))
    )


def _load_valid_operators(module_path: str, operator_name: str) -> List:
    """
    Load valid operators from a given module path.

    Args:
        module_path (str): The path to the operator module.
        operator_name (str): The name of the operator.

    Returns:
        List: A list of loaded operator classes.

    Raises:
        OperatorNotFoundError: If the operator module fails to load.
    """
    loaded_operators = []
    cls_name = "Operator"

    # Import the operator module
    try:
        operator_module = importlib.import_module(module_path, package=__name__)
        # We only load the operator files that define the valid_operator_files attribute in the operator module
        valid_operator_files = getattr(operator_module, "valid_operator_files", None)
        if valid_operator_files is None:
            raise ImportError(f"{module_path} does not define valid_operator_files")
    except ImportError as e:
        raise OperatorNotFoundError(
            f"Failed to load operator module {module_path}: {str(e)}"
        ) from e

    for file_name in valid_operator_files:
        tmp_file_name = file_name
        if file_name.endswith(".py"):
            tmp_file_name = file_name[:-3]
        operator_file_module_path = f"{module_path}.{tmp_file_name}"
        try:
            file_module = importlib.import_module(
                operator_file_module_path, package=__name__
            )
            Operator = getattr(file_module, cls_name, None)
            if Operator is None:
                print(
                    f"Warning: {file_module} does not define attribute '{cls_name}', skipping."
                )
            else:
                if not hasattr(Operator, "name") or Operator.name is None:
                    Operator.name = f"{operator_name}"
                loaded_operators.append(Operator)
        except ImportError as e:
            print(
                f"Warning: Failed to load operator from {operator_file_module_path}: {str(e)}"
            )
    return loaded_operators


def list_operators(benchmark_config: BenchmarkConfig):
    """
    List all available operators. Each operator represents a variant of a base operator.

    Returns:
        List: A list of all operator classes.
    """
    # This list is used to store all the operator classes, not instances
    operators = []
    if benchmark_config.mode == "native":
        operators.extend(dynamically_create_native_operator_classes(benchmark_config))
    else:
        for operator_path in _list_operator_paths():
            operator_name = os.path.basename(operator_path)
            module_path = f"operators.{operator_name}"
            loaded_operators = _load_valid_operators(module_path, operator_name)
            operators.extend(loaded_operators)
    return operators


def dynamically_create_native_operator_classes(benchmark_config: BenchmarkConfig):
    """
    To keep consistent with custom operators, we dynamically create operator classes here.
    """
    timm_loader = OperatorInputsLoader.get_timm_loader()
    huggingface_loader = OperatorInputsLoader.get_huggingface_loader()
    torchbench_loader = OperatorInputsLoader.get_torchbench_loader()
    all_ops = (
        list(timm_loader.get_all_ops())
        + list(huggingface_loader.get_all_ops())
        + list(torchbench_loader.get_all_ops())
    )
    # remove duplicate operators
    all_ops = list(set(all_ops))

    def merge_inputs(cls, benchmark_config: BenchmarkConfig):
        """
        We don't differentiate inputs for different suites anymore.
        """
        op_eval = cls.op_eval
        inps_gens = []
        if str(op_eval) in timm_loader.operator_db:
            inps_gens.append(
                timm_loader.get_inputs_for_operator(
                    op_eval,
                    dtype=benchmark_config.dtype,
                    device=benchmark_config.device.value,
                )
            )
        if str(op_eval) in huggingface_loader.operator_db:
            inps_gens.append(
                huggingface_loader.get_inputs_for_operator(
                    op_eval,
                    dtype=benchmark_config.dtype,
                    device=benchmark_config.device.value,
                )
            )
        if str(op_eval) in torchbench_loader.operator_db:
            inps_gens.append(
                torchbench_loader.get_inputs_for_operator(
                    op_eval,
                    dtype=benchmark_config.dtype,
                    device=benchmark_config.device.value,
                )
            )
        input_list = []
        num_samples = min(benchmark_config.max_samples, 1000000)
        index = 0
        while index < num_samples:
            for inp_gen in inps_gens:
                try:
                    inps = next(inp_gen)
                    # the second element is kwargs and it has to be an empty dict
                    input_list.append(inps)
                    index += 1
                except StopIteration:
                    break
        return input_list

    def prepare_input_and_functions(self, input: Any, phase: Phase):
        args, kwargs = input
        if self.benchmark_config.channels_last:
            args, kwargs = tree_map_only(torch.Tensor, to_channels_last, (args, kwargs))

        gm, gm_args = gen_gm_and_inputs(self.op_eval, args, kwargs)
        torch.jit._builtins._register_builtin(
            torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
        )
        if self.benchmark_config.device == Device.CUDA:
            if self.variant == "Eager":
                cudagraphs_eager = cudagraphs_inner(
                    gm, gm_args, copy_outputs=False, copy_inputs=False
                )
                self.forward = cudagraphs_eager
                self.full = cudagraphs_eager
            elif self.variant == "Inductor":
                compiled_fn = compile_fx(gm, gm_args)
                cudagraphs_compiled = cudagraphs_inner(
                    compiled_fn, gm_args, copy_outputs=False, copy_inputs=False
                )
                self.forward = cudagraphs_compiled
                self.full = cudagraphs_compiled
        else:
            if self.variant == "Eager":
                self.forward = gm
                self.full = gm
            elif self.variant == "Inductor":
                compiled_fn = compile_fx(gm, gm_args)
                self.forward = compiled_fn
                self.full = compiled_fn
        return gm_args

    operators = []
    for op_eval in all_ops:
        class_name = f"native_{str(op_eval).replace('.', '_')}"
        # create a new module for each operator
        op_name_module = types.ModuleType(f"operators.{class_name}")
        sys.modules[f"operators.{class_name}"] = op_name_module
        # create a new module for each variant to help with code organization and printing
        eager_module = types.ModuleType(f"operators.{class_name}.Eager")
        sys.modules[f"operators.{class_name}.Eager"] = eager_module
        inductor_module = types.ModuleType(f"operators.{class_name}.Inductor")
        sys.modules[f"operators.{class_name}.Inductor"] = inductor_module
        # the new class for the operator; it is the parent class for all its variants
        new_op_class = type(class_name, (BaseOperator,), {})
        # need the loaders to generate inputs for the same operator
        new_op_class.huggingface_loader = huggingface_loader
        new_op_class.torchbench_loader = torchbench_loader
        new_op_class.timm_loader = timm_loader
        new_op_class.op_eval = op_eval
        new_op_class.name = str(op_eval)
        new_op_class.generate_inputs = classmethod(merge_inputs)
        # create eager and inductor variant classes
        new_eager_op_class = type(f"{class_name}.Eager.Operator", (new_op_class,), {})
        new_eager_op_class.variant = "Eager"
        new_eager_op_class.full_name = f"{new_eager_op_class.name}.Eager"
        new_eager_op_class.prepare_input_and_functions = prepare_input_and_functions
        eager_module.Operator = new_eager_op_class
        new_inductor_op_class = type(
            f"{class_name}.Inductor.Operator", (new_op_class,), {}
        )
        new_inductor_op_class.variant = "Inductor"
        new_inductor_op_class.full_name = f"{new_inductor_op_class.name}.Inductor"
        new_inductor_op_class.prepare_input_and_functions = prepare_input_and_functions
        inductor_module.Operator = new_inductor_op_class
        operators.append(new_eager_op_class)
        operators.append(new_inductor_op_class)
    return operators
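The registry above supports two discovery paths: directory-based operators, where each subpackage lists its variant files in valid_operator_files, and "native" operators that are synthesized at runtime from the timm/huggingface/torchbench input traces, each receiving an Eager and an Inductor variant. A rough sketch of how a caller might consume the registry; only list_operators and get_inputs are taken from this file, while the BenchmarkConfig handling and the printing are assumptions:

# Hypothetical consumer sketch for the registry defined above.
def run_all(benchmark_config):
    input_mapping = {}  # shared cache so all variants of an operator reuse the same inputs
    for op_cls in list_operators(benchmark_config):
        op = op_cls(benchmark_config)
        inputs = op_cls.get_inputs(input_mapping, benchmark_config)
        print(f"{op.full_name}: {len(inputs)} example inputs")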
@@ -8,7 +8,7 @@ from contextlib import nullcontext

 import click
 import numpy as np
-from operator_inp_utils import OperatorInputsLoader
+from operator_inp_utils import OperatorInputsLoader, to_channels_last
 from tqdm import tqdm

 import torch

@@ -72,6 +72,7 @@ def compute_speedups(
             from torch._inductor.utils import timed

             timings[rep, m] = timed(model, example_inputs)
+    print(timings)
     return np.median(timings, axis=0)


@@ -96,10 +97,6 @@ def convert_to_jit(gm, gm_args):
     return torch.jit.trace(gm, gm_args)


-def to_channels_last(ten):
-    return ten if ten.ndim != 4 else ten.to(memory_format=torch.channels_last)
-
-
 def microbenchmark(
     operator,
     args,
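The to_channels_last helper deleted here is the same one now imported from operator_inp_utils in the first hunk above; it only rewrites the memory format of 4-D tensors and returns anything else untouched. A small behavior sketch, with the helper body copied from the removed lines and the checks added for illustration:

# Behavior sketch for the relocated helper; the function body is copied from the
# removed lines above, the two checks are illustrative.
import torch


def to_channels_last(ten):
    return ten if ten.ndim != 4 else ten.to(memory_format=torch.channels_last)


nchw = torch.randn(2, 3, 8, 8)
vec = torch.randn(16)
print(to_channels_last(nchw).is_contiguous(memory_format=torch.channels_last))  # True
print(to_channels_last(vec) is vec)  # non-4D tensors are returned unchanged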
Some files were not shown because too many files have changed in this diff.