Compare commits

...

54 Commits

Author SHA1 Message Date
7669445a70 Merge remote-tracking branch 'origin/add_op_tests' into perf_ops 2025-09-28 01:55:13 -07:00
b7eae1cc34 b200 benchmarks seperate run 2025-09-28 01:51:07 -07:00
a710d65523 Update measurement for cuda 2025-09-26 22:16:33 -07:00
eddf149b0c Remove short config from cpu run 2025-09-24 09:10:53 -07:00
62a91acda9 Updates 2025-09-24 09:05:42 -07:00
45760a2f7f Don't configure AWS credentials on A100/H100 2025-09-23 23:48:43 -07:00
5fa2fe9539 Seperate the b100 and h100+a100 run 2025-09-23 22:18:47 -07:00
2e3d0429c2 Update and re-run 2025-09-23 20:04:26 -07:00
c8a53c3383 <Replace this line with a title. Use 1 line only, 67 chars or less>
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
2025-09-23 19:59:50 -07:00
682d542bfb Change cron schedule comment to 'everyday' 2025-09-23 15:18:58 -07:00
a60037fd72 Add B200 and seperate add addmm 2025-09-23 15:12:52 -07:00
bfc9680175 Remove py2.8 seperate code 2025-09-22 11:24:24 -07:00
5031e026fc Build with fallback 2025-09-21 23:44:31 -07:00
195779ec3b [Testing] Operator benchmark baseline 2.8 2025-09-19 18:12:49 -07:00
e934b6ab40 Update add_test 2025-09-18 11:50:06 -07:00
ca59a71675 Test 2025-09-16 18:06:05 -07:00
93ad3fec44 Merge remote-tracking branch 'origin/main' into add_op_tests 2025-09-16 10:01:05 -07:00
783f8064d1 Fix yml string 2025-09-15 21:29:02 -07:00
f47fa8d2f8 Fix syntax for running operator benchmarks 2025-09-14 11:04:24 -07:00
3635731fc2 Update yml 2025-09-12 18:47:13 -07:00
98a71c71b2 Fix extra_flags formatting in benchmark workflow 2025-09-12 17:38:22 -07:00
86e3803f3b Update Docker image name for operator benchmark 2025-09-12 16:05:05 -07:00
056bcfc333 Update operator_microbenchmark.yml to use a bigger runner 2025-09-12 15:38:56 -07:00
cc2b171704 Update docker 2025-09-12 15:36:11 -07:00
7b16f72b09 Fix quotes 2025-09-12 13:50:46 -07:00
f47e539765 fix trailing comma 2025-09-12 13:23:30 -07:00
49e5e122fe Fix params for python cmd 2025-09-12 12:55:57 -07:00
6b8cc19597 Fix concurrency and cuda arch 2025-09-12 11:43:08 -07:00
d683fb9ebe Add include into the test matrix 2025-09-12 11:14:54 -07:00
9eca494626 Tweak the build step 2025-09-12 10:56:27 -07:00
7f5b0bcec8 Fix CI 2025-09-11 23:30:42 -07:00
4c257bca07 Fix CI 2025-09-11 21:26:28 -07:00
05bb4d4fc6 Fix CI 2025-09-11 17:00:49 -07:00
8d0cafb8bb Add h100, a100 2025-09-11 14:23:36 -07:00
629de8d7ba Fixes 2025-09-11 14:18:51 -07:00
71ae2d8280 Add ci flow 2025-09-10 22:05:03 -07:00
2fe66701c1 Merge remote-tracking branch 'origin/main' into add_op_tests 2025-09-10 21:54:16 -07:00
c021d0349e Add ci flow 2025-09-10 21:47:50 -07:00
c6f1a29b17 Merge remote-tracking branch 'origin/main' into add_op_tests 2025-09-09 21:23:23 -07:00
54c9527a81 Add mm benchmarking tests 2025-09-09 14:25:07 -07:00
cf31d4b744 Add mm benchmarking tests 2025-09-09 14:17:08 -07:00
9c701f03ee update json 2025-09-08 22:16:33 -07:00
c193ed6c84 Merge remote-tracking branch 'origin/main' into add_compile_benchmarking 2025-09-08 12:42:06 -07:00
eab7bd0d4c Remove mm_bwd_test.py 2025-09-08 11:34:58 -07:00
199318f978 Remove cpu benchmarking 2025-09-08 11:28:30 -07:00
9b226b2ce4 Add cpu memory calculation 2025-09-08 00:10:10 -07:00
6357d4e05a Add cpu memory calculation 2025-09-08 00:03:50 -07:00
162e7d3c20 Updates 2025-09-07 22:02:53 -07:00
ada9c165dd Lint fixes 2025-09-04 13:12:34 -07:00
461c7ad698 Enable bwd pass 2025-09-03 21:51:42 -07:00
819159610d Add fixes 2025-09-03 20:47:09 -07:00
d257ebf9c7 Add peak memory calculation 2025-09-03 11:10:16 -07:00
aab478833d Make jit and compile mutually exclusive 2025-08-27 14:21:37 -07:00
ba1319f414 Update the op benchmarking, to benchmark using torch.compile 2025-08-25 00:15:50 -07:00
11 changed files with 325 additions and 97 deletions

View File

@ -1614,6 +1614,26 @@ test_operator_benchmark() {
--expected "expected_ci_operator_benchmark_eager_float32_cpu.csv" --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
} }
test_operator_microbenchmark() {
  # Runs the matmul, mm, addmm and bmm operator microbenchmarks with the
  # "long" tag filter, once with --use-compile and once without, writing one
  # dashboard JSON file per run under test/test-reports.
  TEST_DIR=$(pwd)
  TEST_REPORTS_DIR="${TEST_DIR}/test/test-reports"
  mkdir -p "$TEST_REPORTS_DIR"

  # Replace the installed torch packages with the pinned 2.8.0 release.
  pip_uninstall torch torchvision torchaudio
  pip_install torch==2.8.0 torchvision torchaudio ninja --force-reinstall

  # Install the package found in the benchmark's pt_extension directory.
  cd benchmarks/operator_benchmark/pt_extension
  python -m pip install .

  cd "${TEST_DIR}"/benchmarks/operator_benchmark
  local report_stem
  for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
    report_stem="${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}"
    # $TASKSET is expanded unquoted, so it may be empty or hold several words.
    # First run: torch.compile variant, reported to "<op>_compile.json".
    $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
      --output-json-for-dashboard "${report_stem}_compile.json" \
      --benchmark-name "PyTorch operator microbenchmark" --use-compile
    # Second run: same benchmark without --use-compile, reported to "<op>.json".
    $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
      --output-json-for-dashboard "${report_stem}.json" \
      --benchmark-name "PyTorch operator microbenchmark"
  done
}
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
(cd test && python -c "import torch; print(torch.__config__.show())") (cd test && python -c "import torch; print(torch.__config__.show())")
@ -1668,6 +1688,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
test_operator_benchmark cpu ${TEST_MODE} test_operator_benchmark cpu ${TEST_MODE}
fi fi
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
test_operator_microbenchmark
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then

View File

@ -273,6 +273,8 @@ jobs:
TEST_CONFIG: ${{ matrix.config }} TEST_CONFIG: ${{ matrix.config }}
SHARD_NUMBER: ${{ matrix.shard }} SHARD_NUMBER: ${{ matrix.shard }}
NUM_TEST_SHARDS: ${{ matrix.num_shards }} NUM_TEST_SHARDS: ${{ matrix.num_shards }}
EXTRA_FLAGS: ${{ matrix.extra_flags || '' }}
OP_BENCHMARK_TESTS: ${{ matrix.op_benchmark_tests }}
REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }} REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }}
CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }} CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }}
VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }} VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }}

View File

@ -0,0 +1,46 @@
# Operator microbenchmark workflow: one build job followed by a test job whose
# matrix targets the linux.aws.h100 and linux.aws.a100 runners.
name: operator_microbenchmark
# Triggers: pushed tags matching ciflow/op-benchmark/*, manual dispatch, and a
# daily schedule.
on:
push:
tags:
- ciflow/op-benchmark/*
workflow_dispatch:
schedule:
# Run at 06:00 UTC every day
- cron: 0 6 * * *
# The group key combines workflow name, PR number or ref name, commit SHA (for
# branch refs) and the trigger type; a newer run in the same group cancels the
# one in progress.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
# Token permissions requested by this workflow.
permissions:
id-token: write
contents: read
jobs:
# Build job: delegates to the reusable _linux-build.yml workflow and only runs
# when the repository owner is 'pytorch'.
opmicrobenchmark-build:
if: github.repository_owner == 'pytorch'
name: opmicrobenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '8.0 9.0'
# Matrix handed to the test job: one single-shard entry per GPU runner.
test-matrix: |
{ include: [
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
]}
secrets: inherit
# Test job: delegates to the reusable _linux-test.yml workflow, consuming the
# docker image and test matrix emitted by the build job.
opmicrobenchmark-test:
name: opmicrobenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: opmicrobenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }}
secrets: inherit

View File

@ -0,0 +1,46 @@
# Operator microbenchmark workflow for the linux.dgx.b200 runner: one build job
# followed by a single-entry test matrix.
name: operator_microbenchmark_b200
# Triggers: pushed tags matching ciflow/op-benchmark/*, manual dispatch, and a
# daily schedule.
on:
push:
tags:
- ciflow/op-benchmark/*
workflow_dispatch:
schedule:
# Run at 06:00 UTC every day
- cron: 0 6 * * *
# The group key combines workflow name, PR number or ref name, commit SHA (for
# branch refs) and the trigger type; a newer run in the same group cancels the
# one in progress.
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
# Token permissions requested by this workflow.
permissions:
id-token: write
contents: read
jobs:
# Build job: delegates to the reusable _linux-build.yml workflow and only runs
# when the repository owner is 'pytorch'.
opmicrobenchmark-build:
if: github.repository_owner == 'pytorch'
name: opmicrobenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
# NOTE(review): the environment name ends in "sm80" while cuda-arch-list below
# is '10.0' — confirm the name is intentional for the B200 build.
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
# Matrix handed to the test job: a single-shard entry on the B200 runner.
test-matrix: |
{ include: [
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
secrets: inherit
# Test job: delegates to the reusable _linux-test.yml workflow, consuming the
# docker image and test matrix emitted by the build job, and additionally
# passes an AWS role ARN through the aws-role-to-assume input.
opmicrobenchmark-test:
name: opmicrobenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: opmicrobenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit

View File

@ -18,6 +18,7 @@ import torch
# needs to be imported after torch # needs to be imported after torch
import torch.utils.cpp_extension as cpp_extension # noqa: F401 import torch.utils.cpp_extension as cpp_extension # noqa: F401
from torch.utils.benchmark import Timer
"""Performance microbenchmarks. """Performance microbenchmarks.
@ -348,10 +349,24 @@ class BenchmarkRunner:
func = test_case.run_jit_forward func = test_case.run_jit_forward
if self.use_compile: if self.use_compile:
func = test_case.run_compile_forward func = test_case.run_compile_forward
forward_time = timeit.timeit(
functools.partial(func, iters, print_per_iter, cuda_sync), number=1 if not cuda_sync:
forward_time = timeit.timeit(
functools.partial(func, iters, print_per_iter, cuda_sync), number=1
)
return forward_time
# Stable timing with Timer
timer = Timer(
stmt="func(iters, print_per_iter, cuda_sync)",
globals={
"func": func,
"iters": iters,
"print_per_iter": print_per_iter,
"cuda_sync": cuda_sync,
},
) )
return forward_time result = timer.adaptive_autorange(min_run_time=0.0001)
return result.median * iters
def _launch_backward(self, test_case, iters, print_per_iter=False): def _launch_backward(self, test_case, iters, print_per_iter=False):
"""This function runs forward path of an op to get an output. Then the backward path is executed """This function runs forward path of an op to get an output. Then the backward path is executed

View File

@ -161,6 +161,8 @@ class PyTorchOperatorTestCase:
if self._compile_forward_graph is None: if self._compile_forward_graph is None:
self._compile_forward_graph = self._generate_compile_forward_graph() self._compile_forward_graph = self._generate_compile_forward_graph()
self._compile_forward_graph(num_runs) self._compile_forward_graph(num_runs)
if cuda_sync:
torch.cuda.synchronize(torch.cuda.current_device())
def _print_per_iter(self): def _print_per_iter(self):
# print last 50 values # print last 50 values

View File

@ -52,27 +52,6 @@ class AddBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark) op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark)
op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark) op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark)
"""Mircobenchmark for addmm operator."""
class AddmmBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device):
self.inputs = {
"input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()),
"mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()),
"mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()),
}
self.set_module_name("addmm")
def forward(self, input_one, mat1, mat2):
return torch.addmm(input_one, mat1, mat2)
op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark)
op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark)
"""Mircobenchmark for addr operator.""" """Mircobenchmark for addr operator."""
@ -106,46 +85,5 @@ addr_configs = op_bench.cross_product_configs(
op_bench.generate_pt_test(addr_configs, AddrBenchmark) op_bench.generate_pt_test(addr_configs, AddrBenchmark)
op_bench.generate_pt_gradient_test(addr_configs, AddrBenchmark) op_bench.generate_pt_gradient_test(addr_configs, AddrBenchmark)
"""Mircobenchmark for addbmm operator."""
class AddbmmBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device):
self.inputs = {
"input_one": torch.rand(
(M, N), device=device, requires_grad=self.auto_set()
),
"batch1": torch.rand(
(B, M, K), device=device, requires_grad=self.auto_set()
),
"batch2": torch.rand(
(
B,
K,
N,
),
device=device,
requires_grad=self.auto_set(),
),
}
self.set_module_name("addbmm")
def forward(self, input_one, batch1, batch2):
return torch.addbmm(input_one, batch1, batch2)
addbmm_configs = op_bench.cross_product_configs(
B=[2, 100],
M=[8, 256],
N=[256, 16],
K=[15, 16],
device=["cpu", "cuda"],
tags=["addbmm"],
)
op_bench.generate_pt_test(addbmm_configs, AddbmmBenchmark)
op_bench.generate_pt_gradient_test(addbmm_configs, AddbmmBenchmark)
if __name__ == "__main__": if __name__ == "__main__":
op_bench.benchmark_runner.main() op_bench.benchmark_runner.main()

View File

@ -0,0 +1,115 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for the addmm and addbmm operators (PyTorch)."""
# Configs for the PT addmm operator.
# "long" configs: built with cross_product_configs from the listed M/N/K sizes
# and dtypes, on CUDA only.
addmm_long_configs = op_bench.cross_product_configs(
M=[256, 1024, 3000],
N=[512, 4096],
K=[512, 4096],
device=["cuda"],
tags=["long"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
)
# "short" configs: three explicit (M, N, K) shapes, combined with the listed
# devices (CPU and CUDA) and a single dtype.
addmm_short_configs = op_bench.config_list(
attr_names=["M", "N", "K"],
attrs=[
[1, 1, 1],
[64, 64, 64],
[64, 64, 128],
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype": [torch.float],
},
tags=["short"],
)
"""Microbenchmark for addmm operator."""
class AddmmBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case for ``torch.addmm``."""

    def init(self, M, N, K, device, dtype):
        """Create the random operands and name the benchmarked module.

        ``input_one`` is (M, K), ``mat1`` is (M, N) and ``mat2`` is (N, K);
        all three share ``device`` and ``dtype``.
        """

        def _rand(rows, cols):
            # auto_set() is evaluated once per tensor, in creation order.
            return torch.rand(
                rows, cols, device=device, requires_grad=self.auto_set(), dtype=dtype
            )

        input_one = _rand(M, K)
        mat1 = _rand(M, N)
        mat2 = _rand(N, K)
        self.inputs = {"input_one": input_one, "mat1": mat1, "mat2": mat2}
        self.set_module_name("addmm")

    def forward(self, input_one, mat1, mat2):
        """Return ``torch.addmm(input_one, mat1, mat2)``."""
        return torch.addmm(input_one, mat1, mat2)
# Register the addmm forward and gradient tests over the long AND short
# configs. Concatenating the long list with itself would list every long
# config twice and leave addmm_short_configs unused.
op_bench.generate_pt_test(addmm_long_configs + addmm_short_configs, AddmmBenchmark)
op_bench.generate_pt_gradient_test(
    addmm_long_configs + addmm_short_configs, AddmmBenchmark
)
"""Microbenchmark for addbmm operator."""
class AddbmmBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case for ``torch.addbmm``."""

    def init(self, B, M, N, K, device, dtype):
        """Create the random operands and name the benchmarked module.

        ``input_one`` is (M, N), ``batch1`` is (B, M, K) and ``batch2`` is
        (B, K, N); all three share ``device`` and ``dtype``.
        """

        def _rand(*shape):
            # auto_set() is evaluated once per tensor, in creation order.
            return torch.rand(
                shape, device=device, requires_grad=self.auto_set(), dtype=dtype
            )

        input_one = _rand(M, N)
        batch1 = _rand(B, M, K)
        batch2 = _rand(B, K, N)
        self.inputs = {"input_one": input_one, "batch1": batch1, "batch2": batch2}
        self.set_module_name("addbmm")

    def forward(self, input_one, batch1, batch2):
        """Return ``torch.addbmm(input_one, batch1, batch2)``."""
        return torch.addbmm(input_one, batch1, batch2)
# "long" addbmm configs: built with cross_product_configs from the listed
# B/M/N/K sizes and dtypes, on CUDA only.
addbmm_long_configs = op_bench.cross_product_configs(
B=[8, 32],
M=[256, 1024],
N=[256, 1024],
K=[64, 128],
device=["cuda"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["long"],
)
# "short" addbmm configs: smaller B/M/N values (but larger K), on both CPU and
# CUDA, with the same three dtypes.
addbmm_short_configs = op_bench.cross_product_configs(
B=[1, 8],
M=[8, 128],
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["short"],
)
# Register the addbmm forward and gradient tests over the long and short configs.
op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark)
op_bench.generate_pt_gradient_test(
addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark
)
# When executed as a script, hand control to the operator-benchmark runner.
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -27,12 +27,12 @@ batched_binary_configs_short = op_bench.config_list(
) )
batched_binary_configs_long = op_bench.cross_product_configs( batched_binary_configs_long = op_bench.cross_product_configs(
B=[1, 128], B=[8, 32],
M=[8, 128], M=[256, 1024],
N=[32, 64], N=[256, 1024],
K=[4, 256], K=[64, 128],
device=["cpu", "cuda"], device=["cuda"],
dtype=[torch.float, torch.bfloat16], dtype=[torch.float32, torch.bfloat16, torch.float16],
tags=["long"], tags=["long"],
) )
@ -40,8 +40,12 @@ batched_binary_configs_long = op_bench.cross_product_configs(
class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase): class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, dtype, op_func): def init(self, B, M, N, K, device, dtype, op_func):
self.inputs = { self.inputs = {
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), "batch1": torch.rand(
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), (B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set()
),
"batch2": torch.rand(
(B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set()
),
} }
self.op_func = op_func self.op_func = op_func
@ -54,6 +58,11 @@ op_bench.generate_pt_tests_from_op_list(
batched_binary_configs_short + batched_binary_configs_long, batched_binary_configs_short + batched_binary_configs_long,
BatchedBinaryOpBenchmark, BatchedBinaryOpBenchmark,
) )
op_bench.generate_pt_gradient_tests_from_op_list(
batched_binary_ops,
batched_binary_configs_long,
BatchedBinaryOpBenchmark,
)
# batched ternary ops # batched ternary ops
@ -66,9 +75,15 @@ batched_ternary_ops = op_bench.op_list(
class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase): class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, dtype, op_func): def init(self, B, M, N, K, device, dtype, op_func):
self.inputs = { self.inputs = {
"input_": torch.rand((B, M, K), device=device).to(dtype=dtype), "input_": torch.rand(
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), (B, M, K), device=device, dtype=dtype, requires_grad=self.auto_set()
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), ),
"batch1": torch.rand(
(B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set()
),
"batch2": torch.rand(
(B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set()
),
} }
self.op_func = op_func self.op_func = op_func
@ -81,6 +96,12 @@ op_bench.generate_pt_tests_from_op_list(
batched_binary_configs_short + batched_binary_configs_long, batched_binary_configs_short + batched_binary_configs_long,
BatchedTernaryOpBenchmark, BatchedTernaryOpBenchmark,
) )
op_bench.generate_pt_gradient_tests_from_op_list(
batched_ternary_ops,
batched_binary_configs_long,
BatchedTernaryOpBenchmark,
)
# TODO: does it automatically register new scripts? # TODO: does it automatically register new scripts?

View File

@ -13,33 +13,46 @@ mm_short_configs = op_bench.config_list(
[128, 128, 128, True, False], [128, 128, 128, True, False],
[256, 256, 256, False, True], [256, 256, 256, False, True],
], ],
cross_product_configs={ cross_product_configs={"device": ["cpu", "cuda"], "dtype": [torch.float]},
"device": ["cpu", "cuda"],
},
tags=["short"], tags=["short"],
) )
mm_long_configs = op_bench.cross_product_configs( mm_long_configs = op_bench.cross_product_configs(
M=[32], M=[256, 1024, 3000],
N=[512, 128], N=[512, 4096],
K=[64], K=[512, 4096],
trans_a=[False, True], trans_a=[False, True],
trans_b=[True, False], trans_b=[True, False],
device=["cpu", "cuda"], device=["cuda"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["long"], tags=["long"],
) )
class MatMulBenchmark(op_bench.TorchBenchmarkBase): class MatMulBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, trans_a, trans_b, device): def init(self, M, N, K, trans_a, trans_b, device, dtype):
# Create tensors without requires_grad first, then set it separately
# This avoids creating graph leaves that cannot be deep copied
if trans_a:
input_one = torch.rand(M, N, device=device, dtype=dtype)
else:
input_one = torch.rand(N, M, device=device, dtype=dtype).t()
if trans_b:
input_two = torch.rand(N, K, device=device, dtype=dtype)
else:
input_two = torch.rand(K, N, device=device, dtype=dtype).t()
# Set requires_grad after tensor creation to avoid graph leaf issues
if self.auto_set():
input_one.requires_grad_(True)
if self.auto_set():
input_two.requires_grad_(True)
self.inputs = { self.inputs = {
"input_one": torch.rand(M, N, device=device) "input_one": input_one,
if trans_a "input_two": input_two,
else torch.rand(N, M, device=device).t(),
"input_two": torch.rand(N, K, device=device)
if trans_b
else torch.rand(K, N, device=device).t(),
} }
self.set_module_name("matmul") self.set_module_name("matmul")
@ -48,6 +61,7 @@ class MatMulBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark)
op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -23,11 +23,11 @@ mm_short_configs = op_bench.config_list(
) )
mm_long_configs = op_bench.cross_product_configs( mm_long_configs = op_bench.cross_product_configs(
M=[8, 128], M=[256, 1024, 3000],
N=[32, 64], N=[512, 4096],
K=[256, 512], K=[512, 4096],
device=["cpu", "cuda"], device=["cuda"],
dtype=[torch.float, torch.bfloat16], dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["long"], tags=["long"],
) )
@ -35,8 +35,12 @@ mm_long_configs = op_bench.cross_product_configs(
class MmOpBenchmark(op_bench.TorchBenchmarkBase): class MmOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device, dtype, op_func): def init(self, M, N, K, device, dtype, op_func):
self.inputs = { self.inputs = {
"input_one": torch.randn(M, N, device=device).to(dtype=dtype), "input_one": torch.randn(
"input_two": torch.randn(N, K, device=device).to(dtype=dtype), M, N, device=device, requires_grad=self.auto_set(), dtype=dtype
),
"input_two": torch.randn(
N, K, device=device, requires_grad=self.auto_set(), dtype=dtype
),
} }
self.op_func = op_func self.op_func = op_func
@ -47,6 +51,9 @@ class MmOpBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_tests_from_op_list( op_bench.generate_pt_tests_from_op_list(
ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark
) )
op_bench.generate_pt_gradient_tests_from_op_list(
ops_list, mm_long_configs, MmOpBenchmark
)
if __name__ == "__main__": if __name__ == "__main__":