mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
compile time benchmarks for AOTDispatcher (partitioner) (#136760)
compile time benchmark for the min cut partitioner. I'm hoping that this is a reasonable benchmark because: (1) it consists of a single input + many weights that are used sequentially (2) contains a mix of recompute vs non-recomputed ops (matmul + sin) (3) it is relatively simple from running locally: ``` collecting compile time instruction count for aotdispatcher_partitioner_cpu compile time instruction count for iteration 0 is 21764219181 compile time instruction count for iteration 1 is 12475020009 compile time instruction count for iteration 2 is 12463710140 compile time instruction count for iteration 3 is 12455676489 compile time instruction count for iteration 4 is 12451344330 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/136760 Approved by: https://github.com/ezyang ghstack dependencies: #136759
This commit is contained in:
committed by
PyTorch MergeBot
parent
48b8f818b2
commit
b41fc14072
@ -0,0 +1,46 @@
|
||||
import sys
|
||||
|
||||
from benchmark_base import BenchmarkBase
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Benchmark(BenchmarkBase):
    """Compile-time benchmark for the AOTDispatcher min-cut partitioner (CPU).

    A single 16x16 input is threaded sequentially through 100 weight
    matmuls, each followed by two sin() calls, so the partitioner sees a
    mix of ops it may recompute and ops it may save for backward.
    """

    def name(self):
        # Identifier used when reporting collected instruction counts.
        return "aotdispatcher_partitioner_cpu"

    def description(self):
        return "partitioner benchmark 1 input and 100 weights, mix of recompute and non-recompute ops"

    def _prepare_once(self):
        # Weights require grad so a backward graph is produced for the
        # partitioner to split; the input itself does not.
        weight_list = []
        for _ in range(100):
            weight_list.append(torch.randn(16, 16, requires_grad=True))
        self.weights = weight_list
        self.inp = torch.randn(16, 16)

    def _prepare(self):
        # Clear cached compilation state so every iteration recompiles
        # from scratch and the measurement covers the full compile.
        torch._dynamo.reset()

    def _work(self):
        @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
        def chain(first, *mats):
            out = first
            for mat in mats:
                out = torch.matmul(mat, out).sin().sin()
            return out

        chain(self.inp, *self.weights)
|
||||
|
||||
|
||||
def main():
    """Run every registered benchmark and append its results to a file.

    Expects the results output path as the first command-line argument
    (``sys.argv[1]``).
    """
    result_path = sys.argv[1]
    # Renamed from `all` so the builtin of the same name is not shadowed.
    benchmarks = [
        Benchmark(),
    ]

    for benchmark in benchmarks:
        benchmark.enable_compile_time_instruction_count().collect_all().append_results(
            result_path
        )
|
||||
|
||||
|
||||
# Script entry point: collect compile-time instruction counts and append
# them to the results path given on the command line.
if __name__ == "__main__":
    main()
|
Reference in New Issue
Block a user