Revert "[logging] Overhaul dynamo_timed and CompilationMetrics logging. (#139849)"

This reverts commit cb15c1515778499ae801dcf67d55c8bdab4724ef.

Reverted https://github.com/pytorch/pytorch/pull/139849 on behalf of https://github.com/kit1980 due to Breaking an internal tests + there is a bug according to the author ([comment](https://github.com/pytorch/pytorch/pull/139849#issuecomment-2474459094))
This commit is contained in:
PyTorch MergeBot
2024-11-13 18:47:51 +00:00
parent 42622cf7d5
commit d63eb3c46c
14 changed files with 326 additions and 456 deletions

View File

@ -1,78 +0,0 @@
# Owner(s): ["module: dynamo"]
from torch._dynamo.metrics_context import MetricsContext
from torch._dynamo.test_case import run_tests, TestCase
class TestMetricsContext(TestCase):
def setUp(self):
super().setUp()
self.metrics = {}
def _on_exit(self, metrics):
# Save away the metrics to be validated in the test.
self.metrics = metrics.copy()
def test_context_exists(self):
"""
Setting a value without entering the context should raise.
"""
context = MetricsContext(self._on_exit)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.increment("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.set("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.update({"m", 1})
def test_nested_context(self):
"""
Only the outermost context should get an on_exit call, and it should
include everything.
"""
context = MetricsContext(self._on_exit)
with context:
with context:
context.set("m1", 1)
self.assertEqual(self.metrics, {})
context.set("m2", 2)
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
def test_set(self):
"""
Validate various ways to set metrics.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
context.set("m2", 2)
context.update({"m3": 3, "m4": 4})
self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})
def test_set_disallow_overwrite(self):
"""
Validate set won't overwrite.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.set("m1", 2)
self.assertEqual(self.metrics, {"m1": 1})
def test_update_disallow_overwrite(self):
"""
Validate update won't overwite.
"""
with MetricsContext(self._on_exit) as context:
context.update({"m1": 1, "m2": 2})
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.update({"m1": 7, "m3": 3})
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
if __name__ == "__main__":
run_tests()

View File

@ -159,7 +159,6 @@ class TestDynamoTimed(TestCase):
'_recursive_post_grad_passes': [0.0, 0.0],
'_recursive_pre_grad_passes': [0.0],
'async_compile.wait': [0.0, 0.0],
'backward._backward_impl': [0.0],
'compile_file': [0.0, 0.0],
'compile_fx.<locals>.bw_compiler': [0.0],
'compile_fx.<locals>.fw_compiler_base': [0.0],
@ -175,7 +174,6 @@ class TestDynamoTimed(TestCase):
"""\
{'backend_compile': 0.0,
'code_gen': 0.0,
'entire_backward_compile': 0.0,
'entire_frame_compile': 0.0,
'inductor_compile': 0.0,
'total_wall_time': 0.0}""", # noqa: B950
@ -202,7 +200,6 @@ class TestDynamoTimed(TestCase):
{'accumulated_cache_size': 0,
'aot_autograd_cumulative_compile_time_us': 0,
'backend_compile_time_s': 0.0,
'backward_cumulative_compile_time_us': None,
'cache_size': 0,
'co_filename': None,
'co_firstlineno': None,
@ -213,13 +210,12 @@ class TestDynamoTimed(TestCase):
'config_inline_inbuilt_nn_modules': False,
'config_suppress_errors': False,
'cuda_synchronize_time_us': None,
'distributed_ephemeral_timeout_us': None,
'distributed_ephemeral_timeout_us': 0,
'duration_us': 0,
'dynamo_compile_time_before_restart_us': 0,
'dynamo_config': None,
'dynamo_cumulative_compile_time_us': 0,
'dynamo_time_before_restart_s': 0.0,
'end_time_us': 100,
'entire_frame_compile_time_s': 0.0,
'fail_reason': None,
'fail_type': None,
@ -235,10 +231,9 @@ class TestDynamoTimed(TestCase):
'inductor_compile_time_s': 0.0,
'inductor_cumulative_compile_time_us': 0,
'is_forward': True,
'log_format_version': 2,
'non_compliant_ops': set(),
'num_triton_bundles': None,
'remote_cache_time_saved_s': None,
'remote_cache_time_saved_s': 0,
'remote_fx_graph_cache_get_time_ms': None,
'remote_fx_graph_cache_get_time_us': None,
'remote_fx_graph_cache_put_time_ms': None,
@ -262,7 +257,6 @@ class TestDynamoTimed(TestCase):
{'accumulated_cache_size': None,
'aot_autograd_cumulative_compile_time_us': None,
'backend_compile_time_s': None,
'backward_cumulative_compile_time_us': 0,
'cache_size': None,
'co_filename': None,
'co_firstlineno': None,
@ -279,7 +273,6 @@ class TestDynamoTimed(TestCase):
'dynamo_config': None,
'dynamo_cumulative_compile_time_us': None,
'dynamo_time_before_restart_s': None,
'end_time_us': 100,
'entire_frame_compile_time_s': None,
'fail_reason': None,
'fail_type': None,
@ -295,7 +288,6 @@ class TestDynamoTimed(TestCase):
'inductor_compile_time_s': 0.0,
'inductor_cumulative_compile_time_us': 0,
'is_forward': False,
'log_format_version': 2,
'non_compliant_ops': None,
'num_triton_bundles': None,
'remote_cache_time_saved_s': None,
@ -310,7 +302,7 @@ class TestDynamoTimed(TestCase):
'specialize_float': None,
'start_time': None,
'start_time_us': 100,
'structured_logging_overhead_s': None,
'structured_logging_overhead_s': 0.0,
'structured_logging_overhead_us': 0,
'triton_compile_time_us': None}""", # noqa: B950
)

View File

@ -30,7 +30,7 @@ import torch
import torch._logging
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._dynamo.utils import CompileTimeInstructionCounter, get_metrics_context
from torch._dynamo.utils import CompileTimeInstructionCounter
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import (
@ -105,9 +105,12 @@ from .symbolic_convert import (
from .trace_rules import is_numpy
from .utils import (
CleanupManager,
codecache_metrics,
CompilationMetrics,
counters,
dynamo_timed,
format_bytecode,
frame_phase_timing,
gen_record_file_name,
get_chromium_event_logger,
increment_frame,
@ -115,8 +118,10 @@ from .utils import (
istype,
LazyString,
orig_code_map,
record_compilation_metrics,
reset_graph_break_dup_checker,
setup_compile_debug,
to_int_ms,
to_int_us,
troubleshooting_url,
write_record_to_file,
@ -693,9 +698,7 @@ def _compile(
with contextlib.ExitStack() as stack:
stack.enter_context(
dynamo_timed(
"_compile.compile_inner",
phase_name="entire_frame_compile",
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
"_compile.compile_inner", phase_name="entire_frame_compile"
)
)
stack.enter_context(
@ -860,11 +863,9 @@ def _compile(
chromium_event_log.reset()
chromium_start_time = time.time_ns()
chromium_event_log.log_event_start("dynamo", chromium_start_time, {})
metrics_context = get_metrics_context()
with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(
CompileContext(compile_id)
), metrics_context:
):
restart_reasons: set[str] = set()
# This is shared across restarts
speculation_log = SpeculationLog()
@ -971,6 +972,7 @@ def _compile(
fail_user_frame_lineno: Optional[int] = None
torch._dynamo.utils.ReinplaceCounters.clear()
guarded_code = None
codecache_metrics.clear()
try:
guarded_code = compile_inner(code, one_graph, hooks, transform)
@ -987,7 +989,6 @@ def _compile(
return guarded_code
except Exception as e:
# TODO(masnesral): Populating the exception info should be automatic
fail_type = type(e).__qualname__
fail_reason = str(e)
# NB: e's msg is mutated here to add user stack, but we DON'T want
@ -1037,34 +1038,66 @@ def _compile(
if tracer:
tracer.output.local_scope = {}
end_time_ns = time.time_ns()
duration_ns = end_time_ns - start_time_ns
duration_ns = time.time_ns() - start_time_ns
from .utils import curr_frame
frame_key = str(curr_frame)
if fail_reason is None and output is not None:
if (
fail_reason is None
and output is not None
and frame_key in frame_phase_timing
):
guard_count = len(output.guards)
shape_env_guard_count = len(output.shape_env.guards)
graph_op_count = output.count_calls()
graph_node_count = len(output.graph.nodes)
graph_input_count = len(output.placeholders)
entire_frame_compile_time = frame_phase_timing[frame_key].get(
"entire_frame_compile", None
)
backend_compile_time = frame_phase_timing[frame_key].get(
"backend_compile", None
)
inductor_compile_time = frame_phase_timing[frame_key].get(
"inductor_compile", None
)
code_gen_time = frame_phase_timing[frame_key].get("code_gen", None)
non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
compliant_custom_ops = {
op.__qualname__ for op in output.compliant_custom_ops
}
remote_cache_time_saved = frame_phase_timing[frame_key].get(
"remote_cache_time_saved", 0
)
remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get(
"remote_fx_graph_cache_get", None
)
remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get(
"remote_fx_graph_cache_put", None
)
num_triton_bundles = codecache_metrics.get("num_triton_bundles", None)
torch._dynamo.utils.ReinplaceCounters.log()
else:
guard_count = None
shape_env_guard_count = None
graph_op_count = None
graph_node_count = None
graph_input_count = None
entire_frame_compile_time = None
backend_compile_time = None
inductor_compile_time = None
code_gen_time = None
non_compliant_ops = set({})
compliant_custom_ops = set({})
restart_reasons = set()
# If compilation failed, the entire time is wasted
dynamo_time_before_restart = duration_ns / 1e9
remote_cache_time_saved = None
remote_fx_graph_cache_get_time = None
remote_fx_graph_cache_put_time = None
num_triton_bundles = None
structured_logging_overhead_s = (
torch._logging.get_structured_logging_overhead()
@ -1099,55 +1132,74 @@ def _compile(
}
config_dict = clean_for_json(config.get_config_copy())
metrics = {
"compile_id": str(compile_id),
"frame_key": frame_key,
"co_name": code.co_name,
"co_filename": code.co_filename,
"co_firstlineno": code.co_firstlineno,
"cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
"accumulated_cache_size": cache_size.num_cache_entries,
"guard_count": guard_count,
"shape_env_guard_count": shape_env_guard_count,
"graph_op_count": graph_op_count,
"graph_node_count": graph_node_count,
"graph_input_count": graph_input_count,
# TODO(masnesral): start_time and end_time shouldn't need to be
# populated manually.
"start_time": start_time_ns / 1e9,
"fail_type": fail_type,
"fail_reason": fail_reason,
"fail_user_frame_filename": fail_user_frame_filename,
"fail_user_frame_lineno": fail_user_frame_lineno,
"non_compliant_ops": non_compliant_ops,
"compliant_custom_ops": compliant_custom_ops,
"restart_reasons": restart_reasons,
"dynamo_time_before_restart_s": dynamo_time_before_restart,
"has_guarded_code": guarded_code is not None,
"structured_logging_overhead_s": structured_logging_overhead_s,
"config_suppress_errors": config.suppress_errors,
"config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules,
"specialize_float": config.specialize_float,
"dynamo_config": json.dumps(config_dict),
"is_forward": True,
"start_time_us": start_time_ns // 1000,
"end_time_us": end_time_ns // 1000,
"duration_us": duration_ns // 1000,
"dynamo_compile_time_before_restart_us": to_int_us(
metrics = CompilationMetrics(
str(compile_id),
frame_key,
code.co_name,
code.co_filename,
code.co_firstlineno,
cache_size.num_cache_entries_with_same_id_matched_objs,
cache_size.num_cache_entries,
guard_count,
shape_env_guard_count,
graph_op_count,
graph_node_count,
graph_input_count,
start_time_ns / 1e9,
entire_frame_compile_time,
backend_compile_time,
inductor_compile_time,
code_gen_time,
fail_type,
fail_reason,
fail_user_frame_filename,
fail_user_frame_lineno,
non_compliant_ops,
compliant_custom_ops,
restart_reasons,
dynamo_time_before_restart,
guarded_code is not None,
remote_cache_time_saved,
structured_logging_overhead_s,
config.suppress_errors,
config.inline_inbuilt_nn_modules,
config.specialize_float,
json.dumps(config_dict),
True, # is_forward
num_triton_bundles,
to_int_ms(remote_fx_graph_cache_get_time),
to_int_ms(remote_fx_graph_cache_put_time),
start_time_us=start_time_ns // 1000,
duration_us=duration_ns // 1000,
dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time),
aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time),
inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time),
inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time),
triton_compile_time_us=None, # TODO: instrument
runtime_cudagraphify_time_us=None, # TODO: instrument in separate event
runtime_triton_autotune_time_us=None, # TODO: instrument in separate event
dynamo_compile_time_before_restart_us=to_int_us(
dynamo_time_before_restart
),
"structured_logging_overhead_us": to_int_us(
structured_logging_overhead_s
cuda_synchronize_time_us=None, # TODO: instrument
distributed_ephemeral_timeout_us=to_int_us(
remote_cache_time_saved
), # TODO: instrument more accurately
structured_logging_overhead_us=to_int_us(structured_logging_overhead_s),
remote_fx_graph_cache_get_time_us=to_int_us(
remote_fx_graph_cache_get_time
),
}
metrics_context.update_outer(metrics)
remote_fx_graph_cache_put_time_us=to_int_us(
remote_fx_graph_cache_put_time
),
)
record_compilation_metrics(metrics)
torch._dynamo.callback_handler.run_end_callbacks()
chromium_event_log.log_event_end(
"dynamo", time.time_ns(), {}, chromium_start_time, True
)
# === END WARNING WARNING WARNING ===
chromium_event_log.log_event_end(
"dynamo", time.time_ns(), {}, chromium_start_time, True
)
class ConvertFrame:
def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None:

View File

@ -1,95 +0,0 @@
from typing import Any, Callable, Dict, Optional, Type
from typing_extensions import TypeAlias
OnExitType: TypeAlias = Callable[[Dict[str, Any]], None]
class MetricsContext:
def __init__(self, on_exit: OnExitType):
"""
Use this class as a contextmanager to create a context under which to accumulate
a set of metrics, e.g., metrics gathered during a compilation. On exit of the
contextmanager, call the provided 'on_exit' function and pass a dictionary of
all metrics set during the lifetime of the contextmanager.
"""
self._on_exit = on_exit
self._metrics: Dict[str, Any] = {}
self._level = 0
def __enter__(self) -> "MetricsContext":
"""
Initialize metrics recording.
"""
if self._level == 0:
# In case of recursion, track at the outermost context.
self._metrics = {}
self._level += 1
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
_traceback: Any,
) -> None:
"""
At exit, call the provided on_exit function.
"""
self._level -= 1
assert self._level >= 0
if self._level == 0:
self._on_exit(self._metrics)
def in_progress(self) -> bool:
"""
True if we've entered the context.
"""
return self._level > 0
def increment(self, metric: str, value: int) -> None:
"""
Increment a metric by a given amount.
"""
if self._level == 0:
raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext")
if metric not in self._metrics:
self._metrics[metric] = 0
self._metrics[metric] += value
def set(self, metric: str, value: Any) -> None:
"""
Set a metric to a given value. Raises if the metric has been assigned previously
in the current context.
"""
if self._level == 0:
raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
if metric in self._metrics:
raise RuntimeError(
f"Metric '{metric}' has already been set in the current context"
)
self._metrics[metric] = value
def update(self, values: Dict[str, Any]) -> None:
"""
Set multiple metrics directly. This method does NOT increment. Raises if any
metric has been assigned previously in the current context.
"""
if self._level == 0:
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
existing = self._metrics.keys() & values.keys()
if existing:
raise RuntimeError(
f"Metric(s) {existing} have already been set in the current context"
)
self._metrics.update(values)
def update_outer(self, values: Dict[str, Any]) -> None:
"""
Update, but only when at the outermost context.
"""
if self._level == 0:
raise RuntimeError("Cannot update metrics outside of a MetricsContext")
if self._level == 1:
self.update(values)

View File

@ -1394,7 +1394,6 @@ class OutputGraph:
"OutputGraph.call_user_compiler",
phase_name="backend_compile",
log_pt2_compile_event=True,
dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
):
return self._call_user_compiler(gm)

View File

@ -43,7 +43,6 @@ from typing import (
DefaultDict,
Deque,
Dict,
Generator,
Iterable,
Iterator,
KeysView,
@ -72,7 +71,6 @@ from torch._C import (
_push_on_torch_function_stack,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.metrics_context import MetricsContext
from torch._guards import Source, TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import (
@ -141,10 +139,12 @@ log = logging.getLogger(__name__)
# profiling compilation time by function
compilation_time_metrics: Dict[str, List[float]] = {}
# This supports calculate_time_spent(), which reports cumulative times
# across the process for any "phase" populated by dynamo_timed. Reset if
# reset_frame_count() is called.
cumulative_time_spent_ns: Dict[str, float] = collections.defaultdict(float)
# profiling compilation time by frame phase
frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
lambda: collections.defaultdict(float)
)
codecache_metrics: Counter[str] = collections.Counter()
timer_counter = itertools.count()
@ -220,7 +220,7 @@ def increment_frame() -> None:
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def reset_frame_count() -> None:
global curr_frame
cumulative_time_spent_ns.clear()
frame_phase_timing.clear()
compilation_time_metrics.clear()
curr_frame = 0
@ -233,16 +233,25 @@ def increment_op_count(cnt: int) -> None:
op_count += cnt
# Get the total time in seconds for each "phase"
# Calculate total time spent so far for each phase
# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
def calculate_time_spent() -> Dict[str, float]:
total_wall_time = 0.0
total_by_key = {}
for phase, timing in cumulative_time_spent_ns.items():
total_by_key[phase] = timing / 1e9
for timings in frame_phase_timing.values():
total_wall_time += timings.get(
"entire_frame_compile", timings.get("inductor_compile", 0)
)
for key, timing in timings.items():
if key not in total_by_key:
total_by_key[key] = timing
else:
total_by_key[key] += timing
if total_by_key:
total_by_key["total_wall_time"] = total_wall_time
total_by_key["total_wall_time"] = total_by_key.get(
"entire_frame_compile", 0
) + total_by_key.get("entire_backward_compile", 0)
return total_by_key
@ -261,124 +270,188 @@ def print_time_report() -> None:
print(out)
# Use the following singleton to capture and log CompilationMetrics. Entering the context
# manager allocates a new record to be logged when it exits. (You should not need to use
# this directly unless you introduce a new code path where compilation metrics would be
# gathered). While compiling, use the setters or timer in MetricsContext to update fields
# in the current context. For example:
#
# To set a single field once (use overwrite=True to overwrite):
# get_metrics_context().set("metric_name", value)
#
# To set multiple fields at once (use overwrite=True to overwrite):
# get_metrics_context().update({"name1": val1, "name2": val2})
#
# To increment an integer field:
# get_metrics_context().increment("metric_name", value)
#
# To record execution time, MetricsContext works with dynamo_timed:
# def foo(...):
# # Updates the "metric_us" field.
# with dynamo_timed("metric", dynamo_compile_column_us="metric_us")
# ...
#
_METRICS_CONTEXT: MetricsContext
def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
frame_phase_timing[key][phase_name] += time_spent
def get_metrics_context() -> MetricsContext:
return _METRICS_CONTEXT
# Use frame_phase_timing to record remote_cache_time_saved
# This follows the same principles of key as the other frame phase timings,
# but is incremented by FxGraphCache (and later AOTAutogradCache) directly
def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None:
key = None
if is_backward:
# Use compile id as the frame key for backwards compilation
key = str(torch._guards.CompileContext.current_compile_id())
else:
key = str(curr_frame)
# Convert to seconds (as a float)
time_saved = time_saved_ns / 1e9
_add_time_spent(key, "remote_cache_time_saved", time_saved)
# dynamo_timed is a context manager
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the functions name.
# For example:
#
# def _foo(...):
# with dynamo_timed("_foo"):
# ...
#
# Would show up as an entry in our timing dict:
# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
# This is extremely useful for granular debugging.
#
# Although it is tempting to use dynamo_timed as a decorator, please do not.
# In its decorator form it makes cProfile traces less useful as dynamo_timed
# suddenly becomes a bottleneck for lots of function calls (as only one parent
# pointer is recorded).
#
# For a higher-level mode, pass a phase_name into dynamo_timed
# phase_names record an extra record into a separate compilation timing structure,
# one keyed on frame+name rather than function.
# The frame is incremented outside of this function, in def increment_frame() above.
# `fwd_only` is used to identify if this phase or function is only called
# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`.
# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
@contextmanager
def dynamo_timed(
key: str,
# TODO(masneral): Deprecate this param.
phase_name: Optional[str] = None,
log_pt2_compile_event: bool = False,
# TODO(masnesral): fwd_only is ignored. Remove it.
log_pt2_compile_event: bool = False, # Whether or not to log it to internal pt2 compile event
fwd_only: bool = True,
metadata: Optional[Dict[str, object]] = None,
dynamo_compile_column_us: Optional[str] = None,
) -> Generator[Any, None, None]:
"""
dynamo_timed is a context manager
By wrapping a function in dynamo_timed, we can get a few things:
1) Log timings to pt2_compile_events.
2) Log timings to CompilationMetrics (dynamo_compile).
3) Chromium events.
4) Storing a record in compilation_time_metrics
For example:
def _foo(...):
with dynamo_timed("_foo"):
...
Would show up as an entry in our timing dict:
OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
This is extremely useful for granular debugging.
Although it is tempting to use dynamo_timed as a decorator, please do not.
In its decorator form it makes cProfile traces less useful as dynamo_timed
suddenly becomes a bottleneck for lots of function calls (as only one parent
pointer is recorded).
Params:
- key: key into compile_time_metrics. If phase_name is not provided, this is
also the event name used for pt2_compile_events logs and chromium events.
- phase_name: Optional override for the event name.
- log_pt2_compile_event: Whether to log a pt2 compile event internally.
- metadata: Extra metadata to put in pt2_compile_events.
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
field to be logged to dyname_compile column. We expect all columns to be _us;
therefore, the field name must end with "_us".
"""
# We're standardizing on microseconds for dynamo_compile timings.
if dynamo_compile_column_us is not None:
assert dynamo_compile_column_us.endswith("_us")
if phase_name:
event_name = phase_name
fn_name = key
else:
event_name = key
fn_name = None
):
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
if key not in compilation_time_metrics:
compilation_time_metrics[key] = []
event_metadata = {}
if metadata:
event_metadata.update(metadata)
if fn_name:
event_metadata.update({"fn_name": fn_name})
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
time_spent = float("-inf")
start_ns = time.time_ns()
chromium_log.log_event_start(event_name, start_ns, event_metadata)
try:
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
t0 = time.time()
if phase_name:
chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key})
else:
chromium_log.log_event_start(key, start_ns, {})
yield
time_spent = time.time() - t0
compilation_time_metrics[key].append(time_spent)
except Exception as e:
fail_type = str(type(e))
fail_reason = str(e)
raise
finally:
end_ns = time.time_ns()
time_spent_ns = end_ns - start_ns
compilation_time_metrics[key].append(time_spent_ns / 1e9)
chromium_log.log_event_end(
event_name, end_ns, {}, start_ns, log_pt2_compile_event
)
if dynamo_compile_column_us:
metrics_context = get_metrics_context()
if metrics_context.in_progress():
metrics_context.increment(
dynamo_compile_column_us, time_spent_ns // 1000
)
# TODO: the events that we capture in calculate_time_spent() seem a little
# arbitrary. Currently, it's only those fields that are present in
# CompilationMetrics (but note that we accumulate by the associated event
# name, not the field name in CompilationMetrics). Do we want to keep it
# this way?
cumulative_time_spent_ns[event_name] += time_spent_ns
# Always log the end event even on exception
if phase_name:
chromium_log.log_event_end(
phase_name,
end_ns,
{},
start_ns,
log_pt2_compile_event,
)
else:
chromium_log.log_event_end(key, end_ns, {}, start_ns, log_pt2_compile_event)
# Only record backward compilation metrics if phase_name is not None!
if phase_name:
frame_key = str(curr_frame)
# fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch.
# use frame_key as time aggregation key.
if fwd_only and fail_type is None:
_add_time_spent(frame_key, phase_name, time_spent)
else:
# fwd + bwd compilation stages: inductor_compile, code_gen.
# use frame_key as time aggregation key for fwd graphs;
# use compile_id as time aggregation key for bwd graphs.
if torch._guards.TracingContext.try_get() is not None:
aot_graph_name = str(
torch._guards.TracingContext.get().aot_graph_name
)
if (
"forward" in aot_graph_name or "inference" in aot_graph_name
) and fail_type is None:
_add_time_spent(frame_key, phase_name, time_spent)
elif "backward" in aot_graph_name:
compile_id = str(
torch._guards.CompileContext.current_compile_id()
)
if fail_type is None:
_add_time_spent(compile_id, phase_name, time_spent)
# log backward compilation metrics at the end of `inductor_compile` of bwd graph,
# one record for one bwd graph.
if phase_name == "inductor_compile":
if fail_type is None:
inductor_compile_time = frame_phase_timing[
compile_id
].get("inductor_compile", None)
code_gen_time = frame_phase_timing[compile_id].get(
"code_gen", None
)
remote_cache_time_saved = frame_phase_timing[
compile_id
].get("remote_cache_time_saved", None)
remote_fx_graph_cache_get_time = frame_phase_timing[
compile_id
].get("remote_fx_graph_cache_get", None)
remote_fx_graph_cache_put_time = frame_phase_timing[
compile_id
].get("remote_fx_graph_cache_put", None)
else:
inductor_compile_time = None
code_gen_time = None
remote_cache_time_saved = None
remote_fx_graph_cache_get_time = None
remote_fx_graph_cache_put_time = None
structured_logging_overhead_s = (
torch._logging.get_structured_logging_overhead()
)
metrics = CompilationMetrics(
compile_id=compile_id,
inductor_compile_time_s=inductor_compile_time,
code_gen_time_s=code_gen_time,
fail_type=fail_type,
fail_reason=fail_reason,
remote_cache_time_saved_s=remote_cache_time_saved,
structured_logging_overhead_s=structured_logging_overhead_s,
is_forward=False, # is_forward
num_triton_bundles=codecache_metrics.get(
"num_triton_bundles", None
),
remote_fx_graph_cache_get_time_ms=to_int_ms(
remote_fx_graph_cache_get_time
),
remote_fx_graph_cache_put_time_ms=to_int_ms(
remote_fx_graph_cache_put_time
),
start_time_us=start_ns // 1000,
duration_us=(end_ns - start_ns) // 1000,
inductor_cumulative_compile_time_us=to_int_us(
inductor_compile_time
),
inductor_code_gen_cumulative_compile_time_us=to_int_us(
code_gen_time
),
distributed_ephemeral_timeout_us=to_int_us(
remote_cache_time_saved
), # TODO: instrument more accurately
structured_logging_overhead_us=to_int_us(
structured_logging_overhead_s
),
remote_fx_graph_cache_get_time_us=to_int_us(
remote_fx_graph_cache_get_time
),
remote_fx_graph_cache_put_time_us=to_int_us(
remote_fx_graph_cache_put_time
),
)
record_compilation_metrics(metrics)
@overload
@ -793,11 +866,6 @@ def to_int_us(v: Optional[float]) -> Optional[int]:
return None if v is None else int(v * 1_000_000)
# Version field added to every log. Increment to make it easier to distinguish new
# vs. old entries when you make a substantive change to how the logs are populated.
LOG_FORMAT_VERSION = 2
@dataclasses.dataclass
class CompilationMetrics:
compile_id: Optional[str] = None
@ -845,18 +913,15 @@ class CompilationMetrics:
aot_autograd_cumulative_compile_time_us: Optional[int] = None
inductor_cumulative_compile_time_us: Optional[int] = None
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
triton_compile_time_us: Optional[int] = None # TODO: instrument
runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument
runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument
triton_compile_time_us: Optional[int] = None
runtime_cudagraphify_time_us: Optional[int] = None
runtime_triton_autotune_time_us: Optional[int] = None
dynamo_compile_time_before_restart_us: Optional[int] = None
cuda_synchronize_time_us: Optional[int] = None # TODO: instrument
cuda_synchronize_time_us: Optional[int] = None
distributed_ephemeral_timeout_us: Optional[int] = None
structured_logging_overhead_us: Optional[int] = None
remote_fx_graph_cache_get_time_us: Optional[int] = None
remote_fx_graph_cache_put_time_us: Optional[int] = None
backward_cumulative_compile_time_us: Optional[int] = None
end_time_us: Optional[int] = None
log_format_version: int = LOG_FORMAT_VERSION
DEFAULT_COMPILATION_METRICS_LIMIT = 64
@ -904,32 +969,8 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics):
)
def record_compilation_metrics(metrics: Dict[str, Any]):
# TODO: Temporary; populate legacy fields from their replacements.
# Remove when we decide we can really deprecate them.
def us_to_s(field):
metric = metrics.get(field, None)
return metric / 1e6 if metric is not None else None
def us_to_ms(field):
metric = metrics.get(field, None)
return metric // 1000 if metric is not None else None
legacy_metrics = {
"entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"),
"backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"),
"inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"),
"code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"),
"remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"),
"remote_fx_graph_cache_get_time_ms": us_to_ms(
"remote_fx_graph_cache_get_time_us"
),
"remote_fx_graph_cache_put_time_ms": us_to_ms(
"remote_fx_graph_cache_put_time_us"
),
}
compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics})
def record_compilation_metrics(compilation_metrics: CompilationMetrics):
global _compilation_metrics
_compilation_metrics.append(compilation_metrics)
if compilation_metrics.is_forward:
name = "compilation_metrics"
@ -938,7 +979,10 @@ def record_compilation_metrics(metrics: Dict[str, Any]):
name = "bwd_compilation_metrics"
torch._logging.trace_structured(
name,
lambda: {k: list(v) if isinstance(v, set) else v for k, v in metrics.items()},
lambda: {
k: list(v) if isinstance(v, set) else v
for k, v in dataclasses.asdict(compilation_metrics).items()
},
# NB: Because compilation metrics *includes* the logging overhead time,
# we can't both *measure* the logging overhead of compilation metrics
# without making it inconsistent with compilation metrics itself, so
@ -949,10 +993,6 @@ def record_compilation_metrics(metrics: Dict[str, Any]):
log_compilation_event(compilation_metrics)
# record_compilation_metrics is called by the singleton MetricsContext exit handler.
_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
def set_compilation_metrics_limit(new_size: int) -> None:
global _compilation_metrics
while len(_compilation_metrics) > new_size:

View File

@ -632,7 +632,7 @@ class AOTAutogradCache:
)
# TODO: should we use the same field for remote cache time saved for both
# FXGraphCache and AOTAutogradCache?
# get_metrics_context().increment(...)
# add_remote_cache_time_saved(time_saved_ns, is_backward=False)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns

View File

@ -10,7 +10,6 @@ import builtins
import collections
import itertools
import pprint
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
@ -19,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.utils.dlpack
from torch import Tensor
from torch._dynamo.utils import dynamo_timed, get_metrics_context, to_int_us
from torch._guards import (
compile_context,
CompileContext,
@ -2004,49 +2002,17 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
context = torch._C._DisableAutocast if disable_amp else nullcontext
with tracing(saved_context), compile_context(
saved_compile_context
), context(), track_graph_compiling(
aot_config, "backward"
), get_metrics_context(), dynamo_timed(
"backward._backward_impl",
phase_name="entire_backward_compile",
dynamo_compile_column_us="backward_cumulative_compile_time_us",
):
fail_type: Optional[str] = None
fail_reason: Optional[str] = None
start_ns = time.time_ns()
try:
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
), context(), track_graph_compiling(aot_config, "backward"):
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
)
# Maybe save cache entry
if try_save_cache_entry is not None:
try_save_cache_entry(
CompiledFunction.compiled_bw,
fw_metadata,
aot_config,
)
# Maybe save cache entry
if try_save_cache_entry is not None:
try_save_cache_entry(
CompiledFunction.compiled_bw,
fw_metadata,
aot_config,
)
except Exception as e:
# TODO(masnesral): Populating the exception info should be automatic.
fail_type = type(e).__qualname__
fail_reason = str(e)
finally:
# TODO(masnesral): Populating time fields should be automatic.
end_ns = time.time_ns()
metrics = {
"compile_id": str(
torch._guards.CompileContext.current_compile_id()
),
"fail_type": fail_type,
"fail_reason": fail_reason,
"is_forward": False,
"start_time_us": start_ns // 1000,
"end_time_us": end_ns // 1000,
"duration_us": (end_ns - start_ns) // 1000,
"structured_logging_overhead_us": to_int_us(
torch._logging.get_structured_logging_overhead(),
),
}
get_metrics_context().update_outer(metrics)
if (
torch._functorch.config.donated_buffer

View File

@ -54,10 +54,11 @@ import torch
import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.utils import (
add_remote_cache_time_saved,
codecache_metrics,
counters,
dynamo_timed,
get_chromium_event_logger,
get_metrics_context,
)
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.cuda import cuda_env
@ -1151,7 +1152,7 @@ class FxGraphCache:
"inductor_compile", cached_kernel_names=meta.cached_kernel_names
)
if len(meta.cached_kernel_names) > 0:
get_metrics_context().increment("num_triton_bundles", 1)
codecache_metrics["num_triton_bundles"] += 1
inductor_meta = autotune_cache.inductor_meta_from_config()
AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
@ -1448,9 +1449,7 @@ class FxGraphCache:
if (time_saved_ns := compiled_graph._time_taken_ns) is not None:
cache_info["time_saved_ns"] = time_saved_ns
get_metrics_context().increment(
"distributed_ephemeral_timeout_us", time_saved_ns // 1000
)
add_remote_cache_time_saved(time_saved_ns, is_backward)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns

View File

@ -567,7 +567,7 @@ def compile_fx_inner(
"compile_fx_inner",
phase_name="inductor_compile",
log_pt2_compile_event=True,
dynamo_compile_column_us="inductor_cumulative_compile_time_us",
fwd_only=False,
)
)
# NB: Why is this the dynamo_compile counter? The rule here is that

View File

@ -1950,7 +1950,7 @@ class GraphLowering(torch.fx.Interpreter):
"GraphLowering.compile_to_module",
phase_name="code_gen",
log_pt2_compile_event=True,
dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
fwd_only=False,
):
return self._compile_to_module()

View File

@ -45,14 +45,14 @@ remote_fx_cache_get_timed = functools.partial(
"FbRemoteFxGraphCache.get",
phase_name="remote_fx_graph_cache_get",
log_pt2_compile_event=False,
dynamo_compile_column_us="remote_fx_graph_cache_get_time_us",
fwd_only=False,
)
remote_fx_cache_put_timed = functools.partial(
dynamo_timed,
"FbRemoteFxGraphCache.put",
phase_name="remote_fx_graph_cache_put",
log_pt2_compile_event=False,
dynamo_compile_column_us="remote_fx_graph_cache_put_time_us",
fwd_only=False,
)

View File

@ -135,13 +135,7 @@ try:
except AttributeError: # Compile workers only have a mock version of torch
@contextlib.contextmanager
def dynamo_timed(
key,
phase_name=None,
fwd_only=True,
metadata=None,
dynamo_compile_column_us=None,
):
def dynamo_timed(key, phase_name=None, fwd_only=True):
yield

View File

@ -1099,6 +1099,7 @@ class LazyString:
structured_logging_overhead: Dict[str, float] = defaultdict(float)
# Same principle as add_remote_cache_time_saved, but do it for structured logging
def add_structured_logging_overhead(time_spent: float) -> None:
global structured_logging_overhead
key = None