Compare commits

...

2 Commits

Author SHA1 Message Date
85ba89cd8f Refactor gradient tests for addmm and addbmm
Updated gradient tests to only use long configurations for addmm and addbmm benchmarks.
2025-10-30 10:04:21 -07:00
eb35fb401a Fix duplicate benchmarking for addmm 2025-10-29 19:19:33 -07:00

View File

@@ -53,10 +53,8 @@ class AddmmBenchmark(op_bench.TorchBenchmarkBase):
         return torch.addmm(input_one, mat1, mat2)
-op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark)
-op_bench.generate_pt_gradient_test(
-    addmm_long_configs + addmm_long_configs, AddmmBenchmark
-)
+op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark)
+op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark)
"""Microbenchmark for addbmm operator."""
@@ -107,9 +105,7 @@ addbmm_short_configs = op_bench.cross_product_configs(
 )
 op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark)
-op_bench.generate_pt_gradient_test(
-    addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark
-)
+op_bench.generate_pt_gradient_test(addbmm_long_configs, AddbmmBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()