[logs] Add dynamo_timed to get better compilation time breakdown for AOTI (#140198)

Adding some dynamo_timed annotations to better understand AOTI compilation time.
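The pattern, visible throughout the diff below, is to wrap a region in `with dynamo_timed("SomeName", log_pt2_compile_event=True):` so its wall time shows up in `torch._dynamo.utils.compilation_time_metrics` (and in the pt2 compile event trace). A minimal sketch of inspecting the resulting breakdown after a compile, mirroring what the updated TestDynamoTimed expectation below asserts on and assuming a working inductor setup:

```python
import pprint

import torch
from torch._dynamo import utils


@torch.compile
def f(x):
    return x.sin() + x.cos()


f(torch.randn(8))

# Every dynamo_timed("Name", ...) region contributes a list of durations here,
# e.g. "GraphLowering.codegen" or "Scheduler.fused_nodes" after this change.
pprint.pprint(utils.compilation_time_metrics)
```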

This will probably require a few more passes: a lot of time is spent in Scheduler.__init__, and there are not enough annotations there yet.

run_command_and_check takes a lot of time as well, but there is probably not much we can do about that. Maybe we could add a config to tune the C++ optimization level?
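Purely to illustrate that last idea (nothing in this PR adds it): such a knob could look roughly like the sketch below, where the `aot_cpp_optimization_level` config name and the helper are made up for illustration and do not exist in torch._inductor.config today.

```python
# Hypothetical sketch only -- this config entry and helper do not exist today.
# The idea: let users trade AOTI compile time for runtime performance by
# choosing the optimization level passed to the C++ compiler.

aot_cpp_optimization_level: str = "O1"  # imagined default; "O0" compiles fastest


def _optimization_flag(level: str = aot_cpp_optimization_level) -> str:
    """Translate the imagined config value into the flag appended to the
    C++ compile command (e.g. the one run_command_and_check executes)."""
    assert level in ("O0", "O1", "O2", "O3"), f"unexpected level: {level}"
    return f"-{level}"


# e.g. cmd = f"{base_cxx_cmd} {_optimization_flag()}"  (illustrative only)
```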

traces:
![Screenshot 2024-11-08 at 4 41 10 PM](https://github.com/user-attachments/assets/61645264-b3af-4d4a-804d-700b0f831c7c)

Differential Revision: D65554141

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140198
Approved by: https://github.com/desertfire
Author: Henry Tsang
Date: 2024-11-19 18:54:17 +00:00
Committed by: PyTorch MergeBot
Parent: 7f10351ba0
Commit: 4f2543c31d
7 changed files with 110 additions and 89 deletions


@@ -148,13 +148,16 @@ class TestDynamoTimed(TestCase):
self.assertExpectedInline(
pprint.pformat(utils.compilation_time_metrics),
"""\
{'GraphLowering.compile_to_module': [0.0, 0.0],
{'GraphLowering.codegen': [0.0, 0.0],
'GraphLowering.compile_to_fn': [0.0, 0.0],
'GraphLowering.compile_to_module': [0.0, 0.0],
'GraphLowering.run': [0.0, 0.0],
'OutputGraph.call_user_compiler': [0.0],
'PyCodeCache.load_by_key_path': [0.0, 0.0],
'PythonWrapperCodegen.generate': [0.0, 0.0],
'Scheduler.__init__': [0.0, 0.0],
'Scheduler.codegen': [0.0, 0.0],
'Scheduler.fused_nodes': [0.0, 0.0],
'_compile.compile_inner': [0.0],
'_recursive_joint_graph_passes': [0.0],
'_recursive_post_grad_passes': [0.0, 0.0],


@@ -1771,11 +1771,12 @@ class CompiledFxGraph:
def run_command_and_check(cmd_: str) -> None:
cmd = shlex.split(cmd_)
try:
subprocess.check_call(cmd)
except subprocess.CalledProcessError as e:
raise exc.CppCompileError(cmd, e.output) from e
with dynamo_timed("run_command_and_check", log_pt2_compile_event=True):
cmd = shlex.split(cmd_)
try:
subprocess.check_call(cmd)
except subprocess.CalledProcessError as e:
raise exc.CppCompileError(cmd, e.output) from e
@functools.lru_cache(None)


@@ -12,6 +12,7 @@ from sympy import Expr
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch._ops
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes
from .. import config, ir
@@ -776,12 +777,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
self.prefix.writeline("}")
def generate(self, is_inference):
if V.graph.aot_mode and not V.graph.is_const_graph:
self.codegen_model_kernels()
self.codegen_model_constructor()
self.codegen_const_run_driver()
self.write_wrapper_decl()
return super().generate(is_inference)
with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True):
if V.graph.aot_mode and not V.graph.is_const_graph:
self.codegen_model_kernels()
self.codegen_model_constructor()
self.codegen_const_run_driver()
self.write_wrapper_decl()
return super().generate(is_inference)
def finalize_prefix(self):
cached_dtypes_buffer = IndentedBuffer()


@@ -8,6 +8,7 @@ import sympy
from torch import dtype as torch_dtype
from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
from ..codecache import CudaKernelParamCache
@@ -230,19 +231,22 @@ class CppWrapperGpu(CppWrapperCpu):
)
def generate(self, is_inference):
self.prefix.writeline("\n")
if not V.graph.aot_mode:
for kernel in chain(
sorted(self.src_to_kernel.values()),
sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
):
self.prefix.writeline(
maybe_hipify_code_wrapper(
f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;"
)
)
with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True):
self.prefix.writeline("\n")
return super().generate(is_inference)
if not V.graph.aot_mode:
for kernel in chain(
sorted(self.src_to_kernel.values()),
sorted(
[entry[0] for entry in self.user_defined_kernel_cache.values()]
),
):
self.prefix.writeline(
maybe_hipify_code_wrapper(
f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;"
)
)
self.prefix.writeline("\n")
return super().generate(is_inference)
def generate_user_defined_triton_kernel(
self,


@@ -1886,24 +1886,25 @@ class GraphLowering(torch.fx.Interpreter):
return self.codegen()
def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]:
from .scheduler import Scheduler
with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
from .scheduler import Scheduler
self.init_wrapper_code()
self.init_wrapper_code()
self.scheduler = Scheduler(self.operations)
V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
self.scheduler = Scheduler(self.operations)
V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
self.wrapper_code.push_codegened_graph(self)
self.scheduler.codegen()
self.wrapper_code.push_codegened_graph(self)
self.scheduler.codegen()
log.debug(
"Finished codegen for all nodes. The list of kernel names available: %s",
V.graph.all_codegen_kernel_names,
)
log.debug(
"Finished codegen for all nodes. The list of kernel names available: %s",
V.graph.all_codegen_kernel_names,
)
result = self.wrapper_code.generate(self.is_inference)
self.wrapper_code.pop_codegened_graph()
return result
result = self.wrapper_code.generate(self.is_inference)
self.wrapper_code.pop_codegened_graph()
return result
def codegen_subgraph(self, parent_graph: "GraphLowering") -> None:
"""
@@ -1915,14 +1916,15 @@ class GraphLowering(torch.fx.Interpreter):
kerenls). The wrapper code is not finalized (via `.generate()`
call), as this will be done in the parent graph's `codegen()`.
"""
from .scheduler import Scheduler
with dynamo_timed("GraphLowering.codegen_subgraph", log_pt2_compile_event=True):
from .scheduler import Scheduler
self.wrapper_code = parent_graph.wrapper_code
self.device_ops = parent_graph.device_ops
self.cpp_wrapper = parent_graph.cpp_wrapper
self.wrapper_code = parent_graph.wrapper_code
self.device_ops = parent_graph.device_ops
self.cpp_wrapper = parent_graph.cpp_wrapper
self.scheduler = Scheduler(self.operations)
self.scheduler.codegen()
self.scheduler = Scheduler(self.operations)
self.scheduler.codegen()
def count_bytes(
self,
@@ -2013,6 +2015,10 @@ class GraphLowering(torch.fx.Interpreter):
return mod
def compile_to_fn(self) -> Any:
with dynamo_timed("GraphLowering.compile_to_fn", log_pt2_compile_event=True):
return self._compile_to_fn()
def _compile_to_fn(self) -> Any:
if self.aot_mode:
from .codecache import AotCodeCompiler
@@ -2032,14 +2038,15 @@ class GraphLowering(torch.fx.Interpreter):
additional_files = self.wrapper_code.additional_files
# Directly return the file path with the compiled code
return AotCodeCompiler.compile(
self,
code,
serialized_extern_kernel_nodes,
device_type=self.device_type,
additional_files=additional_files,
)
with dynamo_timed("AotCodeCompiler.compile", log_pt2_compile_event=True):
# Directly return the file path with the compiled code
return AotCodeCompiler.compile(
self,
code,
serialized_extern_kernel_nodes,
device_type=self.device_type,
additional_files=additional_files,
)
else:
return self.compile_to_module().call


@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Tuple
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
import torch
from torch._dynamo.utils import counters
from torch._dynamo.utils import counters, dynamo_timed
logger = torch._logging.getArtifactLogger(__name__, "benchmarking")
@@ -100,27 +100,28 @@ class Benchmarker:
Returns:
- The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
"""
inferred_device = None
for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
if not isinstance(arg_or_kwarg, torch.Tensor):
continue
with dynamo_timed("Benchmarker.benchmark", log_pt2_compile_event=True):
inferred_device = None
for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
if not isinstance(arg_or_kwarg, torch.Tensor):
continue
if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError(
"Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
)
if inferred_device is None:
inferred_device = arg_or_kwarg.device
elif arg_or_kwarg.device != inferred_device:
raise ValueError(
"Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
"Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
)
if inferred_device is None:
raise ValueError(
"Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950
)
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
if inferred_device == torch.device("cpu"):
return self.benchmark_cpu(_callable, **kwargs)
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
# implementation which was written specifically with CUDA devices in mind, we may want to
# explore alternate implementations for other device types.
return self.benchmark_gpu(_callable, **kwargs)
_callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731
if inferred_device == torch.device("cpu"):
return self.benchmark_cpu(_callable, **kwargs)
# TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
# implementation which was written specifically with CUDA devices in mind, we may want to
# explore alternate implementations for other device types.
return self.benchmark_gpu(_callable, **kwargs)
@maybe_time
@count


@@ -2315,25 +2315,28 @@ class Scheduler:
"""
Combine eligible nodes into FusedSchedulerNodes.
"""
for i in range(10):
old_len = len(nodes)
fusion_log.debug(
"===== attempting fusion (%d/10): %d nodes =====",
i + 1,
old_len,
)
nodes = self.fuse_nodes_once(nodes)
new_len = len(nodes)
fusion_log.debug(
"completed fusion round (%d/10): fused %d nodes into %d nodes\n",
i + 1,
old_len,
new_len,
)
if new_len == old_len or new_len == 1:
fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
break
return nodes
with dynamo_timed("Scheduler.fused_nodes"):
for i in range(10):
old_len = len(nodes)
fusion_log.debug(
"===== attempting fusion (%d/10): %d nodes =====",
i + 1,
old_len,
)
nodes = self.fuse_nodes_once(nodes)
new_len = len(nodes)
fusion_log.debug(
"completed fusion round (%d/10): fused %d nodes into %d nodes\n",
i + 1,
old_len,
new_len,
)
if new_len == old_len or new_len == 1:
fusion_log.debug(
"===== fusion complete (%d iterations) =====", i + 1
)
break
return nodes
def process_grouped_nodes(self) -> None:
"""