diff --git a/test/dynamo/test_metrics_context.py b/test/dynamo/test_metrics_context.py
deleted file mode 100644
index 27cf63252467..000000000000
--- a/test/dynamo/test_metrics_context.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Owner(s): ["module: dynamo"]
-
-from torch._dynamo.metrics_context import MetricsContext
-from torch._dynamo.test_case import run_tests, TestCase
-
-
-class TestMetricsContext(TestCase):
-    def setUp(self):
-        super().setUp()
-        self.metrics = {}
-
-    def _on_exit(self, metrics):
-        # Save away the metrics to be validated in the test.
-        self.metrics = metrics.copy()
-
-    def test_context_exists(self):
-        """
-        Setting a value without entering the context should raise.
-        """
-        context = MetricsContext(self._on_exit)
-        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
-            context.increment("m", 1)
-
-        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
-            context.set("m", 1)
-
-        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
-            context.update({"m", 1})
-
-    def test_nested_context(self):
-        """
-        Only the outermost context should get an on_exit call, and it should
-        include everything.
-        """
-        context = MetricsContext(self._on_exit)
-        with context:
-            with context:
-                context.set("m1", 1)
-            self.assertEqual(self.metrics, {})
-            context.set("m2", 2)
-        self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
-
-    def test_set(self):
-        """
-        Validate various ways to set metrics.
-        """
-        with MetricsContext(self._on_exit) as context:
-            context.set("m1", 1)
-            context.set("m2", 2)
-            context.update({"m3": 3, "m4": 4})
-
-        self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})
-
-    def test_set_disallow_overwrite(self):
-        """
-        Validate set won't overwrite.
-        """
-        with MetricsContext(self._on_exit) as context:
-            context.set("m1", 1)
-            with self.assertRaisesRegex(RuntimeError, "already been set"):
-                context.set("m1", 2)
-
-        self.assertEqual(self.metrics, {"m1": 1})
-
-    def test_update_disallow_overwrite(self):
-        """
-        Validate update won't overwite.
- """ - with MetricsContext(self._on_exit) as context: - context.update({"m1": 1, "m2": 2}) - with self.assertRaisesRegex(RuntimeError, "already been set"): - context.update({"m1": 7, "m3": 3}) - - self.assertEqual(self.metrics, {"m1": 1, "m2": 2}) - - -if __name__ == "__main__": - run_tests() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index 8a8abea173e1..c69cd02cf09c 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -159,7 +159,6 @@ class TestDynamoTimed(TestCase): '_recursive_post_grad_passes': [0.0, 0.0], '_recursive_pre_grad_passes': [0.0], 'async_compile.wait': [0.0, 0.0], - 'backward._backward_impl': [0.0], 'compile_file': [0.0, 0.0], 'compile_fx..bw_compiler': [0.0], 'compile_fx..fw_compiler_base': [0.0], @@ -175,7 +174,6 @@ class TestDynamoTimed(TestCase): """\ {'backend_compile': 0.0, 'code_gen': 0.0, - 'entire_backward_compile': 0.0, 'entire_frame_compile': 0.0, 'inductor_compile': 0.0, 'total_wall_time': 0.0}""", # noqa: B950 @@ -202,7 +200,6 @@ class TestDynamoTimed(TestCase): {'accumulated_cache_size': 0, 'aot_autograd_cumulative_compile_time_us': 0, 'backend_compile_time_s': 0.0, - 'backward_cumulative_compile_time_us': None, 'cache_size': 0, 'co_filename': None, 'co_firstlineno': None, @@ -213,13 +210,12 @@ class TestDynamoTimed(TestCase): 'config_inline_inbuilt_nn_modules': False, 'config_suppress_errors': False, 'cuda_synchronize_time_us': None, - 'distributed_ephemeral_timeout_us': None, + 'distributed_ephemeral_timeout_us': 0, 'duration_us': 0, 'dynamo_compile_time_before_restart_us': 0, 'dynamo_config': None, 'dynamo_cumulative_compile_time_us': 0, 'dynamo_time_before_restart_s': 0.0, - 'end_time_us': 100, 'entire_frame_compile_time_s': 0.0, 'fail_reason': None, 'fail_type': None, @@ -235,10 +231,9 @@ class TestDynamoTimed(TestCase): 'inductor_compile_time_s': 0.0, 'inductor_cumulative_compile_time_us': 0, 'is_forward': True, - 'log_format_version': 2, 'non_compliant_ops': set(), 'num_triton_bundles': None, - 'remote_cache_time_saved_s': None, + 'remote_cache_time_saved_s': 0, 'remote_fx_graph_cache_get_time_ms': None, 'remote_fx_graph_cache_get_time_us': None, 'remote_fx_graph_cache_put_time_ms': None, @@ -262,7 +257,6 @@ class TestDynamoTimed(TestCase): {'accumulated_cache_size': None, 'aot_autograd_cumulative_compile_time_us': None, 'backend_compile_time_s': None, - 'backward_cumulative_compile_time_us': 0, 'cache_size': None, 'co_filename': None, 'co_firstlineno': None, @@ -279,7 +273,6 @@ class TestDynamoTimed(TestCase): 'dynamo_config': None, 'dynamo_cumulative_compile_time_us': None, 'dynamo_time_before_restart_s': None, - 'end_time_us': 100, 'entire_frame_compile_time_s': None, 'fail_reason': None, 'fail_type': None, @@ -295,7 +288,6 @@ class TestDynamoTimed(TestCase): 'inductor_compile_time_s': 0.0, 'inductor_cumulative_compile_time_us': 0, 'is_forward': False, - 'log_format_version': 2, 'non_compliant_ops': None, 'num_triton_bundles': None, 'remote_cache_time_saved_s': None, @@ -310,7 +302,7 @@ class TestDynamoTimed(TestCase): 'specialize_float': None, 'start_time': None, 'start_time_us': 100, - 'structured_logging_overhead_s': None, + 'structured_logging_overhead_s': 0.0, 'structured_logging_overhead_us': 0, 'triton_compile_time_us': None}""", # noqa: B950 ) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 3bb18a85a260..948ea618e9b2 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -30,7 +30,7 @@ import torch import torch._logging from 
torch._C._dynamo.guards import GlobalStateGuard from torch._dynamo.distributed import get_compile_pg -from torch._dynamo.utils import CompileTimeInstructionCounter, get_metrics_context +from torch._dynamo.utils import CompileTimeInstructionCounter from torch._guards import compile_context, CompileContext, CompileId, tracing from torch._logging import structured from torch._utils_internal import ( @@ -105,9 +105,12 @@ from .symbolic_convert import ( from .trace_rules import is_numpy from .utils import ( CleanupManager, + codecache_metrics, + CompilationMetrics, counters, dynamo_timed, format_bytecode, + frame_phase_timing, gen_record_file_name, get_chromium_event_logger, increment_frame, @@ -115,8 +118,10 @@ from .utils import ( istype, LazyString, orig_code_map, + record_compilation_metrics, reset_graph_break_dup_checker, setup_compile_debug, + to_int_ms, to_int_us, troubleshooting_url, write_record_to_file, @@ -693,9 +698,7 @@ def _compile( with contextlib.ExitStack() as stack: stack.enter_context( dynamo_timed( - "_compile.compile_inner", - phase_name="entire_frame_compile", - dynamo_compile_column_us="dynamo_cumulative_compile_time_us", + "_compile.compile_inner", phase_name="entire_frame_compile" ) ) stack.enter_context( @@ -860,11 +863,9 @@ def _compile( chromium_event_log.reset() chromium_start_time = time.time_ns() chromium_event_log.log_event_start("dynamo", chromium_start_time, {}) - - metrics_context = get_metrics_context() with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context( CompileContext(compile_id) - ), metrics_context: + ): restart_reasons: set[str] = set() # This is shared across restarts speculation_log = SpeculationLog() @@ -971,6 +972,7 @@ def _compile( fail_user_frame_lineno: Optional[int] = None torch._dynamo.utils.ReinplaceCounters.clear() guarded_code = None + codecache_metrics.clear() try: guarded_code = compile_inner(code, one_graph, hooks, transform) @@ -987,7 +989,6 @@ def _compile( return guarded_code except Exception as e: - # TODO(masnesral): Populating the exception info should be automatic fail_type = type(e).__qualname__ fail_reason = str(e) # NB: e's msg is mutated here to add user stack, but we DON'T want @@ -1037,34 +1038,66 @@ def _compile( if tracer: tracer.output.local_scope = {} - end_time_ns = time.time_ns() - duration_ns = end_time_ns - start_time_ns + duration_ns = time.time_ns() - start_time_ns from .utils import curr_frame frame_key = str(curr_frame) - if fail_reason is None and output is not None: + if ( + fail_reason is None + and output is not None + and frame_key in frame_phase_timing + ): guard_count = len(output.guards) shape_env_guard_count = len(output.shape_env.guards) graph_op_count = output.count_calls() graph_node_count = len(output.graph.nodes) graph_input_count = len(output.placeholders) + entire_frame_compile_time = frame_phase_timing[frame_key].get( + "entire_frame_compile", None + ) + backend_compile_time = frame_phase_timing[frame_key].get( + "backend_compile", None + ) + inductor_compile_time = frame_phase_timing[frame_key].get( + "inductor_compile", None + ) + code_gen_time = frame_phase_timing[frame_key].get("code_gen", None) non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops} compliant_custom_ops = { op.__qualname__ for op in output.compliant_custom_ops } + remote_cache_time_saved = frame_phase_timing[frame_key].get( + "remote_cache_time_saved", 0 + ) + remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get( + "remote_fx_graph_cache_get", None + ) + 
remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get( + "remote_fx_graph_cache_put", None + ) + num_triton_bundles = codecache_metrics.get("num_triton_bundles", None) torch._dynamo.utils.ReinplaceCounters.log() + else: guard_count = None shape_env_guard_count = None graph_op_count = None graph_node_count = None graph_input_count = None + entire_frame_compile_time = None + backend_compile_time = None + inductor_compile_time = None + code_gen_time = None non_compliant_ops = set({}) compliant_custom_ops = set({}) restart_reasons = set() # If compilation failed, the entire time is wasted dynamo_time_before_restart = duration_ns / 1e9 + remote_cache_time_saved = None + remote_fx_graph_cache_get_time = None + remote_fx_graph_cache_put_time = None + num_triton_bundles = None structured_logging_overhead_s = ( torch._logging.get_structured_logging_overhead() @@ -1099,55 +1132,74 @@ def _compile( } config_dict = clean_for_json(config.get_config_copy()) - metrics = { - "compile_id": str(compile_id), - "frame_key": frame_key, - "co_name": code.co_name, - "co_filename": code.co_filename, - "co_firstlineno": code.co_firstlineno, - "cache_size": cache_size.num_cache_entries_with_same_id_matched_objs, - "accumulated_cache_size": cache_size.num_cache_entries, - "guard_count": guard_count, - "shape_env_guard_count": shape_env_guard_count, - "graph_op_count": graph_op_count, - "graph_node_count": graph_node_count, - "graph_input_count": graph_input_count, - # TODO(masnesral): start_time and end_time shouldn't need to be - # populated manually. - "start_time": start_time_ns / 1e9, - "fail_type": fail_type, - "fail_reason": fail_reason, - "fail_user_frame_filename": fail_user_frame_filename, - "fail_user_frame_lineno": fail_user_frame_lineno, - "non_compliant_ops": non_compliant_ops, - "compliant_custom_ops": compliant_custom_ops, - "restart_reasons": restart_reasons, - "dynamo_time_before_restart_s": dynamo_time_before_restart, - "has_guarded_code": guarded_code is not None, - "structured_logging_overhead_s": structured_logging_overhead_s, - "config_suppress_errors": config.suppress_errors, - "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules, - "specialize_float": config.specialize_float, - "dynamo_config": json.dumps(config_dict), - "is_forward": True, - "start_time_us": start_time_ns // 1000, - "end_time_us": end_time_ns // 1000, - "duration_us": duration_ns // 1000, - "dynamo_compile_time_before_restart_us": to_int_us( + metrics = CompilationMetrics( + str(compile_id), + frame_key, + code.co_name, + code.co_filename, + code.co_firstlineno, + cache_size.num_cache_entries_with_same_id_matched_objs, + cache_size.num_cache_entries, + guard_count, + shape_env_guard_count, + graph_op_count, + graph_node_count, + graph_input_count, + start_time_ns / 1e9, + entire_frame_compile_time, + backend_compile_time, + inductor_compile_time, + code_gen_time, + fail_type, + fail_reason, + fail_user_frame_filename, + fail_user_frame_lineno, + non_compliant_ops, + compliant_custom_ops, + restart_reasons, + dynamo_time_before_restart, + guarded_code is not None, + remote_cache_time_saved, + structured_logging_overhead_s, + config.suppress_errors, + config.inline_inbuilt_nn_modules, + config.specialize_float, + json.dumps(config_dict), + True, # is_forward + num_triton_bundles, + to_int_ms(remote_fx_graph_cache_get_time), + to_int_ms(remote_fx_graph_cache_put_time), + start_time_us=start_time_ns // 1000, + duration_us=duration_ns // 1000, + 
dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time), + aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time), + inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time), + inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time), + triton_compile_time_us=None, # TODO: instrument + runtime_cudagraphify_time_us=None, # TODO: instrument in separate event + runtime_triton_autotune_time_us=None, # TODO: instrument in separate event + dynamo_compile_time_before_restart_us=to_int_us( dynamo_time_before_restart ), - "structured_logging_overhead_us": to_int_us( - structured_logging_overhead_s + cuda_synchronize_time_us=None, # TODO: instrument + distributed_ephemeral_timeout_us=to_int_us( + remote_cache_time_saved + ), # TODO: instrument more accurately + structured_logging_overhead_us=to_int_us(structured_logging_overhead_s), + remote_fx_graph_cache_get_time_us=to_int_us( + remote_fx_graph_cache_get_time ), - } - metrics_context.update_outer(metrics) + remote_fx_graph_cache_put_time_us=to_int_us( + remote_fx_graph_cache_put_time + ), + ) + record_compilation_metrics(metrics) torch._dynamo.callback_handler.run_end_callbacks() + chromium_event_log.log_event_end( + "dynamo", time.time_ns(), {}, chromium_start_time, True + ) # === END WARNING WARNING WARNING === - chromium_event_log.log_event_end( - "dynamo", time.time_ns(), {}, chromium_start_time, True - ) - class ConvertFrame: def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None: diff --git a/torch/_dynamo/metrics_context.py b/torch/_dynamo/metrics_context.py deleted file mode 100644 index a51ad52fe9b0..000000000000 --- a/torch/_dynamo/metrics_context.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Any, Callable, Dict, Optional, Type -from typing_extensions import TypeAlias - - -OnExitType: TypeAlias = Callable[[Dict[str, Any]], None] - - -class MetricsContext: - def __init__(self, on_exit: OnExitType): - """ - Use this class as a contextmanager to create a context under which to accumulate - a set of metrics, e.g., metrics gathered during a compilation. On exit of the - contextmanager, call the provided 'on_exit' function and pass a dictionary of - all metrics set during the lifetime of the contextmanager. - """ - self._on_exit = on_exit - self._metrics: Dict[str, Any] = {} - self._level = 0 - - def __enter__(self) -> "MetricsContext": - """ - Initialize metrics recording. - """ - if self._level == 0: - # In case of recursion, track at the outermost context. - self._metrics = {} - - self._level += 1 - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - _traceback: Any, - ) -> None: - """ - At exit, call the provided on_exit function. - """ - self._level -= 1 - assert self._level >= 0 - if self._level == 0: - self._on_exit(self._metrics) - - def in_progress(self) -> bool: - """ - True if we've entered the context. - """ - return self._level > 0 - - def increment(self, metric: str, value: int) -> None: - """ - Increment a metric by a given amount. - """ - if self._level == 0: - raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext") - if metric not in self._metrics: - self._metrics[metric] = 0 - self._metrics[metric] += value - - def set(self, metric: str, value: Any) -> None: - """ - Set a metric to a given value. Raises if the metric has been assigned previously - in the current context. 
- """ - if self._level == 0: - raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext") - if metric in self._metrics: - raise RuntimeError( - f"Metric '{metric}' has already been set in the current context" - ) - self._metrics[metric] = value - - def update(self, values: Dict[str, Any]) -> None: - """ - Set multiple metrics directly. This method does NOT increment. Raises if any - metric has been assigned previously in the current context. - """ - if self._level == 0: - raise RuntimeError("Cannot update metrics outside of a MetricsContext") - existing = self._metrics.keys() & values.keys() - if existing: - raise RuntimeError( - f"Metric(s) {existing} have already been set in the current context" - ) - self._metrics.update(values) - - def update_outer(self, values: Dict[str, Any]) -> None: - """ - Update, but only when at the outermost context. - """ - if self._level == 0: - raise RuntimeError("Cannot update metrics outside of a MetricsContext") - if self._level == 1: - self.update(values) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 48e226d6b289..4a35bb51af02 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1394,7 +1394,6 @@ class OutputGraph: "OutputGraph.call_user_compiler", phase_name="backend_compile", log_pt2_compile_event=True, - dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us", ): return self._call_user_compiler(gm) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 9319ddfebe75..65095e7daa9d 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -43,7 +43,6 @@ from typing import ( DefaultDict, Deque, Dict, - Generator, Iterable, Iterator, KeysView, @@ -72,7 +71,6 @@ from torch._C import ( _push_on_torch_function_stack, ) from torch._dispatch.python import enable_python_dispatcher -from torch._dynamo.metrics_context import MetricsContext from torch._guards import Source, TracingContext from torch._subclasses.meta_utils import is_sparse_compressed from torch._utils_internal import ( @@ -141,10 +139,12 @@ log = logging.getLogger(__name__) # profiling compilation time by function compilation_time_metrics: Dict[str, List[float]] = {} -# This supports calculate_time_spent(), which reports cumulative times -# across the process for any "phase" populated by dynamo_timed. Reset if -# reset_frame_count() is called. -cumulative_time_spent_ns: Dict[str, float] = collections.defaultdict(float) +# profiling compilation time by frame phase +frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict( + lambda: collections.defaultdict(float) +) + +codecache_metrics: Counter[str] = collections.Counter() timer_counter = itertools.count() @@ -220,7 +220,7 @@ def increment_frame() -> None: # Note: Called for you by dynamo - you almost never ever want to invoke this yourself. 
def reset_frame_count() -> None: global curr_frame - cumulative_time_spent_ns.clear() + frame_phase_timing.clear() compilation_time_metrics.clear() curr_frame = 0 @@ -233,16 +233,25 @@ def increment_op_count(cnt: int) -> None: op_count += cnt -# Get the total time in seconds for each "phase" +# Calculate total time spent so far for each phase # For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806} def calculate_time_spent() -> Dict[str, float]: + total_wall_time = 0.0 total_by_key = {} - for phase, timing in cumulative_time_spent_ns.items(): - total_by_key[phase] = timing / 1e9 + for timings in frame_phase_timing.values(): + total_wall_time += timings.get( + "entire_frame_compile", timings.get("inductor_compile", 0) + ) + + for key, timing in timings.items(): + if key not in total_by_key: + total_by_key[key] = timing + else: + total_by_key[key] += timing + + if total_by_key: + total_by_key["total_wall_time"] = total_wall_time - total_by_key["total_wall_time"] = total_by_key.get( - "entire_frame_compile", 0 - ) + total_by_key.get("entire_backward_compile", 0) return total_by_key @@ -261,124 +270,188 @@ def print_time_report() -> None: print(out) -# Use the following singleton to capture and log CompilationMetrics. Entering the context -# manager allocates a new record to be logged when it exits. (You should not need to use -# this directly unless you introduce a new code path where compilation metrics would be -# gathered). While compiling, use the setters or timer in MetricsContext to update fields -# in the current context. For example: -# -# To set a single field once (use overwrite=True to overwrite): -# get_metrics_context().set("metric_name", value) -# -# To set multiple fields at once (use overwrite=True to overwrite): -# get_metrics_context().update({"name1": val1, "name2": val2}) -# -# To increment an integer field: -# get_metrics_context().increment("metric_name", value) -# -# To record execution time, MetricsContext works with dynamo_timed: -# def foo(...): -# # Updates the "metric_us" field. -# with dynamo_timed("metric", dynamo_compile_column_us="metric_us") -# ... -# -_METRICS_CONTEXT: MetricsContext +def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None: + frame_phase_timing[key][phase_name] += time_spent -def get_metrics_context() -> MetricsContext: - return _METRICS_CONTEXT +# Use frame_phase_timing to record remote_cache_time_saved +# This follows the same principles of key as the other frame phase timings, +# but is incremented by FxGraphCache (and later AOTAutogradCache) directly +def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None: + key = None + if is_backward: + # Use compile id as the frame key for backwards compilation + key = str(torch._guards.CompileContext.current_compile_id()) + else: + key = str(curr_frame) + # Convert to seconds (as a float) + time_saved = time_saved_ns / 1e9 + _add_time_spent(key, "remote_cache_time_saved", time_saved) + + +# dynamo_timed is a context manager +# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics +# where the key is the functions name. +# For example: +# +# def _foo(...): +# with dynamo_timed("_foo"): +# ... +# +# Would show up as an entry in our timing dict: +# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) +# This is extremely useful for granular debugging. +# +# Although it is tempting to use dynamo_timed as a decorator, please do not. 
+# In its decorator form it makes cProfile traces less useful as dynamo_timed +# suddenly becomes a bottleneck for lots of function calls (as only one parent +# pointer is recorded). +# +# For a higher-level mode, pass a phase_name into dynamo_timed +# phase_names record an extra record into a separate compilation timing structure, +# one keyed on frame+name rather than function. +# The frame is incremented outside of this function, in def increment_frame() above. +# `fwd_only` is used to identify if this phase or function is only called +# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`. +# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs. @contextmanager def dynamo_timed( key: str, - # TODO(masneral): Deprecate this param. phase_name: Optional[str] = None, - log_pt2_compile_event: bool = False, - # TODO(masnesral): fwd_only is ignored. Remove it. + log_pt2_compile_event: bool = False, # Whether or not to log it to internal pt2 compile event fwd_only: bool = True, - metadata: Optional[Dict[str, object]] = None, - dynamo_compile_column_us: Optional[str] = None, -) -> Generator[Any, None, None]: - """ - dynamo_timed is a context manager - By wrapping a function in dynamo_timed, we can get a few things: - - 1) Log timings to pt2_compile_events. - 2) Log timings to CompilationMetrics (dynamo_compile). - 3) Chromium events. - 4) Storing a record in compilation_time_metrics - For example: - - def _foo(...): - with dynamo_timed("_foo"): - ... - - Would show up as an entry in our timing dict: - OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])]) - This is extremely useful for granular debugging. - - Although it is tempting to use dynamo_timed as a decorator, please do not. - In its decorator form it makes cProfile traces less useful as dynamo_timed - suddenly becomes a bottleneck for lots of function calls (as only one parent - pointer is recorded). - - Params: - - key: key into compile_time_metrics. If phase_name is not provided, this is - also the event name used for pt2_compile_events logs and chromium events. - - phase_name: Optional override for the event name. - - log_pt2_compile_event: Whether to log a pt2 compile event internally. - - metadata: Extra metadata to put in pt2_compile_events. - - dynamo_compile_column_us: If provided, updates the specified CompilationMetrics - field to be logged to dyname_compile column. We expect all columns to be _us; - therefore, the field name must end with "_us". - """ - # We're standardizing on microseconds for dynamo_compile timings. 
- if dynamo_compile_column_us is not None: - assert dynamo_compile_column_us.endswith("_us") - - if phase_name: - event_name = phase_name - fn_name = key - else: - event_name = key - fn_name = None - +): + chromium_log: ChromiumEventLogger = get_chromium_event_logger() if key not in compilation_time_metrics: compilation_time_metrics[key] = [] - event_metadata = {} - if metadata: - event_metadata.update(metadata) - if fn_name: - event_metadata.update({"fn_name": fn_name}) - - chromium_log: ChromiumEventLogger = get_chromium_event_logger() + fail_type: Optional[str] = None + fail_reason: Optional[str] = None + time_spent = float("-inf") start_ns = time.time_ns() - chromium_log.log_event_start(event_name, start_ns, event_metadata) - try: with torch.profiler.record_function(f"{key} (dynamo_timed)"): + t0 = time.time() + if phase_name: + chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key}) + else: + chromium_log.log_event_start(key, start_ns, {}) yield + time_spent = time.time() - t0 + compilation_time_metrics[key].append(time_spent) + except Exception as e: + fail_type = str(type(e)) + fail_reason = str(e) + raise finally: end_ns = time.time_ns() - time_spent_ns = end_ns - start_ns - compilation_time_metrics[key].append(time_spent_ns / 1e9) - chromium_log.log_event_end( - event_name, end_ns, {}, start_ns, log_pt2_compile_event - ) - if dynamo_compile_column_us: - metrics_context = get_metrics_context() - if metrics_context.in_progress(): - metrics_context.increment( - dynamo_compile_column_us, time_spent_ns // 1000 - ) - # TODO: the events that we capture in calculate_time_spent() seem a little - # arbitrary. Currently, it's only those fields that are present in - # CompilationMetrics (but note that we accumulate by the associated event - # name, not the field name in CompilationMetrics). Do we want to keep it - # this way? - cumulative_time_spent_ns[event_name] += time_spent_ns + # Always log the end event even on exception + if phase_name: + chromium_log.log_event_end( + phase_name, + end_ns, + {}, + start_ns, + log_pt2_compile_event, + ) + else: + chromium_log.log_event_end(key, end_ns, {}, start_ns, log_pt2_compile_event) + # Only record backward compilation metrics if phase_name is not None! + if phase_name: + frame_key = str(curr_frame) + # fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch. + # use frame_key as time aggregation key. + if fwd_only and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + else: + # fwd + bwd compilation stages: inductor_compile, code_gen. + # use frame_key as time aggregation key for fwd graphs; + # use compile_id as time aggregation key for bwd graphs. + if torch._guards.TracingContext.try_get() is not None: + aot_graph_name = str( + torch._guards.TracingContext.get().aot_graph_name + ) + if ( + "forward" in aot_graph_name or "inference" in aot_graph_name + ) and fail_type is None: + _add_time_spent(frame_key, phase_name, time_spent) + elif "backward" in aot_graph_name: + compile_id = str( + torch._guards.CompileContext.current_compile_id() + ) + if fail_type is None: + _add_time_spent(compile_id, phase_name, time_spent) + + # log backward compilation metrics at the end of `inductor_compile` of bwd graph, + # one record for one bwd graph. 
+ if phase_name == "inductor_compile": + if fail_type is None: + inductor_compile_time = frame_phase_timing[ + compile_id + ].get("inductor_compile", None) + code_gen_time = frame_phase_timing[compile_id].get( + "code_gen", None + ) + remote_cache_time_saved = frame_phase_timing[ + compile_id + ].get("remote_cache_time_saved", None) + remote_fx_graph_cache_get_time = frame_phase_timing[ + compile_id + ].get("remote_fx_graph_cache_get", None) + remote_fx_graph_cache_put_time = frame_phase_timing[ + compile_id + ].get("remote_fx_graph_cache_put", None) + else: + inductor_compile_time = None + code_gen_time = None + remote_cache_time_saved = None + remote_fx_graph_cache_get_time = None + remote_fx_graph_cache_put_time = None + structured_logging_overhead_s = ( + torch._logging.get_structured_logging_overhead() + ) + metrics = CompilationMetrics( + compile_id=compile_id, + inductor_compile_time_s=inductor_compile_time, + code_gen_time_s=code_gen_time, + fail_type=fail_type, + fail_reason=fail_reason, + remote_cache_time_saved_s=remote_cache_time_saved, + structured_logging_overhead_s=structured_logging_overhead_s, + is_forward=False, # is_forward + num_triton_bundles=codecache_metrics.get( + "num_triton_bundles", None + ), + remote_fx_graph_cache_get_time_ms=to_int_ms( + remote_fx_graph_cache_get_time + ), + remote_fx_graph_cache_put_time_ms=to_int_ms( + remote_fx_graph_cache_put_time + ), + start_time_us=start_ns // 1000, + duration_us=(end_ns - start_ns) // 1000, + inductor_cumulative_compile_time_us=to_int_us( + inductor_compile_time + ), + inductor_code_gen_cumulative_compile_time_us=to_int_us( + code_gen_time + ), + distributed_ephemeral_timeout_us=to_int_us( + remote_cache_time_saved + ), # TODO: instrument more accurately + structured_logging_overhead_us=to_int_us( + structured_logging_overhead_s + ), + remote_fx_graph_cache_get_time_us=to_int_us( + remote_fx_graph_cache_get_time + ), + remote_fx_graph_cache_put_time_us=to_int_us( + remote_fx_graph_cache_put_time + ), + ) + record_compilation_metrics(metrics) @overload @@ -793,11 +866,6 @@ def to_int_us(v: Optional[float]) -> Optional[int]: return None if v is None else int(v * 1_000_000) -# Version field added to every log. Increment to make it easier to distinguish new -# vs. old entries when you make a substantive change to how the logs are populated. 
-LOG_FORMAT_VERSION = 2 - - @dataclasses.dataclass class CompilationMetrics: compile_id: Optional[str] = None @@ -845,18 +913,15 @@ class CompilationMetrics: aot_autograd_cumulative_compile_time_us: Optional[int] = None inductor_cumulative_compile_time_us: Optional[int] = None inductor_code_gen_cumulative_compile_time_us: Optional[int] = None - triton_compile_time_us: Optional[int] = None # TODO: instrument - runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument - runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument + triton_compile_time_us: Optional[int] = None + runtime_cudagraphify_time_us: Optional[int] = None + runtime_triton_autotune_time_us: Optional[int] = None dynamo_compile_time_before_restart_us: Optional[int] = None - cuda_synchronize_time_us: Optional[int] = None # TODO: instrument + cuda_synchronize_time_us: Optional[int] = None distributed_ephemeral_timeout_us: Optional[int] = None structured_logging_overhead_us: Optional[int] = None remote_fx_graph_cache_get_time_us: Optional[int] = None remote_fx_graph_cache_put_time_us: Optional[int] = None - backward_cumulative_compile_time_us: Optional[int] = None - end_time_us: Optional[int] = None - log_format_version: int = LOG_FORMAT_VERSION DEFAULT_COMPILATION_METRICS_LIMIT = 64 @@ -904,32 +969,8 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics): ) -def record_compilation_metrics(metrics: Dict[str, Any]): - # TODO: Temporary; populate legacy fields from their replacements. - # Remove when we decide we can really deprecate them. - def us_to_s(field): - metric = metrics.get(field, None) - return metric / 1e6 if metric is not None else None - - def us_to_ms(field): - metric = metrics.get(field, None) - return metric // 1000 if metric is not None else None - - legacy_metrics = { - "entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"), - "backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"), - "inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"), - "code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"), - "remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"), - "remote_fx_graph_cache_get_time_ms": us_to_ms( - "remote_fx_graph_cache_get_time_us" - ), - "remote_fx_graph_cache_put_time_ms": us_to_ms( - "remote_fx_graph_cache_put_time_us" - ), - } - - compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics}) +def record_compilation_metrics(compilation_metrics: CompilationMetrics): + global _compilation_metrics _compilation_metrics.append(compilation_metrics) if compilation_metrics.is_forward: name = "compilation_metrics" @@ -938,7 +979,10 @@ def record_compilation_metrics(metrics: Dict[str, Any]): name = "bwd_compilation_metrics" torch._logging.trace_structured( name, - lambda: {k: list(v) if isinstance(v, set) else v for k, v in metrics.items()}, + lambda: { + k: list(v) if isinstance(v, set) else v + for k, v in dataclasses.asdict(compilation_metrics).items() + }, # NB: Because compilation metrics *includes* the logging overhead time, # we can't both *measure* the logging overhead of compilation metrics # without making it inconsistent with compilation metrics itself, so @@ -949,10 +993,6 @@ def record_compilation_metrics(metrics: Dict[str, Any]): log_compilation_event(compilation_metrics) -# record_compilation_metrics is called by the singleton MetricsContext exit handler. 
-_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics) - - def set_compilation_metrics_limit(new_size: int) -> None: global _compilation_metrics while len(_compilation_metrics) > new_size: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 13d7a1460d56..bb6740882937 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -632,7 +632,7 @@ class AOTAutogradCache: ) # TODO: should we use the same field for remote cache time saved for both # FXGraphCache and AOTAutogradCache? - # get_metrics_context().increment(...) + # add_remote_cache_time_saved(time_saved_ns, is_backward=False) if ( ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( time_saved_ns diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index f421f30cb44a..206bab569240 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -10,7 +10,6 @@ import builtins import collections import itertools import pprint -import time from contextlib import nullcontext from dataclasses import dataclass, field from functools import wraps @@ -19,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.utils.dlpack from torch import Tensor -from torch._dynamo.utils import dynamo_timed, get_metrics_context, to_int_us from torch._guards import ( compile_context, CompileContext, @@ -2004,49 +2002,17 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa context = torch._C._DisableAutocast if disable_amp else nullcontext with tracing(saved_context), compile_context( saved_compile_context - ), context(), track_graph_compiling( - aot_config, "backward" - ), get_metrics_context(), dynamo_timed( - "backward._backward_impl", - phase_name="entire_backward_compile", - dynamo_compile_column_us="backward_cumulative_compile_time_us", - ): - fail_type: Optional[str] = None - fail_reason: Optional[str] = None - start_ns = time.time_ns() - try: - CompiledFunction.compiled_bw = aot_config.bw_compiler( - bw_module, placeholder_list + ), context(), track_graph_compiling(aot_config, "backward"): + CompiledFunction.compiled_bw = aot_config.bw_compiler( + bw_module, placeholder_list + ) + # Maybe save cache entry + if try_save_cache_entry is not None: + try_save_cache_entry( + CompiledFunction.compiled_bw, + fw_metadata, + aot_config, ) - # Maybe save cache entry - if try_save_cache_entry is not None: - try_save_cache_entry( - CompiledFunction.compiled_bw, - fw_metadata, - aot_config, - ) - except Exception as e: - # TODO(masnesral): Populating the exception info should be automatic. - fail_type = type(e).__qualname__ - fail_reason = str(e) - finally: - # TODO(masnesral): Populating time fields should be automatic. 
- end_ns = time.time_ns() - metrics = { - "compile_id": str( - torch._guards.CompileContext.current_compile_id() - ), - "fail_type": fail_type, - "fail_reason": fail_reason, - "is_forward": False, - "start_time_us": start_ns // 1000, - "end_time_us": end_ns // 1000, - "duration_us": (end_ns - start_ns) // 1000, - "structured_logging_overhead_us": to_int_us( - torch._logging.get_structured_logging_overhead(), - ), - } - get_metrics_context().update_outer(metrics) if ( torch._functorch.config.donated_buffer diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 821d67514a2b..2371ae38bfbf 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -54,10 +54,11 @@ import torch import torch.distributed as dist from torch import SymInt, Tensor from torch._dynamo.utils import ( + add_remote_cache_time_saved, + codecache_metrics, counters, dynamo_timed, get_chromium_event_logger, - get_metrics_context, ) from torch._inductor import config, exc, metrics from torch._inductor.codegen.cuda import cuda_env @@ -1151,7 +1152,7 @@ class FxGraphCache: "inductor_compile", cached_kernel_names=meta.cached_kernel_names ) if len(meta.cached_kernel_names) > 0: - get_metrics_context().increment("num_triton_bundles", 1) + codecache_metrics["num_triton_bundles"] += 1 inductor_meta = autotune_cache.inductor_meta_from_config() AutotuneCacheBundler.begin_compile(inductor_meta, code=code) @@ -1448,9 +1449,7 @@ class FxGraphCache: if (time_saved_ns := compiled_graph._time_taken_ns) is not None: cache_info["time_saved_ns"] = time_saved_ns - get_metrics_context().increment( - "distributed_ephemeral_timeout_us", time_saved_ns // 1000 - ) + add_remote_cache_time_saved(time_saved_ns, is_backward) if ( ephemeral_increase := add_ephemeral_timeout_increase_for_distributed( time_saved_ns diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index cd997c30e914..1cd8d09043a8 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -567,7 +567,7 @@ def compile_fx_inner( "compile_fx_inner", phase_name="inductor_compile", log_pt2_compile_event=True, - dynamo_compile_column_us="inductor_cumulative_compile_time_us", + fwd_only=False, ) ) # NB: Why is this the dynamo_compile counter? 
The rule here is that diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 390be0c2c0c3..ad4a03ec8afc 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1950,7 +1950,7 @@ class GraphLowering(torch.fx.Interpreter): "GraphLowering.compile_to_module", phase_name="code_gen", log_pt2_compile_event=True, - dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us", + fwd_only=False, ): return self._compile_to_module() diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 4e53b920a423..d03599500647 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -45,14 +45,14 @@ remote_fx_cache_get_timed = functools.partial( "FbRemoteFxGraphCache.get", phase_name="remote_fx_graph_cache_get", log_pt2_compile_event=False, - dynamo_compile_column_us="remote_fx_graph_cache_get_time_us", + fwd_only=False, ) remote_fx_cache_put_timed = functools.partial( dynamo_timed, "FbRemoteFxGraphCache.put", phase_name="remote_fx_graph_cache_put", log_pt2_compile_event=False, - dynamo_compile_column_us="remote_fx_graph_cache_put_time_us", + fwd_only=False, ) diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index ff151e6323df..4eb7af60047c 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -135,13 +135,7 @@ try: except AttributeError: # Compile workers only have a mock version of torch @contextlib.contextmanager - def dynamo_timed( - key, - phase_name=None, - fwd_only=True, - metadata=None, - dynamo_compile_column_us=None, - ): + def dynamo_timed(key, phase_name=None, fwd_only=True): yield diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index a31ea0c198c2..70bbb27bfa26 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1099,6 +1099,7 @@ class LazyString: structured_logging_overhead: Dict[str, float] = defaultdict(float) +# Same principle as add_remote_cache_time_saved, but do it for structured logging def add_structured_logging_overhead(time_spent: float) -> None: global structured_logging_overhead key = None
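
Note for reviewers: the hunks above in torch/_dynamo/utils.py restore the per-frame phase-timing scheme (frame_phase_timing, _add_time_spent, calculate_time_spent) that the MetricsContext singleton had replaced. Below is a minimal, self-contained sketch of that scheme for orientation only; it uses just the standard library, and the timed_phase helper and module-level curr_frame here are illustrative stand-ins rather than the real dynamo_timed/increment_frame machinery.

import collections
import contextlib
import time
from typing import Dict

# Per-frame phase timings, keyed first by frame key and then by phase name,
# mirroring frame_phase_timing in torch/_dynamo/utils.py.
frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
    lambda: collections.defaultdict(float)
)
curr_frame = 0  # the real module bumps this via increment_frame()


def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
    # Accumulate rather than overwrite, so repeated phases within a frame sum up.
    frame_phase_timing[key][phase_name] += time_spent


@contextlib.contextmanager
def timed_phase(phase_name: str):
    # Illustrative stand-in for dynamo_timed(key, phase_name=...): time the body
    # and fold the elapsed seconds into the current frame's totals.
    t0 = time.time()
    try:
        yield
    finally:
        _add_time_spent(str(curr_frame), phase_name, time.time() - t0)


def calculate_time_spent() -> Dict[str, float]:
    # Same shape as the restored helper: per-phase totals plus a total_wall_time
    # derived from entire_frame_compile (falling back to inductor_compile).
    total_wall_time = 0.0
    total_by_key: Dict[str, float] = {}
    for timings in frame_phase_timing.values():
        total_wall_time += timings.get(
            "entire_frame_compile", timings.get("inductor_compile", 0.0)
        )
        for key, timing in timings.items():
            total_by_key[key] = total_by_key.get(key, 0.0) + timing
    if total_by_key:
        total_by_key["total_wall_time"] = total_wall_time
    return total_by_key


if __name__ == "__main__":
    with timed_phase("entire_frame_compile"):
        with timed_phase("backend_compile"):
            time.sleep(0.01)
    # e.g. {'backend_compile': 0.01..., 'entire_frame_compile': 0.01..., 'total_wall_time': 0.01...}
    print(calculate_time_spent())

This is also why the revert reintroduces fwd_only=False at the inductor_compile, code_gen, and remote FX cache call sites: those phases run for backward graphs as well, and in that case the timings are keyed by compile id instead of the frame key, as in the dynamo_timed hunk above.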