Compare commits

...

3 Commits

SHA1        Message                                            Date
9dfb3d234a  Add memory bandwidth calculation                   2025-11-18 23:08:29 -08:00
abfc59b1e3  Add test to ci                                     2025-11-18 12:10:48 -08:00
074dffa1cc  Add optimizer tests in operator microbenchmarks    2025-11-18 12:02:33 -08:00
4 changed files with 142 additions and 2 deletions

View File

@@ -1768,7 +1768,7 @@ test_operator_microbenchmark() {
   cd "${TEST_DIR}"/benchmarks/operator_benchmark
-  for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do
+  for OP_BENCHMARK_TESTS in optimizer; do
     $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
       --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
       --benchmark-name "PyTorch operator microbenchmark" --use-compile

View File

@@ -266,7 +266,11 @@ class BenchmarkRunner:
             print(
                 f"{mode} Execution Time (us) : {results['reported_run_time_us'][0]:.3f}"
             )
-            print(f"Peak Memory (KB) : {results['peak_memory']}\n")
+            print(f"Peak Memory (KB) : {results['peak_memory']}")
+            # Calculate and print memory bandwidth if operator provides memory traffic
+            if results.get('memory_bandwidth_gb_s') is not None:
+                print(f"Memory Bandwidth (GB/s) : {results['memory_bandwidth_gb_s']:.2f}")
+            print()

     def _perf_result_to_dict(self, results, test_case):
         """This function is the parallel of _print_perf_result, which instead of
@@ -711,6 +715,15 @@ class BenchmarkRunner:
         result_dict = dict()
         result_dict["reported_run_time_us"] = [r[0] for r in results]
         result_dict["peak_memory"] = results[0][1]
+        # Calculate memory bandwidth if operator provides memory traffic
+        memory_traffic_bytes = test_case.op_bench.get_memory_traffic_bytes()
+        if memory_traffic_bytes is not None:
+            execution_time_s = result_dict["reported_run_time_us"][0] / 1e6
+            result_dict["memory_bandwidth_gb_s"] = (
+                memory_traffic_bytes / execution_time_s / 1e9
+            )
+        else:
+            result_dict["memory_bandwidth_gb_s"] = None
         self._print_perf_result(results=result_dict, test_case=test_case)
         # output results to csv

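To make the units concrete, here is a small standalone sketch of the bandwidth arithmetic used above; the numbers are hypothetical and only illustrate the formula (traffic in bytes divided by runtime in seconds, scaled to GB/s), not output from this change.

    # Hypothetical figures, chosen to match what get_memory_traffic_bytes would
    # return for a 4096x4096 @ 4096x4096 fp32 mm: 3 * 4096 * 4096 * 4 bytes.
    memory_traffic_bytes = 3 * 4096 * 4096 * 4   # ~201.3 MB moved per iteration
    reported_run_time_us = 500.0                 # assumed measured time

    execution_time_s = reported_run_time_us / 1e6
    memory_bandwidth_gb_s = memory_traffic_bytes / execution_time_s / 1e9
    print(f"Memory Bandwidth (GB/s) : {memory_bandwidth_gb_s:.2f}")  # ~402.65
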
View File

@@ -118,6 +118,54 @@ class TorchBenchmarkBase(torch.nn.Module):
         name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "")
         return name
+
+    def get_memory_traffic_bytes(self):
+        """Return the number of bytes read/written by this operator.
+
+        Override this method in subclasses to enable memory bandwidth calculation.
+        The framework will use this value along with execution time to compute
+        and report memory bandwidth in GB/s.
+
+        This provides automatic calculation for matmul-like operations by
+        inferring dimensions from input tensor shapes:
+        - 2D inputs: (M, N) @ (N, K) → matmul, mm
+        - 3D inputs: (B, M, N) @ (B, N, K) → bmm, baddbmm
+
+        For custom memory patterns, override this method.
+
+        Returns:
+            int or None: Total bytes transferred (reads + writes), or None if not applicable
+        """
+        if not hasattr(self, 'inputs') or not self.inputs:
+            return None
+
+        input_tensors = [v for v in self.inputs.values() if isinstance(v, torch.Tensor)]
+        if len(input_tensors) < 2:
+            return None
+
+        input_a, input_b = input_tensors[0], input_tensors[1]
+        if input_a.dim() != input_b.dim() or input_a.dim() not in (2, 3):
+            return None
+
+        bytes_per_element = input_a.element_size()
+
+        if input_a.dim() == 3:
+            B_a, M, N_a = input_a.shape
+            B_b, N_b, K = input_b.shape
+            if B_a != B_b or N_a != N_b:
+                return None
+            B = B_a
+        else:
+            M, N_a = input_a.shape
+            N_b, K = input_b.shape
+            if N_a != N_b:
+                return None
+            B = 1
+
+        N = N_a
+        total_elements = B * (M * N + N * K + M * K)
+        return total_elements * bytes_per_element


 class PyTorchOperatorTestCase:
     """This class includes all the information needed to benchmark an operator.

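As a worked example of the default shape inference (hypothetical shapes, re-derived here outside the framework rather than through a TorchBenchmarkBase subclass): a bmm-style benchmark with fp16 inputs of shape (8, 64, 128) and (8, 128, 32) would report 8 * (64*128 + 128*32 + 64*32) = 114,688 elements, i.e. 229,376 bytes.

    import torch

    # Hypothetical bmm-shaped inputs; mirrors the default inference above.
    a = torch.randn(8, 64, 128, dtype=torch.float16)   # (B, M, N)
    b = torch.randn(8, 128, 32, dtype=torch.float16)   # (B, N, K)

    B, M, N = a.shape
    K = b.shape[-1]
    total_elements = B * (M * N + N * K + M * K)   # reads of a and b plus the output write
    print(total_elements * a.element_size())       # 229376 bytes for fp16
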
View File

@@ -0,0 +1,79 @@
+import operator_benchmark as op_bench
+import torch
+import torch.optim as optim
+
+"""Microbenchmarks for optimizer operators."""
+
+optimizer_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["adamw", optim.AdamW],
+        ["adam", optim.Adam],
+        ["sgd", optim.SGD],
+        ["rmsprop", optim.RMSprop],
+        ["adagrad", optim.Adagrad],
+    ],
+)
+
+optimizer_configs_long = op_bench.cross_product_configs(
+    num_params=[1, 10, 100],
+    param_size=[100000, 1000000, 10000000],
+    device=["cuda"],
+    tags=["long"],
+)
+
+
+class OptimizerBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, op_func, device, shape=None, num_params=None, param_size=None):
+        if shape is not None:
+            num_params = num_params if num_params is not None else 1
+            self.params = [
+                torch.randn(shape, device=device, requires_grad=True)
+                for _ in range(num_params)
+            ]
+            for param in self.params:
+                param.grad = torch.randn(shape, device=device)
+        else:
+            self.params = [
+                torch.randn(param_size, device=device, requires_grad=True)
+                for _ in range(num_params)
+            ]
+            for param in self.params:
+                param.grad = torch.randn_like(param)
+
+        kwargs = {"momentum": 0.9} if op_func == optim.SGD else {}
+        self.optimizer = op_func(self.params, lr=0.001, **kwargs)
+
+        # Memory traffic calculation for bandwidth
+        self.total_elements = sum(p.numel() for p in self.params)
+        self.bytes_per_element = self.params[0].element_size()
+        # SGD w/ momentum: read(param, grad, momentum) + write(param, momentum) = 5x
+        # Adam/AdamW: read(param, grad, exp_avg, exp_avg_sq) + write(param, exp_avg, exp_avg_sq) = 7x
+        # Adagrad/RMSprop: read(param, grad, state) + write(param, state) = 5x
+        if op_func in (optim.Adam, optim.AdamW):
+            self.memory_multiplier = 7
+        else:
+            self.memory_multiplier = 5
+
+        self.inputs = {"dummy": self.params[0]}
+
+    def forward(self, dummy):
+        self.optimizer.step()
+        for param in self.params:
+            param.grad = torch.randn_like(param)
+        return self.params[0]
+
+    def get_memory_traffic_bytes(self):
+        return self.total_elements * self.bytes_per_element * self.memory_multiplier
+
+
+op_bench.generate_pt_tests_from_op_list(
+    optimizer_list, optimizer_configs_long, OptimizerBenchmark
+)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
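
For a rough sense of what the multiplier implies (a back-of-envelope sketch with an assumed step time, not a measured result): the largest config above, 100 fp32 parameters of 10,000,000 elements each under AdamW, corresponds to about 28 GB of traffic per step, so a 20 ms step would be reported as roughly 1,400 GB/s.

    # Hypothetical arithmetic for num_params=100, param_size=10_000_000, fp32 AdamW.
    total_elements = 100 * 10_000_000        # 1e9 parameter elements
    bytes_per_element = 4                    # fp32
    memory_multiplier = 7                    # Adam/AdamW: 4 reads + 3 writes per element
    traffic_bytes = total_elements * bytes_per_element * memory_multiplier  # 28 GB per step

    assumed_step_time_us = 20_000.0          # assumed, for illustration only
    bandwidth_gb_s = traffic_bytes / (assumed_step_time_us / 1e6) / 1e9
    print(f"{bandwidth_gb_s:.0f} GB/s")      # 1400 with the assumed timing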