mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 17:24:59 +08:00
[logging] Overhaul dynamo_timed and CompilationMetrics logging. (#139849)
Here's the overview: There's a new contextmanager singleton called MetricsContext. Entering the MetricsContext is how we demarcate the boundary on which we'll create a single CompilationMetrics object, and therefore, a single dynamo_compile log entry. While we're inside the MetricsContext, we can update/set many different metrics. Most importantly: `dynamo_timed` can also update the in-progress MetricsContext. In the proposal here, we tell `dynamo_timed` that we want it to do so by providing the name of the MetricsContext field to increment. There can be many `dynamo_timed` calls in different parts of the code updating different fields. Then when the MetricsContext exits, that's when the logging of everything gathered finally happens. One potential footgun is trying to use `dynamo_timed` when we haven't entered the MetricsContext, but we assert on that problem. Another problem is that we re-enter the context recursively, but we watch for that and do the logging only when the outermost exits. Some specifics: * Introduce MetricsContext - a context manager that on exit, records the CompilationMetrics (which also logs to dynamo_compile). * Completely remove the concept of frame_phase_timing. Instead, update the MetricsContext during compilation, either directly or via dynamo_timed. * Remove some globals we previously used to accumulate counters to later populate a CompilationMetrics. We use CompilationMetrics set/update/increment APIs instead. * `record_compilation_metrics` is now called on exit from MetricsContext. * Populate legacy CompilationMetrics fields right before logging, inside `record_compilation_metrics`. * Remove the one-off `add_remote_cache_time_saved` helper; capture that timing directly into the MetricsContext. And specifically, several changes to dynamo_timed: * "Modernize" the parameters and update all callsites accordingly. * Move the backwards logging of the CompilationMetrics to the backwards compile location. * Add a parameter for which CompilationMetrics field to update Pull Request resolved: https://github.com/pytorch/pytorch/pull/139849 Approved by: https://github.com/ezyang ghstack dependencies: #140094
This commit is contained in:
committed by
PyTorch MergeBot
parent
565a7942ee
commit
cb15c15157
78
test/dynamo/test_metrics_context.py
Normal file
78
test/dynamo/test_metrics_context.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
from torch._dynamo.metrics_context import MetricsContext
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
|
||||
|
||||
class TestMetricsContext(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.metrics = {}
|
||||
|
||||
def _on_exit(self, metrics):
|
||||
# Save away the metrics to be validated in the test.
|
||||
self.metrics = metrics.copy()
|
||||
|
||||
def test_context_exists(self):
|
||||
"""
|
||||
Setting a value without entering the context should raise.
|
||||
"""
|
||||
context = MetricsContext(self._on_exit)
|
||||
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
|
||||
context.increment("m", 1)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
|
||||
context.set("m", 1)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
|
||||
context.update({"m", 1})
|
||||
|
||||
def test_nested_context(self):
|
||||
"""
|
||||
Only the outermost context should get an on_exit call, and it should
|
||||
include everything.
|
||||
"""
|
||||
context = MetricsContext(self._on_exit)
|
||||
with context:
|
||||
with context:
|
||||
context.set("m1", 1)
|
||||
self.assertEqual(self.metrics, {})
|
||||
context.set("m2", 2)
|
||||
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
Validate various ways to set metrics.
|
||||
"""
|
||||
with MetricsContext(self._on_exit) as context:
|
||||
context.set("m1", 1)
|
||||
context.set("m2", 2)
|
||||
context.update({"m3": 3, "m4": 4})
|
||||
|
||||
self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})
|
||||
|
||||
def test_set_disallow_overwrite(self):
|
||||
"""
|
||||
Validate set won't overwrite.
|
||||
"""
|
||||
with MetricsContext(self._on_exit) as context:
|
||||
context.set("m1", 1)
|
||||
with self.assertRaisesRegex(RuntimeError, "already been set"):
|
||||
context.set("m1", 2)
|
||||
|
||||
self.assertEqual(self.metrics, {"m1": 1})
|
||||
|
||||
def test_update_disallow_overwrite(self):
|
||||
"""
|
||||
Validate update won't overwite.
|
||||
"""
|
||||
with MetricsContext(self._on_exit) as context:
|
||||
context.update({"m1": 1, "m2": 2})
|
||||
with self.assertRaisesRegex(RuntimeError, "already been set"):
|
||||
context.update({"m1": 7, "m3": 3})
|
||||
|
||||
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -159,6 +159,7 @@ class TestDynamoTimed(TestCase):
|
||||
'_recursive_post_grad_passes': [0.0, 0.0],
|
||||
'_recursive_pre_grad_passes': [0.0],
|
||||
'async_compile.wait': [0.0, 0.0],
|
||||
'backward._backward_impl': [0.0],
|
||||
'compile_file': [0.0, 0.0],
|
||||
'compile_fx.<locals>.bw_compiler': [0.0],
|
||||
'compile_fx.<locals>.fw_compiler_base': [0.0],
|
||||
@ -174,6 +175,7 @@ class TestDynamoTimed(TestCase):
|
||||
"""\
|
||||
{'backend_compile': 0.0,
|
||||
'code_gen': 0.0,
|
||||
'entire_backward_compile': 0.0,
|
||||
'entire_frame_compile': 0.0,
|
||||
'inductor_compile': 0.0,
|
||||
'total_wall_time': 0.0}""", # noqa: B950
|
||||
@ -200,6 +202,7 @@ class TestDynamoTimed(TestCase):
|
||||
{'accumulated_cache_size': 0,
|
||||
'aot_autograd_cumulative_compile_time_us': 0,
|
||||
'backend_compile_time_s': 0.0,
|
||||
'backward_cumulative_compile_time_us': None,
|
||||
'cache_size': 0,
|
||||
'co_filename': None,
|
||||
'co_firstlineno': None,
|
||||
@ -210,12 +213,13 @@ class TestDynamoTimed(TestCase):
|
||||
'config_inline_inbuilt_nn_modules': False,
|
||||
'config_suppress_errors': False,
|
||||
'cuda_synchronize_time_us': None,
|
||||
'distributed_ephemeral_timeout_us': 0,
|
||||
'distributed_ephemeral_timeout_us': None,
|
||||
'duration_us': 0,
|
||||
'dynamo_compile_time_before_restart_us': 0,
|
||||
'dynamo_config': None,
|
||||
'dynamo_cumulative_compile_time_us': 0,
|
||||
'dynamo_time_before_restart_s': 0.0,
|
||||
'end_time_us': 100,
|
||||
'entire_frame_compile_time_s': 0.0,
|
||||
'fail_reason': None,
|
||||
'fail_type': None,
|
||||
@ -231,9 +235,10 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_compile_time_s': 0.0,
|
||||
'inductor_cumulative_compile_time_us': 0,
|
||||
'is_forward': True,
|
||||
'log_format_version': 2,
|
||||
'non_compliant_ops': set(),
|
||||
'num_triton_bundles': None,
|
||||
'remote_cache_time_saved_s': 0,
|
||||
'remote_cache_time_saved_s': None,
|
||||
'remote_fx_graph_cache_get_time_ms': None,
|
||||
'remote_fx_graph_cache_get_time_us': None,
|
||||
'remote_fx_graph_cache_put_time_ms': None,
|
||||
@ -257,6 +262,7 @@ class TestDynamoTimed(TestCase):
|
||||
{'accumulated_cache_size': None,
|
||||
'aot_autograd_cumulative_compile_time_us': None,
|
||||
'backend_compile_time_s': None,
|
||||
'backward_cumulative_compile_time_us': 0,
|
||||
'cache_size': None,
|
||||
'co_filename': None,
|
||||
'co_firstlineno': None,
|
||||
@ -273,6 +279,7 @@ class TestDynamoTimed(TestCase):
|
||||
'dynamo_config': None,
|
||||
'dynamo_cumulative_compile_time_us': None,
|
||||
'dynamo_time_before_restart_s': None,
|
||||
'end_time_us': 100,
|
||||
'entire_frame_compile_time_s': None,
|
||||
'fail_reason': None,
|
||||
'fail_type': None,
|
||||
@ -288,6 +295,7 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_compile_time_s': 0.0,
|
||||
'inductor_cumulative_compile_time_us': 0,
|
||||
'is_forward': False,
|
||||
'log_format_version': 2,
|
||||
'non_compliant_ops': None,
|
||||
'num_triton_bundles': None,
|
||||
'remote_cache_time_saved_s': None,
|
||||
@ -302,7 +310,7 @@ class TestDynamoTimed(TestCase):
|
||||
'specialize_float': None,
|
||||
'start_time': None,
|
||||
'start_time_us': 100,
|
||||
'structured_logging_overhead_s': 0.0,
|
||||
'structured_logging_overhead_s': None,
|
||||
'structured_logging_overhead_us': 0,
|
||||
'triton_compile_time_us': None}""", # noqa: B950
|
||||
)
|
||||
|
||||
@ -30,7 +30,7 @@ import torch
|
||||
import torch._logging
|
||||
from torch._C._dynamo.guards import GlobalStateGuard
|
||||
from torch._dynamo.distributed import get_compile_pg
|
||||
from torch._dynamo.utils import CompileTimeInstructionCounter
|
||||
from torch._dynamo.utils import CompileTimeInstructionCounter, get_metrics_context
|
||||
from torch._guards import compile_context, CompileContext, CompileId, tracing
|
||||
from torch._logging import structured
|
||||
from torch._utils_internal import (
|
||||
@ -105,12 +105,9 @@ from .symbolic_convert import (
|
||||
from .trace_rules import is_numpy
|
||||
from .utils import (
|
||||
CleanupManager,
|
||||
codecache_metrics,
|
||||
CompilationMetrics,
|
||||
counters,
|
||||
dynamo_timed,
|
||||
format_bytecode,
|
||||
frame_phase_timing,
|
||||
gen_record_file_name,
|
||||
get_chromium_event_logger,
|
||||
increment_frame,
|
||||
@ -118,10 +115,8 @@ from .utils import (
|
||||
istype,
|
||||
LazyString,
|
||||
orig_code_map,
|
||||
record_compilation_metrics,
|
||||
reset_graph_break_dup_checker,
|
||||
setup_compile_debug,
|
||||
to_int_ms,
|
||||
to_int_us,
|
||||
troubleshooting_url,
|
||||
write_record_to_file,
|
||||
@ -699,7 +694,9 @@ def _compile(
|
||||
with contextlib.ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
dynamo_timed(
|
||||
"_compile.compile_inner", phase_name="entire_frame_compile"
|
||||
"_compile.compile_inner",
|
||||
phase_name="entire_frame_compile",
|
||||
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
@ -864,9 +861,11 @@ def _compile(
|
||||
chromium_event_log.reset()
|
||||
chromium_start_time = time.time_ns()
|
||||
chromium_event_log.log_event_start("dynamo", chromium_start_time, {})
|
||||
|
||||
metrics_context = get_metrics_context()
|
||||
with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(
|
||||
CompileContext(compile_id)
|
||||
):
|
||||
), metrics_context:
|
||||
restart_reasons: set[str] = set()
|
||||
# This is shared across restarts
|
||||
mutated_closure_cell_ids: Set[int] = set()
|
||||
@ -974,7 +973,6 @@ def _compile(
|
||||
fail_user_frame_lineno: Optional[int] = None
|
||||
torch._dynamo.utils.ReinplaceCounters.clear()
|
||||
guarded_code = None
|
||||
codecache_metrics.clear()
|
||||
try:
|
||||
guarded_code = compile_inner(code, one_graph, hooks, transform)
|
||||
|
||||
@ -991,6 +989,7 @@ def _compile(
|
||||
|
||||
return guarded_code
|
||||
except Exception as e:
|
||||
# TODO(masnesral): Populating the exception info should be automatic
|
||||
fail_type = type(e).__qualname__
|
||||
fail_reason = str(e)
|
||||
# NB: e's msg is mutated here to add user stack, but we DON'T want
|
||||
@ -1040,66 +1039,34 @@ def _compile(
|
||||
if tracer:
|
||||
tracer.output.local_scope = {}
|
||||
|
||||
duration_ns = time.time_ns() - start_time_ns
|
||||
end_time_ns = time.time_ns()
|
||||
duration_ns = end_time_ns - start_time_ns
|
||||
|
||||
from .utils import curr_frame
|
||||
|
||||
frame_key = str(curr_frame)
|
||||
if (
|
||||
fail_reason is None
|
||||
and output is not None
|
||||
and frame_key in frame_phase_timing
|
||||
):
|
||||
if fail_reason is None and output is not None:
|
||||
guard_count = len(output.guards)
|
||||
shape_env_guard_count = len(output.shape_env.guards)
|
||||
graph_op_count = output.count_calls()
|
||||
graph_node_count = len(output.graph.nodes)
|
||||
graph_input_count = len(output.placeholders)
|
||||
entire_frame_compile_time = frame_phase_timing[frame_key].get(
|
||||
"entire_frame_compile", None
|
||||
)
|
||||
backend_compile_time = frame_phase_timing[frame_key].get(
|
||||
"backend_compile", None
|
||||
)
|
||||
inductor_compile_time = frame_phase_timing[frame_key].get(
|
||||
"inductor_compile", None
|
||||
)
|
||||
code_gen_time = frame_phase_timing[frame_key].get("code_gen", None)
|
||||
non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
|
||||
compliant_custom_ops = {
|
||||
op.__qualname__ for op in output.compliant_custom_ops
|
||||
}
|
||||
remote_cache_time_saved = frame_phase_timing[frame_key].get(
|
||||
"remote_cache_time_saved", 0
|
||||
)
|
||||
remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get(
|
||||
"remote_fx_graph_cache_get", None
|
||||
)
|
||||
remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get(
|
||||
"remote_fx_graph_cache_put", None
|
||||
)
|
||||
num_triton_bundles = codecache_metrics.get("num_triton_bundles", None)
|
||||
torch._dynamo.utils.ReinplaceCounters.log()
|
||||
|
||||
else:
|
||||
guard_count = None
|
||||
shape_env_guard_count = None
|
||||
graph_op_count = None
|
||||
graph_node_count = None
|
||||
graph_input_count = None
|
||||
entire_frame_compile_time = None
|
||||
backend_compile_time = None
|
||||
inductor_compile_time = None
|
||||
code_gen_time = None
|
||||
non_compliant_ops = set({})
|
||||
compliant_custom_ops = set({})
|
||||
restart_reasons = set()
|
||||
# If compilation failed, the entire time is wasted
|
||||
dynamo_time_before_restart = duration_ns / 1e9
|
||||
remote_cache_time_saved = None
|
||||
remote_fx_graph_cache_get_time = None
|
||||
remote_fx_graph_cache_put_time = None
|
||||
num_triton_bundles = None
|
||||
|
||||
structured_logging_overhead_s = (
|
||||
torch._logging.get_structured_logging_overhead()
|
||||
@ -1134,74 +1101,55 @@ def _compile(
|
||||
}
|
||||
|
||||
config_dict = clean_for_json(config.get_config_copy())
|
||||
metrics = CompilationMetrics(
|
||||
str(compile_id),
|
||||
frame_key,
|
||||
code.co_name,
|
||||
code.co_filename,
|
||||
code.co_firstlineno,
|
||||
cache_size.num_cache_entries_with_same_id_matched_objs,
|
||||
cache_size.num_cache_entries,
|
||||
guard_count,
|
||||
shape_env_guard_count,
|
||||
graph_op_count,
|
||||
graph_node_count,
|
||||
graph_input_count,
|
||||
start_time_ns / 1e9,
|
||||
entire_frame_compile_time,
|
||||
backend_compile_time,
|
||||
inductor_compile_time,
|
||||
code_gen_time,
|
||||
fail_type,
|
||||
fail_reason,
|
||||
fail_user_frame_filename,
|
||||
fail_user_frame_lineno,
|
||||
non_compliant_ops,
|
||||
compliant_custom_ops,
|
||||
restart_reasons,
|
||||
dynamo_time_before_restart,
|
||||
guarded_code is not None,
|
||||
remote_cache_time_saved,
|
||||
structured_logging_overhead_s,
|
||||
config.suppress_errors,
|
||||
config.inline_inbuilt_nn_modules,
|
||||
config.specialize_float,
|
||||
json.dumps(config_dict),
|
||||
True, # is_forward
|
||||
num_triton_bundles,
|
||||
to_int_ms(remote_fx_graph_cache_get_time),
|
||||
to_int_ms(remote_fx_graph_cache_put_time),
|
||||
start_time_us=start_time_ns // 1000,
|
||||
duration_us=duration_ns // 1000,
|
||||
dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time),
|
||||
aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time),
|
||||
inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time),
|
||||
inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time),
|
||||
triton_compile_time_us=None, # TODO: instrument
|
||||
runtime_cudagraphify_time_us=None, # TODO: instrument in separate event
|
||||
runtime_triton_autotune_time_us=None, # TODO: instrument in separate event
|
||||
dynamo_compile_time_before_restart_us=to_int_us(
|
||||
metrics = {
|
||||
"compile_id": str(compile_id),
|
||||
"frame_key": frame_key,
|
||||
"co_name": code.co_name,
|
||||
"co_filename": code.co_filename,
|
||||
"co_firstlineno": code.co_firstlineno,
|
||||
"cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
|
||||
"accumulated_cache_size": cache_size.num_cache_entries,
|
||||
"guard_count": guard_count,
|
||||
"shape_env_guard_count": shape_env_guard_count,
|
||||
"graph_op_count": graph_op_count,
|
||||
"graph_node_count": graph_node_count,
|
||||
"graph_input_count": graph_input_count,
|
||||
# TODO(masnesral): start_time and end_time shouldn't need to be
|
||||
# populated manually.
|
||||
"start_time": start_time_ns / 1e9,
|
||||
"fail_type": fail_type,
|
||||
"fail_reason": fail_reason,
|
||||
"fail_user_frame_filename": fail_user_frame_filename,
|
||||
"fail_user_frame_lineno": fail_user_frame_lineno,
|
||||
"non_compliant_ops": non_compliant_ops,
|
||||
"compliant_custom_ops": compliant_custom_ops,
|
||||
"restart_reasons": restart_reasons,
|
||||
"dynamo_time_before_restart_s": dynamo_time_before_restart,
|
||||
"has_guarded_code": guarded_code is not None,
|
||||
"structured_logging_overhead_s": structured_logging_overhead_s,
|
||||
"config_suppress_errors": config.suppress_errors,
|
||||
"config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules,
|
||||
"specialize_float": config.specialize_float,
|
||||
"dynamo_config": json.dumps(config_dict),
|
||||
"is_forward": True,
|
||||
"start_time_us": start_time_ns // 1000,
|
||||
"end_time_us": end_time_ns // 1000,
|
||||
"duration_us": duration_ns // 1000,
|
||||
"dynamo_compile_time_before_restart_us": to_int_us(
|
||||
dynamo_time_before_restart
|
||||
),
|
||||
cuda_synchronize_time_us=None, # TODO: instrument
|
||||
distributed_ephemeral_timeout_us=to_int_us(
|
||||
remote_cache_time_saved
|
||||
), # TODO: instrument more accurately
|
||||
structured_logging_overhead_us=to_int_us(structured_logging_overhead_s),
|
||||
remote_fx_graph_cache_get_time_us=to_int_us(
|
||||
remote_fx_graph_cache_get_time
|
||||
"structured_logging_overhead_us": to_int_us(
|
||||
structured_logging_overhead_s
|
||||
),
|
||||
remote_fx_graph_cache_put_time_us=to_int_us(
|
||||
remote_fx_graph_cache_put_time
|
||||
),
|
||||
)
|
||||
record_compilation_metrics(metrics)
|
||||
}
|
||||
metrics_context.update_outer(metrics)
|
||||
torch._dynamo.callback_handler.run_end_callbacks()
|
||||
chromium_event_log.log_event_end(
|
||||
"dynamo", time.time_ns(), {}, chromium_start_time, True
|
||||
)
|
||||
# === END WARNING WARNING WARNING ===
|
||||
|
||||
chromium_event_log.log_event_end(
|
||||
"dynamo", time.time_ns(), {}, chromium_start_time, True
|
||||
)
|
||||
|
||||
|
||||
class ConvertFrame:
|
||||
def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None:
|
||||
|
||||
95
torch/_dynamo/metrics_context.py
Normal file
95
torch/_dynamo/metrics_context.py
Normal file
@ -0,0 +1,95 @@
|
||||
from typing import Any, Callable, Dict, Optional, Type
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
|
||||
OnExitType: TypeAlias = Callable[[Dict[str, Any]], None]
|
||||
|
||||
|
||||
class MetricsContext:
|
||||
def __init__(self, on_exit: OnExitType):
|
||||
"""
|
||||
Use this class as a contextmanager to create a context under which to accumulate
|
||||
a set of metrics, e.g., metrics gathered during a compilation. On exit of the
|
||||
contextmanager, call the provided 'on_exit' function and pass a dictionary of
|
||||
all metrics set during the lifetime of the contextmanager.
|
||||
"""
|
||||
self._on_exit = on_exit
|
||||
self._metrics: Dict[str, Any] = {}
|
||||
self._level = 0
|
||||
|
||||
def __enter__(self) -> "MetricsContext":
|
||||
"""
|
||||
Initialize metrics recording.
|
||||
"""
|
||||
if self._level == 0:
|
||||
# In case of recursion, track at the outermost context.
|
||||
self._metrics = {}
|
||||
|
||||
self._level += 1
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
_traceback: Any,
|
||||
) -> None:
|
||||
"""
|
||||
At exit, call the provided on_exit function.
|
||||
"""
|
||||
self._level -= 1
|
||||
assert self._level >= 0
|
||||
if self._level == 0:
|
||||
self._on_exit(self._metrics)
|
||||
|
||||
def in_progress(self) -> bool:
|
||||
"""
|
||||
True if we've entered the context.
|
||||
"""
|
||||
return self._level > 0
|
||||
|
||||
def increment(self, metric: str, value: int) -> None:
|
||||
"""
|
||||
Increment a metric by a given amount.
|
||||
"""
|
||||
if self._level == 0:
|
||||
raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext")
|
||||
if metric not in self._metrics:
|
||||
self._metrics[metric] = 0
|
||||
self._metrics[metric] += value
|
||||
|
||||
def set(self, metric: str, value: Any) -> None:
|
||||
"""
|
||||
Set a metric to a given value. Raises if the metric has been assigned previously
|
||||
in the current context.
|
||||
"""
|
||||
if self._level == 0:
|
||||
raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
|
||||
if metric in self._metrics:
|
||||
raise RuntimeError(
|
||||
f"Metric '{metric}' has already been set in the current context"
|
||||
)
|
||||
self._metrics[metric] = value
|
||||
|
||||
def update(self, values: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Set multiple metrics directly. This method does NOT increment. Raises if any
|
||||
metric has been assigned previously in the current context.
|
||||
"""
|
||||
if self._level == 0:
|
||||
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
|
||||
existing = self._metrics.keys() & values.keys()
|
||||
if existing:
|
||||
raise RuntimeError(
|
||||
f"Metric(s) {existing} have already been set in the current context"
|
||||
)
|
||||
self._metrics.update(values)
|
||||
|
||||
def update_outer(self, values: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Update, but only when at the outermost context.
|
||||
"""
|
||||
if self._level == 0:
|
||||
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
|
||||
if self._level == 1:
|
||||
self.update(values)
|
||||
@ -1391,6 +1391,7 @@ class OutputGraph:
|
||||
"OutputGraph.call_user_compiler",
|
||||
phase_name="backend_compile",
|
||||
log_pt2_compile_event=True,
|
||||
dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
|
||||
):
|
||||
return self._call_user_compiler(gm)
|
||||
|
||||
|
||||
@ -43,6 +43,7 @@ from typing import (
|
||||
DefaultDict,
|
||||
Deque,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
KeysView,
|
||||
@ -71,6 +72,7 @@ from torch._C import (
|
||||
_push_on_torch_function_stack,
|
||||
)
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dynamo.metrics_context import MetricsContext
|
||||
from torch._guards import Source, TracingContext
|
||||
from torch._subclasses.meta_utils import is_sparse_compressed
|
||||
from torch._utils_internal import (
|
||||
@ -139,12 +141,10 @@ log = logging.getLogger(__name__)
|
||||
# profiling compilation time by function
|
||||
compilation_time_metrics: Dict[str, List[float]] = {}
|
||||
|
||||
# profiling compilation time by frame phase
|
||||
frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
|
||||
lambda: collections.defaultdict(float)
|
||||
)
|
||||
|
||||
codecache_metrics: Counter[str] = collections.Counter()
|
||||
# This supports calculate_time_spent(), which reports cumulative times
|
||||
# across the process for any "phase" populated by dynamo_timed. Reset if
|
||||
# reset_frame_count() is called.
|
||||
cumulative_time_spent_ns: Dict[str, float] = collections.defaultdict(float)
|
||||
|
||||
timer_counter = itertools.count()
|
||||
|
||||
@ -220,7 +220,7 @@ def increment_frame() -> None:
|
||||
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
|
||||
def reset_frame_count() -> None:
|
||||
global curr_frame
|
||||
frame_phase_timing.clear()
|
||||
cumulative_time_spent_ns.clear()
|
||||
compilation_time_metrics.clear()
|
||||
curr_frame = 0
|
||||
|
||||
@ -233,25 +233,16 @@ def increment_op_count(cnt: int) -> None:
|
||||
op_count += cnt
|
||||
|
||||
|
||||
# Calculate total time spent so far for each phase
|
||||
# Get the total time in seconds for each "phase"
|
||||
# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
|
||||
def calculate_time_spent() -> Dict[str, float]:
|
||||
total_wall_time = 0.0
|
||||
total_by_key = {}
|
||||
for timings in frame_phase_timing.values():
|
||||
total_wall_time += timings.get(
|
||||
"entire_frame_compile", timings.get("inductor_compile", 0)
|
||||
)
|
||||
|
||||
for key, timing in timings.items():
|
||||
if key not in total_by_key:
|
||||
total_by_key[key] = timing
|
||||
else:
|
||||
total_by_key[key] += timing
|
||||
|
||||
if total_by_key:
|
||||
total_by_key["total_wall_time"] = total_wall_time
|
||||
for phase, timing in cumulative_time_spent_ns.items():
|
||||
total_by_key[phase] = timing / 1e9
|
||||
|
||||
total_by_key["total_wall_time"] = total_by_key.get(
|
||||
"entire_frame_compile", 0
|
||||
) + total_by_key.get("entire_backward_compile", 0)
|
||||
return total_by_key
|
||||
|
||||
|
||||
@ -270,188 +261,124 @@ def print_time_report() -> None:
|
||||
print(out)
|
||||
|
||||
|
||||
def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
|
||||
frame_phase_timing[key][phase_name] += time_spent
|
||||
# Use the following singleton to capture and log CompilationMetrics. Entering the context
|
||||
# manager allocates a new record to be logged when it exits. (You should not need to use
|
||||
# this directly unless you introduce a new code path where compilation metrics would be
|
||||
# gathered). While compiling, use the setters or timer in MetricsContext to update fields
|
||||
# in the current context. For example:
|
||||
#
|
||||
# To set a single field once (use overwrite=True to overwrite):
|
||||
# get_metrics_context().set("metric_name", value)
|
||||
#
|
||||
# To set multiple fields at once (use overwrite=True to overwrite):
|
||||
# get_metrics_context().update({"name1": val1, "name2": val2})
|
||||
#
|
||||
# To increment an integer field:
|
||||
# get_metrics_context().increment("metric_name", value)
|
||||
#
|
||||
# To record execution time, MetricsContext works with dynamo_timed:
|
||||
# def foo(...):
|
||||
# # Updates the "metric_us" field.
|
||||
# with dynamo_timed("metric", dynamo_compile_column_us="metric_us")
|
||||
# ...
|
||||
#
|
||||
_METRICS_CONTEXT: MetricsContext
|
||||
|
||||
|
||||
# Use frame_phase_timing to record remote_cache_time_saved
|
||||
# This follows the same principles of key as the other frame phase timings,
|
||||
# but is incremented by FxGraphCache (and later AOTAutogradCache) directly
|
||||
def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None:
|
||||
key = None
|
||||
if is_backward:
|
||||
# Use compile id as the frame key for backwards compilation
|
||||
key = str(torch._guards.CompileContext.current_compile_id())
|
||||
else:
|
||||
key = str(curr_frame)
|
||||
# Convert to seconds (as a float)
|
||||
time_saved = time_saved_ns / 1e9
|
||||
_add_time_spent(key, "remote_cache_time_saved", time_saved)
|
||||
|
||||
|
||||
# dynamo_timed is a context manager
|
||||
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
|
||||
# where the key is the functions name.
|
||||
# For example:
|
||||
#
|
||||
# def _foo(...):
|
||||
# with dynamo_timed("_foo"):
|
||||
# ...
|
||||
#
|
||||
# Would show up as an entry in our timing dict:
|
||||
# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
|
||||
# This is extremely useful for granular debugging.
|
||||
#
|
||||
# Although it is tempting to use dynamo_timed as a decorator, please do not.
|
||||
# In its decorator form it makes cProfile traces less useful as dynamo_timed
|
||||
# suddenly becomes a bottleneck for lots of function calls (as only one parent
|
||||
# pointer is recorded).
|
||||
#
|
||||
# For a higher-level mode, pass a phase_name into dynamo_timed
|
||||
# phase_names record an extra record into a separate compilation timing structure,
|
||||
# one keyed on frame+name rather than function.
|
||||
# The frame is incremented outside of this function, in def increment_frame() above.
|
||||
# `fwd_only` is used to identify if this phase or function is only called
|
||||
# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`.
|
||||
# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
|
||||
def get_metrics_context() -> MetricsContext:
|
||||
return _METRICS_CONTEXT
|
||||
|
||||
|
||||
@contextmanager
|
||||
def dynamo_timed(
|
||||
key: str,
|
||||
# TODO(masneral): Deprecate this param.
|
||||
phase_name: Optional[str] = None,
|
||||
log_pt2_compile_event: bool = False, # Whether or not to log it to internal pt2 compile event
|
||||
log_pt2_compile_event: bool = False,
|
||||
# TODO(masnesral): fwd_only is ignored. Remove it.
|
||||
fwd_only: bool = True,
|
||||
):
|
||||
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
|
||||
metadata: Optional[Dict[str, object]] = None,
|
||||
dynamo_compile_column_us: Optional[str] = None,
|
||||
) -> Generator[Any, None, None]:
|
||||
"""
|
||||
dynamo_timed is a context manager
|
||||
By wrapping a function in dynamo_timed, we can get a few things:
|
||||
|
||||
1) Log timings to pt2_compile_events.
|
||||
2) Log timings to CompilationMetrics (dynamo_compile).
|
||||
3) Chromium events.
|
||||
4) Storing a record in compilation_time_metrics
|
||||
For example:
|
||||
|
||||
def _foo(...):
|
||||
with dynamo_timed("_foo"):
|
||||
...
|
||||
|
||||
Would show up as an entry in our timing dict:
|
||||
OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
|
||||
This is extremely useful for granular debugging.
|
||||
|
||||
Although it is tempting to use dynamo_timed as a decorator, please do not.
|
||||
In its decorator form it makes cProfile traces less useful as dynamo_timed
|
||||
suddenly becomes a bottleneck for lots of function calls (as only one parent
|
||||
pointer is recorded).
|
||||
|
||||
Params:
|
||||
- key: key into compile_time_metrics. If phase_name is not provided, this is
|
||||
also the event name used for pt2_compile_events logs and chromium events.
|
||||
- phase_name: Optional override for the event name.
|
||||
- log_pt2_compile_event: Whether to log a pt2 compile event internally.
|
||||
- metadata: Extra metadata to put in pt2_compile_events.
|
||||
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
|
||||
field to be logged to dyname_compile column. We expect all columns to be _us;
|
||||
therefore, the field name must end with "_us".
|
||||
"""
|
||||
# We're standardizing on microseconds for dynamo_compile timings.
|
||||
if dynamo_compile_column_us is not None:
|
||||
assert dynamo_compile_column_us.endswith("_us")
|
||||
|
||||
if phase_name:
|
||||
event_name = phase_name
|
||||
fn_name = key
|
||||
else:
|
||||
event_name = key
|
||||
fn_name = None
|
||||
|
||||
if key not in compilation_time_metrics:
|
||||
compilation_time_metrics[key] = []
|
||||
|
||||
fail_type: Optional[str] = None
|
||||
fail_reason: Optional[str] = None
|
||||
time_spent = float("-inf")
|
||||
event_metadata = {}
|
||||
if metadata:
|
||||
event_metadata.update(metadata)
|
||||
if fn_name:
|
||||
event_metadata.update({"fn_name": fn_name})
|
||||
|
||||
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
|
||||
start_ns = time.time_ns()
|
||||
chromium_log.log_event_start(event_name, start_ns, event_metadata)
|
||||
|
||||
try:
|
||||
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
|
||||
t0 = time.time()
|
||||
if phase_name:
|
||||
chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key})
|
||||
else:
|
||||
chromium_log.log_event_start(key, start_ns, {})
|
||||
yield
|
||||
time_spent = time.time() - t0
|
||||
compilation_time_metrics[key].append(time_spent)
|
||||
except Exception as e:
|
||||
fail_type = str(type(e))
|
||||
fail_reason = str(e)
|
||||
raise
|
||||
finally:
|
||||
end_ns = time.time_ns()
|
||||
# Always log the end event even on exception
|
||||
if phase_name:
|
||||
chromium_log.log_event_end(
|
||||
phase_name,
|
||||
end_ns,
|
||||
{},
|
||||
start_ns,
|
||||
log_pt2_compile_event,
|
||||
)
|
||||
else:
|
||||
chromium_log.log_event_end(key, end_ns, {}, start_ns, log_pt2_compile_event)
|
||||
# Only record backward compilation metrics if phase_name is not None!
|
||||
if phase_name:
|
||||
frame_key = str(curr_frame)
|
||||
# fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch.
|
||||
# use frame_key as time aggregation key.
|
||||
if fwd_only and fail_type is None:
|
||||
_add_time_spent(frame_key, phase_name, time_spent)
|
||||
else:
|
||||
# fwd + bwd compilation stages: inductor_compile, code_gen.
|
||||
# use frame_key as time aggregation key for fwd graphs;
|
||||
# use compile_id as time aggregation key for bwd graphs.
|
||||
if torch._guards.TracingContext.try_get() is not None:
|
||||
aot_graph_name = str(
|
||||
torch._guards.TracingContext.get().aot_graph_name
|
||||
)
|
||||
if (
|
||||
"forward" in aot_graph_name or "inference" in aot_graph_name
|
||||
) and fail_type is None:
|
||||
_add_time_spent(frame_key, phase_name, time_spent)
|
||||
elif "backward" in aot_graph_name:
|
||||
compile_id = str(
|
||||
torch._guards.CompileContext.current_compile_id()
|
||||
)
|
||||
if fail_type is None:
|
||||
_add_time_spent(compile_id, phase_name, time_spent)
|
||||
|
||||
# log backward compilation metrics at the end of `inductor_compile` of bwd graph,
|
||||
# one record for one bwd graph.
|
||||
if phase_name == "inductor_compile":
|
||||
if fail_type is None:
|
||||
inductor_compile_time = frame_phase_timing[
|
||||
compile_id
|
||||
].get("inductor_compile", None)
|
||||
code_gen_time = frame_phase_timing[compile_id].get(
|
||||
"code_gen", None
|
||||
)
|
||||
remote_cache_time_saved = frame_phase_timing[
|
||||
compile_id
|
||||
].get("remote_cache_time_saved", None)
|
||||
remote_fx_graph_cache_get_time = frame_phase_timing[
|
||||
compile_id
|
||||
].get("remote_fx_graph_cache_get", None)
|
||||
remote_fx_graph_cache_put_time = frame_phase_timing[
|
||||
compile_id
|
||||
].get("remote_fx_graph_cache_put", None)
|
||||
else:
|
||||
inductor_compile_time = None
|
||||
code_gen_time = None
|
||||
remote_cache_time_saved = None
|
||||
remote_fx_graph_cache_get_time = None
|
||||
remote_fx_graph_cache_put_time = None
|
||||
structured_logging_overhead_s = (
|
||||
torch._logging.get_structured_logging_overhead()
|
||||
)
|
||||
metrics = CompilationMetrics(
|
||||
compile_id=compile_id,
|
||||
inductor_compile_time_s=inductor_compile_time,
|
||||
code_gen_time_s=code_gen_time,
|
||||
fail_type=fail_type,
|
||||
fail_reason=fail_reason,
|
||||
remote_cache_time_saved_s=remote_cache_time_saved,
|
||||
structured_logging_overhead_s=structured_logging_overhead_s,
|
||||
is_forward=False, # is_forward
|
||||
num_triton_bundles=codecache_metrics.get(
|
||||
"num_triton_bundles", None
|
||||
),
|
||||
remote_fx_graph_cache_get_time_ms=to_int_ms(
|
||||
remote_fx_graph_cache_get_time
|
||||
),
|
||||
remote_fx_graph_cache_put_time_ms=to_int_ms(
|
||||
remote_fx_graph_cache_put_time
|
||||
),
|
||||
start_time_us=start_ns // 1000,
|
||||
duration_us=(end_ns - start_ns) // 1000,
|
||||
inductor_cumulative_compile_time_us=to_int_us(
|
||||
inductor_compile_time
|
||||
),
|
||||
inductor_code_gen_cumulative_compile_time_us=to_int_us(
|
||||
code_gen_time
|
||||
),
|
||||
distributed_ephemeral_timeout_us=to_int_us(
|
||||
remote_cache_time_saved
|
||||
), # TODO: instrument more accurately
|
||||
structured_logging_overhead_us=to_int_us(
|
||||
structured_logging_overhead_s
|
||||
),
|
||||
remote_fx_graph_cache_get_time_us=to_int_us(
|
||||
remote_fx_graph_cache_get_time
|
||||
),
|
||||
remote_fx_graph_cache_put_time_us=to_int_us(
|
||||
remote_fx_graph_cache_put_time
|
||||
),
|
||||
)
|
||||
record_compilation_metrics(metrics)
|
||||
time_spent_ns = end_ns - start_ns
|
||||
compilation_time_metrics[key].append(time_spent_ns / 1e9)
|
||||
chromium_log.log_event_end(
|
||||
event_name, end_ns, {}, start_ns, log_pt2_compile_event
|
||||
)
|
||||
if dynamo_compile_column_us:
|
||||
metrics_context = get_metrics_context()
|
||||
if metrics_context.in_progress():
|
||||
metrics_context.increment(
|
||||
dynamo_compile_column_us, time_spent_ns // 1000
|
||||
)
|
||||
# TODO: the events that we capture in calculate_time_spent() seem a little
|
||||
# arbitrary. Currently, it's only those fields that are present in
|
||||
# CompilationMetrics (but note that we accumulate by the associated event
|
||||
# name, not the field name in CompilationMetrics). Do we want to keep it
|
||||
# this way?
|
||||
cumulative_time_spent_ns[event_name] += time_spent_ns
|
||||
|
||||
|
||||
@overload
|
||||
@ -866,6 +793,11 @@ def to_int_us(v: Optional[float]) -> Optional[int]:
|
||||
return None if v is None else int(v * 1_000_000)
|
||||
|
||||
|
||||
# Version field added to every log. Increment to make it easier to distinguish new
|
||||
# vs. old entries when you make a substantive change to how the logs are populated.
|
||||
LOG_FORMAT_VERSION = 2
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompilationMetrics:
|
||||
compile_id: Optional[str] = None
|
||||
@ -913,15 +845,18 @@ class CompilationMetrics:
|
||||
aot_autograd_cumulative_compile_time_us: Optional[int] = None
|
||||
inductor_cumulative_compile_time_us: Optional[int] = None
|
||||
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
|
||||
triton_compile_time_us: Optional[int] = None
|
||||
runtime_cudagraphify_time_us: Optional[int] = None
|
||||
runtime_triton_autotune_time_us: Optional[int] = None
|
||||
triton_compile_time_us: Optional[int] = None # TODO: instrument
|
||||
runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument
|
||||
runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument
|
||||
dynamo_compile_time_before_restart_us: Optional[int] = None
|
||||
cuda_synchronize_time_us: Optional[int] = None
|
||||
cuda_synchronize_time_us: Optional[int] = None # TODO: instrument
|
||||
distributed_ephemeral_timeout_us: Optional[int] = None
|
||||
structured_logging_overhead_us: Optional[int] = None
|
||||
remote_fx_graph_cache_get_time_us: Optional[int] = None
|
||||
remote_fx_graph_cache_put_time_us: Optional[int] = None
|
||||
backward_cumulative_compile_time_us: Optional[int] = None
|
||||
end_time_us: Optional[int] = None
|
||||
log_format_version: int = LOG_FORMAT_VERSION
|
||||
|
||||
|
||||
DEFAULT_COMPILATION_METRICS_LIMIT = 64
|
||||
@ -969,8 +904,32 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics):
|
||||
)
|
||||
|
||||
|
||||
def record_compilation_metrics(compilation_metrics: CompilationMetrics):
|
||||
global _compilation_metrics
|
||||
def record_compilation_metrics(metrics: Dict[str, Any]):
|
||||
# TODO: Temporary; populate legacy fields from their replacements.
|
||||
# Remove when we decide we can really deprecate them.
|
||||
def us_to_s(field):
|
||||
metric = metrics.get(field, None)
|
||||
return metric / 1e6 if metric is not None else None
|
||||
|
||||
def us_to_ms(field):
|
||||
metric = metrics.get(field, None)
|
||||
return metric // 1000 if metric is not None else None
|
||||
|
||||
legacy_metrics = {
|
||||
"entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"),
|
||||
"backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"),
|
||||
"inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"),
|
||||
"code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"),
|
||||
"remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"),
|
||||
"remote_fx_graph_cache_get_time_ms": us_to_ms(
|
||||
"remote_fx_graph_cache_get_time_us"
|
||||
),
|
||||
"remote_fx_graph_cache_put_time_ms": us_to_ms(
|
||||
"remote_fx_graph_cache_put_time_us"
|
||||
),
|
||||
}
|
||||
|
||||
compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics})
|
||||
_compilation_metrics.append(compilation_metrics)
|
||||
if compilation_metrics.is_forward:
|
||||
name = "compilation_metrics"
|
||||
@ -979,10 +938,7 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics):
|
||||
name = "bwd_compilation_metrics"
|
||||
torch._logging.trace_structured(
|
||||
name,
|
||||
lambda: {
|
||||
k: list(v) if isinstance(v, set) else v
|
||||
for k, v in dataclasses.asdict(compilation_metrics).items()
|
||||
},
|
||||
lambda: {k: list(v) if isinstance(v, set) else v for k, v in metrics.items()},
|
||||
# NB: Because compilation metrics *includes* the logging overhead time,
|
||||
# we can't both *measure* the logging overhead of compilation metrics
|
||||
# without making it inconsistent with compilation metrics itself, so
|
||||
@ -993,6 +949,10 @@ def record_compilation_metrics(compilation_metrics: CompilationMetrics):
|
||||
log_compilation_event(compilation_metrics)
|
||||
|
||||
|
||||
# record_compilation_metrics is called by the singleton MetricsContext exit handler.
|
||||
_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
|
||||
|
||||
|
||||
def set_compilation_metrics_limit(new_size: int) -> None:
|
||||
global _compilation_metrics
|
||||
while len(_compilation_metrics) > new_size:
|
||||
|
||||
@ -632,7 +632,7 @@ class AOTAutogradCache:
|
||||
)
|
||||
# TODO: should we use the same field for remote cache time saved for both
|
||||
# FXGraphCache and AOTAutogradCache?
|
||||
# add_remote_cache_time_saved(time_saved_ns, is_backward=False)
|
||||
# get_metrics_context().increment(...)
|
||||
if (
|
||||
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
|
||||
time_saved_ns
|
||||
|
||||
@ -10,6 +10,7 @@ import builtins
|
||||
import collections
|
||||
import itertools
|
||||
import pprint
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
@ -18,6 +19,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
from torch import Tensor
|
||||
from torch._dynamo.utils import dynamo_timed, get_metrics_context, to_int_us
|
||||
from torch._guards import (
|
||||
compile_context,
|
||||
CompileContext,
|
||||
@ -2002,17 +2004,49 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
with tracing(saved_context), compile_context(
|
||||
saved_compile_context
|
||||
), context(), track_graph_compiling(aot_config, "backward"):
|
||||
CompiledFunction.compiled_bw = aot_config.bw_compiler(
|
||||
bw_module, placeholder_list
|
||||
)
|
||||
# Maybe save cache entry
|
||||
if try_save_cache_entry is not None:
|
||||
try_save_cache_entry(
|
||||
CompiledFunction.compiled_bw,
|
||||
fw_metadata,
|
||||
aot_config,
|
||||
), context(), track_graph_compiling(
|
||||
aot_config, "backward"
|
||||
), get_metrics_context(), dynamo_timed(
|
||||
"backward._backward_impl",
|
||||
phase_name="entire_backward_compile",
|
||||
dynamo_compile_column_us="backward_cumulative_compile_time_us",
|
||||
):
|
||||
fail_type: Optional[str] = None
|
||||
fail_reason: Optional[str] = None
|
||||
start_ns = time.time_ns()
|
||||
try:
|
||||
CompiledFunction.compiled_bw = aot_config.bw_compiler(
|
||||
bw_module, placeholder_list
|
||||
)
|
||||
# Maybe save cache entry
|
||||
if try_save_cache_entry is not None:
|
||||
try_save_cache_entry(
|
||||
CompiledFunction.compiled_bw,
|
||||
fw_metadata,
|
||||
aot_config,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO(masnesral): Populating the exception info should be automatic.
|
||||
fail_type = type(e).__qualname__
|
||||
fail_reason = str(e)
|
||||
finally:
|
||||
# TODO(masnesral): Populating time fields should be automatic.
|
||||
end_ns = time.time_ns()
|
||||
metrics = {
|
||||
"compile_id": str(
|
||||
torch._guards.CompileContext.current_compile_id()
|
||||
),
|
||||
"fail_type": fail_type,
|
||||
"fail_reason": fail_reason,
|
||||
"is_forward": False,
|
||||
"start_time_us": start_ns // 1000,
|
||||
"end_time_us": end_ns // 1000,
|
||||
"duration_us": (end_ns - start_ns) // 1000,
|
||||
"structured_logging_overhead_us": to_int_us(
|
||||
torch._logging.get_structured_logging_overhead(),
|
||||
),
|
||||
}
|
||||
get_metrics_context().update_outer(metrics)
|
||||
|
||||
if (
|
||||
torch._functorch.config.donated_buffer
|
||||
|
||||
@ -54,11 +54,10 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch import SymInt, Tensor
|
||||
from torch._dynamo.utils import (
|
||||
add_remote_cache_time_saved,
|
||||
codecache_metrics,
|
||||
counters,
|
||||
dynamo_timed,
|
||||
get_chromium_event_logger,
|
||||
get_metrics_context,
|
||||
)
|
||||
from torch._inductor import config, exc, metrics
|
||||
from torch._inductor.codegen.cuda import cuda_env
|
||||
@ -1152,7 +1151,7 @@ class FxGraphCache:
|
||||
"inductor_compile", cached_kernel_names=meta.cached_kernel_names
|
||||
)
|
||||
if len(meta.cached_kernel_names) > 0:
|
||||
codecache_metrics["num_triton_bundles"] += 1
|
||||
get_metrics_context().increment("num_triton_bundles", 1)
|
||||
|
||||
inductor_meta = autotune_cache.inductor_meta_from_config()
|
||||
AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
|
||||
@ -1449,7 +1448,9 @@ class FxGraphCache:
|
||||
|
||||
if (time_saved_ns := compiled_graph._time_taken_ns) is not None:
|
||||
cache_info["time_saved_ns"] = time_saved_ns
|
||||
add_remote_cache_time_saved(time_saved_ns, is_backward)
|
||||
get_metrics_context().increment(
|
||||
"distributed_ephemeral_timeout_us", time_saved_ns // 1000
|
||||
)
|
||||
if (
|
||||
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
|
||||
time_saved_ns
|
||||
|
||||
@ -567,7 +567,7 @@ def compile_fx_inner(
|
||||
"compile_fx_inner",
|
||||
phase_name="inductor_compile",
|
||||
log_pt2_compile_event=True,
|
||||
fwd_only=False,
|
||||
dynamo_compile_column_us="inductor_cumulative_compile_time_us",
|
||||
)
|
||||
)
|
||||
# NB: Why is this the dynamo_compile counter? The rule here is that
|
||||
|
||||
@ -1949,7 +1949,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
"GraphLowering.compile_to_module",
|
||||
phase_name="code_gen",
|
||||
log_pt2_compile_event=True,
|
||||
fwd_only=False,
|
||||
dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
|
||||
):
|
||||
return self._compile_to_module()
|
||||
|
||||
|
||||
@ -45,14 +45,14 @@ remote_fx_cache_get_timed = functools.partial(
|
||||
"FbRemoteFxGraphCache.get",
|
||||
phase_name="remote_fx_graph_cache_get",
|
||||
log_pt2_compile_event=False,
|
||||
fwd_only=False,
|
||||
dynamo_compile_column_us="remote_fx_graph_cache_get_time_us",
|
||||
)
|
||||
remote_fx_cache_put_timed = functools.partial(
|
||||
dynamo_timed,
|
||||
"FbRemoteFxGraphCache.put",
|
||||
phase_name="remote_fx_graph_cache_put",
|
||||
log_pt2_compile_event=False,
|
||||
fwd_only=False,
|
||||
dynamo_compile_column_us="remote_fx_graph_cache_put_time_us",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -135,5 +135,11 @@ try:
|
||||
except AttributeError: # Compile workers only have a mock version of torch
|
||||
|
||||
@contextlib.contextmanager
|
||||
def dynamo_timed(key, phase_name=None, fwd_only=True):
|
||||
def dynamo_timed(
|
||||
key,
|
||||
phase_name=None,
|
||||
fwd_only=True,
|
||||
metadata=None,
|
||||
dynamo_compile_column_us=None,
|
||||
):
|
||||
yield
|
||||
|
||||
@ -1099,7 +1099,6 @@ class LazyString:
|
||||
structured_logging_overhead: Dict[str, float] = defaultdict(float)
|
||||
|
||||
|
||||
# Same principle as add_remote_cache_time_saved, but do it for structured logging
|
||||
def add_structured_logging_overhead(time_spent: float) -> None:
|
||||
global structured_logging_overhead
|
||||
key = None
|
||||
|
||||
Reference in New Issue
Block a user