Compare commits

...

54 Commits

Author SHA1 Message Date
7669445a70 Merge remote-tracking branch 'origin/add_op_tests' into perf_ops 2025-09-28 01:55:13 -07:00
b7eae1cc34 b200 benchmarks separate run 2025-09-28 01:51:07 -07:00
a710d65523 Update measurement for cuda 2025-09-26 22:16:33 -07:00
eddf149b0c Remove short config from cpu run 2025-09-24 09:10:53 -07:00
62a91acda9 Updates 2025-09-24 09:05:42 -07:00
45760a2f7f Don't configure AWS credentials on A100/H100 2025-09-23 23:48:43 -07:00
5fa2fe9539 Separate the b100 and h100+a100 run 2025-09-23 22:18:47 -07:00
2e3d0429c2 Update and re-run 2025-09-23 20:04:26 -07:00
c8a53c3383 <Replace this line with a title. Use 1 line only, 67 chars or less> 2025-09-23 19:59:50 -07:00
682d542bfb Change cron schedule comment to 'everyday' 2025-09-23 15:18:58 -07:00
a60037fd72 Add B200 and separate add/addmm 2025-09-23 15:12:52 -07:00
bfc9680175 Remove py2.8 separate code 2025-09-22 11:24:24 -07:00
5031e026fc Build with fallback 2025-09-21 23:44:31 -07:00
195779ec3b [Testing] Operator benchmark baseline 2.8 2025-09-19 18:12:49 -07:00
e934b6ab40 Update add_test 2025-09-18 11:50:06 -07:00
ca59a71675 Test 2025-09-16 18:06:05 -07:00
93ad3fec44 Merge remote-tracking branch 'origin/main' into add_op_tests 2025-09-16 10:01:05 -07:00
783f8064d1 Fix yml string 2025-09-15 21:29:02 -07:00
f47fa8d2f8 Fix syntax for running operator benchmarks 2025-09-14 11:04:24 -07:00
3635731fc2 Update yml 2025-09-12 18:47:13 -07:00
98a71c71b2 Fix extra_flags formatting in benchmark workflow 2025-09-12 17:38:22 -07:00
86e3803f3b Update Docker image name for operator benchmark 2025-09-12 16:05:05 -07:00
056bcfc333 Update operator_microbenchmark.yml to use a bigger runner 2025-09-12 15:38:56 -07:00
cc2b171704 Update docker 2025-09-12 15:36:11 -07:00
7b16f72b09 Fix quotes 2025-09-12 13:50:46 -07:00
f47e539765 fix trailing comma 2025-09-12 13:23:30 -07:00
49e5e122fe Fix params for python cmd 2025-09-12 12:55:57 -07:00
6b8cc19597 Fix concurrency and cuda arch 2025-09-12 11:43:08 -07:00
d683fb9ebe Add include into the test matrix 2025-09-12 11:14:54 -07:00
9eca494626 Tweak the build step 2025-09-12 10:56:27 -07:00
7f5b0bcec8 Fix CI 2025-09-11 23:30:42 -07:00
4c257bca07 Fix CI 2025-09-11 21:26:28 -07:00
05bb4d4fc6 Fix CI 2025-09-11 17:00:49 -07:00
8d0cafb8bb Add h100, a100 2025-09-11 14:23:36 -07:00
629de8d7ba Fixes 2025-09-11 14:18:51 -07:00
71ae2d8280 Add ci flow 2025-09-10 22:05:03 -07:00
2fe66701c1 Merge remote-tracking branch 'origin/main' into add_op_tests 2025-09-10 21:54:16 -07:00
c021d0349e Add ci flow 2025-09-10 21:47:50 -07:00
c6f1a29b17 Merge remote-tracking branch 'origin/main' into add_op_tests 2025-09-09 21:23:23 -07:00
54c9527a81 Add mm benchmarking tests 2025-09-09 14:25:07 -07:00
cf31d4b744 Add mm benchmarking tests 2025-09-09 14:17:08 -07:00
9c701f03ee update json 2025-09-08 22:16:33 -07:00
c193ed6c84 Merge remote-tracking branch 'origin/main' into add_compile_benchmarking 2025-09-08 12:42:06 -07:00
eab7bd0d4c Remove mm_bwd_test.py 2025-09-08 11:34:58 -07:00
199318f978 Remove cpu benchmarking 2025-09-08 11:28:30 -07:00
9b226b2ce4 Add cpu memory calculation 2025-09-08 00:10:10 -07:00
6357d4e05a Add cpu memory calculation 2025-09-08 00:03:50 -07:00
162e7d3c20 Updates 2025-09-07 22:02:53 -07:00
ada9c165dd Lint fixes 2025-09-04 13:12:34 -07:00
461c7ad698 Enable bwd pass 2025-09-03 21:51:42 -07:00
819159610d Add fixes 2025-09-03 20:47:09 -07:00
d257ebf9c7 Add peak memory calculation 2025-09-03 11:10:16 -07:00
aab478833d Make jit and compile mutually exclusive 2025-08-27 14:21:37 -07:00
ba1319f414 Update the op benchmarking, to benchmark using torch.compile 2025-08-25 00:15:50 -07:00
11 changed files with 325 additions and 97 deletions

View File

@@ -1614,6 +1614,26 @@ test_operator_benchmark() {
--expected "expected_ci_operator_benchmark_eager_float32_cpu.csv"
}
test_operator_microbenchmark() {
TEST_REPORTS_DIR=$(pwd)/test/test-reports
mkdir -p "$TEST_REPORTS_DIR"
TEST_DIR=$(pwd)
pip_uninstall torch torchvision torchaudio
pip_install torch==2.8.0 torchvision torchaudio ninja --force-reinstall
cd benchmarks/operator_benchmark/pt_extension
python -m pip install .
cd "${TEST_DIR}"/benchmarks/operator_benchmark
for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
--benchmark-name "PyTorch operator microbenchmark" --use-compile
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}.json" \
--benchmark-name "PyTorch operator microbenchmark"
done
}
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
(cd test && python -c "import torch; print(torch.__config__.show())")
@@ -1668,6 +1688,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then
test_operator_benchmark cpu ${TEST_MODE}
fi
elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then
test_operator_microbenchmark
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
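A hedged sketch of what the shell loop in test_operator_microbenchmark above amounts to, written as a standalone Python driver; the report directory and module names mirror the CI script, and the paths assume a local pytorch checkout.

import subprocess

# Assumption: run from a pytorch checkout; mirrors TEST_REPORTS_DIR in CI.
REPORTS_DIR = "test/test-reports"
for op in ["matmul", "mm", "addmm", "bmm"]:
    for use_compile in (True, False):
        suffix = "_compile" if use_compile else ""
        cmd = [
            "python", "-m", f"pt.{op}_test",
            "--tag-filter", "long",
            "--benchmark-name", "PyTorch operator microbenchmark",
            "--output-json-for-dashboard",
            f"{REPORTS_DIR}/operator_microbenchmark_{op}{suffix}.json",
        ]
        if use_compile:
            cmd.append("--use-compile")  # benchmark through torch.compile
        subprocess.run(cmd, check=True, cwd="benchmarks/operator_benchmark")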

View File

@@ -273,6 +273,8 @@ jobs:
TEST_CONFIG: ${{ matrix.config }}
SHARD_NUMBER: ${{ matrix.shard }}
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
EXTRA_FLAGS: ${{ matrix.extra_flags || '' }}
OP_BENCHMARK_TESTS: ${{ matrix.op_benchmark_tests }}
REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }}
CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }}
VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }}

View File

@@ -0,0 +1,46 @@
name: operator_microbenchmark
on:
push:
tags:
- ciflow/op-benchmark/*
workflow_dispatch:
schedule:
# Run at 06:00 UTC every day
- cron: 0 6 * * *
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
opmicrobenchmark-build:
if: github.repository_owner == 'pytorch'
name: opmicrobenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '8.0 9.0'
test-matrix: |
{ include: [
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" },
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" },
]}
secrets: inherit
opmicrobenchmark-test:
name: opmicrobenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: opmicrobenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }}
secrets: inherit

View File

@@ -0,0 +1,46 @@
name: operator_microbenchmark_b200
on:
push:
tags:
- ciflow/op-benchmark/*
workflow_dispatch:
schedule:
# Run at 06:00 UTC every day
- cron: 0 6 * * *
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
opmicrobenchmark-build:
if: github.repository_owner == 'pytorch'
name: opmicrobenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
secrets: inherit
opmicrobenchmark-test:
name: opmicrobenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: opmicrobenchmark-build
with:
timeout-minutes: 500
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80
docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
secrets: inherit

View File

@@ -18,6 +18,7 @@ import torch
# needs to be imported after torch
import torch.utils.cpp_extension as cpp_extension # noqa: F401
from torch.utils.benchmark import Timer
"""Performance microbenchmarks.
@@ -348,10 +349,24 @@ class BenchmarkRunner:
func = test_case.run_jit_forward
if self.use_compile:
func = test_case.run_compile_forward
forward_time = timeit.timeit(
functools.partial(func, iters, print_per_iter, cuda_sync), number=1
if not cuda_sync:
forward_time = timeit.timeit(
functools.partial(func, iters, print_per_iter, cuda_sync), number=1
)
return forward_time
# With CUDA sync, use torch.utils.benchmark.Timer for statistically stable timing
timer = Timer(
stmt="func(iters, print_per_iter, cuda_sync)",
globals={
"func": func,
"iters": iters,
"print_per_iter": print_per_iter,
"cuda_sync": cuda_sync,
},
)
return forward_time
result = timer.adaptive_autorange(min_run_time=0.0001)
return result.median * iters
def _launch_backward(self, test_case, iters, print_per_iter=False):
"""This function runs forward path of an op to get an output. Then the backward path is executed

View File

@@ -161,6 +161,8 @@ class PyTorchOperatorTestCase:
if self._compile_forward_graph is None:
self._compile_forward_graph = self._generate_compile_forward_graph()
self._compile_forward_graph(num_runs)
if cuda_sync:
torch.cuda.synchronize(torch.cuda.current_device())
def _print_per_iter(self):
# print last 50 values
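The cuda_sync guard added above matters because CUDA kernels launch asynchronously; a minimal illustration (again assuming a CUDA device) of why a synchronize is needed before a host-side timer is read:

import time
import torch

x = torch.rand(4096, 4096, device="cuda")
start = time.perf_counter()
y = x @ x                  # the kernel is only enqueued here
torch.cuda.synchronize()   # without this, the timer may stop before the GPU finishes
print(f"{time.perf_counter() - start:.4f} s")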

View File

@@ -52,27 +52,6 @@ class AddBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark)
op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark)
"""Mircobenchmark for addmm operator."""
class AddmmBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device):
self.inputs = {
"input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()),
"mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()),
"mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()),
}
self.set_module_name("addmm")
def forward(self, input_one, mat1, mat2):
return torch.addmm(input_one, mat1, mat2)
op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark)
op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark)
"""Mircobenchmark for addr operator."""
@@ -106,46 +85,5 @@ addr_configs = op_bench.cross_product_configs(
op_bench.generate_pt_test(addr_configs, AddrBenchmark)
op_bench.generate_pt_gradient_test(addr_configs, AddrBenchmark)
"""Mircobenchmark for addbmm operator."""
class AddbmmBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device):
self.inputs = {
"input_one": torch.rand(
(M, N), device=device, requires_grad=self.auto_set()
),
"batch1": torch.rand(
(B, M, K), device=device, requires_grad=self.auto_set()
),
"batch2": torch.rand(
(
B,
K,
N,
),
device=device,
requires_grad=self.auto_set(),
),
}
self.set_module_name("addbmm")
def forward(self, input_one, batch1, batch2):
return torch.addbmm(input_one, batch1, batch2)
addbmm_configs = op_bench.cross_product_configs(
B=[2, 100],
M=[8, 256],
N=[256, 16],
K=[15, 16],
device=["cpu", "cuda"],
tags=["addbmm"],
)
op_bench.generate_pt_test(addbmm_configs, AddbmmBenchmark)
op_bench.generate_pt_gradient_test(addbmm_configs, AddbmmBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@@ -0,0 +1,115 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for add_(matmul) operator. Supports both Caffe2/PyTorch."""
# Configs for PT add operator
addmm_long_configs = op_bench.cross_product_configs(
M=[256, 1024, 3000],
N=[512, 4096],
K=[512, 4096],
device=["cuda"],
tags=["long"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
)
addmm_short_configs = op_bench.config_list(
attr_names=["M", "N", "K"],
attrs=[
[1, 1, 1],
[64, 64, 64],
[64, 64, 128],
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype": [torch.float],
},
tags=["short"],
)
"""Mircobenchmark for addmm operator."""
class AddmmBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device, dtype):
self.inputs = {
"input_one": torch.rand(
M, K, device=device, requires_grad=self.auto_set(), dtype=dtype
),
"mat1": torch.rand(
M, N, device=device, requires_grad=self.auto_set(), dtype=dtype
),
"mat2": torch.rand(
N, K, device=device, requires_grad=self.auto_set(), dtype=dtype
),
}
self.set_module_name("addmm")
def forward(self, input_one, mat1, mat2):
return torch.addmm(input_one, mat1, mat2)
op_bench.generate_pt_test(addmm_long_configs + addmm_short_configs, AddmmBenchmark)
op_bench.generate_pt_gradient_test(
addmm_long_configs + addmm_short_configs, AddmmBenchmark
)
"""Mircobenchmark for addbmm operator."""
class AddbmmBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, dtype):
self.inputs = {
"input_one": torch.rand(
(M, N), device=device, requires_grad=self.auto_set(), dtype=dtype
),
"batch1": torch.rand(
(B, M, K), device=device, requires_grad=self.auto_set(), dtype=dtype
),
"batch2": torch.rand(
(
B,
K,
N,
),
device=device,
requires_grad=self.auto_set(),
dtype=dtype,
),
}
self.set_module_name("addbmm")
def forward(self, input_one, batch1, batch2):
return torch.addbmm(input_one, batch1, batch2)
addbmm_long_configs = op_bench.cross_product_configs(
B=[8, 32],
M=[256, 1024],
N=[256, 1024],
K=[64, 128],
device=["cuda"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["long"],
)
addbmm_short_configs = op_bench.cross_product_configs(
B=[1, 8],
M=[8, 128],
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["short"],
)
op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark)
op_bench.generate_pt_gradient_test(
addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()
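As a reading aid, a rough illustration of what op_bench.cross_product_configs expands the attribute lists above into; this is a simplification under stated assumptions, not the library's actual implementation (which also threads tags through each config):

import itertools

def sketch_cross_product_configs(**attrs):
    # One config dict per element of the Cartesian product of all attribute lists.
    keys = list(attrs)
    return [dict(zip(keys, combo)) for combo in itertools.product(*attrs.values())]

configs = sketch_cross_product_configs(
    M=[256, 1024, 3000], N=[512, 4096], K=[512, 4096], device=["cuda"]
)
print(len(configs))  # 3 * 2 * 2 * 1 = 12 configurations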

View File

@@ -27,12 +27,12 @@ batched_binary_configs_short = op_bench.config_list(
)
batched_binary_configs_long = op_bench.cross_product_configs(
B=[1, 128],
M=[8, 128],
N=[32, 64],
K=[4, 256],
device=["cpu", "cuda"],
dtype=[torch.float, torch.bfloat16],
B=[8, 32],
M=[256, 1024],
N=[256, 1024],
K=[64, 128],
device=["cuda"],
dtype=[torch.float32, torch.bfloat16, torch.float16],
tags=["long"],
)
@@ -40,8 +40,12 @@ batched_binary_configs_long = op_bench.cross_product_configs(
class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, dtype, op_func):
self.inputs = {
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
"batch1": torch.rand(
(B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set()
),
"batch2": torch.rand(
(B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set()
),
}
self.op_func = op_func
@@ -54,6 +58,11 @@ op_bench.generate_pt_tests_from_op_list(
batched_binary_configs_short + batched_binary_configs_long,
BatchedBinaryOpBenchmark,
)
op_bench.generate_pt_gradient_tests_from_op_list(
batched_binary_ops,
batched_binary_configs_long,
BatchedBinaryOpBenchmark,
)
# batched ternary ops
@@ -66,9 +75,15 @@ batched_ternary_ops = op_bench.op_list(
class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, dtype, op_func):
self.inputs = {
"input_": torch.rand((B, M, K), device=device).to(dtype=dtype),
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
"input_": torch.rand(
(B, M, K), device=device, dtype=dtype, requires_grad=self.auto_set()
),
"batch1": torch.rand(
(B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set()
),
"batch2": torch.rand(
(B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set()
),
}
self.op_func = op_func
@@ -81,6 +96,12 @@ op_bench.generate_pt_tests_from_op_list(
batched_binary_configs_short + batched_binary_configs_long,
BatchedTernaryOpBenchmark,
)
op_bench.generate_pt_gradient_tests_from_op_list(
batched_ternary_ops,
batched_binary_configs_long,
BatchedTernaryOpBenchmark,
)
# TODO: does it automatically register new scripts?

View File

@@ -13,33 +13,46 @@ mm_short_configs = op_bench.config_list(
[128, 128, 128, True, False],
[256, 256, 256, False, True],
],
cross_product_configs={
"device": ["cpu", "cuda"],
},
cross_product_configs={"device": ["cpu", "cuda"], "dtype": [torch.float]},
tags=["short"],
)
mm_long_configs = op_bench.cross_product_configs(
M=[32],
N=[512, 128],
K=[64],
M=[256, 1024, 3000],
N=[512, 4096],
K=[512, 4096],
trans_a=[False, True],
trans_b=[True, False],
device=["cpu", "cuda"],
device=["cuda"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["long"],
)
class MatMulBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, trans_a, trans_b, device):
def init(self, M, N, K, trans_a, trans_b, device, dtype):
# Create tensors without requires_grad first, then set it separately
# This avoids creating non-leaf views, which cannot be deep copied
if trans_a:
input_one = torch.rand(M, N, device=device, dtype=dtype)
else:
input_one = torch.rand(N, M, device=device, dtype=dtype).t()
if trans_b:
input_two = torch.rand(N, K, device=device, dtype=dtype)
else:
input_two = torch.rand(K, N, device=device, dtype=dtype).t()
# Set requires_grad after tensor creation to avoid graph leaf issues
if self.auto_set():
input_one.requires_grad_(True)
if self.auto_set():
input_two.requires_grad_(True)
self.inputs = {
"input_one": torch.rand(M, N, device=device)
if trans_a
else torch.rand(N, M, device=device).t(),
"input_two": torch.rand(N, K, device=device)
if trans_b
else torch.rand(K, N, device=device).t(),
"input_one": input_one,
"input_two": input_two,
}
self.set_module_name("matmul")
@@ -48,6 +61,7 @@ class MatMulBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark)
op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark)
if __name__ == "__main__":
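A standalone demonstration (CPU only, no benchmark harness) of the graph-leaf issue the comments in MatMulBenchmark.init refer to: transposing a requires_grad tensor yields a non-leaf view that deepcopy rejects, while setting requires_grad_ after construction keeps a leaf.

import copy
import torch

a = torch.rand(4, 8, requires_grad=True).t()
print(a.is_leaf)        # False: .t() records an autograd op, so `a` is a view
# copy.deepcopy(a)      # would raise: only graph leaves support deepcopy

b = torch.rand(4, 8).t()
b.requires_grad_(True)  # allowed: `b` has no grad_fn, so it is still a leaf
print(b.is_leaf)        # True
copy.deepcopy(b)        # succeeds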

View File

@@ -23,11 +23,11 @@ mm_short_configs = op_bench.config_list(
)
mm_long_configs = op_bench.cross_product_configs(
M=[8, 128],
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype=[torch.float, torch.bfloat16],
M=[256, 1024, 3000],
N=[512, 4096],
K=[512, 4096],
device=["cuda"],
dtype=[torch.float16, torch.bfloat16, torch.float32],
tags=["long"],
)
@@ -35,8 +35,12 @@ mm_long_configs = op_bench.cross_product_configs(
class MmOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device, dtype, op_func):
self.inputs = {
"input_one": torch.randn(M, N, device=device).to(dtype=dtype),
"input_two": torch.randn(N, K, device=device).to(dtype=dtype),
"input_one": torch.randn(
M, N, device=device, requires_grad=self.auto_set(), dtype=dtype
),
"input_two": torch.randn(
N, K, device=device, requires_grad=self.auto_set(), dtype=dtype
),
}
self.op_func = op_func
@@ -47,6 +51,9 @@ class MmOpBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_tests_from_op_list(
ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark
)
op_bench.generate_pt_gradient_tests_from_op_list(
ops_list, mm_long_configs, MmOpBenchmark
)
if __name__ == "__main__":