mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add compile time instruction count metric (#133834)
Run `PYTHONPATH=$(pwd) python benchmarks/update_hint_benchmark.py out`. As of this diff, compile_time_instruction_count counts the number of instructions executed within convert_frame.compile_inner: ``` update_hint_regression,compile_time_instruction_count,10522459165 ``` Results from CI will be added once populated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133834 Approved by: https://github.com/aorenste
This commit is contained in:
committed by
PyTorch MergeBot
parent
ef0f5919c7
commit
d6091c8726
@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
|
||||
from fbscribelogger import make_scribe_logger
|
||||
|
||||
import torch._C._instruction_counter as i_counter
|
||||
import torch._dynamo.config as config
|
||||
from torch._dynamo.utils import CompileTimeInstructionCounter
|
||||
|
||||
|
||||
scribe_log_torch_benchmark_compile_time = make_scribe_logger(
|
||||
@ -51,10 +53,19 @@ struct TorchBenchmarkCompileTimeLogEntry {
|
||||
|
||||
|
||||
class BenchmarkBase(ABC):
|
||||
_instruction_count = False
|
||||
# Measure the total number of instructions spent in _work.
|
||||
_enable_instruction_count = False
|
||||
|
||||
# Measure the total number of instructions spent in convert_frame.compile_inner.
|
||||
# TODO: are there other parts we need to add?
|
||||
_enable_compile_time_instruction_count = False
|
||||
|
||||
def enable_instruction_count(self):
|
||||
self._instruction_count = True
|
||||
self._enable_instruction_count = True
|
||||
return self
|
||||
|
||||
def enable_compile_time_instruction_count(self):
|
||||
self._enable_compile_time_instruction_count = True
|
||||
return self
|
||||
|
||||
def name(self):
|
||||
@ -64,29 +75,44 @@ class BenchmarkBase(ABC):
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def prepare(self):
|
||||
def _prepare(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def work(self):
|
||||
def _work(self):
|
||||
pass
|
||||
|
||||
def prepare_once(self): # noqa: B027
|
||||
def _prepare_once(self): # noqa: B027
|
||||
pass
|
||||
|
||||
def count_instructions(self):
|
||||
def _count_instructions(self):
|
||||
print(f"collecting instruction count for {self.name()}")
|
||||
self.prepare_once()
|
||||
results = []
|
||||
for i in range(10):
|
||||
self._prepare()
|
||||
id = i_counter.start()
|
||||
self._work()
|
||||
count = i_counter.end(id)
|
||||
print(f"instruction count for iteration {i} is {count}")
|
||||
results.append(count)
|
||||
return min(results)
|
||||
|
||||
def _count_compile_time_instructions(self):
|
||||
print(f"collecting compile time instruction count for {self.name()}")
|
||||
config.record_compile_time_instruction_count = True
|
||||
|
||||
results = []
|
||||
for i in range(10):
|
||||
self.prepare()
|
||||
id = i_counter.start()
|
||||
self.work()
|
||||
count = i_counter.end(id)
|
||||
print(f"instruction count for iteration {i} is {count}")
|
||||
if i != 0:
|
||||
results.append(count)
|
||||
self._prepare()
|
||||
# CompileTimeInstructionCounter.record is only called on convert_frame._compile_inner
|
||||
# hence this will only count the instructions spent in compile_inner.
|
||||
CompileTimeInstructionCounter.clear()
|
||||
self._work()
|
||||
count = CompileTimeInstructionCounter.value()
|
||||
print(f"compile time instruction count for iteration {i} is {count}")
|
||||
results.append(count)
|
||||
|
||||
config.record_compile_time_instruction_count = False
|
||||
return min(results)
|
||||
|
||||
def append_results(self, path):
|
||||
@ -102,12 +128,36 @@ class BenchmarkBase(ABC):
|
||||
print(f"{entry[0]},{entry[1]},{entry[2]}")
|
||||
|
||||
def collect_all(self):
|
||||
self._prepare_once()
|
||||
self.results = []
|
||||
if self._instruction_count:
|
||||
r = self.count_instructions()
|
||||
if (
|
||||
self._enable_instruction_count
|
||||
and self._enable_compile_time_instruction_count
|
||||
):
|
||||
raise RuntimeError(
|
||||
"not supported until we update the logger, both logs to the same field now"
|
||||
)
|
||||
|
||||
if self._enable_instruction_count:
|
||||
r = self._count_instructions()
|
||||
self.results.append((self.name(), "instruction_count", r))
|
||||
scribe_log_torch_benchmark_compile_time(
|
||||
name=self.name(),
|
||||
instruction_count=r,
|
||||
)
|
||||
if self._enable_compile_time_instruction_count:
|
||||
r = self._count_compile_time_instructions()
|
||||
|
||||
self.results.append(
|
||||
(
|
||||
self.name(),
|
||||
"compile_time_instruction_count",
|
||||
r,
|
||||
)
|
||||
)
|
||||
# TODO add a new field compile_time_instruction_count to the logger.
|
||||
scribe_log_torch_benchmark_compile_time(
|
||||
name=self.name(),
|
||||
instruction_count=r,
|
||||
)
|
||||
return self
|
||||
|
@ -15,17 +15,17 @@ class Benchmark(BenchmarkBase):
|
||||
def description(self):
|
||||
return "information at https://github.com/pytorch/pytorch/pull/129893"
|
||||
|
||||
def prepare_once(self):
|
||||
def _prepare_once(self):
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
random.seed(42)
|
||||
self.splits = torch.randint(10, (self.N,))
|
||||
sz = self.splits.sum().item()
|
||||
self.input = torch.randn(sz)
|
||||
|
||||
def prepare(self):
|
||||
def _prepare(self):
|
||||
torch._dynamo.reset()
|
||||
|
||||
def work(self):
|
||||
def _work(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
def f(a, b):
|
||||
xs = b.tolist()
|
||||
@ -34,12 +34,15 @@ class Benchmark(BenchmarkBase):
|
||||
torch._check(x <= self.N)
|
||||
return a.split(xs)
|
||||
|
||||
f(self.input, self.splits)
|
||||
for i in range(1000):
|
||||
f(self.input, self.splits)
|
||||
|
||||
|
||||
def main():
|
||||
result_path = sys.argv[1]
|
||||
Benchmark().enable_instruction_count().collect_all().append_results(result_path)
|
||||
Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
|
||||
result_path
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
4
torch/_C/_instruction_counter.pyi
Normal file
4
torch/_C/_instruction_counter.pyi
Normal file
@ -0,0 +1,4 @@
|
||||
# Defined in torch/csrc/instruction_counter/Module.cpp
|
||||
|
||||
# Returns an integer id for the started instruction counter; callers pass this
# id to end() to obtain the count.
def start() -> int: ...
|
||||
# Takes the id returned by start() and returns the instruction count as an int.
def end(id: int) -> int: ...
|
@ -374,6 +374,10 @@ enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1")
|
||||
# Inline inbuilt nn modules
|
||||
inline_inbuilt_nn_modules = not is_fbcode()
|
||||
|
||||
# When set, total compile time instruction count is recorded using
|
||||
# torch._dynamo.utils.CompileTimeInstructionCounter.
|
||||
record_compile_time_instruction_count = False
|
||||
|
||||
|
||||
def default_debug_dir_root():
|
||||
# [@compile_ignored: debug]
|
||||
|
@ -27,6 +27,7 @@ import torch
|
||||
import torch._logging
|
||||
from torch._C._dynamo.guards import GlobalStateGuard
|
||||
from torch._dynamo.distributed import get_compile_pg
|
||||
from torch._dynamo.utils import CompileTimeInstructionCounter
|
||||
from torch._guards import compile_context, CompileContext, CompileId, tracing
|
||||
from torch._logging import structured
|
||||
from torch._utils_internal import (
|
||||
@ -652,7 +653,8 @@ def _compile(
|
||||
transform: Callable[[List[Instruction], Dict[str, Any]], Any],
|
||||
) -> Optional[GuardedCode]:
|
||||
with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
|
||||
return _compile_inner(code, one_graph, hooks, transform)
|
||||
with CompileTimeInstructionCounter.record():
|
||||
return _compile_inner(code, one_graph, hooks, transform)
|
||||
|
||||
@compile_time_strobelight_meta(phase_name="compile_inner")
|
||||
@maybe_cprofile
|
||||
|
@ -64,6 +64,7 @@ import torch.utils._pytree as pytree
|
||||
from torch import fx
|
||||
from torch._C import (
|
||||
_get_function_stack_at,
|
||||
_instruction_counter,
|
||||
_len_torch_function_stack,
|
||||
_pop_torch_function_stack,
|
||||
_push_on_torch_function_stack,
|
||||
@ -3203,3 +3204,41 @@ def get_user_object_from_id(obj_id):
|
||||
def store_user_object_weakref(obj):
|
||||
obj_id = id(obj)
|
||||
user_obj_id_to_weakref[obj_id] = weakref.ref(obj)
|
||||
|
||||
|
||||
class CompileTimeInstructionCounter:
    """Process-wide accumulator for instruction counts of recorded regions.

    All state lives on the class, so every caller shares one running total.
    Regions may nest: only the outermost ``start()``/``end()`` pair talks to
    ``_instruction_counter``; inner pairs just adjust the nesting depth.
    The total keeps growing across regions until ``clear()`` is called.
    """

    # Sum of the counts returned by _instruction_counter.end() since clear().
    _counter: int = 0
    # Id returned by _instruction_counter.start() for the active outermost
    # region, or -1 when no region is active.
    _id: int = -1
    # Number of currently open (nested) start() calls.
    _depth: int = 0

    @classmethod
    def start(cls) -> None:
        """Open a region; begin counting only when it is the outermost one."""
        cls._depth = cls._depth + 1
        if cls._depth == 1:
            cls._id = _instruction_counter.start()

    @classmethod
    def end(cls) -> None:
        """Close a region; fold the count into the total when the outermost closes."""
        cls._depth = cls._depth - 1
        if cls._depth == 0:
            cls._counter += _instruction_counter.end(cls._id)
            cls._id = -1

    @classmethod
    def clear(cls) -> None:
        """Reset the accumulated total to zero."""
        cls._counter = 0

    @classmethod
    def value(cls) -> int:
        """Return the total accumulated since the last ``clear()``."""
        return cls._counter

    @classmethod
    @contextmanager
    def record(cls):
        """Context manager that counts the wrapped region when enabled.

        Recording is active only if
        ``config.record_compile_time_instruction_count`` is truthy on entry.
        The flag is read once, so ``start()`` and ``end()`` always stay paired
        even if the flag is changed while the region is running.
        """
        if not config.record_compile_time_instruction_count:
            # Disabled: behave as a transparent no-op context manager.
            yield
            return

        # start() sits outside the try so that a failure to start is not
        # followed by an unmatched end().
        cls.start()
        try:
            yield
        finally:
            cls.end()
|
||||
|
Reference in New Issue
Block a user