[logging] Overhaul dynamo_timed and CompilationMetrics logging. (#139849)

Here's the overview:

There's a new context-manager singleton called MetricsContext. Entering the MetricsContext demarcates the boundary of a single CompilationMetrics object, and therefore a single dynamo_compile log entry. While inside the MetricsContext, we can update/set many different metrics. Most importantly, `dynamo_timed` can also update the in-progress MetricsContext: we tell `dynamo_timed` to do so by providing the name of the MetricsContext field to increment. There can be many `dynamo_timed` calls in different parts of the code, each updating a different field. When the MetricsContext finally exits, everything gathered is logged in one shot. One potential footgun is using `dynamo_timed` before entering the MetricsContext; we assert on that. Another is that the context can be re-entered recursively; we track the nesting depth and log only when the outermost context exits.
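
To make the flow concrete, here's a minimal sketch of the intended pattern (`compile_something`, `do_work`, and the `"my_phase"`/`"my_phase_us"` names are hypothetical stand-ins; real callsites pass actual CompilationMetrics field names):

```python
from torch._dynamo.utils import dynamo_timed, get_metrics_context

def compile_something():
    # Entering the singleton context demarcates one CompilationMetrics record.
    with get_metrics_context():
        # Inside the context, a timed region folds its duration into the named
        # CompilationMetrics field; the column name must end with "_us".
        # ("my_phase" / "my_phase_us" are stand-ins, not real fields.)
        with dynamo_timed("my_phase", dynamo_compile_column_us="my_phase_us"):
            do_work()  # hypothetical workload
    # Exiting the outermost context triggers record_compilation_metrics(),
    # which logs one dynamo_compile entry with everything gathered.
```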

Some specifics:
* Introduce MetricsContext - a context manager that, on exit, records the CompilationMetrics (which also logs to dynamo_compile).
* Completely remove the concept of frame_phase_timing. Instead, update the MetricsContext during compilation, either directly or via dynamo_timed.
* Remove some globals we previously used to accumulate counters for later populating a CompilationMetrics. We use the MetricsContext set/update/increment APIs instead (sketched after this list).
* `record_compilation_metrics` is now called on exit from MetricsContext.
* Populate legacy CompilationMetrics fields right before logging, inside `record_compilation_metrics`.
* Remove the one-off `add_remote_cache_time_saved` helper; capture that timing directly into the MetricsContext.
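
As a sketch of those MetricsContext APIs (the field names here, e.g. `co_name` and `num_triton_bundles`, are real CompilationMetrics fields; note that entering the singleton yourself is normally `_compile`'s job):

```python
from torch._dynamo.utils import get_metrics_context

ctx = get_metrics_context()
with ctx:
    ctx.set("co_name", "foo")               # raises if already set in this context
    ctx.update({"guard_count": 7})          # bulk set; same no-overwrite rule
    ctx.increment("num_triton_bundles", 1)  # accumulates across multiple calls
# record_compilation_metrics(...) fires here, on the outermost exit
```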

And specifically, several changes to `dynamo_timed`:
* "Modernize" the parameters and update all callsites accordingly.
* Move logging of the backward CompilationMetrics to the backward-compile location.
* Add a parameter specifying which CompilationMetrics field to update (see the callsite example after this list).
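
For example, the updated `GraphLowering.compile_to_module` callsite (excerpted from the diff below) now reads:

```python
with dynamo_timed(
    "GraphLowering.compile_to_module",
    phase_name="code_gen",
    log_pt2_compile_event=True,
    dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
):
    return self._compile_to_module()
```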

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139849
Approved by: https://github.com/ezyang
ghstack dependencies: #140094
Author: Sam Larsen
Date: 2024-11-09 10:30:47 -08:00
Committed by: PyTorch MergeBot
Parent: 565a7942ee
Commit: cb15c15157
14 changed files with 456 additions and 326 deletions

View File

@@ -0,0 +1,78 @@
# Owner(s): ["module: dynamo"]
from torch._dynamo.metrics_context import MetricsContext
from torch._dynamo.test_case import run_tests, TestCase
class TestMetricsContext(TestCase):
def setUp(self):
super().setUp()
self.metrics = {}
def _on_exit(self, metrics):
# Save away the metrics to be validated in the test.
self.metrics = metrics.copy()
def test_context_exists(self):
"""
Setting a value without entering the context should raise.
"""
context = MetricsContext(self._on_exit)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.increment("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.set("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.update({"m", 1})
def test_nested_context(self):
"""
Only the outermost context should get an on_exit call, and it should
include everything.
"""
context = MetricsContext(self._on_exit)
with context:
with context:
context.set("m1", 1)
self.assertEqual(self.metrics, {})
context.set("m2", 2)
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
def test_set(self):
"""
Validate various ways to set metrics.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
context.set("m2", 2)
context.update({"m3": 3, "m4": 4})
self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})
def test_set_disallow_overwrite(self):
"""
Validate set won't overwrite.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.set("m1", 2)
self.assertEqual(self.metrics, {"m1": 1})
def test_update_disallow_overwrite(self):
"""
Validate update won't overwrite.
"""
with MetricsContext(self._on_exit) as context:
context.update({"m1": 1, "m2": 2})
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.update({"m1": 7, "m3": 3})
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
if __name__ == "__main__":
run_tests()

View File

@@ -159,6 +159,7 @@ class TestDynamoTimed(TestCase):
'_recursive_post_grad_passes': [0.0, 0.0],
'_recursive_pre_grad_passes': [0.0],
'async_compile.wait': [0.0, 0.0],
'backward._backward_impl': [0.0],
'compile_file': [0.0, 0.0],
'compile_fx.<locals>.bw_compiler': [0.0],
'compile_fx.<locals>.fw_compiler_base': [0.0],
@@ -174,6 +175,7 @@ class TestDynamoTimed(TestCase):
"""\
{'backend_compile': 0.0,
'code_gen': 0.0,
'entire_backward_compile': 0.0,
'entire_frame_compile': 0.0,
'inductor_compile': 0.0,
'total_wall_time': 0.0}""", # noqa: B950
@@ -200,6 +202,7 @@ class TestDynamoTimed(TestCase):
{'accumulated_cache_size': 0,
'aot_autograd_cumulative_compile_time_us': 0,
'backend_compile_time_s': 0.0,
'backward_cumulative_compile_time_us': None,
'cache_size': 0,
'co_filename': None,
'co_firstlineno': None,
@@ -210,12 +213,13 @@ class TestDynamoTimed(TestCase):
'config_inline_inbuilt_nn_modules': False,
'config_suppress_errors': False,
'cuda_synchronize_time_us': None,
'distributed_ephemeral_timeout_us': 0,
'distributed_ephemeral_timeout_us': None,
'duration_us': 0,
'dynamo_compile_time_before_restart_us': 0,
'dynamo_config': None,
'dynamo_cumulative_compile_time_us': 0,
'dynamo_time_before_restart_s': 0.0,
'end_time_us': 100,
'entire_frame_compile_time_s': 0.0,
'fail_reason': None,
'fail_type': None,
@@ -231,9 +235,10 @@ class TestDynamoTimed(TestCase):
'inductor_compile_time_s': 0.0,
'inductor_cumulative_compile_time_us': 0,
'is_forward': True,
'log_format_version': 2,
'non_compliant_ops': set(),
'num_triton_bundles': None,
'remote_cache_time_saved_s': 0,
'remote_cache_time_saved_s': None,
'remote_fx_graph_cache_get_time_ms': None,
'remote_fx_graph_cache_get_time_us': None,
'remote_fx_graph_cache_put_time_ms': None,
@@ -257,6 +262,7 @@ class TestDynamoTimed(TestCase):
{'accumulated_cache_size': None,
'aot_autograd_cumulative_compile_time_us': None,
'backend_compile_time_s': None,
'backward_cumulative_compile_time_us': 0,
'cache_size': None,
'co_filename': None,
'co_firstlineno': None,
@@ -273,6 +279,7 @@ class TestDynamoTimed(TestCase):
'dynamo_config': None,
'dynamo_cumulative_compile_time_us': None,
'dynamo_time_before_restart_s': None,
'end_time_us': 100,
'entire_frame_compile_time_s': None,
'fail_reason': None,
'fail_type': None,
@@ -288,6 +295,7 @@ class TestDynamoTimed(TestCase):
'inductor_compile_time_s': 0.0,
'inductor_cumulative_compile_time_us': 0,
'is_forward': False,
'log_format_version': 2,
'non_compliant_ops': None,
'num_triton_bundles': None,
'remote_cache_time_saved_s': None,
@@ -302,7 +310,7 @@ class TestDynamoTimed(TestCase):
'specialize_float': None,
'start_time': None,
'start_time_us': 100,
'structured_logging_overhead_s': 0.0,
'structured_logging_overhead_s': None,
'structured_logging_overhead_us': 0,
'triton_compile_time_us': None}""", # noqa: B950
)

View File

@@ -30,7 +30,7 @@ import torch
import torch._logging
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._dynamo.utils import CompileTimeInstructionCounter
from torch._dynamo.utils import CompileTimeInstructionCounter, get_metrics_context
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import (
@@ -105,12 +105,9 @@ from .symbolic_convert import (
from .trace_rules import is_numpy
from .utils import (
CleanupManager,
codecache_metrics,
CompilationMetrics,
counters,
dynamo_timed,
format_bytecode,
frame_phase_timing,
gen_record_file_name,
get_chromium_event_logger,
increment_frame,
@@ -118,10 +115,8 @@ from .utils import (
istype,
LazyString,
orig_code_map,
record_compilation_metrics,
reset_graph_break_dup_checker,
setup_compile_debug,
to_int_ms,
to_int_us,
troubleshooting_url,
write_record_to_file,
@@ -699,7 +694,9 @@ def _compile(
with contextlib.ExitStack() as stack:
stack.enter_context(
dynamo_timed(
"_compile.compile_inner", phase_name="entire_frame_compile"
"_compile.compile_inner",
phase_name="entire_frame_compile",
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
)
)
stack.enter_context(
@@ -864,9 +861,11 @@ def _compile(
chromium_event_log.reset()
chromium_start_time = time.time_ns()
chromium_event_log.log_event_start("dynamo", chromium_start_time, {})
metrics_context = get_metrics_context()
with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(
CompileContext(compile_id)
):
), metrics_context:
restart_reasons: set[str] = set()
# This is shared across restarts
mutated_closure_cell_ids: Set[int] = set()
@@ -974,7 +973,6 @@ def _compile(
fail_user_frame_lineno: Optional[int] = None
torch._dynamo.utils.ReinplaceCounters.clear()
guarded_code = None
codecache_metrics.clear()
try:
guarded_code = compile_inner(code, one_graph, hooks, transform)
@@ -991,6 +989,7 @@ def _compile(
return guarded_code
except Exception as e:
# TODO(masnesral): Populating the exception info should be automatic
fail_type = type(e).__qualname__
fail_reason = str(e)
# NB: e's msg is mutated here to add user stack, but we DON'T want
@@ -1040,66 +1039,34 @@ def _compile(
if tracer:
tracer.output.local_scope = {}
duration_ns = time.time_ns() - start_time_ns
end_time_ns = time.time_ns()
duration_ns = end_time_ns - start_time_ns
from .utils import curr_frame
frame_key = str(curr_frame)
if (
fail_reason is None
and output is not None
and frame_key in frame_phase_timing
):
if fail_reason is None and output is not None:
guard_count = len(output.guards)
shape_env_guard_count = len(output.shape_env.guards)
graph_op_count = output.count_calls()
graph_node_count = len(output.graph.nodes)
graph_input_count = len(output.placeholders)
entire_frame_compile_time = frame_phase_timing[frame_key].get(
"entire_frame_compile", None
)
backend_compile_time = frame_phase_timing[frame_key].get(
"backend_compile", None
)
inductor_compile_time = frame_phase_timing[frame_key].get(
"inductor_compile", None
)
code_gen_time = frame_phase_timing[frame_key].get("code_gen", None)
non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
compliant_custom_ops = {
op.__qualname__ for op in output.compliant_custom_ops
}
remote_cache_time_saved = frame_phase_timing[frame_key].get(
"remote_cache_time_saved", 0
)
remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get(
"remote_fx_graph_cache_get", None
)
remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get(
"remote_fx_graph_cache_put", None
)
num_triton_bundles = codecache_metrics.get("num_triton_bundles", None)
torch._dynamo.utils.ReinplaceCounters.log()
else:
guard_count = None
shape_env_guard_count = None
graph_op_count = None
graph_node_count = None
graph_input_count = None
entire_frame_compile_time = None
backend_compile_time = None
inductor_compile_time = None
code_gen_time = None
non_compliant_ops = set({})
compliant_custom_ops = set({})
restart_reasons = set()
# If compilation failed, the entire time is wasted
dynamo_time_before_restart = duration_ns / 1e9
remote_cache_time_saved = None
remote_fx_graph_cache_get_time = None
remote_fx_graph_cache_put_time = None
num_triton_bundles = None
structured_logging_overhead_s = (
torch._logging.get_structured_logging_overhead()
@@ -1134,74 +1101,55 @@ def _compile(
}
config_dict = clean_for_json(config.get_config_copy())
metrics = CompilationMetrics(
str(compile_id),
frame_key,
code.co_name,
code.co_filename,
code.co_firstlineno,
cache_size.num_cache_entries_with_same_id_matched_objs,
cache_size.num_cache_entries,
guard_count,
shape_env_guard_count,
graph_op_count,
graph_node_count,
graph_input_count,
start_time_ns / 1e9,
entire_frame_compile_time,
backend_compile_time,
inductor_compile_time,
code_gen_time,
fail_type,
fail_reason,
fail_user_frame_filename,
fail_user_frame_lineno,
non_compliant_ops,
compliant_custom_ops,
restart_reasons,
dynamo_time_before_restart,
guarded_code is not None,
remote_cache_time_saved,
structured_logging_overhead_s,
config.suppress_errors,
config.inline_inbuilt_nn_modules,
config.specialize_float,
json.dumps(config_dict),
True, # is_forward
num_triton_bundles,
to_int_ms(remote_fx_graph_cache_get_time),
to_int_ms(remote_fx_graph_cache_put_time),
start_time_us=start_time_ns // 1000,
duration_us=duration_ns // 1000,
dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time),
aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time),
inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time),
inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time),
triton_compile_time_us=None, # TODO: instrument
runtime_cudagraphify_time_us=None, # TODO: instrument in separate event
runtime_triton_autotune_time_us=None, # TODO: instrument in separate event
dynamo_compile_time_before_restart_us=to_int_us(
metrics = {
"compile_id": str(compile_id),
"frame_key": frame_key,
"co_name": code.co_name,
"co_filename": code.co_filename,
"co_firstlineno": code.co_firstlineno,
"cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
"accumulated_cache_size": cache_size.num_cache_entries,
"guard_count": guard_count,
"shape_env_guard_count": shape_env_guard_count,
"graph_op_count": graph_op_count,
"graph_node_count": graph_node_count,
"graph_input_count": graph_input_count,
# TODO(masnesral): start_time and end_time shouldn't need to be
# populated manually.
"start_time": start_time_ns / 1e9,
"fail_type": fail_type,
"fail_reason": fail_reason,
"fail_user_frame_filename": fail_user_frame_filename,
"fail_user_frame_lineno": fail_user_frame_lineno,
"non_compliant_ops": non_compliant_ops,
"compliant_custom_ops": compliant_custom_ops,
"restart_reasons": restart_reasons,
"dynamo_time_before_restart_s": dynamo_time_before_restart,
"has_guarded_code": guarded_code is not None,
"structured_logging_overhead_s": structured_logging_overhead_s,
"config_suppress_errors": config.suppress_errors,
"config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules,
"specialize_float": config.specialize_float,
"dynamo_config": json.dumps(config_dict),
"is_forward": True,
"start_time_us": start_time_ns // 1000,
"end_time_us": end_time_ns // 1000,
"duration_us": duration_ns // 1000,
"dynamo_compile_time_before_restart_us": to_int_us(
dynamo_time_before_restart
),
cuda_synchronize_time_us=None, # TODO: instrument
distributed_ephemeral_timeout_us=to_int_us(
remote_cache_time_saved
), # TODO: instrument more accurately
structured_logging_overhead_us=to_int_us(structured_logging_overhead_s),
remote_fx_graph_cache_get_time_us=to_int_us(
remote_fx_graph_cache_get_time
"structured_logging_overhead_us": to_int_us(
structured_logging_overhead_s
),
remote_fx_graph_cache_put_time_us=to_int_us(
remote_fx_graph_cache_put_time
),
)
record_compilation_metrics(metrics)
}
metrics_context.update_outer(metrics)
torch._dynamo.callback_handler.run_end_callbacks()
chromium_event_log.log_event_end(
"dynamo", time.time_ns(), {}, chromium_start_time, True
)
# === END WARNING WARNING WARNING ===
chromium_event_log.log_event_end(
"dynamo", time.time_ns(), {}, chromium_start_time, True
)
class ConvertFrame:
def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None:

View File

@@ -0,0 +1,95 @@
from typing import Any, Callable, Dict, Optional, Type
from typing_extensions import TypeAlias
OnExitType: TypeAlias = Callable[[Dict[str, Any]], None]
class MetricsContext:
def __init__(self, on_exit: OnExitType):
"""
Use this class as a contextmanager to create a context under which to accumulate
a set of metrics, e.g., metrics gathered during a compilation. On exit of the
contextmanager, call the provided 'on_exit' function and pass a dictionary of
all metrics set during the lifetime of the contextmanager.
"""
self._on_exit = on_exit
self._metrics: Dict[str, Any] = {}
self._level = 0
def __enter__(self) -> "MetricsContext":
"""
Initialize metrics recording.
"""
if self._level == 0:
# In case of recursion, track at the outermost context.
self._metrics = {}
self._level += 1
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
_traceback: Any,
) -> None:
"""
At exit, call the provided on_exit function.
"""
self._level -= 1
assert self._level >= 0
if self._level == 0:
self._on_exit(self._metrics)
def in_progress(self) -> bool:
"""
True if we've entered the context.
"""
return self._level > 0
def increment(self, metric: str, value: int) -> None:
"""
Increment a metric by a given amount.
"""
if self._level == 0:
raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext")
if metric not in self._metrics:
self._metrics[metric] = 0
self._metrics[metric] += value
def set(self, metric: str, value: Any) -> None:
"""
Set a metric to a given value. Raises if the metric has been assigned previously
in the current context.
"""
if self._level == 0:
raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
if metric in self._metrics:
raise RuntimeError(
f"Metric '{metric}' has already been set in the current context"
)
self._metrics[metric] = value
def update(self, values: Dict[str, Any]) -> None:
"""
Set multiple metrics directly. This method does NOT increment. Raises if any
metric has been assigned previously in the current context.
"""
if self._level == 0:
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
existing = self._metrics.keys() & values.keys()
if existing:
raise RuntimeError(
f"Metric(s) {existing} have already been set in the current context"
)
self._metrics.update(values)
def update_outer(self, values: Dict[str, Any]) -> None:
"""
Update, but only when at the outermost context.
"""
if self._level == 0:
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
if self._level == 1:
self.update(values)
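
For illustration, a minimal standalone exercise of the class above (the `on_exit` callback here is a hypothetical stand-in; dynamo wires it to `record_compilation_metrics`):

```python
def on_exit(metrics):
    print(metrics)  # stand-in sink for the gathered metrics

ctx = MetricsContext(on_exit)
with ctx:            # outermost entry resets the metrics dict
    with ctx:        # recursive re-entry is tracked; inner exit does not log
        ctx.set("m1", 1)
    ctx.increment("m2", 5)
# prints {'m1': 1, 'm2': 5} here, at the outermost exit
```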

View File

@@ -1391,6 +1391,7 @@ class OutputGraph:
"OutputGraph.call_user_compiler",
phase_name="backend_compile",
log_pt2_compile_event=True,
dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
):
return self._call_user_compiler(gm)

View File

@@ -43,6 +43,7 @@ from typing import (
DefaultDict,
Deque,
Dict,
Generator,
Iterable,
Iterator,
KeysView,
@@ -71,6 +72,7 @@ from torch._C import (
_push_on_torch_function_stack,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.metrics_context import MetricsContext
from torch._guards import Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import (
@@ -139,12 +141,10 @@ log = logging.getLogger(__name__)
# profiling compilation time by function
compilation_time_metrics: Dict[str, List[float]] = {}
# profiling compilation time by frame phase
frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
lambda: collections.defaultdict(float)
)
codecache_metrics: Counter[str] = collections.Counter()
# This supports calculate_time_spent(), which reports cumulative times
# across the process for any "phase" populated by dynamo_timed. Reset if
# reset_frame_count() is called.
cumulative_time_spent_ns: Dict[str, float] = collections.defaultdict(float)
timer_counter = itertools.count()
@@ -220,7 +220,7 @@ def increment_frame() -> None:
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def reset_frame_count() -> None:
global curr_frame
frame_phase_timing.clear()
cumulative_time_spent_ns.clear()
compilation_time_metrics.clear()
curr_frame = 0
@@ -233,25 +233,16 @@ def increment_op_count(cnt: int) -> None:
op_count += cnt
# Calculate total time spent so far for each phase
# Get the total time in seconds for each "phase"
# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
def calculate_time_spent() -> Dict[str, float]:
total_wall_time = 0.0
total_by_key = {}
for timings in frame_phase_timing.values():
total_wall_time += timings.get(
"entire_frame_compile", timings.get("inductor_compile", 0)
)
for key, timing in timings.items():
if key not in total_by_key:
total_by_key[key] = timing
else:
total_by_key[key] += timing
if total_by_key:
total_by_key["total_wall_time"] = total_wall_time
for phase, timing in cumulative_time_spent_ns.items():
total_by_key[phase] = timing / 1e9
total_by_key["total_wall_time"] = total_by_key.get(
"entire_frame_compile", 0
) + total_by_key.get("entire_backward_compile", 0)
return total_by_key
@@ -270,188 +261,124 @@ def print_time_report() -> None:
print(out)
def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
frame_phase_timing[key][phase_name] += time_spent
# Use the following singleton to capture and log CompilationMetrics. Entering the context
# manager allocates a new record to be logged when it exits. (You should not need to use
# this directly unless you introduce a new code path where compilation metrics would be
# gathered). While compiling, use the setters or timer in MetricsContext to update fields
# in the current context. For example:
#
# To set a single field once (raises if it has already been set):
# get_metrics_context().set("metric_name", value)
#
# To set multiple fields at once (raises if any has already been set):
# get_metrics_context().update({"name1": val1, "name2": val2})
#
# To increment an integer field:
# get_metrics_context().increment("metric_name", value)
#
# To record execution time, MetricsContext works with dynamo_timed:
# def foo(...):
# # Updates the "metric_us" field.
# with dynamo_timed("metric", dynamo_compile_column_us="metric_us")
# ...
#
_METRICS_CONTEXT: MetricsContext
# Use frame_phase_timing to record remote_cache_time_saved
# This follows the same principles of key as the other frame phase timings,
# but is incremented by FxGraphCache (and later AOTAutogradCache) directly
def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None:
key = None
if is_backward:
# Use compile id as the frame key for backwards compilation
key = str(torch._guards.CompileContext.current_compile_id())
else:
key = str(curr_frame)
# Convert to seconds (as a float)
time_saved = time_saved_ns / 1e9
_add_time_spent(key, "remote_cache_time_saved", time_saved)
# dynamo_timed is a context manager
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the functions name.
# For example:
#
# def _foo(...):
# with dynamo_timed("_foo"):
# ...
#
# Would show up as an entry in our timing dict:
# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
# This is extremely useful for granular debugging.
#
# Although it is tempting to use dynamo_timed as a decorator, please do not.
# In its decorator form it makes cProfile traces less useful as dynamo_timed
# suddenly becomes a bottleneck for lots of function calls (as only one parent
# pointer is recorded).
#
# For a higher-level mode, pass a phase_name into dynamo_timed
# phase_names record an extra record into a separate compilation timing structure,
# one keyed on frame+name rather than function.
# The frame is incremented outside of this function, in def increment_frame() above.
# `fwd_only` is used to identify if this phase or function is only called
# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`.
# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
def get_metrics_context() -> MetricsContext:
return _METRICS_CONTEXT
@contextmanager
def dynamo_timed(
key: str,
# TODO(masnesral): Deprecate this param.
phase_name: Optional[str] = None,
log_pt2_compile_event: bool = False, # Whether or not to log it to internal pt2 compile event
log_pt2_compile_event: bool = False,
# TODO(masnesral): fwd_only is ignored. Remove it.
fwd_only: bool = True,
):
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
metadata: Optional[Dict[str, object]] = None,
dynamo_compile_column_us: Optional[str] = None,
) -> Generator[Any, None, None]:
"""
dynamo_timed is a context manager
By wrapping a function in dynamo_timed, we can get a few things:
1) Log timings to pt2_compile_events.
2) Log timings to CompilationMetrics (dynamo_compile).
3) Chromium events.
4) Storing a record in compilation_time_metrics
For example:
def _foo(...):
with dynamo_timed("_foo"):
...
Would show up as an entry in our timing dict:
OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
This is extremely useful for granular debugging.
Although it is tempting to use dynamo_timed as a decorator, please do not.
In its decorator form it makes cProfile traces less useful as dynamo_timed
suddenly becomes a bottleneck for lots of function calls (as only one parent
pointer is recorded).
Params:
- key: key into compile_time_metrics. If phase_name is not provided, this is
also the event name used for pt2_compile_events logs and chromium events.
- phase_name: Optional override for the event name.
- log_pt2_compile_event: Whether to log a pt2 compile event internally.
- metadata: Extra metadata to put in pt2_compile_events.
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
field to be logged to the dynamo_compile column. We expect all columns to be _us;
therefore, the field name must end with "_us".
"""
# We're standardizing on microseconds for dynamo_compile timings.
if dynamo_compile_column_us is not None:
assert dynamo_compile_column_us.endswith("_us")
if phase_name:
event_name = phase_name
fn_name = key
else:
event_name = key
fn_name = None
if key not in compilation_time_metrics:
compilation_time_metrics[key] = []
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
time_spent = float("-inf")
event_metadata = {}
if metadata:
event_metadata.update(metadata)
if fn_name:
event_metadata.update({"fn_name": fn_name})
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
start_ns = time.time_ns()
chromium_log.log_event_start(event_name, start_ns, event_metadata)
try:
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
t0 = time.time()
if phase_name:
chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key})
else:
chromium_log.log_event_start(key, start_ns, {})
yield
time_spent = time.time() - t0
compilation_time_metrics[key].append(time_spent)
except Exception as e:
fail_type = str(type(e))
fail_reason = str(e)
raise
finally:
end_ns = time.time_ns()
# Always log the end event even on exception
if phase_name:
chromium_log.log_event_end(
phase_name,
end_ns,
{},
start_ns,
log_pt2_compile_event,
)
else:
chromium_log.log_event_end(key, end_ns, {}, start_ns, log_pt2_compile_event)
# Only record backward compilation metrics if phase_name is not None!
if phase_name:
frame_key = str(curr_frame)
# fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch.
# use frame_key as time aggregation key.
if fwd_only and fail_type is None:
_add_time_spent(frame_key, phase_name, time_spent)
else:
# fwd + bwd compilation stages: inductor_compile, code_gen.
# use frame_key as time aggregation key for fwd graphs;
# use compile_id as time aggregation key for bwd graphs.
if torch._guards.TracingContext.try_get() is not None:
aot_graph_name = str(
torch._guards.TracingContext.get().aot_graph_name
)
if (
"forward" in aot_graph_name or "inference" in aot_graph_name
) and fail_type is None:
_add_time_spent(frame_key, phase_name, time_spent)
elif "backward" in aot_graph_name:
compile_id = str(
torch._guards.CompileContext.current_compile_id()
)
if fail_type is None:
_add_time_spent(compile_id, phase_name, time_spent)
# log backward compilation metrics at the end of `inductor_compile` of bwd graph,
# one record for one bwd graph.
if phase_name == "inductor_compile":
if fail_type is None:
inductor_compile_time = frame_phase_timing[
compile_id
].get("inductor_compile", None)
code_gen_time = frame_phase_timing[compile_id].get(
"code_gen", None
)
remote_cache_time_saved = frame_phase_timing[
compile_id
].get("remote_cache_time_saved", None)
remote_fx_graph_cache_get_time = frame_phase_timing[
compile_id
].get("remote_fx_graph_cache_get", None)
remote_fx_graph_cache_put_time = frame_phase_timing[
compile_id
].get("remote_fx_graph_cache_put", None)
else:
inductor_compile_time = None
code_gen_time = None
remote_cache_time_saved = None
remote_fx_graph_cache_get_time = None
remote_fx_graph_cache_put_time = None
structured_logging_overhead_s = (
torch._logging.get_structured_logging_overhead()
)
metrics = CompilationMetrics(
compile_id=compile_id,
inductor_compile_time_s=inductor_compile_time,
code_gen_time_s=code_gen_time,
fail_type=fail_type,
fail_reason=fail_reason,
remote_cache_time_saved_s=remote_cache_time_saved,
structured_logging_overhead_s=structured_logging_overhead_s,
is_forward=False, # is_forward
num_triton_bundles=codecache_metrics.get(
"num_triton_bundles", None
),
remote_fx_graph_cache_get_time_ms=to_int_ms(
remote_fx_graph_cache_get_time
),
remote_fx_graph_cache_put_time_ms=to_int_ms(
remote_fx_graph_cache_put_time
),
start_time_us=start_ns // 1000,
duration_us=(end_ns - start_ns) // 1000,
inductor_cumulative_compile_time_us=to_int_us(
inductor_compile_time
),
inductor_code_gen_cumulative_compile_time_us=to_int_us(
code_gen_time
),
distributed_ephemeral_timeout_us=to_int_us(
remote_cache_time_saved
), # TODO: instrument more accurately
structured_logging_overhead_us=to_int_us(
structured_logging_overhead_s
),
remote_fx_graph_cache_get_time_us=to_int_us(
remote_fx_graph_cache_get_time
),
remote_fx_graph_cache_put_time_us=to_int_us(
remote_fx_graph_cache_put_time
),
)
record_compilation_metrics(metrics)
time_spent_ns = end_ns - start_ns
compilation_time_metrics[key].append(time_spent_ns / 1e9)
chromium_log.log_event_end(
event_name, end_ns, {}, start_ns, log_pt2_compile_event
)
if dynamo_compile_column_us:
metrics_context = get_metrics_context()
if metrics_context.in_progress():
metrics_context.increment(
dynamo_compile_column_us, time_spent_ns // 1000
)
# TODO: the events that we capture in calculate_time_spent() seem a little
# arbitrary. Currently, it's only those fields that are present in
# CompilationMetrics (but note that we accumulate by the associated event
# name, not the field name in CompilationMetrics). Do we want to keep it
# this way?
cumulative_time_spent_ns[event_name] += time_spent_ns
@overload
@@ -866,6 +793,11 @@ def to_int_us(v: Optional[float]) -> Optional[int]:
return None if v is None else int(v * 1_000_000)
# Version field added to every log. Increment to make it easier to distinguish new
# vs. old entries when you make a substantive change to how the logs are populated.
LOG_FORMAT_VERSION = 2
@dataclasses.dataclass
class CompilationMetrics:
compile_id: Optional[str] = None
@@ -913,15 +845,18 @@ class CompilationMetrics:
aot_autograd_cumulative_compile_time_us: Optional[int] = None
inductor_cumulative_compile_time_us: Optional[int] = None
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
triton_compile_time_us: Optional[int] = None
runtime_cudagraphify_time_us: Optional[int] = None
runtime_triton_autotune_time_us: Optional[int] = None
triton_compile_time_us: Optional[int] = None # TODO: instrument
runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument
runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument
dynamo_compile_time_before_restart_us: Optional[int] = None
cuda_synchronize_time_us: Optional[int] = None
cuda_synchronize_time_us: Optional[int] = None # TODO: instrument
distributed_ephemeral_timeout_us: Optional[int] = None
structured_logging_overhead_us: Optional[int] = None
remote_fx_graph_cache_get_time_us: Optional[int] = None
remote_fx_graph_cache_put_time_us: Optional[int] = None
backward_cumulative_compile_time_us: Optional[int] = None
end_time_us: Optional[int] = None
log_format_version: int = LOG_FORMAT_VERSION
DEFAULT_COMPILATION_METRICS_LIMIT = 64
@@ -969,8 +904,32 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics):
)
def record_compilation_metrics(compilation_metrics: CompilationMetrics):
global _compilation_metrics
def record_compilation_metrics(metrics: Dict[str, Any]):
# TODO: Temporary; populate legacy fields from their replacements.
# Remove when we decide we can really deprecate them.
def us_to_s(field):
metric = metrics.get(field, None)
return metric / 1e6 if metric is not None else None
def us_to_ms(field):
metric = metrics.get(field, None)
return metric // 1000 if metric is not None else None
legacy_metrics = {
"entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"),
"backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"),
"inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"),
"code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"),
"remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"),
"remote_fx_graph_cache_get_time_ms": us_to_ms(
"remote_fx_graph_cache_get_time_us"
),
"remote_fx_graph_cache_put_time_ms": us_to_ms(
"remote_fx_graph_cache_put_time_us"
),
}
compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics})
_compilation_metrics.append(compilation_metrics)
if compilation_metrics.is_forward:
name = "compilation_metrics"
@@ -979,10 +938,7 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics):
name = "bwd_compilation_metrics"
torch._logging.trace_structured(
name,
lambda: {
k: list(v) if isinstance(v, set) else v
for k, v in dataclasses.asdict(compilation_metrics).items()
},
lambda: {k: list(v) if isinstance(v, set) else v for k, v in metrics.items()},
# NB: Because compilation metrics *includes* the logging overhead time,
# we can't both *measure* the logging overhead of compilation metrics
# without making it inconsistent with compilation metrics itself, so
@@ -993,6 +949,10 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics):
log_compilation_event(compilation_metrics)
# record_compilation_metrics is called by the singleton MetricsContext exit handler.
_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
def set_compilation_metrics_limit(new_size: int) -> None:
global _compilation_metrics
while len(_compilation_metrics) > new_size:

View File

@@ -632,7 +632,7 @@ class AOTAutogradCache:
)
# TODO: should we use the same field for remote cache time saved for both
# FXGraphCache and AOTAutogradCache?
# add_remote_cache_time_saved(time_saved_ns, is_backward=False)
# get_metrics_context().increment(...)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns

View File

@@ -10,6 +10,7 @@ import builtins
import collections
import itertools
import pprint
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
@@ -18,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import dynamo_timed, get_metrics_context, to_int_us
from torch._guards import (
compile_context,
CompileContext,
@@ -2002,17 +2004,49 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
context = torch._C._DisableAutocast if disable_amp else nullcontext
with tracing(saved_context), compile_context(
saved_compile_context
), context(), track_graph_compiling(aot_config, "backward"):
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
)
# Maybe save cache entry
if try_save_cache_entry is not None:
try_save_cache_entry(
CompiledFunction.compiled_bw,
fw_metadata,
aot_config,
), context(), track_graph_compiling(
aot_config, "backward"
), get_metrics_context(), dynamo_timed(
"backward._backward_impl",
phase_name="entire_backward_compile",
dynamo_compile_column_us="backward_cumulative_compile_time_us",
):
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
start_ns = time.time_ns()
try:
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
)
# Maybe save cache entry
if try_save_cache_entry is not None:
try_save_cache_entry(
CompiledFunction.compiled_bw,
fw_metadata,
aot_config,
)
except Exception as e:
# TODO(masnesral): Populating the exception info should be automatic.
fail_type = type(e).__qualname__
fail_reason = str(e)
finally:
# TODO(masnesral): Populating time fields should be automatic.
end_ns = time.time_ns()
metrics = {
"compile_id": str(
torch._guards.CompileContext.current_compile_id()
),
"fail_type": fail_type,
"fail_reason": fail_reason,
"is_forward": False,
"start_time_us": start_ns // 1000,
"end_time_us": end_ns // 1000,
"duration_us": (end_ns - start_ns) // 1000,
"structured_logging_overhead_us": to_int_us(
torch._logging.get_structured_logging_overhead(),
),
}
get_metrics_context().update_outer(metrics)
if (
torch._functorch.config.donated_buffer

View File

@@ -54,11 +54,10 @@ import torch
import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.utils import (
add_remote_cache_time_saved,
codecache_metrics,
counters,
dynamo_timed,
get_chromium_event_logger,
get_metrics_context,
)
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.cuda import cuda_env
@@ -1152,7 +1151,7 @@ class FxGraphCache:
"inductor_compile", cached_kernel_names=meta.cached_kernel_names
)
if len(meta.cached_kernel_names) > 0:
codecache_metrics["num_triton_bundles"] += 1
get_metrics_context().increment("num_triton_bundles", 1)
inductor_meta = autotune_cache.inductor_meta_from_config()
AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
@@ -1449,7 +1448,9 @@ class FxGraphCache:
if (time_saved_ns := compiled_graph._time_taken_ns) is not None:
cache_info["time_saved_ns"] = time_saved_ns
add_remote_cache_time_saved(time_saved_ns, is_backward)
get_metrics_context().increment(
"distributed_ephemeral_timeout_us", time_saved_ns // 1000
)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns

View File

@@ -567,7 +567,7 @@ def compile_fx_inner(
"compile_fx_inner",
phase_name="inductor_compile",
log_pt2_compile_event=True,
fwd_only=False,
dynamo_compile_column_us="inductor_cumulative_compile_time_us",
)
)
# NB: Why is this the dynamo_compile counter? The rule here is that

View File

@@ -1949,7 +1949,7 @@ class GraphLowering(torch.fx.Interpreter):
"GraphLowering.compile_to_module",
phase_name="code_gen",
log_pt2_compile_event=True,
fwd_only=False,
dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
):
return self._compile_to_module()

View File

@@ -45,14 +45,14 @@ remote_fx_cache_get_timed = functools.partial(
"FbRemoteFxGraphCache.get",
phase_name="remote_fx_graph_cache_get",
log_pt2_compile_event=False,
fwd_only=False,
dynamo_compile_column_us="remote_fx_graph_cache_get_time_us",
)
remote_fx_cache_put_timed = functools.partial(
dynamo_timed,
"FbRemoteFxGraphCache.put",
phase_name="remote_fx_graph_cache_put",
log_pt2_compile_event=False,
fwd_only=False,
dynamo_compile_column_us="remote_fx_graph_cache_put_time_us",
)

View File

@@ -135,5 +135,11 @@ try:
except AttributeError: # Compile workers only have a mock version of torch
@contextlib.contextmanager
def dynamo_timed(key, phase_name=None, fwd_only=True):
def dynamo_timed(
key,
phase_name=None,
fwd_only=True,
metadata=None,
dynamo_compile_column_us=None,
):
yield

View File

@@ -1099,7 +1099,6 @@ class LazyString:
structured_logging_overhead: Dict[str, float] = defaultdict(float)
# Same principle as add_remote_cache_time_saved, but do it for structured logging
def add_structured_logging_overhead(time_spent: float) -> None:
global structured_logging_overhead
key = None