[logs] Add dynamo_timed to get better compilation time breakdown for AOTI (#140198)
Adds some dynamo_timed annotations to better understand where AOTI compilation time goes. This will probably need a few more passes: a lot of time is spent in Scheduler.__init__, which does not have enough annotations yet, and run_command_and_check also takes a long time, though there is probably not much we can do there beyond perhaps adding a config to tune the C++ optimization level.

Traces: screenshot at https://github.com/user-attachments/assets/61645264-b3af-4d4a-804d-700b0f831c7c

Differential Revision: D65554141
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140198
Approved by: https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 7f10351ba0
Commit: 4f2543c31d
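For reference, `dynamo_timed` is used throughout this change as a context manager: the wall time of the wrapped block is accumulated under the given key in `torch._dynamo.utils.compilation_time_metrics` (the dict the updated test below asserts on), and `log_pt2_compile_event=True` additionally emits the span as a PT2 compile event for trace tooling. A minimal sketch of the annotation pattern follows; the phase name `"my_aoti_phase"` and the timed work are made up, and outside a real compile some of the event logging may be a no-op:

```python
import time

from torch._dynamo.utils import compilation_time_metrics, dynamo_timed


def run_some_phase() -> None:
    # Hypothetical compilation phase; any expensive block can be annotated
    # the same way the AOTI code paths in this PR are.
    with dynamo_timed("my_aoti_phase"):
        time.sleep(0.01)  # stand-in for the real work being timed


run_some_phase()
# Timings accumulate per key; each entry is a list of recorded durations.
print(compilation_time_metrics.get("my_aoti_phase"))
```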
@@ -148,13 +148,16 @@ class TestDynamoTimed(TestCase):
         self.assertExpectedInline(
             pprint.pformat(utils.compilation_time_metrics),
             """\
-{'GraphLowering.compile_to_module': [0.0, 0.0],
+{'GraphLowering.codegen': [0.0, 0.0],
+ 'GraphLowering.compile_to_fn': [0.0, 0.0],
+ 'GraphLowering.compile_to_module': [0.0, 0.0],
  'GraphLowering.run': [0.0, 0.0],
  'OutputGraph.call_user_compiler': [0.0],
  'PyCodeCache.load_by_key_path': [0.0, 0.0],
  'PythonWrapperCodegen.generate': [0.0, 0.0],
  'Scheduler.__init__': [0.0, 0.0],
  'Scheduler.codegen': [0.0, 0.0],
+ 'Scheduler.fused_nodes': [0.0, 0.0],
  '_compile.compile_inner': [0.0],
  '_recursive_joint_graph_passes': [0.0],
  '_recursive_post_grad_passes': [0.0, 0.0],
@@ -1771,11 +1771,12 @@ class CompiledFxGraph:


 def run_command_and_check(cmd_: str) -> None:
-    cmd = shlex.split(cmd_)
-    try:
-        subprocess.check_call(cmd)
-    except subprocess.CalledProcessError as e:
-        raise exc.CppCompileError(cmd, e.output) from e
+    with dynamo_timed("run_command_and_check", log_pt2_compile_event=True):
+        cmd = shlex.split(cmd_)
+        try:
+            subprocess.check_call(cmd)
+        except subprocess.CalledProcessError as e:
+            raise exc.CppCompileError(cmd, e.output) from e


 @functools.lru_cache(None)
@@ -12,6 +12,7 @@ from sympy import Expr
 import torch
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
 import torch._ops
+from torch._inductor.runtime.runtime_utils import dynamo_timed
 from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes

 from .. import config, ir
@@ -776,12 +777,13 @@ class CppWrapperCpu(PythonWrapperCodegen):
         self.prefix.writeline("}")

     def generate(self, is_inference):
-        if V.graph.aot_mode and not V.graph.is_const_graph:
-            self.codegen_model_kernels()
-            self.codegen_model_constructor()
-            self.codegen_const_run_driver()
-        self.write_wrapper_decl()
-        return super().generate(is_inference)
+        with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True):
+            if V.graph.aot_mode and not V.graph.is_const_graph:
+                self.codegen_model_kernels()
+                self.codegen_model_constructor()
+                self.codegen_const_run_driver()
+            self.write_wrapper_decl()
+            return super().generate(is_inference)

     def finalize_prefix(self):
         cached_dtypes_buffer = IndentedBuffer()
@@ -8,6 +8,7 @@ import sympy

 from torch import dtype as torch_dtype
 from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name
+from torch._inductor.runtime.runtime_utils import dynamo_timed
 from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn

 from ..codecache import CudaKernelParamCache
@@ -230,19 +231,22 @@ class CppWrapperGpu(CppWrapperCpu):
         )

     def generate(self, is_inference):
-        self.prefix.writeline("\n")
-        if not V.graph.aot_mode:
-            for kernel in chain(
-                sorted(self.src_to_kernel.values()),
-                sorted([entry[0] for entry in self.user_defined_kernel_cache.values()]),
-            ):
-                self.prefix.writeline(
-                    maybe_hipify_code_wrapper(
-                        f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;"
-                    )
-                )
-        self.prefix.writeline("\n")
-        return super().generate(is_inference)
+        with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True):
+            self.prefix.writeline("\n")
+            if not V.graph.aot_mode:
+                for kernel in chain(
+                    sorted(self.src_to_kernel.values()),
+                    sorted(
+                        [entry[0] for entry in self.user_defined_kernel_cache.values()]
+                    ),
+                ):
+                    self.prefix.writeline(
+                        maybe_hipify_code_wrapper(
+                            f"static {self.device_codegen.cpp_kernel_type()} {kernel} = nullptr;"
+                        )
+                    )
+            self.prefix.writeline("\n")
+            return super().generate(is_inference)

     def generate_user_defined_triton_kernel(
         self,
@@ -1886,24 +1886,25 @@ class GraphLowering(torch.fx.Interpreter):
         return self.codegen()

     def codegen(self) -> Tuple[str, List[Tuple[int, Node]]]:
-        from .scheduler import Scheduler
+        with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
+            from .scheduler import Scheduler

-        self.init_wrapper_code()
+            self.init_wrapper_code()

-        self.scheduler = Scheduler(self.operations)
-        V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
+            self.scheduler = Scheduler(self.operations)
+            V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)

-        self.wrapper_code.push_codegened_graph(self)
-        self.scheduler.codegen()
+            self.wrapper_code.push_codegened_graph(self)
+            self.scheduler.codegen()

-        log.debug(
-            "Finished codegen for all nodes. The list of kernel names available: %s",
-            V.graph.all_codegen_kernel_names,
-        )
+            log.debug(
+                "Finished codegen for all nodes. The list of kernel names available: %s",
+                V.graph.all_codegen_kernel_names,
+            )

-        result = self.wrapper_code.generate(self.is_inference)
-        self.wrapper_code.pop_codegened_graph()
-        return result
+            result = self.wrapper_code.generate(self.is_inference)
+            self.wrapper_code.pop_codegened_graph()
+            return result

     def codegen_subgraph(self, parent_graph: "GraphLowering") -> None:
         """
@@ -1915,14 +1916,15 @@ class GraphLowering(torch.fx.Interpreter):
         kerenls). The wrapper code is not finalized (via `.generate()`
         call), as this will be done in the parent graph's `codegen()`.
         """
-        from .scheduler import Scheduler
+        with dynamo_timed("GraphLowering.codegen_subgraph", log_pt2_compile_event=True):
+            from .scheduler import Scheduler

-        self.wrapper_code = parent_graph.wrapper_code
-        self.device_ops = parent_graph.device_ops
-        self.cpp_wrapper = parent_graph.cpp_wrapper
+            self.wrapper_code = parent_graph.wrapper_code
+            self.device_ops = parent_graph.device_ops
+            self.cpp_wrapper = parent_graph.cpp_wrapper

-        self.scheduler = Scheduler(self.operations)
-        self.scheduler.codegen()
+            self.scheduler = Scheduler(self.operations)
+            self.scheduler.codegen()

     def count_bytes(
         self,
@@ -2013,6 +2015,10 @@ class GraphLowering(torch.fx.Interpreter):
         return mod

     def compile_to_fn(self) -> Any:
+        with dynamo_timed("GraphLowering.compile_to_fn", log_pt2_compile_event=True):
+            return self._compile_to_fn()
+
+    def _compile_to_fn(self) -> Any:
         if self.aot_mode:
             from .codecache import AotCodeCompiler
@@ -2032,14 +2038,15 @@ class GraphLowering(torch.fx.Interpreter):

             additional_files = self.wrapper_code.additional_files

-            # Directly return the file path with the compiled code
-            return AotCodeCompiler.compile(
-                self,
-                code,
-                serialized_extern_kernel_nodes,
-                device_type=self.device_type,
-                additional_files=additional_files,
-            )
+            with dynamo_timed("AotCodeCompiler.compile", log_pt2_compile_event=True):
+                # Directly return the file path with the compiled code
+                return AotCodeCompiler.compile(
+                    self,
+                    code,
+                    serialized_extern_kernel_nodes,
+                    device_type=self.device_type,
+                    additional_files=additional_files,
+                )
         else:
             return self.compile_to_module().call

@@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Tuple
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

 import torch
-from torch._dynamo.utils import counters
+from torch._dynamo.utils import counters, dynamo_timed


 logger = torch._logging.getArtifactLogger(__name__, "benchmarking")
@@ -100,27 +100,28 @@ class Benchmarker:
         Returns:
         - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
         """
-        inferred_device = None
-        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
-            if not isinstance(arg_or_kwarg, torch.Tensor):
-                continue
-            if inferred_device is None:
-                inferred_device = arg_or_kwarg.device
-            elif arg_or_kwarg.device != inferred_device:
-                raise ValueError(
-                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
-                )
-        if inferred_device is None:
-            raise ValueError(
-                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
-            )
-        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
-        if inferred_device == torch.device("cpu"):
-            return self.benchmark_cpu(_callable, **kwargs)
-        # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
-        # implementation which was written specifically with CUDA devices in mind, we may want to
-        # explore alternate implementations for other device types.
-        return self.benchmark_gpu(_callable, **kwargs)
+        with dynamo_timed("Benchmarker.benchmark", log_pt2_compile_event=True):
+            inferred_device = None
+            for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
+                if not isinstance(arg_or_kwarg, torch.Tensor):
+                    continue
+                if inferred_device is None:
+                    inferred_device = arg_or_kwarg.device
+                elif arg_or_kwarg.device != inferred_device:
+                    raise ValueError(
+                        "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
+                    )
+            if inferred_device is None:
+                raise ValueError(
+                    "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
+                )
+            _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
+            if inferred_device == torch.device("cpu"):
+                return self.benchmark_cpu(_callable, **kwargs)
+            # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
+            # implementation which was written specifically with CUDA devices in mind, we may want to
+            # explore alternate implementations for other device types.
+            return self.benchmark_gpu(_callable, **kwargs)

     @maybe_time
     @count
@@ -2315,25 +2315,28 @@ class Scheduler:
         """
         Combine eligible nodes into FusedSchedulerNodes.
         """
-        for i in range(10):
-            old_len = len(nodes)
-            fusion_log.debug(
-                "===== attempting fusion (%d/10): %d nodes =====",
-                i + 1,
-                old_len,
-            )
-            nodes = self.fuse_nodes_once(nodes)
-            new_len = len(nodes)
-            fusion_log.debug(
-                "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
-                i + 1,
-                old_len,
-                new_len,
-            )
-            if new_len == old_len or new_len == 1:
-                fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
-                break
-        return nodes
+        with dynamo_timed("Scheduler.fused_nodes"):
+            for i in range(10):
+                old_len = len(nodes)
+                fusion_log.debug(
+                    "===== attempting fusion (%d/10): %d nodes =====",
+                    i + 1,
+                    old_len,
+                )
+                nodes = self.fuse_nodes_once(nodes)
+                new_len = len(nodes)
+                fusion_log.debug(
+                    "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
+                    i + 1,
+                    old_len,
+                    new_len,
+                )
+                if new_len == old_len or new_len == 1:
+                    fusion_log.debug(
+                        "===== fusion complete (%d iterations) =====", i + 1
+                    )
+                    break
+            return nodes

     def process_grouped_nodes(self) -> None:
         """
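As a hedged usage note (not part of the diff): one way to see the breakdown these annotations produce is to run a small AOTInductor compile and dump the timing dict that the test at the top of this change asserts on. The tiny module and input shape below are made up for illustration, and `torch._export.aot_compile` is assumed as the AOTI entry point of this era:

```python
import pprint

import torch
from torch._dynamo import utils


class Small(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x).sum(dim=-1)


# Compile ahead-of-time with AOTInductor (CPU for portability), then print the
# per-phase wall times accumulated by the dynamo_timed annotations above.
so_path = torch._export.aot_compile(Small(), (torch.randn(4, 8),))
pprint.pprint(utils.compilation_time_metrics)
```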