mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add operator benchmarking run to CI nightly (#162530)
This PR introduces a new "operator microbenchmark" CI workflow and GitHub Actions for operator microbenchmarks, updating test scripts and job matrices to support new parameters, and broadening the operator benchmark tests to include more data types, larger shapes, and gradient tests. The benchmark configurations now focus more on different cuda hardware and multiple dtypes (bf16, fp16, fp32), for both compile and eager mode. **Benchmark Configuration and Coverage:** * Expanded operator benchmark configurations in `addmm_test.py`, `bmm_test.py`, `matmul_test.py`, and `mm_test.py` to benchmark multiple dtypes on CUDA devices, in eager and compile mode, for forward and backward run. The configs with tag "long" for the above mentioned files are being run in CI. * The CI benchmarking is running on various hardwares: H100, A100. * The CI job also uploads the microbenchmarking outputs to a [HUD](https://hud.pytorch.org/benchmark/llms?repoName=pytorch%2Fpytorch&benchmarkName=PyTorch+operator+microbenchmark) dashboard. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162530 Approved by: https://github.com/huydhn Co-authored-by: Huy Do <huydhn@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
bc5a072ebf
commit
54b38f3b46
@ -1630,6 +1630,25 @@ test_operator_benchmark() {
|
||||
--expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
|
||||
}
|
||||
|
||||
test_operator_microbenchmark() {
|
||||
TEST_REPORTS_DIR=$(pwd)/test/test-reports
|
||||
mkdir -p "$TEST_REPORTS_DIR"
|
||||
TEST_DIR=$(pwd)
|
||||
|
||||
cd benchmarks/operator_benchmark/pt_extension
|
||||
python -m pip install .
|
||||
|
||||
cd "${TEST_DIR}"/benchmarks/operator_benchmark
|
||||
|
||||
for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
|
||||
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
|
||||
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
|
||||
--benchmark-name "PyTorch operator microbenchmark" --use-compile
|
||||
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
|
||||
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}.json" \
|
||||
--benchmark-name "PyTorch operator microbenchmark"
|
||||
done
|
||||
}
|
||||
|
||||
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
|
||||
(cd test && python -c "import torch; print(torch.__config__.show())")
|
||||
@ -1686,6 +1705,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
|
||||
test_operator_benchmark cpu ${TEST_MODE}
|
||||
|
||||
fi
|
||||
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
|
||||
test_operator_microbenchmark
|
||||
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
|
2
.github/workflows/_linux-test.yml
vendored
2
.github/workflows/_linux-test.yml
vendored
@ -273,6 +273,8 @@ jobs:
|
||||
TEST_CONFIG: ${{ matrix.config }}
|
||||
SHARD_NUMBER: ${{ matrix.shard }}
|
||||
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
|
||||
EXTRA_FLAGS: ${{ matrix.extra_flags || '' }}
|
||||
OP_BENCHMARK_TESTS: ${{ matrix.op_benchmark_tests }}
|
||||
REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }}
|
||||
CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }}
|
||||
VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }}
|
||||
|
46
.github/workflows/operator_microbenchmark.yml
vendored
Normal file
46
.github/workflows/operator_microbenchmark.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
name: operator_microbenchmark
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/op-benchmark/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
# Run at 06:00 UTC everyday
|
||||
- cron: 0 6 * * *
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
opmicrobenchmark-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: opmicrobenchmark-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '8.0 9.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
|
||||
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
opmicrobenchmark-test:
|
||||
name: opmicrobenchmark-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: opmicrobenchmark-build
|
||||
with:
|
||||
timeout-minutes: 500
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
|
||||
docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
@ -18,6 +18,7 @@ import torch
|
||||
|
||||
# needs to be imported after torch
|
||||
import torch.utils.cpp_extension as cpp_extension # noqa: F401
|
||||
from torch.utils.benchmark import Timer
|
||||
|
||||
|
||||
"""Performance microbenchmarks.
|
||||
@ -348,10 +349,24 @@ class BenchmarkRunner:
|
||||
func = test_case.run_jit_forward
|
||||
if self.use_compile:
|
||||
func = test_case.run_compile_forward
|
||||
forward_time = timeit.timeit(
|
||||
functools.partial(func, iters, print_per_iter, cuda_sync), number=1
|
||||
|
||||
if not cuda_sync:
|
||||
forward_time = timeit.timeit(
|
||||
functools.partial(func, iters, print_per_iter, cuda_sync), number=1
|
||||
)
|
||||
return forward_time
|
||||
# Stable timing with Timer
|
||||
timer = Timer(
|
||||
stmt="func(iters, print_per_iter, cuda_sync)",
|
||||
globals={
|
||||
"func": func,
|
||||
"iters": iters,
|
||||
"print_per_iter": print_per_iter,
|
||||
"cuda_sync": cuda_sync,
|
||||
},
|
||||
)
|
||||
return forward_time
|
||||
result = timer.adaptive_autorange(min_run_time=0.0001)
|
||||
return result.median * iters
|
||||
|
||||
def _launch_backward(self, test_case, iters, print_per_iter=False):
|
||||
"""This function runs forward path of an op to get an output. Then the backward path is executed
|
||||
|
@ -161,6 +161,8 @@ class PyTorchOperatorTestCase:
|
||||
if self._compile_forward_graph is None:
|
||||
self._compile_forward_graph = self._generate_compile_forward_graph()
|
||||
self._compile_forward_graph(num_runs)
|
||||
if cuda_sync:
|
||||
torch.cuda.synchronize(torch.cuda.current_device())
|
||||
|
||||
def _print_per_iter(self):
|
||||
# print last 50 values
|
||||
|
@ -52,27 +52,6 @@ class AddBenchmark(op_bench.TorchBenchmarkBase):
|
||||
op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark)
|
||||
op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark)
|
||||
|
||||
|
||||
"""Mircobenchmark for addmm operator."""
|
||||
|
||||
|
||||
class AddmmBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, K, device):
|
||||
self.inputs = {
|
||||
"input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()),
|
||||
"mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()),
|
||||
"mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()),
|
||||
}
|
||||
self.set_module_name("addmm")
|
||||
|
||||
def forward(self, input_one, mat1, mat2):
|
||||
return torch.addmm(input_one, mat1, mat2)
|
||||
|
||||
|
||||
op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark)
|
||||
op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark)
|
||||
|
||||
|
||||
"""Mircobenchmark for addr operator."""
|
||||
|
||||
|
||||
@ -106,46 +85,5 @@ addr_configs = op_bench.cross_product_configs(
|
||||
op_bench.generate_pt_test(addr_configs, AddrBenchmark)
|
||||
op_bench.generate_pt_gradient_test(addr_configs, AddrBenchmark)
|
||||
|
||||
|
||||
"""Mircobenchmark for addbmm operator."""
|
||||
|
||||
|
||||
class AddbmmBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, B, M, N, K, device):
|
||||
self.inputs = {
|
||||
"input_one": torch.rand(
|
||||
(M, N), device=device, requires_grad=self.auto_set()
|
||||
),
|
||||
"batch1": torch.rand(
|
||||
(B, M, K), device=device, requires_grad=self.auto_set()
|
||||
),
|
||||
"batch2": torch.rand(
|
||||
(
|
||||
B,
|
||||
K,
|
||||
N,
|
||||
),
|
||||
device=device,
|
||||
requires_grad=self.auto_set(),
|
||||
),
|
||||
}
|
||||
self.set_module_name("addbmm")
|
||||
|
||||
def forward(self, input_one, batch1, batch2):
|
||||
return torch.addbmm(input_one, batch1, batch2)
|
||||
|
||||
|
||||
addbmm_configs = op_bench.cross_product_configs(
|
||||
B=[2, 100],
|
||||
M=[8, 256],
|
||||
N=[256, 16],
|
||||
K=[15, 16],
|
||||
device=["cpu", "cuda"],
|
||||
tags=["addbmm"],
|
||||
)
|
||||
|
||||
op_bench.generate_pt_test(addbmm_configs, AddbmmBenchmark)
|
||||
op_bench.generate_pt_gradient_test(addbmm_configs, AddbmmBenchmark)
|
||||
|
||||
if __name__ == "__main__":
|
||||
op_bench.benchmark_runner.main()
|
||||
|
115
benchmarks/operator_benchmark/pt/addmm_test.py
Normal file
115
benchmarks/operator_benchmark/pt/addmm_test.py
Normal file
@ -0,0 +1,115 @@
|
||||
import operator_benchmark as op_bench
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
"""Microbenchmarks for add_(matmul) operator. Supports both Caffe2/PyTorch."""
|
||||
|
||||
# Configs for PT add operator
|
||||
addmm_long_configs = op_bench.cross_product_configs(
|
||||
M=[256, 1024, 3000],
|
||||
N=[512, 4096],
|
||||
K=[512, 4096],
|
||||
device=["cuda"],
|
||||
tags=["long"],
|
||||
dtype=[torch.float16, torch.bfloat16, torch.float32],
|
||||
)
|
||||
|
||||
|
||||
addmm_short_configs = op_bench.config_list(
|
||||
attr_names=["M", "N", "K"],
|
||||
attrs=[
|
||||
[1, 1, 1],
|
||||
[64, 64, 64],
|
||||
[64, 64, 128],
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu", "cuda"],
|
||||
"dtype": [torch.float],
|
||||
},
|
||||
tags=["short"],
|
||||
)
|
||||
|
||||
|
||||
"""Mircobenchmark for addmm operator."""
|
||||
|
||||
|
||||
class AddmmBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, K, device, dtype):
|
||||
self.inputs = {
|
||||
"input_one": torch.rand(
|
||||
M, K, device=device, requires_grad=self.auto_set(), dtype=dtype
|
||||
),
|
||||
"mat1": torch.rand(
|
||||
M, N, device=device, requires_grad=self.auto_set(), dtype=dtype
|
||||
),
|
||||
"mat2": torch.rand(
|
||||
N, K, device=device, requires_grad=self.auto_set(), dtype=dtype
|
||||
),
|
||||
}
|
||||
self.set_module_name("addmm")
|
||||
|
||||
def forward(self, input_one, mat1, mat2):
|
||||
return torch.addmm(input_one, mat1, mat2)
|
||||
|
||||
|
||||
op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
addmm_long_configs + addmm_long_configs, AddmmBenchmark
|
||||
)
|
||||
|
||||
"""Mircobenchmark for addbmm operator."""
|
||||
|
||||
|
||||
class AddbmmBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, B, M, N, K, device, dtype):
|
||||
self.inputs = {
|
||||
"input_one": torch.rand(
|
||||
(M, N), device=device, requires_grad=self.auto_set(), dtype=dtype
|
||||
),
|
||||
"batch1": torch.rand(
|
||||
(B, M, K), device=device, requires_grad=self.auto_set(), dtype=dtype
|
||||
),
|
||||
"batch2": torch.rand(
|
||||
(
|
||||
B,
|
||||
K,
|
||||
N,
|
||||
),
|
||||
device=device,
|
||||
requires_grad=self.auto_set(),
|
||||
dtype=dtype,
|
||||
),
|
||||
}
|
||||
self.set_module_name("addbmm")
|
||||
|
||||
def forward(self, input_one, batch1, batch2):
|
||||
return torch.addbmm(input_one, batch1, batch2)
|
||||
|
||||
|
||||
addbmm_long_configs = op_bench.cross_product_configs(
|
||||
B=[8, 32],
|
||||
M=[256, 1024],
|
||||
N=[256, 1024],
|
||||
K=[64, 128],
|
||||
device=["cuda"],
|
||||
dtype=[torch.float16, torch.bfloat16, torch.float32],
|
||||
tags=["long"],
|
||||
)
|
||||
addbmm_short_configs = op_bench.cross_product_configs(
|
||||
B=[1, 8],
|
||||
M=[8, 128],
|
||||
N=[32, 64],
|
||||
K=[256, 512],
|
||||
device=["cpu", "cuda"],
|
||||
dtype=[torch.float16, torch.bfloat16, torch.float32],
|
||||
tags=["short"],
|
||||
)
|
||||
|
||||
op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
op_bench.benchmark_runner.main()
|
@ -27,12 +27,12 @@ batched_binary_configs_short = op_bench.config_list(
|
||||
)
|
||||
|
||||
batched_binary_configs_long = op_bench.cross_product_configs(
|
||||
B=[1, 128],
|
||||
M=[8, 128],
|
||||
N=[32, 64],
|
||||
K=[4, 256],
|
||||
device=["cpu", "cuda"],
|
||||
dtype=[torch.float, torch.bfloat16],
|
||||
B=[8, 32],
|
||||
M=[256, 1024],
|
||||
N=[256, 1024],
|
||||
K=[64, 128],
|
||||
device=["cuda"],
|
||||
dtype=[torch.float32, torch.bfloat16, torch.float16],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
@ -40,8 +40,12 @@ batched_binary_configs_long = op_bench.cross_product_configs(
|
||||
class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, B, M, N, K, device, dtype, op_func):
|
||||
self.inputs = {
|
||||
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
|
||||
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
|
||||
"batch1": torch.rand(
|
||||
(B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set()
|
||||
),
|
||||
"batch2": torch.rand(
|
||||
(B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set()
|
||||
),
|
||||
}
|
||||
self.op_func = op_func
|
||||
|
||||
@ -54,6 +58,11 @@ op_bench.generate_pt_tests_from_op_list(
|
||||
batched_binary_configs_short + batched_binary_configs_long,
|
||||
BatchedBinaryOpBenchmark,
|
||||
)
|
||||
op_bench.generate_pt_gradient_tests_from_op_list(
|
||||
batched_binary_ops,
|
||||
batched_binary_configs_long,
|
||||
BatchedBinaryOpBenchmark,
|
||||
)
|
||||
|
||||
|
||||
# batched ternary ops
|
||||
@ -66,9 +75,15 @@ batched_ternary_ops = op_bench.op_list(
|
||||
class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, B, M, N, K, device, dtype, op_func):
|
||||
self.inputs = {
|
||||
"input_": torch.rand((B, M, K), device=device).to(dtype=dtype),
|
||||
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
|
||||
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
|
||||
"input_": torch.rand(
|
||||
(B, M, K), device=device, dtype=dtype, requires_grad=self.auto_set()
|
||||
),
|
||||
"batch1": torch.rand(
|
||||
(B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set()
|
||||
),
|
||||
"batch2": torch.rand(
|
||||
(B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set()
|
||||
),
|
||||
}
|
||||
self.op_func = op_func
|
||||
|
||||
@ -81,6 +96,12 @@ op_bench.generate_pt_tests_from_op_list(
|
||||
batched_binary_configs_short + batched_binary_configs_long,
|
||||
BatchedTernaryOpBenchmark,
|
||||
)
|
||||
op_bench.generate_pt_gradient_tests_from_op_list(
|
||||
batched_ternary_ops,
|
||||
batched_binary_configs_long,
|
||||
BatchedTernaryOpBenchmark,
|
||||
)
|
||||
|
||||
|
||||
# TODO: does it automatically register new scripts?
|
||||
|
||||
|
@ -13,33 +13,46 @@ mm_short_configs = op_bench.config_list(
|
||||
[128, 128, 128, True, False],
|
||||
[256, 256, 256, False, True],
|
||||
],
|
||||
cross_product_configs={
|
||||
"device": ["cpu", "cuda"],
|
||||
},
|
||||
cross_product_configs={"device": ["cpu", "cuda"]},
|
||||
tags=["short"],
|
||||
)
|
||||
|
||||
|
||||
mm_long_configs = op_bench.cross_product_configs(
|
||||
M=[32],
|
||||
N=[512, 128],
|
||||
K=[64],
|
||||
M=[256, 1024, 3000],
|
||||
N=[512, 4096],
|
||||
K=[512, 4096],
|
||||
trans_a=[False, True],
|
||||
trans_b=[True, False],
|
||||
device=["cpu", "cuda"],
|
||||
device=["cuda"],
|
||||
dtype=[torch.float16, torch.bfloat16, torch.float32],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
|
||||
class MatMulBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, K, trans_a, trans_b, device):
|
||||
def init(self, M, N, K, trans_a, trans_b, device, dtype=torch.float):
|
||||
# Create tensors without requires_grad first, then set it separately
|
||||
# This avoids creating graph leaves that cannot be deep copied
|
||||
if trans_a:
|
||||
input_one = torch.rand(M, N, device=device, dtype=dtype)
|
||||
else:
|
||||
input_one = torch.rand(N, M, device=device, dtype=dtype).t()
|
||||
|
||||
if trans_b:
|
||||
input_two = torch.rand(N, K, device=device, dtype=dtype)
|
||||
else:
|
||||
input_two = torch.rand(K, N, device=device, dtype=dtype).t()
|
||||
|
||||
# Set requires_grad after tensor creation to avoid graph leaf issues
|
||||
if self.auto_set():
|
||||
input_one.requires_grad_(True)
|
||||
if self.auto_set():
|
||||
input_two.requires_grad_(True)
|
||||
|
||||
self.inputs = {
|
||||
"input_one": torch.rand(M, N, device=device)
|
||||
if trans_a
|
||||
else torch.rand(N, M, device=device).t(),
|
||||
"input_two": torch.rand(N, K, device=device)
|
||||
if trans_b
|
||||
else torch.rand(K, N, device=device).t(),
|
||||
"input_one": input_one,
|
||||
"input_two": input_two,
|
||||
}
|
||||
self.set_module_name("matmul")
|
||||
|
||||
@ -48,6 +61,7 @@ class MatMulBenchmark(op_bench.TorchBenchmarkBase):
|
||||
|
||||
|
||||
op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark)
|
||||
op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -23,11 +23,11 @@ mm_short_configs = op_bench.config_list(
|
||||
)
|
||||
|
||||
mm_long_configs = op_bench.cross_product_configs(
|
||||
M=[8, 128],
|
||||
N=[32, 64],
|
||||
K=[256, 512],
|
||||
device=["cpu", "cuda"],
|
||||
dtype=[torch.float, torch.bfloat16],
|
||||
M=[256, 1024, 3000],
|
||||
N=[512, 4096],
|
||||
K=[512, 4096],
|
||||
device=["cuda"],
|
||||
dtype=[torch.float16, torch.bfloat16, torch.float32],
|
||||
tags=["long"],
|
||||
)
|
||||
|
||||
@ -35,8 +35,12 @@ mm_long_configs = op_bench.cross_product_configs(
|
||||
class MmOpBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, K, device, dtype, op_func):
|
||||
self.inputs = {
|
||||
"input_one": torch.randn(M, N, device=device).to(dtype=dtype),
|
||||
"input_two": torch.randn(N, K, device=device).to(dtype=dtype),
|
||||
"input_one": torch.randn(
|
||||
M, N, device=device, requires_grad=self.auto_set(), dtype=dtype
|
||||
),
|
||||
"input_two": torch.randn(
|
||||
N, K, device=device, requires_grad=self.auto_set(), dtype=dtype
|
||||
),
|
||||
}
|
||||
self.op_func = op_func
|
||||
|
||||
@ -47,6 +51,9 @@ class MmOpBenchmark(op_bench.TorchBenchmarkBase):
|
||||
op_bench.generate_pt_tests_from_op_list(
|
||||
ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark
|
||||
)
|
||||
op_bench.generate_pt_gradient_tests_from_op_list(
|
||||
ops_list, mm_long_configs, MmOpBenchmark
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user