[compile-time traces] Profile large missing gaps in compile time (#151256)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151256
Approved by: https://github.com/bdhirsh, https://github.com/masnesral, https://github.com/zou3519, https://github.com/jansel
Animesh Jain
2025-05-12 22:35:03 -07:00
committed by PyTorch MergeBot
parent ee096b89f6
commit 7fdd754136
8 changed files with 81 additions and 29 deletions
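Every gap closed by this commit follows the same pattern: the previously untimed call is wrapped in a dynamo_timed(...) context manager so the phase is recorded as a pt2_compile event. A minimal sketch of that pattern in isolation; the event name and wrapped callable here are placeholders, not part of the commit:

    from torch._dynamo.utils import dynamo_timed

    def run_timed_phase(phase_fn):
        # The string names the event; log_pt2_compile_event=True surfaces it in the
        # pt2_compile event log, matching the calls added in the hunks below.
        with dynamo_timed("my_phase", log_pt2_compile_event=True):
            return phase_fn()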

View File

@@ -230,14 +230,21 @@ class TestDynamoTimed(TestCase):
             '_recursive_joint_graph_passes': [0.0],
             '_recursive_post_grad_passes': [0.0, 0.0],
             '_recursive_pre_grad_passes': [0.0],
+            'additional_fake_tensor_prop': [0.0, 0.0],
+            'aot_collect_metadata': [0.0],
+            'aot_trace_joint_graph': [0.0],
             'async_compile.wait': [0.0, 0.0],
             'backward._backward_impl': [0.0],
+            'build_guards': [0.0],
+            'bytecode_tracing': [0.0],
+            'compile_attempt_0': [0.0],
             'compile_file': [0.0, 0.0],
             'compile_fx.<locals>.bw_compiler': [0.0],
             'compile_fx.<locals>.fw_compiler_base': [0.0],
             'compile_fx_inner': [0.0, 0.0],
             'create_aot_dispatcher_function': [0.0],
-            'gc': [0.0]}""", # noqa: B950
+            'gc': [0.0],
+            'min_cut_rematerialization_partition': [0.0]}""", # noqa: B950
         )
         # Now validate utils.calculate_time_spent(). Formatting the return

View File

@@ -13,11 +13,13 @@ The inductor backend can be used with torch.compile():
 """
 
 from torch._dynamo import register_backend
+from torch._dynamo.utils import dynamo_timed
 
 
 @register_backend
 def inductor(*args, **kwargs):
-    # do import here to avoid loading inductor into memory when it is not used
-    from torch._inductor.compile_fx import compile_fx
+    with dynamo_timed("inductor_import", log_pt2_compile_event=True):
+        # do import here to avoid loading inductor into memory when it is not used
+        from torch._inductor.compile_fx import compile_fx
 
     return compile_fx(*args, **kwargs)
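Because the inductor import is deferred until the backend is first invoked, wrapping it attributes the one-time module-load cost to an explicit inductor_import event instead of letting it silently inflate the first compile. A stdlib-only sketch of timing a lazy import in the same spirit; the helper below is illustrative and not part of the commit:

    import importlib
    import time

    def timed_lazy_import(module_name):
        # Import on first use and report how long loading the module took.
        start = time.perf_counter()
        module = importlib.import_module(module_name)
        print(f"import {module_name}: {time.perf_counter() - start:.3f}s")
        return module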

View File

@@ -737,6 +737,7 @@ def _compile(
             )
 
         try:
+            tracer.output.mark_bytecode_tracing_start()
             with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                 tracer.run()
         except exc.UnspecializeRestartAnalysis:
@@ -810,7 +811,10 @@ def _compile(
     for attempt in itertools.count():
         CompileContext.get().attempt = attempt
         try:
-            out_code = transform_code_object(code, transform)
+            with dynamo_timed(
+                f"compile_attempt_{attempt}", log_pt2_compile_event=True
+            ):
+                out_code = transform_code_object(code, transform)
             break
         except exc.RestartAnalysis as e:
             if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
@@ -919,13 +923,14 @@ def _compile(
     assert output.guards is not None
     CleanupManager.instance[out_code] = output.cleanups
 
     nonlocal cache_entry
-    check_fn = CheckFunctionManager(
-        code,
-        output,
-        cache_entry,
-        hooks.guard_fail_fn if hooks else None,
-        hooks.guard_filter_fn if hooks else None,
-    )
+    with dynamo_timed("build_guards", log_pt2_compile_event=True):
+        check_fn = CheckFunctionManager(
+            code,
+            output,
+            cache_entry,
+            hooks.guard_fail_fn if hooks else None,
+            hooks.guard_filter_fn if hooks else None,
+        )
 
     compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
     annotation_str = "Torch-Compiled Region: " + compile_id_str
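In the compile_attempt hunk above, the event name is formatted per attempt, so every restarted analysis gets its own timer (compile_attempt_0, compile_attempt_1, ...). A condensed sketch of that naming scheme; the retry loop and exception type are simplified stand-ins for the real restart machinery:

    import itertools

    from torch._dynamo.utils import dynamo_timed

    def compile_with_restarts(transform_once):
        # Each attempt is timed under its own event name, mirroring the hunk above.
        for attempt in itertools.count():
            try:
                with dynamo_timed(
                    f"compile_attempt_{attempt}", log_pt2_compile_event=True
                ):
                    return transform_once(attempt)
            except RuntimeError:  # stand-in for exc.RestartAnalysis
                continue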

View File

@@ -517,6 +517,19 @@ class OutputGraph(OutputGraphGuardsState):
             self.install_builtins_dict_in_fglobals()
         )
 
+        self.compiler_trace_stack = contextlib.ExitStack()
+
+    def mark_bytecode_tracing_start(self):
+        self.compiler_trace_stack.enter_context(
+            dynamo_timed(
+                "bytecode_tracing",
+                log_pt2_compile_event=True,
+            )
+        )
+
+    def mark_bytecode_tracing_stop(self):
+        self.compiler_trace_stack.close()
+
     def install_builtins_dict_in_fglobals(self):
         # f_globals["__builtins__"] can be a dict or a module. This is an
         # implementation detail -
@@ -1068,6 +1081,8 @@ class OutputGraph(OutputGraphGuardsState):
         Generate a subgraph to continue execution on user code.
         Automatically restore live variables.
         """
+        # bytecode tracing has finished. Pop the context manager for dynamo_timed
+        self.mark_bytecode_tracing_stop()
         assert reason is not None
 
         from .decorators import disable
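Unlike the other wraps, the bytecode_tracing span cannot be a single with statement because it opens in one method and closes in another, so the commit parks the open dynamo_timed context on a contextlib.ExitStack. A self-contained, stdlib-only sketch of that start/stop pattern; the Tracer class and timed helper are illustrative stand-ins:

    import contextlib
    import time

    @contextlib.contextmanager
    def timed(name):
        # Minimal stand-in for dynamo_timed: report elapsed time for the region.
        start = time.perf_counter()
        try:
            yield
        finally:
            print(f"{name}: {time.perf_counter() - start:.3f}s")

    class Tracer:
        def __init__(self):
            # Holds context managers that must stay open across method boundaries.
            self.trace_stack = contextlib.ExitStack()

        def start(self):
            # Enter the timer now; it stays active until stop() closes the stack.
            self.trace_stack.enter_context(timed("bytecode_tracing"))

        def stop(self):
            # close() exits everything on the stack; a second close is a no-op,
            # which is why the extra stop call in the next hunk is safe.
            self.trace_stack.close()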

View File

@@ -1374,6 +1374,11 @@ class InstructionTranslatorBase(
         if isinstance(self, InstructionTranslator):
             self.output.cleanup()
 
+            # Note that this call may be redundant if compile_subgraph is
+            # called. This is ok, because calling exit stack close()
+            # twice is not an issue (the second stop is a no-op).
+            self.output.mark_bytecode_tracing_stop()
+
     def push(self, val: Optional[VariableTracker]):
         assert val is None or isinstance(val, VariableTracker), (
             f"push expects VariableTracker, got {typestr(val)}"

View File

@@ -23,7 +23,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING
 import torch
 import torch.utils.dlpack
 from torch import Tensor
-from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
+from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code
 from torch._guards import CompileContext, TracingContext
 from torch._logging import getArtifactLogger, trace_structured
 from torch._subclasses import FakeTensor
@@ -792,9 +792,10 @@ def aot_dispatch_autograd(
     )
     fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
 
-    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
-        flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
-    )
+    with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
+        fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
+            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
+        )
 
     # Copied from aot_dispatch_autograd_graph.
     disable_amp = torch._C._is_any_autocast_enabled()

View File

@@ -673,7 +673,17 @@ def _create_aot_dispatcher_function(
             ctx = _detect_attribute_assignment(mod)
         else:
             ctx = nullcontext()
-        with ctx:
+
+        if torch._functorch.config.fake_tensor_propagate_real_tensors:
+            # Running dynamo_timed causes fake tensor issues when
+            # propagate real tensor is switched on.
+            dynamo_timed_ctx = nullcontext()
+        else:
+            dynamo_timed_ctx = dynamo_timed(
+                "aot_collect_metadata", log_pt2_compile_event=True
+            )
+
+        with dynamo_timed_ctx, ctx:
             fw_metadata = run_functionalized_fw_and_collect_metadata(
                 flat_fn,
                 static_input_indices=aot_config.static_input_indices,
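The hunk above skips timing entirely when fake_tensor_propagate_real_tensors is enabled, substituting a nullcontext() for the timer so the with statement stays uniform. A minimal sketch of that conditional-context pattern; the timing_disabled flag stands in for the config check:

    from contextlib import nullcontext

    from torch._dynamo.utils import dynamo_timed

    def collect_metadata(run_collection, timing_disabled=False):
        # Pick a no-op context when timing must be avoided, otherwise the timer;
        # both are entered the same way below.
        timer_ctx = (
            nullcontext()
            if timing_disabled
            else dynamo_timed("aot_collect_metadata", log_pt2_compile_event=True)
        )
        with timer_ctx:
            return run_collection()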

View File

@@ -1140,12 +1140,15 @@ class _InProcessFxCompile(FxCompile):
             # .view() call.
             view_to_reshape(gm)
 
-            # It is safe to run FakeTensorProp under no_grad because by the time
-            # we're in inductor, we assume that AOTAutograd has already "taken care"
-            # of autograd, so there should be no more autograd-related API's in the
-            # graph.
-            with torch.no_grad():
-                fake_mode = fake_tensor_prop(gm, example_inputs)
+            with dynamo_timed(
+                "additional_fake_tensor_prop", log_pt2_compile_event=True
+            ):
+                # It is safe to run FakeTensorProp under no_grad because by the time
+                # we're in inductor, we assume that AOTAutograd has already "taken care"
+                # of autograd, so there should be no more autograd-related API's in the
+                # graph.
+                with torch.no_grad():
+                    fake_mode = fake_tensor_prop(gm, example_inputs)
 
             record_original_output_strides(gm)
@@ -2196,13 +2199,17 @@ def compile_fx(
         static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment]
             "static_lifetime_input_indices", None
         )
-        return min_cut_rematerialization_partition(
-            gm,
-            joint_inputs,
-            compiler="inductor",
-            static_lifetime_input_indices=static_lifetime_input_indices,
-            **kwargs,
-        )
+        with dynamo_utils.dynamo_timed(
+            "min_cut_rematerialization_partition", log_pt2_compile_event=True
+        ):
+            return min_cut_rematerialization_partition(
+                gm,
+                joint_inputs,
+                compiler="inductor",
+                static_lifetime_input_indices=static_lifetime_input_indices,
+                **kwargs,
+            )
 
     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(