diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index a6cb190b7eaf..b44836867977 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1630,6 +1630,25 @@ test_operator_benchmark() { --expected "expected_ci_operator_benchmark_eager_float32_cpu.csv" } +test_operator_microbenchmark() { + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + TEST_DIR=$(pwd) + + cd benchmarks/operator_benchmark/pt_extension + python -m pip install . + + cd "${TEST_DIR}"/benchmarks/operator_benchmark + + for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do + $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ + --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \ + --benchmark-name "PyTorch operator microbenchmark" --use-compile + $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ + --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}.json" \ + --benchmark-name "PyTorch operator microbenchmark" + done +} if ! 
[[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then (cd test && python -c "import torch; print(torch.__config__.show())") @@ -1686,6 +1705,8 @@ elif [[ "${TEST_CONFIG}" == *operator_benchmark* ]]; then test_operator_benchmark cpu ${TEST_MODE} fi +elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then + test_operator_microbenchmark elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 537e94488b36..5fa116d74e6e 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -273,6 +273,8 @@ jobs: TEST_CONFIG: ${{ matrix.config }} SHARD_NUMBER: ${{ matrix.shard }} NUM_TEST_SHARDS: ${{ matrix.num_shards }} + EXTRA_FLAGS: ${{ matrix.extra_flags || '' }} + OP_BENCHMARK_TESTS: ${{ matrix.op_benchmark_tests }} REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }} CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }} VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }} diff --git a/.github/workflows/operator_microbenchmark.yml b/.github/workflows/operator_microbenchmark.yml new file mode 100644 index 000000000000..9205b927c5d7 --- /dev/null +++ b/.github/workflows/operator_microbenchmark.yml @@ -0,0 +1,46 @@ +name: operator_microbenchmark + +on: + push: + tags: + - ciflow/op-benchmark/* + workflow_dispatch: + schedule: + # Run at 06:00 UTC everyday + - cron: 0 6 * * * + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + opmicrobenchmark-build: + if: github.repository_owner == 'pytorch' + name: 
opmicrobenchmark-build + uses: ./.github/workflows/_linux-build.yml + with: + runner: linux.12xlarge.memory + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '8.0 9.0' + test-matrix: | + { include: [ + { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.h100" }, + { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, + ]} + secrets: inherit + + opmicrobenchmark-test: + name: opmicrobenchmark-test + uses: ./.github/workflows/_linux-test.yml + needs: opmicrobenchmark-build + with: + timeout-minutes: 500 + build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }} + test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} + secrets: inherit diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index 3caaf3e3a916..3f79ed2318c4 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -18,6 +18,7 @@ import torch # needs to be imported after torch import torch.utils.cpp_extension as cpp_extension # noqa: F401 +from torch.utils.benchmark import Timer """Performance microbenchmarks. 
@@ -348,10 +349,24 @@ class BenchmarkRunner: func = test_case.run_jit_forward if self.use_compile: func = test_case.run_compile_forward - forward_time = timeit.timeit( - functools.partial(func, iters, print_per_iter, cuda_sync), number=1 + + if not cuda_sync: + forward_time = timeit.timeit( + functools.partial(func, iters, print_per_iter, cuda_sync), number=1 + ) + return forward_time + # Stable timing with Timer + timer = Timer( + stmt="func(iters, print_per_iter, cuda_sync)", + globals={ + "func": func, + "iters": iters, + "print_per_iter": print_per_iter, + "cuda_sync": cuda_sync, + }, ) - return forward_time + result = timer.adaptive_autorange(min_run_time=0.0001) + return result.median * iters def _launch_backward(self, test_case, iters, print_per_iter=False): """This function runs forward path of an op to get an output. Then the backward path is executed diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index a7ff40ebb340..cfed9ebac04b 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -161,6 +161,8 @@ class PyTorchOperatorTestCase: if self._compile_forward_graph is None: self._compile_forward_graph = self._generate_compile_forward_graph() self._compile_forward_graph(num_runs) + if cuda_sync: + torch.cuda.synchronize(torch.cuda.current_device()) def _print_per_iter(self): # print last 50 values diff --git a/benchmarks/operator_benchmark/pt/add_test.py b/benchmarks/operator_benchmark/pt/add_test.py index 54504c4f3005..739b8ef14a54 100644 --- a/benchmarks/operator_benchmark/pt/add_test.py +++ b/benchmarks/operator_benchmark/pt/add_test.py @@ -52,27 +52,6 @@ class AddBenchmark(op_bench.TorchBenchmarkBase): op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark) op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark) - -"""Mircobenchmark for addmm operator.""" - - -class 
AddmmBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, K, device): - self.inputs = { - "input_one": torch.rand(M, K, device=device, requires_grad=self.auto_set()), - "mat1": torch.rand(M, N, device=device, requires_grad=self.auto_set()), - "mat2": torch.rand(N, K, device=device, requires_grad=self.auto_set()), - } - self.set_module_name("addmm") - - def forward(self, input_one, mat1, mat2): - return torch.addmm(input_one, mat1, mat2) - - -op_bench.generate_pt_test(add_long_configs + add_short_configs, AddmmBenchmark) -op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddmmBenchmark) - - """Mircobenchmark for addr operator.""" @@ -106,46 +85,5 @@ addr_configs = op_bench.cross_product_configs( op_bench.generate_pt_test(addr_configs, AddrBenchmark) op_bench.generate_pt_gradient_test(addr_configs, AddrBenchmark) - -"""Mircobenchmark for addbmm operator.""" - - -class AddbmmBenchmark(op_bench.TorchBenchmarkBase): - def init(self, B, M, N, K, device): - self.inputs = { - "input_one": torch.rand( - (M, N), device=device, requires_grad=self.auto_set() - ), - "batch1": torch.rand( - (B, M, K), device=device, requires_grad=self.auto_set() - ), - "batch2": torch.rand( - ( - B, - K, - N, - ), - device=device, - requires_grad=self.auto_set(), - ), - } - self.set_module_name("addbmm") - - def forward(self, input_one, batch1, batch2): - return torch.addbmm(input_one, batch1, batch2) - - -addbmm_configs = op_bench.cross_product_configs( - B=[2, 100], - M=[8, 256], - N=[256, 16], - K=[15, 16], - device=["cpu", "cuda"], - tags=["addbmm"], -) - -op_bench.generate_pt_test(addbmm_configs, AddbmmBenchmark) -op_bench.generate_pt_gradient_test(addbmm_configs, AddbmmBenchmark) - if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/addmm_test.py b/benchmarks/operator_benchmark/pt/addmm_test.py new file mode 100644 index 000000000000..a98628944b3e --- /dev/null +++ 
b/benchmarks/operator_benchmark/pt/addmm_test.py
@@ -0,0 +1,116 @@
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for the addmm and addbmm operators."""
+
+# Configs for PT addmm operator
+addmm_long_configs = op_bench.cross_product_configs(
+    M=[256, 1024, 3000],
+    N=[512, 4096],
+    K=[512, 4096],
+    device=["cuda"],
+    tags=["long"],
+    dtype=[torch.float16, torch.bfloat16, torch.float32],
+)
+
+
+addmm_short_configs = op_bench.config_list(
+    attr_names=["M", "N", "K"],
+    attrs=[
+        [1, 1, 1],
+        [64, 64, 64],
+        [64, 64, 128],
+    ],
+    cross_product_configs={
+        "device": ["cpu", "cuda"],
+        "dtype": [torch.float],
+    },
+    tags=["short"],
+)
+
+
+"""Microbenchmark for addmm operator."""
+
+
+class AddmmBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, M, N, K, device, dtype):
+        self.inputs = {
+            "input_one": torch.rand(
+                M, K, device=device, requires_grad=self.auto_set(), dtype=dtype
+            ),
+            "mat1": torch.rand(
+                M, N, device=device, requires_grad=self.auto_set(), dtype=dtype
+            ),
+            "mat2": torch.rand(
+                N, K, device=device, requires_grad=self.auto_set(), dtype=dtype
+            ),
+        }
+        self.set_module_name("addmm")
+
+    def forward(self, input_one, mat1, mat2):
+        return torch.addmm(input_one, mat1, mat2)
+
+
+# Cover both the long (CUDA-only) and the short (CPU and CUDA) configs.
+op_bench.generate_pt_test(addmm_long_configs + addmm_short_configs, AddmmBenchmark)
+op_bench.generate_pt_gradient_test(
+    addmm_long_configs + addmm_short_configs, AddmmBenchmark
+)
+
+"""Microbenchmark for addbmm operator."""
+
+
+class AddbmmBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, B, M, N, K, device, dtype):
+        self.inputs = {
+            "input_one": torch.rand(
+                (M, N), device=device, requires_grad=self.auto_set(), dtype=dtype
+            ),
+            "batch1": torch.rand(
+                (B, M, K), device=device, requires_grad=self.auto_set(), dtype=dtype
+            ),
+            "batch2": torch.rand(
+                (
+                    B,
+                    K,
+                    N,
+                ),
+                device=device,
+                requires_grad=self.auto_set(),
+                dtype=dtype,
+            ),
+        }
+        self.set_module_name("addbmm")
+
+    def forward(self,
input_one, batch1, batch2): + return torch.addbmm(input_one, batch1, batch2) + + +addbmm_long_configs = op_bench.cross_product_configs( + B=[8, 32], + M=[256, 1024], + N=[256, 1024], + K=[64, 128], + device=["cuda"], + dtype=[torch.float16, torch.bfloat16, torch.float32], + tags=["long"], +) +addbmm_short_configs = op_bench.cross_product_configs( + B=[1, 8], + M=[8, 128], + N=[32, 64], + K=[256, 512], + device=["cpu", "cuda"], + dtype=[torch.float16, torch.bfloat16, torch.float32], + tags=["short"], +) + +op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark) +op_bench.generate_pt_gradient_test( + addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark +) + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/pt/bmm_test.py b/benchmarks/operator_benchmark/pt/bmm_test.py index 1c6d1f9aca55..f867f6ac09f8 100644 --- a/benchmarks/operator_benchmark/pt/bmm_test.py +++ b/benchmarks/operator_benchmark/pt/bmm_test.py @@ -27,12 +27,12 @@ batched_binary_configs_short = op_bench.config_list( ) batched_binary_configs_long = op_bench.cross_product_configs( - B=[1, 128], - M=[8, 128], - N=[32, 64], - K=[4, 256], - device=["cpu", "cuda"], - dtype=[torch.float, torch.bfloat16], + B=[8, 32], + M=[256, 1024], + N=[256, 1024], + K=[64, 128], + device=["cuda"], + dtype=[torch.float32, torch.bfloat16, torch.float16], tags=["long"], ) @@ -40,8 +40,12 @@ batched_binary_configs_long = op_bench.cross_product_configs( class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, B, M, N, K, device, dtype, op_func): self.inputs = { - "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), - "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), + "batch1": torch.rand( + (B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set() + ), + "batch2": torch.rand( + (B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set() + ), } self.op_func = op_func @@ -54,6 
+58,11 @@ op_bench.generate_pt_tests_from_op_list( batched_binary_configs_short + batched_binary_configs_long, BatchedBinaryOpBenchmark, ) +op_bench.generate_pt_gradient_tests_from_op_list( + batched_binary_ops, + batched_binary_configs_long, + BatchedBinaryOpBenchmark, +) # batched ternary ops @@ -66,9 +75,15 @@ batched_ternary_ops = op_bench.op_list( class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, B, M, N, K, device, dtype, op_func): self.inputs = { - "input_": torch.rand((B, M, K), device=device).to(dtype=dtype), - "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype), - "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype), + "input_": torch.rand( + (B, M, K), device=device, dtype=dtype, requires_grad=self.auto_set() + ), + "batch1": torch.rand( + (B, M, N), device=device, dtype=dtype, requires_grad=self.auto_set() + ), + "batch2": torch.rand( + (B, N, K), device=device, dtype=dtype, requires_grad=self.auto_set() + ), } self.op_func = op_func @@ -81,6 +96,12 @@ op_bench.generate_pt_tests_from_op_list( batched_binary_configs_short + batched_binary_configs_long, BatchedTernaryOpBenchmark, ) +op_bench.generate_pt_gradient_tests_from_op_list( + batched_ternary_ops, + batched_binary_configs_long, + BatchedTernaryOpBenchmark, +) + # TODO: does it automatically register new scripts? 
diff --git a/benchmarks/operator_benchmark/pt/matmul_test.py b/benchmarks/operator_benchmark/pt/matmul_test.py index e92728e9ebd3..d0c58aa16e8f 100644 --- a/benchmarks/operator_benchmark/pt/matmul_test.py +++ b/benchmarks/operator_benchmark/pt/matmul_test.py @@ -13,33 +13,46 @@ mm_short_configs = op_bench.config_list( [128, 128, 128, True, False], [256, 256, 256, False, True], ], - cross_product_configs={ - "device": ["cpu", "cuda"], - }, + cross_product_configs={"device": ["cpu", "cuda"]}, tags=["short"], ) mm_long_configs = op_bench.cross_product_configs( - M=[32], - N=[512, 128], - K=[64], + M=[256, 1024, 3000], + N=[512, 4096], + K=[512, 4096], trans_a=[False, True], trans_b=[True, False], - device=["cpu", "cuda"], + device=["cuda"], + dtype=[torch.float16, torch.bfloat16, torch.float32], tags=["long"], ) class MatMulBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, K, trans_a, trans_b, device): + def init(self, M, N, K, trans_a, trans_b, device, dtype=torch.float): + # Create tensors without requires_grad first, then set it separately + # This avoids creating graph leaves that cannot be deep copied + if trans_a: + input_one = torch.rand(M, N, device=device, dtype=dtype) + else: + input_one = torch.rand(N, M, device=device, dtype=dtype).t() + + if trans_b: + input_two = torch.rand(N, K, device=device, dtype=dtype) + else: + input_two = torch.rand(K, N, device=device, dtype=dtype).t() + + # Set requires_grad after tensor creation to avoid graph leaf issues + if self.auto_set(): + input_one.requires_grad_(True) + if self.auto_set(): + input_two.requires_grad_(True) + self.inputs = { - "input_one": torch.rand(M, N, device=device) - if trans_a - else torch.rand(N, M, device=device).t(), - "input_two": torch.rand(N, K, device=device) - if trans_b - else torch.rand(K, N, device=device).t(), + "input_one": input_one, + "input_two": input_two, } self.set_module_name("matmul") @@ -48,6 +61,7 @@ class MatMulBenchmark(op_bench.TorchBenchmarkBase): 
op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) +op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/pt/mm_test.py b/benchmarks/operator_benchmark/pt/mm_test.py index bf2a2651e8fb..f9e0743ba712 100644 --- a/benchmarks/operator_benchmark/pt/mm_test.py +++ b/benchmarks/operator_benchmark/pt/mm_test.py @@ -23,11 +23,11 @@ mm_short_configs = op_bench.config_list( ) mm_long_configs = op_bench.cross_product_configs( - M=[8, 128], - N=[32, 64], - K=[256, 512], - device=["cpu", "cuda"], - dtype=[torch.float, torch.bfloat16], + M=[256, 1024, 3000], + N=[512, 4096], + K=[512, 4096], + device=["cuda"], + dtype=[torch.float16, torch.bfloat16, torch.float32], tags=["long"], ) @@ -35,8 +35,12 @@ mm_long_configs = op_bench.cross_product_configs( class MmOpBenchmark(op_bench.TorchBenchmarkBase): def init(self, M, N, K, device, dtype, op_func): self.inputs = { - "input_one": torch.randn(M, N, device=device).to(dtype=dtype), - "input_two": torch.randn(N, K, device=device).to(dtype=dtype), + "input_one": torch.randn( + M, N, device=device, requires_grad=self.auto_set(), dtype=dtype + ), + "input_two": torch.randn( + N, K, device=device, requires_grad=self.auto_set(), dtype=dtype + ), } self.op_func = op_func @@ -47,6 +51,9 @@ class MmOpBenchmark(op_bench.TorchBenchmarkBase): op_bench.generate_pt_tests_from_op_list( ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark ) +op_bench.generate_pt_gradient_tests_from_op_list( + ops_list, mm_long_configs, MmOpBenchmark +) if __name__ == "__main__":