Mirror of https://github.com/pytorch/pytorch.git
Compare commits (38 commits): ciflow/b20 ... findhao/op
| SHA1 |
|---|
| 1d983f0775 |
| f320e4ba86 |
| 73dfb2bc2d |
| f2654ae713 |
| ab40b51c5d |
| 8c1b793071 |
| de205901f3 |
| 9f2936931a |
| 5fa8031ae5 |
| a245137d76 |
| 2d93c5f720 |
| 6f3b42a073 |
| db4c9a54a2 |
| f78da95bc5 |
| 7c2bc74a72 |
| 7b366a2b70 |
| b437ffe8b0 |
| 085e2f5416 |
| 425ad9ccdb |
| 900671f799 |
| 1f30017712 |
| 1fdf24d9a5 |
| 8a4bc3cc09 |
| 78f5027b48 |
| 6e2d4c661a |
| 30dd419560 |
| f280038562 |
| 0bb482185c |
| 18c2804981 |
| a6b6bbc293 |
| ebd4755b0d |
| 8779577950 |
| a6d9a506c3 |
| a0ecd4f45d |
| 9cdac0662b |
| 57be1aae4b |
| 22ee74895b |
| ea97de291b |
@@ -4785,9 +4785,9 @@ def log_operator_inputs(model, example_inputs, model_iter_fn, name, args):
     print(f"Running {name}")
     try:
-        from .microbenchmarks.operator_inp_utils import OperatorInputsMode
+        from operatorbench.operators.operator_inp_utils import OperatorInputsMode
     except ImportError:
-        from microbenchmarks.operator_inp_utils import OperatorInputsMode
+        from .operatorbench.operators.operator_inp_utils import OperatorInputsMode
 
     operator_mode = OperatorInputsMode()
     fake_tensor_mode = FakeTensorMode()
benchmarks/dynamo/operatorbench/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
import os
import sys


# Add the current directory to the system path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
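The path tweak above matters because the files added later in this diff use absolute imports (`from utils.common import BenchmarkConfig`) and absolute module paths (`operators.<name>`), which only resolve when this directory is on `sys.path`. A minimal, illustrative sketch of the effect; the runner script and how it is launched are assumptions, not part of this diff:

```python
# hypothetical_runner.py -- illustrative only; assumes it runs with
# benchmarks/dynamo on sys.path so that "operatorbench" itself is importable.
import importlib

# Importing the package executes operatorbench/__init__.py, which appends the
# package directory to sys.path ...
importlib.import_module("operatorbench")

# ... so sibling packages added later in this diff resolve by absolute name,
# matching the "operators.<operator_name>" module paths used by the loader code.
operators_pkg = importlib.import_module("operators")
print(operators_pkg.__file__)
```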
@@ -0,0 +1,65 @@
from typing import Any, Callable, List

from utils.common import BenchmarkConfig, Phase

import torch

from .. import BaseOperator


H = 4096
V = 128256
# Each file defines an operator variant
valid_operator_files = ["baseline.py", "custom.py", "inductor.py"]


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py


class FusedLinearCrossEntropyOperator(BaseOperator):
    # The base operator name
    name = "FusedLinearCrossEntropy"
    # The variant placeholder. No need to set in the base operator class
    variant = None

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.forward_output = None

    @classmethod
    def generate_inputs(cls, benchmark_config: BenchmarkConfig):
        example_inputs_list = []
        # May need OOM check
        for BT in [2**i for i in range(12, 16)]:
            _input = torch.randn(
                BT,
                H,
                requires_grad=True,
                dtype=benchmark_config.dtype,
                device=benchmark_config.device.value,
            )
            target = torch.randint(
                V, (BT, 1), dtype=torch.long, device=benchmark_config.device.value
            ).squeeze(1)
            # This operator needs two inputs
            example_inputs_list.append((_input, target))
        return example_inputs_list

    def forward(self, input: Any):
        return self.operator(input)

    # backward doesn't need inputs, but we need to pass it to match the interface
    def backward(self, input: Any):
        assert self.forward_output is not None
        return self.forward_output.backward(retain_graph=True)

    def full(self, input: Any):
        y = self.forward(input)
        y.backward()
        return y

    def prepare_input_and_functions(self, input: Any, phase: Phase):
        if phase == Phase.BACKWARD:
            self.forward_output = self.forward(input)
        return input
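For orientation, here is a hedged sketch of how a harness could drive the phase protocol defined above. The timing and synchronization details, and any `Phase` member other than `Phase.BACKWARD`, are assumptions for illustration and not code from this PR:

```python
# Illustrative only: exercising one operator variant through the phase protocol above.
import torch

from utils.common import Phase  # added elsewhere in this PR


def run_phase(op, example_inputs, phase):
    for inp in example_inputs:
        # For Phase.BACKWARD, prepare_input_and_functions runs forward() once and
        # caches the output so backward() can call .backward(retain_graph=True).
        inp = op.prepare_input_and_functions(inp, phase)
        if phase == Phase.BACKWARD:
            op.backward(inp)   # benchmarked region: backward only
        else:
            op.full(inp)       # forward + backward in one go
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # assumed: keep timing honest on CUDA
```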
@@ -0,0 +1,43 @@
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L17


class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation: a linear layer followed by torch's cross entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param ignore_index: index to ignore
    :param reduction: reduction method
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, inputs):
        x, y = inputs
        logits = self.lin(x)
        return self.ce_loss(logits, y)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Baseline"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = TorchLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
            self.benchmark_config.device.value
        )
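A quick standalone sanity check of what the baseline module computes (a minimal sketch, not part of the diff; it assumes `TorchLMHeadCE` from the baseline file above has been imported, and runs in float32 on CPU to keep it simple):

```python
# Illustrative only: TorchLMHeadCE materializes the full (BT, V) logits before the loss.
import torch

H, V = 4096, 128256  # same constants as the operator package above
model = TorchLMHeadCE(H=H, V=V, dtype=torch.float32)  # assumed importable from the baseline module

x = torch.randn(8, H, requires_grad=True)     # (BT, H) hidden states
y = torch.randint(V, (8,), dtype=torch.long)  # (BT,) target token ids

loss = model((x, y))  # forward takes a single (input, target) tuple, as in generate_inputs
loss.backward()
print(loss.item())
```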
@@ -0,0 +1,34 @@
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V


# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py#L40
class LigerLMHeadCE(torch.nn.Module):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, inputs):
        return self.ce_loss(self.lin.weight, *inputs)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Liger"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = LigerLMHeadCE(H=H, V=V, dtype=self.benchmark_config.dtype).to(
            self.benchmark_config.device.value
        )
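The only interface difference from the baseline is that `LigerFusedLinearCrossEntropyLoss` consumes the projection weight and the hidden states directly, so the (BT, V) logits matrix is never fully materialized. A hedged parity-check sketch, not part of this PR; it assumes a CUDA device and an installed `liger_kernel`, and the tolerances are illustrative:

```python
# Illustrative only: compare Baseline vs Liger losses on shared weights.
import torch

torch.manual_seed(0)
H, V = 4096, 128256
x = torch.randn(32, H, dtype=torch.bfloat16, device="cuda")
y = torch.randint(V, (32,), dtype=torch.long, device="cuda")

baseline = TorchLMHeadCE(H=H, V=V, dtype=torch.bfloat16).to("cuda")
liger = LigerLMHeadCE(H=H, V=V, dtype=torch.bfloat16).to("cuda")
with torch.no_grad():
    liger.lin.weight.copy_(baseline.lin.weight)  # share weights so the losses should match

# Both variants take the same (hidden_states, targets) tuple.
torch.testing.assert_close(baseline((x, y)), liger((x, y)), rtol=2e-2, atol=2e-2)
```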
@@ -0,0 +1,22 @@
from utils.common import BenchmarkConfig

import torch

from . import FusedLinearCrossEntropyOperator, H, V
from .baseline import TorchLMHeadCE


class TorchLMHeadCECompiled(TorchLMHeadCE):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__(H, V, dtype, ignore_index)


class Operator(FusedLinearCrossEntropyOperator):
    variant = "Inductor"

    def __init__(self, benchmark_config: BenchmarkConfig):
        super().__init__(benchmark_config)
        self.operator = TorchLMHeadCECompiled(
            H=H, V=V, dtype=self.benchmark_config.dtype
        ).to(self.benchmark_config.device.value)
        self.operator = torch.compile(self.operator)
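The Inductor variant is simply the baseline module wrapped in `torch.compile`. A hedged timing sketch outside this PR's harness (using `torch.utils.benchmark` is an assumption about how one might spot-check the variants, not how the operatorbench runner measures them):

```python
# Illustrative only: spot-check eager vs torch.compile timings for the same module.
import torch
import torch.utils.benchmark as benchmark

H, V = 4096, 128256
x = torch.randn(2**12, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.randint(V, (2**12,), dtype=torch.long, device="cuda")

eager = TorchLMHeadCE(H=H, V=V, dtype=torch.bfloat16).to("cuda")
compiled = torch.compile(eager)  # the same wrapping the Inductor variant applies
compiled((x, y))                 # warm-up so compilation happens outside the timed region

for name, fn in [("eager", eager), ("inductor", compiled)]:
    timer = benchmark.Timer(stmt="fn((x, y))", globals={"fn": fn, "x": x, "y": y})
    print(name, timer.blocked_autorange())
```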
benchmarks/dynamo/operatorbench/operators/__init__.py (new file, 342 lines)
@@ -0,0 +1,342 @@
import importlib
import os
import pathlib
import sys
import types
from typing import Any, Dict, List, Optional

from utils.common import BenchmarkConfig, Phase
from utils.metrics import Device

import torch
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import gen_gm_and_inputs
from torch.utils._pytree import tree_map_only

from .operator_inp_utils import OperatorInputsLoader, to_channels_last


class OperatorNotFoundError(RuntimeError):
    """Custom exception raised when an operator is not found."""


class BaseOperator:
    """
    Base class for operators.

    This class defines the structure for operator implementations.
    The forward, backward, and full methods should contain **only**
    the code that users want to benchmark.

    Attributes:
        name (str): The main name of the operator, e.g. "FusedLinearCrossEntropy".
        variant (str): The variant of the operator, e.g. "baseline".
        benchmark_config (BenchmarkConfig): Configuration for the benchmark.
        full_name (str): The full name of the operator (name.variant). It is only valid for variants.
            It can be either assigned in the operator file or generated from name and variant.
        example_inputs_list (list): List of example inputs for the operator.
    """

    name = None
    variant = None
    benchmark_config = None
    full_name = None

    def __init__(self, benchmark_config: BenchmarkConfig):
        """
        Initialize the BaseOperator.

        Args:
            benchmark_config (BenchmarkConfig): Configuration for the benchmark.
        """
        self.benchmark_config = benchmark_config
        if self.full_name is None:
            self.full_name = f"{self.name}.{self.variant}"

    @classmethod
    def get_inputs(
        cls,
        input_mapping: Dict[str, List],
        benchmark_config: Optional[BenchmarkConfig] = None,
    ):
        """
        Get or generate example inputs for the operator.

        The format of the inputs is important and should meet the requirements
        of the operator. It is not necessary to have a unified format across
        different operators, but the format should be consistent within the
        same operator.

        This function differs from generate_inputs in that it does not
        generate inputs; it returns the inputs that were generated in
        previous runs.

        Args:
            input_mapping (Dict[str, List]): Mapping from operator name to the input list.
            benchmark_config (Optional[BenchmarkConfig]): Configuration for the benchmark.

        Returns:
            list: List of example inputs.
        """
        if cls.name not in input_mapping:
            assert (
                benchmark_config is not None
            ), "Benchmark config is required to generate inputs"
            generated_inputs = cls.generate_inputs(benchmark_config)
            input_mapping[cls.name] = generated_inputs
        return input_mapping[cls.name]

    @classmethod
    def generate_inputs(cls, benchmark_config: BenchmarkConfig):
        """
        Generate example inputs for the operator. Each operator should implement
        this method, and the format should be consistent with the operator.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def forward(self):
        """Perform the forward pass of the operator."""
        raise NotImplementedError("Subclasses must implement this method.")

    def backward(self):
        """Perform the backward pass of the operator. It can be bypassed if the operator does not have a backward pass."""
        raise NotImplementedError("Subclasses must implement this method.")

    def full(self):
        """Perform the full (forward + backward) pass of the operator."""
        raise NotImplementedError("Subclasses must implement this method.")

    def prepare_input_and_functions(self, input):
        """
        If needed, process the input before running the operator. This can be
        used to prepare the forward output for backward benchmarking. By default,
        we return the input directly.

        Args:
            input: The input to the operator.

        Returns:
            The processed input.
        """
        return input


def _list_operator_paths() -> List[str]:
    """
    List the paths of all operator directories.

    Returns:
        List[str]: A sorted list of absolute paths to operator directories.
    """
    p = pathlib.Path(__file__).parent
    # Only load the operator directories that contain an "__init__.py" file
    return sorted(
        str(child.absolute())
        for child in p.iterdir()
        if child.is_dir() and os.path.exists(os.path.join(child, "__init__.py"))
    )


def _load_valid_operators(module_path: str, operator_name: str) -> List:
    """
    Load valid operators from a given module path.

    Args:
        module_path (str): The path to the operator module.
        operator_name (str): The name of the operator.

    Returns:
        List: A list of loaded operator classes.

    Raises:
        OperatorNotFoundError: If the operator module fails to load.
    """
    loaded_operators = []
    cls_name = "Operator"

    # Import the operator module
    try:
        operator_module = importlib.import_module(module_path, package=__name__)
        # We only load operator files listed in the module's valid_operator_files attribute
        valid_operator_files = getattr(operator_module, "valid_operator_files", None)
        if valid_operator_files is None:
            raise ImportError(f"{module_path} does not define valid_operator_files")
    except ImportError as e:
        raise OperatorNotFoundError(
            f"Failed to load operator module {module_path}: {str(e)}"
        ) from e

    for file_name in valid_operator_files:
        tmp_file_name = file_name
        if file_name.endswith(".py"):
            tmp_file_name = file_name[:-3]
        operator_file_module_path = f"{module_path}.{tmp_file_name}"
        try:
            file_module = importlib.import_module(
                operator_file_module_path, package=__name__
            )
            Operator = getattr(file_module, cls_name, None)
            if Operator is None:
                print(
                    f"Warning: {file_module} does not define attribute '{cls_name}', skipping."
                )
            else:
                if not hasattr(Operator, "name") or Operator.name is None:
                    Operator.name = f"{operator_name}"
                loaded_operators.append(Operator)
        except ImportError as e:
            print(
                f"Warning: Failed to load operator from {operator_file_module_path}: {str(e)}"
            )
    return loaded_operators


def list_operators(benchmark_config: BenchmarkConfig):
    """
    List all available operators. Each operator represents a variant of a base operator.

    Returns:
        List: A list of all operator classes.
    """
    # This list is used to store all the operator classes, not instances
    operators = []
    if benchmark_config.mode == "native":
        operators.extend(dynamically_create_native_operator_classes(benchmark_config))
    else:
        for operator_path in _list_operator_paths():
            operator_name = os.path.basename(operator_path)
            module_path = f"operators.{operator_name}"
            loaded_operators = _load_valid_operators(module_path, operator_name)
            operators.extend(loaded_operators)
    return operators


def dynamically_create_native_operator_classes(benchmark_config: BenchmarkConfig):
    """
    To keep parity with custom operators, we dynamically create operator classes here.
    """
    timm_loader = OperatorInputsLoader.get_timm_loader()
    huggingface_loader = OperatorInputsLoader.get_huggingface_loader()
    torchbench_loader = OperatorInputsLoader.get_torchbench_loader()
    all_ops = (
        list(timm_loader.get_all_ops())
        + list(huggingface_loader.get_all_ops())
        + list(torchbench_loader.get_all_ops())
    )
    # remove duplicate operators
    all_ops = list(set(all_ops))

    def merge_inputs(cls, benchmark_config: BenchmarkConfig):
        """
        We no longer differentiate inputs between suites.
        """
        op_eval = cls.op_eval
        inps_gens = []
        if str(op_eval) in timm_loader.operator_db:
            inps_gens.append(
                timm_loader.get_inputs_for_operator(
                    op_eval,
                    dtype=benchmark_config.dtype,
                    device=benchmark_config.device.value,
                )
            )
        if str(op_eval) in huggingface_loader.operator_db:
            inps_gens.append(
                huggingface_loader.get_inputs_for_operator(
                    op_eval,
                    dtype=benchmark_config.dtype,
                    device=benchmark_config.device.value,
                )
            )
        if str(op_eval) in torchbench_loader.operator_db:
            inps_gens.append(
                torchbench_loader.get_inputs_for_operator(
                    op_eval,
                    dtype=benchmark_config.dtype,
                    device=benchmark_config.device.value,
                )
            )
        input_list = []
        num_samples = min(benchmark_config.max_samples, 1000000)
        index = 0
        while index < num_samples:
            for inp_gen in inps_gens:
                try:
                    inps = next(inp_gen)
                    # the second element is kwargs and it has to be an empty dict
                    input_list.append(inps)
                    index += 1
                except StopIteration:
                    break
        return input_list

    def prepare_input_and_functions(self, input: Any, phase: Phase):
        args, kwargs = input
        if self.benchmark_config.channels_last:
            args, kwargs = tree_map_only(torch.Tensor, to_channels_last, (args, kwargs))

        gm, gm_args = gen_gm_and_inputs(self.op_eval, args, kwargs)
        torch.jit._builtins._register_builtin(
            torch.ops.aten.convolution_backward.default, "aten::convolution_backward"
        )
        if self.benchmark_config.device == Device.CUDA:
            if self.variant == "Eager":
                cudagraphs_eager = cudagraphs_inner(
                    gm, gm_args, copy_outputs=False, copy_inputs=False
                )
                self.forward = cudagraphs_eager
                self.full = cudagraphs_eager
            elif self.variant == "Inductor":
                compiled_fn = compile_fx(gm, gm_args)
                cudagraphs_compiled = cudagraphs_inner(
                    compiled_fn, gm_args, copy_outputs=False, copy_inputs=False
                )
                self.forward = cudagraphs_compiled
                self.full = cudagraphs_compiled
        else:
            if self.variant == "Eager":
                self.forward = gm
                self.full = gm
            elif self.variant == "Inductor":
                compiled_fn = compile_fx(gm, gm_args)
                self.forward = compiled_fn
                self.full = compiled_fn
        return gm_args

    operators = []
    for op_eval in all_ops:
        class_name = f"native_{str(op_eval).replace('.', '_')}"
        # create a new module for each operator
        op_name_module = types.ModuleType(f"operators.{class_name}")
        sys.modules[f"operators.{class_name}"] = op_name_module
        # create a new module for each variant to help with code organization and printing
        eager_module = types.ModuleType(f"operators.{class_name}.Eager")
        sys.modules[f"operators.{class_name}.Eager"] = eager_module
        inductor_module = types.ModuleType(f"operators.{class_name}.Inductor")
        sys.modules[f"operators.{class_name}.Inductor"] = inductor_module
        # the new class for the operator; it is the parent class for all its variants
        new_op_class = type(class_name, (BaseOperator,), {})
        # need the loaders to generate inputs for the same operator
        new_op_class.huggingface_loader = huggingface_loader
        new_op_class.torchbench_loader = torchbench_loader
        new_op_class.timm_loader = timm_loader
        new_op_class.op_eval = op_eval
        new_op_class.name = str(op_eval)
        new_op_class.generate_inputs = classmethod(merge_inputs)
        # create the Eager and Inductor variant classes
        new_eager_op_class = type(f"{class_name}.Eager.Operator", (new_op_class,), {})
        new_eager_op_class.variant = "Eager"
        new_eager_op_class.full_name = f"{new_eager_op_class.name}.Eager"
        new_eager_op_class.prepare_input_and_functions = prepare_input_and_functions
        eager_module.Operator = new_eager_op_class
        new_inductor_op_class = type(
            f"{class_name}.Inductor.Operator", (new_op_class,), {}
        )
        new_inductor_op_class.variant = "Inductor"
        new_inductor_op_class.full_name = f"{new_inductor_op_class.name}.Inductor"
        new_inductor_op_class.prepare_input_and_functions = prepare_input_and_functions
        inductor_module.Operator = new_inductor_op_class
        operators.append(new_eager_op_class)
        operators.append(new_inductor_op_class)
    return operators
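Putting the pieces together, here is a hedged sketch of how a driver might consume this module; the `BenchmarkConfig` construction is elided because its signature is not shown in this diff, and the driver itself is an illustration rather than this PR's runner:

```python
# Illustrative only: enumerate operators and fetch their cached/generated inputs.
from operators import list_operators

config = ...  # a BenchmarkConfig built by the benchmark CLI; its constructor is not shown here

# Shared across variants so Baseline/Liger/Inductor of one operator see identical inputs.
input_mapping = {}
for OperatorCls in list_operators(config):
    op = OperatorCls(config)
    inputs = OperatorCls.get_inputs(input_mapping, config)  # generated once per base name
    print(op.full_name, "->", len(inputs), "example inputs")
```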
@@ -8,7 +8,7 @@ from contextlib import nullcontext
 
 import click
 import numpy as np
-from operator_inp_utils import OperatorInputsLoader
+from operator_inp_utils import OperatorInputsLoader, to_channels_last
 from tqdm import tqdm
 
 import torch
@@ -72,6 +72,7 @@ def compute_speedups(
                 from torch._inductor.utils import timed
 
                 timings[rep, m] = timed(model, example_inputs)
+    print(timings)
     return np.median(timings, axis=0)
 
 
@@ -96,10 +97,6 @@ def convert_to_jit(gm, gm_args):
     return torch.jit.trace(gm, gm_args)
 
 
-def to_channels_last(ten):
-    return ten if ten.ndim != 4 else ten.to(memory_format=torch.channels_last)
-
-
 def microbenchmark(
     operator,
     args,
Some files were not shown because too many files have changed in this diff.