Add compile time instruction count metric (#133834)

PYTHONPATH=$(pwd) python benchmarks/update_hint_benchmark.py out
As of this diff, compile_time_instruction_count counts the number of instructions executed within convert_frame.compile_inner:
```
update_hint_regression,compile_time_instruction_count,10522459165
```
Results from CI will be added once populated.
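
For reference, a benchmark opts into the new metric through the BenchmarkBase API added below. A minimal sketch (the subclass, workload, and output path here are hypothetical; only enable_compile_time_instruction_count / collect_all / append_results come from this diff):

```python
# Hypothetical benchmark opting into the new compile-time instruction count metric.
import torch
from benchmark_base import BenchmarkBase  # import path assumed


class MyBenchmark(BenchmarkBase):
    def name(self):
        return "my_benchmark"

    def _prepare_once(self):
        self.input = torch.randn(8)

    def _prepare(self):
        torch._dynamo.reset()  # force a fresh compile each iteration

    def _work(self):
        torch.compile(lambda x: x + 1, fullgraph=True)(self.input)


MyBenchmark().enable_compile_time_instruction_count().collect_all().append_results(
    "out"
)
```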

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133834
Approved by: https://github.com/aorenste
Laith Sakka
2024-08-27 11:17:01 -07:00
committed by PyTorch MergeBot
parent ef0f5919c7
commit d6091c8726
6 changed files with 124 additions and 22 deletions


@@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
 from fbscribelogger import make_scribe_logger
 
 import torch._C._instruction_counter as i_counter
+import torch._dynamo.config as config
+from torch._dynamo.utils import CompileTimeInstructionCounter
 
 scribe_log_torch_benchmark_compile_time = make_scribe_logger(
@@ -51,10 +53,19 @@ struct TorchBenchmarkCompileTimeLogEntry {
 class BenchmarkBase(ABC):
-    _instruction_count = False
+    # measure total number of instructions spent in _work.
+    _enable_instruction_count = False
+
+    # measure total number of instructions spent in convert_frame.compile_inner
+    # TODO are there other parts we need to add?
+    _enable_compile_time_instruction_count = False
 
     def enable_instruction_count(self):
-        self._instruction_count = True
+        self._enable_instruction_count = True
         return self
 
+    def enable_compile_time_instruction_count(self):
+        self._enable_compile_time_instruction_count = True
+        return self
+
     def name(self):
@@ -64,29 +75,44 @@ class BenchmarkBase(ABC):
         return ""
 
     @abstractmethod
-    def prepare(self):
+    def _prepare(self):
         pass
 
     @abstractmethod
-    def work(self):
+    def _work(self):
         pass
 
-    def prepare_once(self):  # noqa: B027
+    def _prepare_once(self):  # noqa: B027
         pass
 
-    def count_instructions(self):
+    def _count_instructions(self):
         print(f"collecting instruction count for {self.name()}")
-        self.prepare_once()
         results = []
+        for i in range(10):
+            self._prepare()
+            id = i_counter.start()
+            self._work()
+            count = i_counter.end(id)
+            print(f"instruction count for iteration {i} is {count}")
+            results.append(count)
+        return min(results)
+
+    def _count_compile_time_instructions(self):
+        print(f"collecting compile time instruction count for {self.name()}")
+        config.record_compile_time_instruction_count = True
+        results = []
         for i in range(10):
-            self.prepare()
-            id = i_counter.start()
-            self.work()
-            count = i_counter.end(id)
-            print(f"instruction count for iteration {i} is {count}")
-            if i != 0:
-                results.append(count)
+            self._prepare()
+            # CompileTimeInstructionCounter.record is only entered around convert_frame._compile_inner,
+            # hence this only counts instructions spent in compile_inner.
+            CompileTimeInstructionCounter.clear()
+            self._work()
+            count = CompileTimeInstructionCounter.value()
+            print(f"compile time instruction count for iteration {i} is {count}")
+            results.append(count)
+        config.record_compile_time_instruction_count = False
         return min(results)
 
     def append_results(self, path):
@@ -102,12 +128,36 @@ class BenchmarkBase(ABC):
             print(f"{entry[0]},{entry[1]},{entry[2]}")
 
     def collect_all(self):
+        self._prepare_once()
         self.results = []
-        if self._instruction_count:
-            r = self.count_instructions()
+        if (
+            self._enable_instruction_count
+            and self._enable_compile_time_instruction_count
+        ):
+            raise RuntimeError(
+                "not supported until we update the logger, both logs to the same field now"
+            )
+
+        if self._enable_instruction_count:
+            r = self._count_instructions()
             self.results.append((self.name(), "instruction_count", r))
             scribe_log_torch_benchmark_compile_time(
                 name=self.name(),
                 instruction_count=r,
             )
+
+        if self._enable_compile_time_instruction_count:
+            r = self._count_compile_time_instructions()
+            self.results.append(
+                (
+                    self.name(),
+                    "compile_time_instruction_count",
+                    r,
+                )
+            )
+            # TODO add a new field compile_time_instruction_count to the logger.
+            scribe_log_torch_benchmark_compile_time(
+                name=self.name(),
+                instruction_count=r,
+            )
         return self


@@ -15,17 +15,17 @@ class Benchmark(BenchmarkBase):
     def description(self):
         return "information at https://github.com/pytorch/pytorch/pull/129893"
 
-    def prepare_once(self):
+    def _prepare_once(self):
        torch._dynamo.config.capture_scalar_outputs = True
        random.seed(42)
        self.splits = torch.randint(10, (self.N,))
        sz = self.splits.sum().item()
        self.input = torch.randn(sz)
 
-    def prepare(self):
+    def _prepare(self):
        torch._dynamo.reset()
 
-    def work(self):
+    def _work(self):
        @torch.compile(fullgraph=True)
        def f(a, b):
            xs = b.tolist()
@@ -34,12 +34,15 @@ class Benchmark(BenchmarkBase):
                torch._check(x <= self.N)
            return a.split(xs)
 
-        f(self.input, self.splits)
+        for i in range(1000):
+            f(self.input, self.splits)
 
 
 def main():
     result_path = sys.argv[1]
-    Benchmark().enable_instruction_count().collect_all().append_results(result_path)
+    Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
+        result_path
+    )
 
 
 if __name__ == "__main__":


@@ -0,0 +1,4 @@
+# Defined in torch/csrc/instruction_counter/Module.cpp
+
+def start() -> int: ...
+def end(id: int) -> int: ...
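
A minimal sketch of how this raw counter API is driven (mirroring _count_instructions above, platform support permitting): start() returns a handle and end(id) returns the instruction count accumulated in between.

```python
# Sketch of the raw instruction-counter API added above; the workload is a placeholder.
import torch._C._instruction_counter as i_counter

id = i_counter.start()
sum(range(10_000))  # measured region (placeholder workload)
count = i_counter.end(id)
print(f"instruction count: {count}")
```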


@@ -374,6 +374,10 @@ enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1")
 # Inline inbuilt nn modules
 inline_inbuilt_nn_modules = not is_fbcode()
 
+# When set, total compile time instruction count is recorded using
+# torch._dynamo.utils.CompileTimeInstructionCounter.
+record_compile_time_instruction_count = False
+
 
 def default_debug_dir_root():
     # [@compile_ignored: debug]
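
A minimal sketch of how this flag ties together with CompileTimeInstructionCounter outside the benchmark harness (assuming instruction counting is supported on the platform); record() in convert_frame only starts counting when the flag is set:

```python
# Sketch: toggle the new flag, trigger a compile, and read the aggregated count.
import torch
import torch._dynamo.config as config
from torch._dynamo.utils import CompileTimeInstructionCounter

config.record_compile_time_instruction_count = True
CompileTimeInstructionCounter.clear()

torch.compile(lambda x: x * 2)(torch.randn(4))  # goes through convert_frame.compile_inner

print(CompileTimeInstructionCounter.value())  # instructions spent inside compile_inner
config.record_compile_time_instruction_count = False
```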


@@ -27,6 +27,7 @@ import torch
 import torch._logging
 from torch._C._dynamo.guards import GlobalStateGuard
 from torch._dynamo.distributed import get_compile_pg
+from torch._dynamo.utils import CompileTimeInstructionCounter
 from torch._guards import compile_context, CompileContext, CompileId, tracing
 from torch._logging import structured
 from torch._utils_internal import (
@@ -652,7 +653,8 @@ def _compile(
         transform: Callable[[List[Instruction], Dict[str, Any]], Any],
     ) -> Optional[GuardedCode]:
         with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
-            return _compile_inner(code, one_graph, hooks, transform)
+            with CompileTimeInstructionCounter.record():
+                return _compile_inner(code, one_graph, hooks, transform)
 
     @compile_time_strobelight_meta(phase_name="compile_inner")
     @maybe_cprofile


@@ -64,6 +64,7 @@ import torch.utils._pytree as pytree
 from torch import fx
 from torch._C import (
     _get_function_stack_at,
+    _instruction_counter,
     _len_torch_function_stack,
     _pop_torch_function_stack,
     _push_on_torch_function_stack,
@@ -3203,3 +3204,41 @@ def get_user_object_from_id(obj_id):
 def store_user_object_weakref(obj):
     obj_id = id(obj)
     user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
+
+
+class CompileTimeInstructionCounter:
+    _counter: int = 0
+    _id: int = -1
+    _depth = 0
+
+    @classmethod
+    def start(cls) -> None:
+        cls._depth = cls._depth + 1
+        if cls._depth == 1:
+            cls._id = _instruction_counter.start()
+
+    @classmethod
+    def end(cls) -> None:
+        cls._depth = cls._depth - 1
+        if cls._depth == 0:
+            cls._counter += _instruction_counter.end(cls._id)
+            cls._id = -1
+
+    @classmethod
+    def clear(cls) -> None:
+        cls._counter = 0
+
+    @classmethod
+    def value(cls) -> int:
+        return cls._counter
+
+    @classmethod
+    @contextmanager
+    def record(cls):
+        try:
+            if config.record_compile_time_instruction_count:
+                cls.start()
+            yield
+        finally:
+            if config.record_compile_time_instruction_count:
+                cls.end()
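
A small usage sketch of the nesting behavior: because of the _depth counter, only the outermost record() starts and stops the underlying hardware counter, so re-entrant compiles are not double counted.

```python
# Sketch of the nesting semantics of CompileTimeInstructionCounter.record().
import torch._dynamo.config as config
from torch._dynamo.utils import CompileTimeInstructionCounter

config.record_compile_time_instruction_count = True
CompileTimeInstructionCounter.clear()

with CompileTimeInstructionCounter.record():      # depth 1 -> starts the counter
    with CompileTimeInstructionCounter.record():  # depth 2 -> no-op
        pass
# the outer exit stops the counter once and adds the reading to the aggregate
print(CompileTimeInstructionCounter.value())
```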