mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
compile time benchmarks for AOTDispatcher (inference/training/subclasses) (#136759)
This adds a few compile-time benchmarks for some disjoint paths in AOTDispatcher: (1) inference vs. training code paths, and (2) "subclasses" vs. "no subclasses" code paths. Also see https://github.com/pytorch/pytorch/pull/136760 for a partitioner benchmark (I'm not sure why ghstack didn't display the stack nicely). I ran the benchmarks locally and got these numbers on the 4 paths:
```
collecting compile time instruction count for aotdispatcher_inference_nosubclass_cpu
compile time instruction count for iteration 0 is 11692348671
compile time instruction count for iteration 1 is 3026287204
compile time instruction count for iteration 2 is 3011467318
compile time instruction count for iteration 3 is 3004485935
compile time instruction count for iteration 4 is 3003087410
collecting compile time instruction count for aotdispatcher_training_nosubclass_cpu
compile time instruction count for iteration 0 is 6068003223
compile time instruction count for iteration 1 is 5585418102
compile time instruction count for iteration 2 is 5581856618
compile time instruction count for iteration 3 is 5581651794
compile time instruction count for iteration 4 is 5578742619
collecting compile time instruction count for aotdispatcher_inference_subclass_cpu
compile time instruction count for iteration 0 is 8634984264
compile time instruction count for iteration 1 is 8633467573
compile time instruction count for iteration 2 is 8632182092
compile time instruction count for iteration 3 is 8632056925
compile time instruction count for iteration 4 is 8632543871
collecting compile time instruction count for aotdispatcher_training_subclass_cpu
compile time instruction count for iteration 0 is 14737239311
compile time instruction count for iteration 1 is 14734346427
compile time instruction count for iteration 2 is 14736493730
compile time instruction count for iteration 3 is 14734121272
compile time instruction count for iteration 4 is 14733852882
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136759
Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
53af729a66
commit
48b8f818b2
@ -0,0 +1,72 @@
|
||||
import sys
|
||||
|
||||
from benchmark_base import BenchmarkBase
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.two_tensor import TwoTensor
|
||||
|
||||
|
||||
class Benchmark(BenchmarkBase):
    """AOTDispatcher benchmark that compiles a function doubling 100 inputs.

    Each instance picks one combination of two switches: inference vs.
    training inputs, and plain tensors vs. ``TwoTensor``-wrapped inputs.
    The device is fixed to ``"cpu"``.
    """

    def __init__(self, *, training, subclass):
        """Store the configuration of this benchmark variant.

        Args:
            training: Passed as ``requires_grad`` when the inputs are built;
                also selects the ``_training`` / ``_inference`` name segment.
            subclass: When truthy, every input is wrapped in a ``TwoTensor``;
                also selects the ``_subclass`` / ``_nosubclass`` name segment.
        """
        self._device = "cpu"
        self._subclass = subclass
        self._training = training

    def name(self):
        """Return the benchmark identifier built from the configuration.

        The result has the form
        ``aotdispatcher_<training|inference>_<subclass|nosubclass>``, with a
        trailing ``_cpu`` appended when the device is ``"cpu"``.
        """
        segments = [
            "aotdispatcher",
            "training" if self._training else "inference",
            "subclass" if self._subclass else "nosubclass",
        ]
        if self._device == "cpu":
            segments.append("cpu")
        return "_".join(segments)

    def description(self):
        """Return a one-line, human-readable summary of the workload."""
        return "100 inputs, 100 outputs, each input is added once"

    def _prepare_once(self):
        """Build the 100 input tensors and store them on ``self._args``."""
        plain_inputs = [
            torch.ones(100, requires_grad=self._training, device=self._device)
            for _ in range(100)
        ]
        if not self._subclass:
            self._args = plain_inputs
            return

        # Subclass variant: pair each tensor with a detached clone of itself
        # inside a TwoTensor, keeping the same requires_grad setting.
        self._args = [
            TwoTensor(
                tensor,
                tensor.clone().detach().requires_grad_(self._training),
            )
            for tensor in plain_inputs
        ]

    def _prepare(self):
        """Reset TorchDynamo state before the measured work."""
        torch._dynamo.reset()

    def _work(self):
        """Compile and run a function that adds every input to itself."""

        def f(*args):
            outs = [torch.add(x, x) for x in args]
            return outs

        # fullgraph=True: no graph breaks are tolerated while compiling `f`.
        compiled = torch.compile(
            f, backend="aot_eager_decomp_partition", fullgraph=True
        )
        compiled(*self._args)
|
||||
|
||||
|
||||
def main():
    """Run every AOTDispatcher benchmark variant and record the results.

    The results path is taken from the first command-line argument
    (``sys.argv[1]``). Each of the four training/subclass combinations has
    compile-time instruction counting enabled, is collected, and appends its
    results to that path.

    Raises:
        IndexError: If no results path was given on the command line.
    """
    result_path = sys.argv[1]

    # Named `benchmarks` rather than `all` so the builtin `all()` is not
    # shadowed inside this function.
    benchmarks = [
        Benchmark(training=False, subclass=False),
        Benchmark(training=True, subclass=False),
        Benchmark(training=False, subclass=True),
        Benchmark(training=True, subclass=True),
    ]

    for benchmark in benchmarks:
        benchmark.enable_compile_time_instruction_count().collect_all().append_results(
            result_path
        )


# Only run the benchmarks when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user