Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[compile-time traces] Profile large missing gaps in compile time (#151256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151256
Approved by: https://github.com/bdhirsh, https://github.com/masnesral, https://github.com/zou3519, https://github.com/jansel
Committed by PyTorch MergeBot
parent ee096b89f6
commit 7fdd754136
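Every hunk below follows the same pattern: a previously untimed compile phase is wrapped in dynamo_timed(..., log_pt2_compile_event=True) from torch._dynamo.utils so that it shows up as a named event in compile-time traces and in the per-phase timing tables. A minimal sketch of the pattern, with a hypothetical phase name and placeholder work that are not part of this PR:

import torch

from torch._dynamo.utils import dynamo_timed


def some_phase() -> torch.Tensor:
    # Hypothetical stand-in for a compile-time phase (graph pass, guard
    # construction, joint-graph tracing, ...).
    return torch.randn(4) @ torch.randn(4, 4)


# The string key becomes the phase/event name; log_pt2_compile_event=True is
# the flag used throughout this diff so the span is also logged as a pt2
# compile event for tracing tools.
with dynamo_timed("my_phase", log_pt2_compile_event=True):
    some_phase()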
@@ -230,14 +230,21 @@ class TestDynamoTimed(TestCase):
  '_recursive_joint_graph_passes': [0.0],
  '_recursive_post_grad_passes': [0.0, 0.0],
  '_recursive_pre_grad_passes': [0.0],
+ 'additional_fake_tensor_prop': [0.0, 0.0],
+ 'aot_collect_metadata': [0.0],
+ 'aot_trace_joint_graph': [0.0],
  'async_compile.wait': [0.0, 0.0],
  'backward._backward_impl': [0.0],
+ 'build_guards': [0.0],
+ 'bytecode_tracing': [0.0],
+ 'compile_attempt_0': [0.0],
  'compile_file': [0.0, 0.0],
  'compile_fx.<locals>.bw_compiler': [0.0],
  'compile_fx.<locals>.fw_compiler_base': [0.0],
  'compile_fx_inner': [0.0, 0.0],
  'create_aot_dispatcher_function': [0.0],
- 'gc': [0.0]}""",  # noqa: B950
+ 'gc': [0.0],
+ 'min_cut_rematerialization_partition': [0.0]}""",  # noqa: B950
         )

         # Now validate utils.calculate_time_spent(). Formatting the return
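The expected table above is built from the per-phase timings that dynamo_timed records globally. A minimal sketch of how to look at them outside the test, assuming the pre-existing helpers torch._dynamo.utils.compilation_time_metrics and compile_times() (neither is added by this PR):

import torch
import torch._dynamo.utils as dynamo_utils


@torch.compile
def f(x):
    return torch.sin(x) + torch.cos(x)


f(torch.randn(8))

# Mapping of phase name -> list of durations in seconds, one entry per call;
# after this PR the keys should include regions such as "bytecode_tracing",
# "build_guards", and "aot_trace_joint_graph".
print(dynamo_utils.compilation_time_metrics)

# Human-readable summary of the same data.
print(dynamo_utils.compile_times())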
@@ -13,11 +13,13 @@ The inductor backend can be used with torch.compile():
 """

 from torch._dynamo import register_backend
+from torch._dynamo.utils import dynamo_timed


 @register_backend
 def inductor(*args, **kwargs):
-    # do import here to avoid loading inductor into memory when it is not used
-    from torch._inductor.compile_fx import compile_fx
+    with dynamo_timed("inductor_import", log_pt2_compile_event=True):
+        # do import here to avoid loading inductor into memory when it is not used
+        from torch._inductor.compile_fx import compile_fx

     return compile_fx(*args, **kwargs)
@@ -737,6 +737,7 @@ def _compile(
         )

         try:
+            tracer.output.mark_bytecode_tracing_start()
             with tracing(tracer.output.tracing_context), tracer.set_current_tx():
                 tracer.run()
         except exc.UnspecializeRestartAnalysis:
@@ -810,7 +811,10 @@ def _compile(
         for attempt in itertools.count():
             CompileContext.get().attempt = attempt
             try:
-                out_code = transform_code_object(code, transform)
+                with dynamo_timed(
+                    f"compile_attempt_{attempt}", log_pt2_compile_event=True
+                ):
+                    out_code = transform_code_object(code, transform)
                 break
             except exc.RestartAnalysis as e:
                 if not isinstance(e, exc.TensorifyScalarRestartAnalysis):
@@ -919,13 +923,14 @@ def _compile(
         assert output.guards is not None
         CleanupManager.instance[out_code] = output.cleanups
         nonlocal cache_entry
-        check_fn = CheckFunctionManager(
-            code,
-            output,
-            cache_entry,
-            hooks.guard_fail_fn if hooks else None,
-            hooks.guard_filter_fn if hooks else None,
-        )
+        with dynamo_timed("build_guards", log_pt2_compile_event=True):
+            check_fn = CheckFunctionManager(
+                code,
+                output,
+                cache_entry,
+                hooks.guard_fail_fn if hooks else None,
+                hooks.guard_filter_fn if hooks else None,
+            )

         compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
         annotation_str = "Torch-Compiled Region: " + compile_id_str
@@ -517,6 +517,19 @@ class OutputGraph(OutputGraphGuardsState):
             self.install_builtins_dict_in_fglobals()
         )

+        self.compiler_trace_stack = contextlib.ExitStack()
+
+    def mark_bytecode_tracing_start(self):
+        self.compiler_trace_stack.enter_context(
+            dynamo_timed(
+                "bytecode_tracing",
+                log_pt2_compile_event=True,
+            )
+        )
+
+    def mark_bytecode_tracing_stop(self):
+        self.compiler_trace_stack.close()
+
     def install_builtins_dict_in_fglobals(self):
         # f_globals["__builtins__"] can be a dict or a module. This is an
         # implemenation detail -
@@ -1068,6 +1081,8 @@ class OutputGraph(OutputGraphGuardsState):
         Generate a subgraph to continue execution on user code.
         Automatically restore live variables.
         """
+        # bytecode tracing has finished. Pop the context manager for dynamo_timed
+        self.mark_bytecode_tracing_stop()
         assert reason is not None

         from .decorators import disable
@@ -1374,6 +1374,11 @@ class InstructionTranslatorBase(
         if isinstance(self, InstructionTranslator):
             self.output.cleanup()

+            # Note that this call maybe redundant if compile_subgraph is
+            # called. This is ok, because calling exit stack close()
+            # twice is not an issue (second stop is a no op).
+            self.output.mark_bytecode_tracing_stop()
+
     def push(self, val: Optional[VariableTracker]):
         assert val is None or isinstance(val, VariableTracker), (
             f"push expects VariableTracker, got {typestr(val)}"
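The OutputGraph and InstructionTranslatorBase hunks above rely on contextlib.ExitStack to keep a dynamo_timed region open across separate methods: the normal path closes it when bytecode tracing ends, and the cleanup path closes it again harmlessly, since closing an already-closed ExitStack is a no-op. A standalone sketch of that pattern (the class and phase name are illustrative, not from the PR):

import contextlib

from torch._dynamo.utils import dynamo_timed


class PhaseTimer:
    # Illustrative analogue of OutputGraph.compiler_trace_stack plus the
    # mark_bytecode_tracing_start/stop methods added above.
    def __init__(self) -> None:
        self._stack = contextlib.ExitStack()

    def start(self) -> None:
        # Open the timed region; it stays open across method calls.
        self._stack.enter_context(
            dynamo_timed("my_phase", log_pt2_compile_event=True)
        )

    def stop(self) -> None:
        # ExitStack.close() unwinds whatever it holds; a second call does
        # nothing, so redundant stops are safe.
        self._stack.close()


timer = PhaseTimer()
timer.start()
# ... the traced work happens here ...
timer.stop()
timer.stop()  # redundant, but harmless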
@@ -23,7 +23,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING
 import torch
 import torch.utils.dlpack
 from torch import Tensor
-from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
+from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code
 from torch._guards import CompileContext, TracingContext
 from torch._logging import getArtifactLogger, trace_structured
 from torch._subclasses import FakeTensor
@@ -792,9 +792,10 @@ def aot_dispatch_autograd(
     )

     fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled()
-    fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
-        flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
-    )
+    with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True):
+        fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(
+            flat_fn, flat_args, aot_config, fw_metadata=fw_metadata
+        )

     # Copied from aot_dispatch_autograd_graph.
     disable_amp = torch._C._is_any_autocast_enabled()
@@ -673,7 +673,17 @@ def _create_aot_dispatcher_function(
             ctx = _detect_attribute_assignment(mod)
         else:
             ctx = nullcontext()
-        with ctx:
+
+        if torch._functorch.config.fake_tensor_propagate_real_tensors:
+            # Running dynamo_timed causes fake tensor issues when
+            # propagate real tensor is switched on.
+            dynamo_timed_ctx = nullcontext()
+        else:
+            dynamo_timed_ctx = dynamo_timed(
+                "aot_collect_metadata", log_pt2_compile_event=True
+            )
+
+        with dynamo_timed_ctx, ctx:
             fw_metadata = run_functionalized_fw_and_collect_metadata(
                 flat_fn,
                 static_input_indices=aot_config.static_input_indices,
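This hunk also shows the fallback used when timing itself would interfere with the measured work: when fake_tensor_propagate_real_tensors is set, the timer is replaced with contextlib.nullcontext(). A generic sketch of that conditional-timing choice (the boolean flag here is a hypothetical stand-in for the functorch config read in the diff):

from contextlib import nullcontext

from torch._dynamo.utils import dynamo_timed


def collect_metadata(skip_timing: bool) -> None:
    # Pick a real timer or a do-nothing context up front, then use a single
    # `with` statement for both cases.
    timing_ctx = (
        nullcontext()
        if skip_timing
        else dynamo_timed("aot_collect_metadata", log_pt2_compile_event=True)
    )
    with timing_ctx:
        pass  # run_functionalized_fw_and_collect_metadata(...) would run here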
@@ -1140,12 +1140,15 @@ class _InProcessFxCompile(FxCompile):
         # .view() call.
         view_to_reshape(gm)

-        # It is safe to run FakeTensorProp under no_grad because by the time
-        # we're in inductor, we assume that AOTAutograd has already "taken care"
-        # of autograd, so there should be no more autograd-related API's in the
-        # graph.
-        with torch.no_grad():
-            fake_mode = fake_tensor_prop(gm, example_inputs)
+        with dynamo_timed(
+            "additional_fake_tensor_prop", log_pt2_compile_event=True
+        ):
+            # It is safe to run FakeTensorProp under no_grad because by the time
+            # we're in inductor, we assume that AOTAutograd has already "taken care"
+            # of autograd, so there should be no more autograd-related API's in the
+            # graph.
+            with torch.no_grad():
+                fake_mode = fake_tensor_prop(gm, example_inputs)

         record_original_output_strides(gm)
@@ -2196,13 +2199,17 @@ def compile_fx(
         static_lifetime_input_indices: Optional[list[int]] = kwargs.pop(  # type: ignore[assignment]
             "static_lifetime_input_indices", None
         )
-        return min_cut_rematerialization_partition(
-            gm,
-            joint_inputs,
-            compiler="inductor",
-            static_lifetime_input_indices=static_lifetime_input_indices,
-            **kwargs,
-        )
+
+        with dynamo_utils.dynamo_timed(
+            "min_cut_rematerialization_partition", log_pt2_compile_event=True
+        ):
+            return min_cut_rematerialization_partition(
+                gm,
+                joint_inputs,
+                compiler="inductor",
+                static_lifetime_input_indices=static_lifetime_input_indices,
+                **kwargs,
+            )

     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(