mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[logging] Overhaul dynamo_timed and CompilationMetrics logging. (#139849)"
This reverts commit cb15c1515778499ae801dcf67d55c8bdab4724ef. Reverted https://github.com/pytorch/pytorch/pull/139849 on behalf of https://github.com/kit1980 due to Breaking an internal tests + there is a bug according to the author ([comment](https://github.com/pytorch/pytorch/pull/139849#issuecomment-2474459094))
This commit is contained in:
@ -1,78 +0,0 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
from torch._dynamo.metrics_context import MetricsContext
|
||||
from torch._dynamo.test_case import run_tests, TestCase
|
||||
|
||||
|
||||
class TestMetricsContext(TestCase):
    """
    Unit tests for MetricsContext: guarding against use outside the context,
    nesting behavior, and the no-overwrite rule of set() / update().
    """

    def setUp(self):
        super().setUp()
        # Populated by _on_exit when the outermost context exits.
        self.metrics = {}

    def _on_exit(self, metrics):
        # Save away the metrics to be validated in the test.
        self.metrics = metrics.copy()

    def test_context_exists(self):
        """
        Setting a value without entering the context should raise.
        """
        context = MetricsContext(self._on_exit)
        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
            context.increment("m", 1)

        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
            context.set("m", 1)

        with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
            # Pass a real dict: the previous `{"m", 1}` was a set literal and
            # only "worked" because the outside-of-context check fires first.
            context.update({"m": 1})

    def test_nested_context(self):
        """
        Only the outermost context should get an on_exit call, and it should
        include everything.
        """
        context = MetricsContext(self._on_exit)
        with context:
            with context:
                context.set("m1", 1)
            # Leaving the inner context must not trigger on_exit yet.
            self.assertEqual(self.metrics, {})
            context.set("m2", 2)
        self.assertEqual(self.metrics, {"m1": 1, "m2": 2})

    def test_set(self):
        """
        Validate various ways to set metrics.
        """
        with MetricsContext(self._on_exit) as context:
            context.set("m1", 1)
            context.set("m2", 2)
            context.update({"m3": 3, "m4": 4})

        self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})

    def test_set_disallow_overwrite(self):
        """
        Validate set won't overwrite.
        """
        with MetricsContext(self._on_exit) as context:
            context.set("m1", 1)
            with self.assertRaisesRegex(RuntimeError, "already been set"):
                context.set("m1", 2)

        # The rejected set() must leave the original value untouched.
        self.assertEqual(self.metrics, {"m1": 1})

    def test_update_disallow_overwrite(self):
        """
        Validate update won't overwrite.
        """
        with MetricsContext(self._on_exit) as context:
            context.update({"m1": 1, "m2": 2})
            with self.assertRaisesRegex(RuntimeError, "already been set"):
                context.update({"m1": 7, "m3": 3})

        # The rejected update() must not apply any of its keys (no "m3").
        self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -159,7 +159,6 @@ class TestDynamoTimed(TestCase):
|
||||
'_recursive_post_grad_passes': [0.0, 0.0],
|
||||
'_recursive_pre_grad_passes': [0.0],
|
||||
'async_compile.wait': [0.0, 0.0],
|
||||
'backward._backward_impl': [0.0],
|
||||
'compile_file': [0.0, 0.0],
|
||||
'compile_fx.<locals>.bw_compiler': [0.0],
|
||||
'compile_fx.<locals>.fw_compiler_base': [0.0],
|
||||
@ -175,7 +174,6 @@ class TestDynamoTimed(TestCase):
|
||||
"""\
|
||||
{'backend_compile': 0.0,
|
||||
'code_gen': 0.0,
|
||||
'entire_backward_compile': 0.0,
|
||||
'entire_frame_compile': 0.0,
|
||||
'inductor_compile': 0.0,
|
||||
'total_wall_time': 0.0}""", # noqa: B950
|
||||
@ -202,7 +200,6 @@ class TestDynamoTimed(TestCase):
|
||||
{'accumulated_cache_size': 0,
|
||||
'aot_autograd_cumulative_compile_time_us': 0,
|
||||
'backend_compile_time_s': 0.0,
|
||||
'backward_cumulative_compile_time_us': None,
|
||||
'cache_size': 0,
|
||||
'co_filename': None,
|
||||
'co_firstlineno': None,
|
||||
@ -213,13 +210,12 @@ class TestDynamoTimed(TestCase):
|
||||
'config_inline_inbuilt_nn_modules': False,
|
||||
'config_suppress_errors': False,
|
||||
'cuda_synchronize_time_us': None,
|
||||
'distributed_ephemeral_timeout_us': None,
|
||||
'distributed_ephemeral_timeout_us': 0,
|
||||
'duration_us': 0,
|
||||
'dynamo_compile_time_before_restart_us': 0,
|
||||
'dynamo_config': None,
|
||||
'dynamo_cumulative_compile_time_us': 0,
|
||||
'dynamo_time_before_restart_s': 0.0,
|
||||
'end_time_us': 100,
|
||||
'entire_frame_compile_time_s': 0.0,
|
||||
'fail_reason': None,
|
||||
'fail_type': None,
|
||||
@ -235,10 +231,9 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_compile_time_s': 0.0,
|
||||
'inductor_cumulative_compile_time_us': 0,
|
||||
'is_forward': True,
|
||||
'log_format_version': 2,
|
||||
'non_compliant_ops': set(),
|
||||
'num_triton_bundles': None,
|
||||
'remote_cache_time_saved_s': None,
|
||||
'remote_cache_time_saved_s': 0,
|
||||
'remote_fx_graph_cache_get_time_ms': None,
|
||||
'remote_fx_graph_cache_get_time_us': None,
|
||||
'remote_fx_graph_cache_put_time_ms': None,
|
||||
@ -262,7 +257,6 @@ class TestDynamoTimed(TestCase):
|
||||
{'accumulated_cache_size': None,
|
||||
'aot_autograd_cumulative_compile_time_us': None,
|
||||
'backend_compile_time_s': None,
|
||||
'backward_cumulative_compile_time_us': 0,
|
||||
'cache_size': None,
|
||||
'co_filename': None,
|
||||
'co_firstlineno': None,
|
||||
@ -279,7 +273,6 @@ class TestDynamoTimed(TestCase):
|
||||
'dynamo_config': None,
|
||||
'dynamo_cumulative_compile_time_us': None,
|
||||
'dynamo_time_before_restart_s': None,
|
||||
'end_time_us': 100,
|
||||
'entire_frame_compile_time_s': None,
|
||||
'fail_reason': None,
|
||||
'fail_type': None,
|
||||
@ -295,7 +288,6 @@ class TestDynamoTimed(TestCase):
|
||||
'inductor_compile_time_s': 0.0,
|
||||
'inductor_cumulative_compile_time_us': 0,
|
||||
'is_forward': False,
|
||||
'log_format_version': 2,
|
||||
'non_compliant_ops': None,
|
||||
'num_triton_bundles': None,
|
||||
'remote_cache_time_saved_s': None,
|
||||
@ -310,7 +302,7 @@ class TestDynamoTimed(TestCase):
|
||||
'specialize_float': None,
|
||||
'start_time': None,
|
||||
'start_time_us': 100,
|
||||
'structured_logging_overhead_s': None,
|
||||
'structured_logging_overhead_s': 0.0,
|
||||
'structured_logging_overhead_us': 0,
|
||||
'triton_compile_time_us': None}""", # noqa: B950
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ import torch
|
||||
import torch._logging
|
||||
from torch._C._dynamo.guards import GlobalStateGuard
|
||||
from torch._dynamo.distributed import get_compile_pg
|
||||
from torch._dynamo.utils import CompileTimeInstructionCounter, get_metrics_context
|
||||
from torch._dynamo.utils import CompileTimeInstructionCounter
|
||||
from torch._guards import compile_context, CompileContext, CompileId, tracing
|
||||
from torch._logging import structured
|
||||
from torch._utils_internal import (
|
||||
@ -105,9 +105,12 @@ from .symbolic_convert import (
|
||||
from .trace_rules import is_numpy
|
||||
from .utils import (
|
||||
CleanupManager,
|
||||
codecache_metrics,
|
||||
CompilationMetrics,
|
||||
counters,
|
||||
dynamo_timed,
|
||||
format_bytecode,
|
||||
frame_phase_timing,
|
||||
gen_record_file_name,
|
||||
get_chromium_event_logger,
|
||||
increment_frame,
|
||||
@ -115,8 +118,10 @@ from .utils import (
|
||||
istype,
|
||||
LazyString,
|
||||
orig_code_map,
|
||||
record_compilation_metrics,
|
||||
reset_graph_break_dup_checker,
|
||||
setup_compile_debug,
|
||||
to_int_ms,
|
||||
to_int_us,
|
||||
troubleshooting_url,
|
||||
write_record_to_file,
|
||||
@ -693,9 +698,7 @@ def _compile(
|
||||
with contextlib.ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
dynamo_timed(
|
||||
"_compile.compile_inner",
|
||||
phase_name="entire_frame_compile",
|
||||
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
|
||||
"_compile.compile_inner", phase_name="entire_frame_compile"
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
@ -860,11 +863,9 @@ def _compile(
|
||||
chromium_event_log.reset()
|
||||
chromium_start_time = time.time_ns()
|
||||
chromium_event_log.log_event_start("dynamo", chromium_start_time, {})
|
||||
|
||||
metrics_context = get_metrics_context()
|
||||
with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(
|
||||
CompileContext(compile_id)
|
||||
), metrics_context:
|
||||
):
|
||||
restart_reasons: set[str] = set()
|
||||
# This is shared across restarts
|
||||
speculation_log = SpeculationLog()
|
||||
@ -971,6 +972,7 @@ def _compile(
|
||||
fail_user_frame_lineno: Optional[int] = None
|
||||
torch._dynamo.utils.ReinplaceCounters.clear()
|
||||
guarded_code = None
|
||||
codecache_metrics.clear()
|
||||
try:
|
||||
guarded_code = compile_inner(code, one_graph, hooks, transform)
|
||||
|
||||
@ -987,7 +989,6 @@ def _compile(
|
||||
|
||||
return guarded_code
|
||||
except Exception as e:
|
||||
# TODO(masnesral): Populating the exception info should be automatic
|
||||
fail_type = type(e).__qualname__
|
||||
fail_reason = str(e)
|
||||
# NB: e's msg is mutated here to add user stack, but we DON'T want
|
||||
@ -1037,34 +1038,66 @@ def _compile(
|
||||
if tracer:
|
||||
tracer.output.local_scope = {}
|
||||
|
||||
end_time_ns = time.time_ns()
|
||||
duration_ns = end_time_ns - start_time_ns
|
||||
duration_ns = time.time_ns() - start_time_ns
|
||||
|
||||
from .utils import curr_frame
|
||||
|
||||
frame_key = str(curr_frame)
|
||||
if fail_reason is None and output is not None:
|
||||
if (
|
||||
fail_reason is None
|
||||
and output is not None
|
||||
and frame_key in frame_phase_timing
|
||||
):
|
||||
guard_count = len(output.guards)
|
||||
shape_env_guard_count = len(output.shape_env.guards)
|
||||
graph_op_count = output.count_calls()
|
||||
graph_node_count = len(output.graph.nodes)
|
||||
graph_input_count = len(output.placeholders)
|
||||
entire_frame_compile_time = frame_phase_timing[frame_key].get(
|
||||
"entire_frame_compile", None
|
||||
)
|
||||
backend_compile_time = frame_phase_timing[frame_key].get(
|
||||
"backend_compile", None
|
||||
)
|
||||
inductor_compile_time = frame_phase_timing[frame_key].get(
|
||||
"inductor_compile", None
|
||||
)
|
||||
code_gen_time = frame_phase_timing[frame_key].get("code_gen", None)
|
||||
non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
|
||||
compliant_custom_ops = {
|
||||
op.__qualname__ for op in output.compliant_custom_ops
|
||||
}
|
||||
remote_cache_time_saved = frame_phase_timing[frame_key].get(
|
||||
"remote_cache_time_saved", 0
|
||||
)
|
||||
remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get(
|
||||
"remote_fx_graph_cache_get", None
|
||||
)
|
||||
remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get(
|
||||
"remote_fx_graph_cache_put", None
|
||||
)
|
||||
num_triton_bundles = codecache_metrics.get("num_triton_bundles", None)
|
||||
torch._dynamo.utils.ReinplaceCounters.log()
|
||||
|
||||
else:
|
||||
guard_count = None
|
||||
shape_env_guard_count = None
|
||||
graph_op_count = None
|
||||
graph_node_count = None
|
||||
graph_input_count = None
|
||||
entire_frame_compile_time = None
|
||||
backend_compile_time = None
|
||||
inductor_compile_time = None
|
||||
code_gen_time = None
|
||||
non_compliant_ops = set({})
|
||||
compliant_custom_ops = set({})
|
||||
restart_reasons = set()
|
||||
# If compilation failed, the entire time is wasted
|
||||
dynamo_time_before_restart = duration_ns / 1e9
|
||||
remote_cache_time_saved = None
|
||||
remote_fx_graph_cache_get_time = None
|
||||
remote_fx_graph_cache_put_time = None
|
||||
num_triton_bundles = None
|
||||
|
||||
structured_logging_overhead_s = (
|
||||
torch._logging.get_structured_logging_overhead()
|
||||
@ -1099,55 +1132,74 @@ def _compile(
|
||||
}
|
||||
|
||||
config_dict = clean_for_json(config.get_config_copy())
|
||||
metrics = {
|
||||
"compile_id": str(compile_id),
|
||||
"frame_key": frame_key,
|
||||
"co_name": code.co_name,
|
||||
"co_filename": code.co_filename,
|
||||
"co_firstlineno": code.co_firstlineno,
|
||||
"cache_size": cache_size.num_cache_entries_with_same_id_matched_objs,
|
||||
"accumulated_cache_size": cache_size.num_cache_entries,
|
||||
"guard_count": guard_count,
|
||||
"shape_env_guard_count": shape_env_guard_count,
|
||||
"graph_op_count": graph_op_count,
|
||||
"graph_node_count": graph_node_count,
|
||||
"graph_input_count": graph_input_count,
|
||||
# TODO(masnesral): start_time and end_time shouldn't need to be
|
||||
# populated manually.
|
||||
"start_time": start_time_ns / 1e9,
|
||||
"fail_type": fail_type,
|
||||
"fail_reason": fail_reason,
|
||||
"fail_user_frame_filename": fail_user_frame_filename,
|
||||
"fail_user_frame_lineno": fail_user_frame_lineno,
|
||||
"non_compliant_ops": non_compliant_ops,
|
||||
"compliant_custom_ops": compliant_custom_ops,
|
||||
"restart_reasons": restart_reasons,
|
||||
"dynamo_time_before_restart_s": dynamo_time_before_restart,
|
||||
"has_guarded_code": guarded_code is not None,
|
||||
"structured_logging_overhead_s": structured_logging_overhead_s,
|
||||
"config_suppress_errors": config.suppress_errors,
|
||||
"config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules,
|
||||
"specialize_float": config.specialize_float,
|
||||
"dynamo_config": json.dumps(config_dict),
|
||||
"is_forward": True,
|
||||
"start_time_us": start_time_ns // 1000,
|
||||
"end_time_us": end_time_ns // 1000,
|
||||
"duration_us": duration_ns // 1000,
|
||||
"dynamo_compile_time_before_restart_us": to_int_us(
|
||||
metrics = CompilationMetrics(
|
||||
str(compile_id),
|
||||
frame_key,
|
||||
code.co_name,
|
||||
code.co_filename,
|
||||
code.co_firstlineno,
|
||||
cache_size.num_cache_entries_with_same_id_matched_objs,
|
||||
cache_size.num_cache_entries,
|
||||
guard_count,
|
||||
shape_env_guard_count,
|
||||
graph_op_count,
|
||||
graph_node_count,
|
||||
graph_input_count,
|
||||
start_time_ns / 1e9,
|
||||
entire_frame_compile_time,
|
||||
backend_compile_time,
|
||||
inductor_compile_time,
|
||||
code_gen_time,
|
||||
fail_type,
|
||||
fail_reason,
|
||||
fail_user_frame_filename,
|
||||
fail_user_frame_lineno,
|
||||
non_compliant_ops,
|
||||
compliant_custom_ops,
|
||||
restart_reasons,
|
||||
dynamo_time_before_restart,
|
||||
guarded_code is not None,
|
||||
remote_cache_time_saved,
|
||||
structured_logging_overhead_s,
|
||||
config.suppress_errors,
|
||||
config.inline_inbuilt_nn_modules,
|
||||
config.specialize_float,
|
||||
json.dumps(config_dict),
|
||||
True, # is_forward
|
||||
num_triton_bundles,
|
||||
to_int_ms(remote_fx_graph_cache_get_time),
|
||||
to_int_ms(remote_fx_graph_cache_put_time),
|
||||
start_time_us=start_time_ns // 1000,
|
||||
duration_us=duration_ns // 1000,
|
||||
dynamo_cumulative_compile_time_us=to_int_us(entire_frame_compile_time),
|
||||
aot_autograd_cumulative_compile_time_us=to_int_us(backend_compile_time),
|
||||
inductor_cumulative_compile_time_us=to_int_us(inductor_compile_time),
|
||||
inductor_code_gen_cumulative_compile_time_us=to_int_us(code_gen_time),
|
||||
triton_compile_time_us=None, # TODO: instrument
|
||||
runtime_cudagraphify_time_us=None, # TODO: instrument in separate event
|
||||
runtime_triton_autotune_time_us=None, # TODO: instrument in separate event
|
||||
dynamo_compile_time_before_restart_us=to_int_us(
|
||||
dynamo_time_before_restart
|
||||
),
|
||||
"structured_logging_overhead_us": to_int_us(
|
||||
structured_logging_overhead_s
|
||||
cuda_synchronize_time_us=None, # TODO: instrument
|
||||
distributed_ephemeral_timeout_us=to_int_us(
|
||||
remote_cache_time_saved
|
||||
), # TODO: instrument more accurately
|
||||
structured_logging_overhead_us=to_int_us(structured_logging_overhead_s),
|
||||
remote_fx_graph_cache_get_time_us=to_int_us(
|
||||
remote_fx_graph_cache_get_time
|
||||
),
|
||||
}
|
||||
metrics_context.update_outer(metrics)
|
||||
remote_fx_graph_cache_put_time_us=to_int_us(
|
||||
remote_fx_graph_cache_put_time
|
||||
),
|
||||
)
|
||||
record_compilation_metrics(metrics)
|
||||
torch._dynamo.callback_handler.run_end_callbacks()
|
||||
chromium_event_log.log_event_end(
|
||||
"dynamo", time.time_ns(), {}, chromium_start_time, True
|
||||
)
|
||||
# === END WARNING WARNING WARNING ===
|
||||
|
||||
chromium_event_log.log_event_end(
|
||||
"dynamo", time.time_ns(), {}, chromium_start_time, True
|
||||
)
|
||||
|
||||
|
||||
class ConvertFrame:
|
||||
def __init__(self, compiler_fn: CompilerFn, hooks: Hooks) -> None:
|
||||
|
@ -1,95 +0,0 @@
|
||||
from typing import Any, Callable, Dict, Optional, Type
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
|
||||
OnExitType: TypeAlias = Callable[[Dict[str, Any]], None]
|
||||
|
||||
|
||||
class MetricsContext:
    """
    Re-entrant context manager that accumulates named metrics and hands them
    to a callback when the outermost ``with`` block exits.
    """

    def __init__(self, on_exit: "OnExitType"):
        """
        Use this class as a contextmanager to create a context under which to accumulate
        a set of metrics, e.g., metrics gathered during a compilation. On exit of the
        contextmanager, call the provided 'on_exit' function and pass a dictionary of
        all metrics set during the lifetime of the contextmanager.
        """
        self._on_exit = on_exit
        self._metrics: Dict[str, Any] = {}
        # Nesting depth: 0 means "not inside the context".
        self._level = 0

    def __enter__(self) -> "MetricsContext":
        """
        Initialize metrics recording.
        """
        if self._level == 0:
            # In case of recursion, track at the outermost context.
            self._metrics = {}

        self._level += 1
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        _traceback: Any,
    ) -> None:
        """
        At exit, call the provided on_exit function.
        """
        # Check before decrementing so that an unbalanced __exit__ does not
        # leave the nesting level negative (which would defeat the
        # `_level == 0` guards in the setters below).
        assert self._level > 0
        self._level -= 1
        if self._level == 0:
            self._on_exit(self._metrics)

    def in_progress(self) -> bool:
        """
        True if we've entered the context.
        """
        return self._level > 0

    def increment(self, metric: str, value: int) -> None:
        """
        Increment a metric by a given amount. A metric that has not been seen
        yet in the current context starts from 0.
        """
        if self._level == 0:
            raise RuntimeError(f"Cannot increment {metric} outside of a MetricsContext")
        self._metrics[metric] = self._metrics.get(metric, 0) + value

    def set(self, metric: str, value: Any, overwrite: bool = False) -> None:
        """
        Set a metric to a given value. Raises if the metric has been assigned previously
        in the current context, unless overwrite=True is passed.
        """
        if self._level == 0:
            raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
        if metric in self._metrics and not overwrite:
            raise RuntimeError(
                f"Metric '{metric}' has already been set in the current context"
            )
        self._metrics[metric] = value

    def update(self, values: Dict[str, Any], overwrite: bool = False) -> None:
        """
        Set multiple metrics directly. This method does NOT increment. Raises if any
        metric has been assigned previously in the current context, unless
        overwrite=True is passed. When it raises, none of the values are applied.
        """
        if self._level == 0:
            raise RuntimeError("Cannot update metrics outside of a MetricsContext")
        existing = self._metrics.keys() & values.keys()
        if existing and not overwrite:
            raise RuntimeError(
                f"Metric(s) {existing} have already been set in the current context"
            )
        self._metrics.update(values)

    def update_outer(self, values: Dict[str, Any]) -> None:
        """
        Update, but only when at the outermost context. Inside a nested
        context this is a no-op.
        """
        if self._level == 0:
            raise RuntimeError("Cannot update metrics outside of a MetricsContext")
        if self._level == 1:
            self.update(values)
|
@ -1394,7 +1394,6 @@ class OutputGraph:
|
||||
"OutputGraph.call_user_compiler",
|
||||
phase_name="backend_compile",
|
||||
log_pt2_compile_event=True,
|
||||
dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
|
||||
):
|
||||
return self._call_user_compiler(gm)
|
||||
|
||||
|
@ -43,7 +43,6 @@ from typing import (
|
||||
DefaultDict,
|
||||
Deque,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
KeysView,
|
||||
@ -72,7 +71,6 @@ from torch._C import (
|
||||
_push_on_torch_function_stack,
|
||||
)
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dynamo.metrics_context import MetricsContext
|
||||
from torch._guards import Source, TracingContext
|
||||
from torch._subclasses.meta_utils import is_sparse_compressed
|
||||
from torch._utils_internal import (
|
||||
@ -141,10 +139,12 @@ log = logging.getLogger(__name__)
|
||||
# profiling compilation time by function
|
||||
compilation_time_metrics: Dict[str, List[float]] = {}
|
||||
|
||||
# This supports calculate_time_spent(), which reports cumulative times
|
||||
# across the process for any "phase" populated by dynamo_timed. Reset if
|
||||
# reset_frame_count() is called.
|
||||
cumulative_time_spent_ns: Dict[str, float] = collections.defaultdict(float)
|
||||
# profiling compilation time by frame phase
|
||||
frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
|
||||
lambda: collections.defaultdict(float)
|
||||
)
|
||||
|
||||
codecache_metrics: Counter[str] = collections.Counter()
|
||||
|
||||
timer_counter = itertools.count()
|
||||
|
||||
@ -220,7 +220,7 @@ def increment_frame() -> None:
|
||||
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
|
||||
def reset_frame_count() -> None:
|
||||
global curr_frame
|
||||
cumulative_time_spent_ns.clear()
|
||||
frame_phase_timing.clear()
|
||||
compilation_time_metrics.clear()
|
||||
curr_frame = 0
|
||||
|
||||
@ -233,16 +233,25 @@ def increment_op_count(cnt: int) -> None:
|
||||
op_count += cnt
|
||||
|
||||
|
||||
# Get the total time in seconds for each "phase"
|
||||
# Calculate total time spent so far for each phase
|
||||
# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
|
||||
def calculate_time_spent() -> Dict[str, float]:
|
||||
total_wall_time = 0.0
|
||||
total_by_key = {}
|
||||
for phase, timing in cumulative_time_spent_ns.items():
|
||||
total_by_key[phase] = timing / 1e9
|
||||
for timings in frame_phase_timing.values():
|
||||
total_wall_time += timings.get(
|
||||
"entire_frame_compile", timings.get("inductor_compile", 0)
|
||||
)
|
||||
|
||||
for key, timing in timings.items():
|
||||
if key not in total_by_key:
|
||||
total_by_key[key] = timing
|
||||
else:
|
||||
total_by_key[key] += timing
|
||||
|
||||
if total_by_key:
|
||||
total_by_key["total_wall_time"] = total_wall_time
|
||||
|
||||
total_by_key["total_wall_time"] = total_by_key.get(
|
||||
"entire_frame_compile", 0
|
||||
) + total_by_key.get("entire_backward_compile", 0)
|
||||
return total_by_key
|
||||
|
||||
|
||||
@ -261,124 +270,188 @@ def print_time_report() -> None:
|
||||
print(out)
|
||||
|
||||
|
||||
# Use the following singleton to capture and log CompilationMetrics. Entering the context
|
||||
# manager allocates a new record to be logged when it exits. (You should not need to use
|
||||
# this directly unless you introduce a new code path where compilation metrics would be
|
||||
# gathered). While compiling, use the setters or timer in MetricsContext to update fields
|
||||
# in the current context. For example:
|
||||
#
|
||||
# To set a single field once (use overwrite=True to overwrite):
|
||||
# get_metrics_context().set("metric_name", value)
|
||||
#
|
||||
# To set multiple fields at once (use overwrite=True to overwrite):
|
||||
# get_metrics_context().update({"name1": val1, "name2": val2})
|
||||
#
|
||||
# To increment an integer field:
|
||||
# get_metrics_context().increment("metric_name", value)
|
||||
#
|
||||
# To record execution time, MetricsContext works with dynamo_timed:
|
||||
# def foo(...):
|
||||
# # Updates the "metric_us" field.
|
||||
# with dynamo_timed("metric", dynamo_compile_column_us="metric_us")
|
||||
# ...
|
||||
#
|
||||
_METRICS_CONTEXT: MetricsContext
|
||||
def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
|
||||
frame_phase_timing[key][phase_name] += time_spent
|
||||
|
||||
|
||||
def get_metrics_context() -> MetricsContext:
|
||||
return _METRICS_CONTEXT
|
||||
# Use frame_phase_timing to record remote_cache_time_saved
|
||||
# This follows the same principles of key as the other frame phase timings,
|
||||
# but is incremented by FxGraphCache (and later AOTAutogradCache) directly
|
||||
def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) -> None:
|
||||
key = None
|
||||
if is_backward:
|
||||
# Use compile id as the frame key for backwards compilation
|
||||
key = str(torch._guards.CompileContext.current_compile_id())
|
||||
else:
|
||||
key = str(curr_frame)
|
||||
# Convert to seconds (as a float)
|
||||
time_saved = time_saved_ns / 1e9
|
||||
_add_time_spent(key, "remote_cache_time_saved", time_saved)
|
||||
|
||||
|
||||
# dynamo_timed is a context manager
|
||||
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
|
||||
# where the key is the functions name.
|
||||
# For example:
|
||||
#
|
||||
# def _foo(...):
|
||||
# with dynamo_timed("_foo"):
|
||||
# ...
|
||||
#
|
||||
# Would show up as an entry in our timing dict:
|
||||
# OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
|
||||
# This is extremely useful for granular debugging.
|
||||
#
|
||||
# Although it is tempting to use dynamo_timed as a decorator, please do not.
|
||||
# In its decorator form it makes cProfile traces less useful as dynamo_timed
|
||||
# suddenly becomes a bottleneck for lots of function calls (as only one parent
|
||||
# pointer is recorded).
|
||||
#
|
||||
# For a higher-level mode, pass a phase_name into dynamo_timed
|
||||
# phase_names record an extra record into a separate compilation timing structure,
|
||||
# one keyed on frame+name rather than function.
|
||||
# The frame is incremented outside of this function, in def increment_frame() above.
|
||||
# `fwd_only` is used to identify if this phase or function is only called
|
||||
# during compiling fwd graphs, e.g, `entire_frame_compile` and `backend_compile`.
|
||||
# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
|
||||
|
||||
|
||||
@contextmanager
|
||||
def dynamo_timed(
|
||||
key: str,
|
||||
# TODO(masnesral): Deprecate this param.
|
||||
phase_name: Optional[str] = None,
|
||||
log_pt2_compile_event: bool = False,
|
||||
# TODO(masnesral): fwd_only is ignored. Remove it.
|
||||
log_pt2_compile_event: bool = False, # Whether or not to log it to internal pt2 compile event
|
||||
fwd_only: bool = True,
|
||||
metadata: Optional[Dict[str, object]] = None,
|
||||
dynamo_compile_column_us: Optional[str] = None,
|
||||
) -> Generator[Any, None, None]:
|
||||
"""
|
||||
dynamo_timed is a context manager
|
||||
By wrapping a function in dynamo_timed, we can get a few things:
|
||||
|
||||
1) Log timings to pt2_compile_events.
|
||||
2) Log timings to CompilationMetrics (dynamo_compile).
|
||||
3) Chromium events.
|
||||
4) Storing a record in compilation_time_metrics
|
||||
For example:
|
||||
|
||||
def _foo(...):
|
||||
with dynamo_timed("_foo"):
|
||||
...
|
||||
|
||||
Would show up as an entry in our timing dict:
|
||||
OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
|
||||
This is extremely useful for granular debugging.
|
||||
|
||||
Although it is tempting to use dynamo_timed as a decorator, please do not.
|
||||
In its decorator form it makes cProfile traces less useful as dynamo_timed
|
||||
suddenly becomes a bottleneck for lots of function calls (as only one parent
|
||||
pointer is recorded).
|
||||
|
||||
Params:
|
||||
- key: key into compile_time_metrics. If phase_name is not provided, this is
|
||||
also the event name used for pt2_compile_events logs and chromium events.
|
||||
- phase_name: Optional override for the event name.
|
||||
- log_pt2_compile_event: Whether to log a pt2 compile event internally.
|
||||
- metadata: Extra metadata to put in pt2_compile_events.
|
||||
- dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
|
||||
field to be logged to dynamo_compile column. We expect all columns to be _us;
|
||||
therefore, the field name must end with "_us".
|
||||
"""
|
||||
# We're standardizing on microseconds for dynamo_compile timings.
|
||||
if dynamo_compile_column_us is not None:
|
||||
assert dynamo_compile_column_us.endswith("_us")
|
||||
|
||||
if phase_name:
|
||||
event_name = phase_name
|
||||
fn_name = key
|
||||
else:
|
||||
event_name = key
|
||||
fn_name = None
|
||||
|
||||
):
|
||||
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
|
||||
if key not in compilation_time_metrics:
|
||||
compilation_time_metrics[key] = []
|
||||
|
||||
event_metadata = {}
|
||||
if metadata:
|
||||
event_metadata.update(metadata)
|
||||
if fn_name:
|
||||
event_metadata.update({"fn_name": fn_name})
|
||||
|
||||
chromium_log: ChromiumEventLogger = get_chromium_event_logger()
|
||||
fail_type: Optional[str] = None
|
||||
fail_reason: Optional[str] = None
|
||||
time_spent = float("-inf")
|
||||
start_ns = time.time_ns()
|
||||
chromium_log.log_event_start(event_name, start_ns, event_metadata)
|
||||
|
||||
try:
|
||||
with torch.profiler.record_function(f"{key} (dynamo_timed)"):
|
||||
t0 = time.time()
|
||||
if phase_name:
|
||||
chromium_log.log_event_start(phase_name, start_ns, {"fn_name": key})
|
||||
else:
|
||||
chromium_log.log_event_start(key, start_ns, {})
|
||||
yield
|
||||
time_spent = time.time() - t0
|
||||
compilation_time_metrics[key].append(time_spent)
|
||||
except Exception as e:
|
||||
fail_type = str(type(e))
|
||||
fail_reason = str(e)
|
||||
raise
|
||||
finally:
|
||||
end_ns = time.time_ns()
|
||||
time_spent_ns = end_ns - start_ns
|
||||
compilation_time_metrics[key].append(time_spent_ns / 1e9)
|
||||
chromium_log.log_event_end(
|
||||
event_name, end_ns, {}, start_ns, log_pt2_compile_event
|
||||
)
|
||||
if dynamo_compile_column_us:
|
||||
metrics_context = get_metrics_context()
|
||||
if metrics_context.in_progress():
|
||||
metrics_context.increment(
|
||||
dynamo_compile_column_us, time_spent_ns // 1000
|
||||
)
|
||||
# TODO: the events that we capture in calculate_time_spent() seem a little
|
||||
# arbitrary. Currently, it's only those fields that are present in
|
||||
# CompilationMetrics (but note that we accumulate by the associated event
|
||||
# name, not the field name in CompilationMetrics). Do we want to keep it
|
||||
# this way?
|
||||
cumulative_time_spent_ns[event_name] += time_spent_ns
|
||||
# Always log the end event even on exception
|
||||
if phase_name:
|
||||
chromium_log.log_event_end(
|
||||
phase_name,
|
||||
end_ns,
|
||||
{},
|
||||
start_ns,
|
||||
log_pt2_compile_event,
|
||||
)
|
||||
else:
|
||||
chromium_log.log_event_end(key, end_ns, {}, start_ns, log_pt2_compile_event)
|
||||
# Only record backward compilation metrics if phase_name is not None!
|
||||
if phase_name:
|
||||
frame_key = str(curr_frame)
|
||||
# fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch.
|
||||
# use frame_key as time aggregation key.
|
||||
if fwd_only and fail_type is None:
|
||||
_add_time_spent(frame_key, phase_name, time_spent)
|
||||
else:
|
||||
# fwd + bwd compilation stages: inductor_compile, code_gen.
|
||||
# use frame_key as time aggregation key for fwd graphs;
|
||||
# use compile_id as time aggregation key for bwd graphs.
|
||||
if torch._guards.TracingContext.try_get() is not None:
|
||||
aot_graph_name = str(
|
||||
torch._guards.TracingContext.get().aot_graph_name
|
||||
)
|
||||
if (
|
||||
"forward" in aot_graph_name or "inference" in aot_graph_name
|
||||
) and fail_type is None:
|
||||
_add_time_spent(frame_key, phase_name, time_spent)
|
||||
elif "backward" in aot_graph_name:
|
||||
compile_id = str(
|
||||
torch._guards.CompileContext.current_compile_id()
|
||||
)
|
||||
if fail_type is None:
|
||||
_add_time_spent(compile_id, phase_name, time_spent)
|
||||
|
||||
# log backward compilation metrics at the end of `inductor_compile` of bwd graph,
|
||||
# one record for one bwd graph.
|
||||
if phase_name == "inductor_compile":
|
||||
if fail_type is None:
|
||||
inductor_compile_time = frame_phase_timing[
|
||||
compile_id
|
||||
].get("inductor_compile", None)
|
||||
code_gen_time = frame_phase_timing[compile_id].get(
|
||||
"code_gen", None
|
||||
)
|
||||
remote_cache_time_saved = frame_phase_timing[
|
||||
compile_id
|
||||
].get("remote_cache_time_saved", None)
|
||||
remote_fx_graph_cache_get_time = frame_phase_timing[
|
||||
compile_id
|
||||
].get("remote_fx_graph_cache_get", None)
|
||||
remote_fx_graph_cache_put_time = frame_phase_timing[
|
||||
compile_id
|
||||
].get("remote_fx_graph_cache_put", None)
|
||||
else:
|
||||
inductor_compile_time = None
|
||||
code_gen_time = None
|
||||
remote_cache_time_saved = None
|
||||
remote_fx_graph_cache_get_time = None
|
||||
remote_fx_graph_cache_put_time = None
|
||||
structured_logging_overhead_s = (
|
||||
torch._logging.get_structured_logging_overhead()
|
||||
)
|
||||
metrics = CompilationMetrics(
|
||||
compile_id=compile_id,
|
||||
inductor_compile_time_s=inductor_compile_time,
|
||||
code_gen_time_s=code_gen_time,
|
||||
fail_type=fail_type,
|
||||
fail_reason=fail_reason,
|
||||
remote_cache_time_saved_s=remote_cache_time_saved,
|
||||
structured_logging_overhead_s=structured_logging_overhead_s,
|
||||
is_forward=False, # is_forward
|
||||
num_triton_bundles=codecache_metrics.get(
|
||||
"num_triton_bundles", None
|
||||
),
|
||||
remote_fx_graph_cache_get_time_ms=to_int_ms(
|
||||
remote_fx_graph_cache_get_time
|
||||
),
|
||||
remote_fx_graph_cache_put_time_ms=to_int_ms(
|
||||
remote_fx_graph_cache_put_time
|
||||
),
|
||||
start_time_us=start_ns // 1000,
|
||||
duration_us=(end_ns - start_ns) // 1000,
|
||||
inductor_cumulative_compile_time_us=to_int_us(
|
||||
inductor_compile_time
|
||||
),
|
||||
inductor_code_gen_cumulative_compile_time_us=to_int_us(
|
||||
code_gen_time
|
||||
),
|
||||
distributed_ephemeral_timeout_us=to_int_us(
|
||||
remote_cache_time_saved
|
||||
), # TODO: instrument more accurately
|
||||
structured_logging_overhead_us=to_int_us(
|
||||
structured_logging_overhead_s
|
||||
),
|
||||
remote_fx_graph_cache_get_time_us=to_int_us(
|
||||
remote_fx_graph_cache_get_time
|
||||
),
|
||||
remote_fx_graph_cache_put_time_us=to_int_us(
|
||||
remote_fx_graph_cache_put_time
|
||||
),
|
||||
)
|
||||
record_compilation_metrics(metrics)
|
||||
|
||||
|
||||
@overload
|
||||
@ -793,11 +866,6 @@ def to_int_us(v: Optional[float]) -> Optional[int]:
|
||||
return None if v is None else int(v * 1_000_000)
|
||||
|
||||
|
||||
# Version field added to every log. Increment to make it easier to distinguish new
|
||||
# vs. old entries when you make a substantive change to how the logs are populated.
|
||||
LOG_FORMAT_VERSION = 2
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompilationMetrics:
|
||||
compile_id: Optional[str] = None
|
||||
@ -845,18 +913,15 @@ class CompilationMetrics:
|
||||
aot_autograd_cumulative_compile_time_us: Optional[int] = None
|
||||
inductor_cumulative_compile_time_us: Optional[int] = None
|
||||
inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
|
||||
triton_compile_time_us: Optional[int] = None # TODO: instrument
|
||||
runtime_cudagraphify_time_us: Optional[int] = None # TODO: instrument
|
||||
runtime_triton_autotune_time_us: Optional[int] = None # TODO: instrument
|
||||
triton_compile_time_us: Optional[int] = None
|
||||
runtime_cudagraphify_time_us: Optional[int] = None
|
||||
runtime_triton_autotune_time_us: Optional[int] = None
|
||||
dynamo_compile_time_before_restart_us: Optional[int] = None
|
||||
cuda_synchronize_time_us: Optional[int] = None # TODO: instrument
|
||||
cuda_synchronize_time_us: Optional[int] = None
|
||||
distributed_ephemeral_timeout_us: Optional[int] = None
|
||||
structured_logging_overhead_us: Optional[int] = None
|
||||
remote_fx_graph_cache_get_time_us: Optional[int] = None
|
||||
remote_fx_graph_cache_put_time_us: Optional[int] = None
|
||||
backward_cumulative_compile_time_us: Optional[int] = None
|
||||
end_time_us: Optional[int] = None
|
||||
log_format_version: int = LOG_FORMAT_VERSION
|
||||
|
||||
|
||||
DEFAULT_COMPILATION_METRICS_LIMIT = 64
|
||||
@ -904,32 +969,8 @@ def add_compilation_metrics_to_chromium(c: CompilationMetrics):
|
||||
)
|
||||
|
||||
|
||||
def record_compilation_metrics(metrics: Dict[str, Any]):
|
||||
# TODO: Temporary; populate legacy fields from their replacements.
|
||||
# Remove when we decide we can really deprecate them.
|
||||
def us_to_s(field):
|
||||
metric = metrics.get(field, None)
|
||||
return metric / 1e6 if metric is not None else None
|
||||
|
||||
def us_to_ms(field):
|
||||
metric = metrics.get(field, None)
|
||||
return metric // 1000 if metric is not None else None
|
||||
|
||||
legacy_metrics = {
|
||||
"entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"),
|
||||
"backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"),
|
||||
"inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"),
|
||||
"code_gen_time_s": us_to_s("inductor_code_gen_cumulative_compile_time_us"),
|
||||
"remote_cache_time_saved_s": us_to_s("distributed_ephemeral_timeout_us"),
|
||||
"remote_fx_graph_cache_get_time_ms": us_to_ms(
|
||||
"remote_fx_graph_cache_get_time_us"
|
||||
),
|
||||
"remote_fx_graph_cache_put_time_ms": us_to_ms(
|
||||
"remote_fx_graph_cache_put_time_us"
|
||||
),
|
||||
}
|
||||
|
||||
compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics})
|
||||
def record_compilation_metrics(compilation_metrics: CompilationMetrics):
|
||||
global _compilation_metrics
|
||||
_compilation_metrics.append(compilation_metrics)
|
||||
if compilation_metrics.is_forward:
|
||||
name = "compilation_metrics"
|
||||
@ -938,7 +979,10 @@ def record_compilation_metrics(metrics: Dict[str, Any]):
|
||||
name = "bwd_compilation_metrics"
|
||||
torch._logging.trace_structured(
|
||||
name,
|
||||
lambda: {k: list(v) if isinstance(v, set) else v for k, v in metrics.items()},
|
||||
lambda: {
|
||||
k: list(v) if isinstance(v, set) else v
|
||||
for k, v in dataclasses.asdict(compilation_metrics).items()
|
||||
},
|
||||
# NB: Because compilation metrics *includes* the logging overhead time,
|
||||
# we can't both *measure* the logging overhead of compilation metrics
|
||||
# without making it inconsistent with compilation metrics itself, so
|
||||
@ -949,10 +993,6 @@ def record_compilation_metrics(metrics: Dict[str, Any]):
|
||||
log_compilation_event(compilation_metrics)
|
||||
|
||||
|
||||
# record_compilation_metrics is called by the singleton MetricsContext exit handler.
|
||||
_METRICS_CONTEXT = MetricsContext(on_exit=record_compilation_metrics)
|
||||
|
||||
|
||||
def set_compilation_metrics_limit(new_size: int) -> None:
|
||||
global _compilation_metrics
|
||||
while len(_compilation_metrics) > new_size:
|
||||
|
@ -632,7 +632,7 @@ class AOTAutogradCache:
|
||||
)
|
||||
# TODO: should we use the same field for remote cache time saved for both
|
||||
# FXGraphCache and AOTAutogradCache?
|
||||
# get_metrics_context().increment(...)
|
||||
# add_remote_cache_time_saved(time_saved_ns, is_backward=False)
|
||||
if (
|
||||
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
|
||||
time_saved_ns
|
||||
|
@ -10,7 +10,6 @@ import builtins
|
||||
import collections
|
||||
import itertools
|
||||
import pprint
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
@ -19,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
from torch import Tensor
|
||||
from torch._dynamo.utils import dynamo_timed, get_metrics_context, to_int_us
|
||||
from torch._guards import (
|
||||
compile_context,
|
||||
CompileContext,
|
||||
@ -2004,49 +2002,17 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
with tracing(saved_context), compile_context(
|
||||
saved_compile_context
|
||||
), context(), track_graph_compiling(
|
||||
aot_config, "backward"
|
||||
), get_metrics_context(), dynamo_timed(
|
||||
"backward._backward_impl",
|
||||
phase_name="entire_backward_compile",
|
||||
dynamo_compile_column_us="backward_cumulative_compile_time_us",
|
||||
):
|
||||
fail_type: Optional[str] = None
|
||||
fail_reason: Optional[str] = None
|
||||
start_ns = time.time_ns()
|
||||
try:
|
||||
CompiledFunction.compiled_bw = aot_config.bw_compiler(
|
||||
bw_module, placeholder_list
|
||||
), context(), track_graph_compiling(aot_config, "backward"):
|
||||
CompiledFunction.compiled_bw = aot_config.bw_compiler(
|
||||
bw_module, placeholder_list
|
||||
)
|
||||
# Maybe save cache entry
|
||||
if try_save_cache_entry is not None:
|
||||
try_save_cache_entry(
|
||||
CompiledFunction.compiled_bw,
|
||||
fw_metadata,
|
||||
aot_config,
|
||||
)
|
||||
# Maybe save cache entry
|
||||
if try_save_cache_entry is not None:
|
||||
try_save_cache_entry(
|
||||
CompiledFunction.compiled_bw,
|
||||
fw_metadata,
|
||||
aot_config,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO(masnesral): Populating the exception info should be automatic.
|
||||
fail_type = type(e).__qualname__
|
||||
fail_reason = str(e)
|
||||
finally:
|
||||
# TODO(masnesral): Populating time fields should be automatic.
|
||||
end_ns = time.time_ns()
|
||||
metrics = {
|
||||
"compile_id": str(
|
||||
torch._guards.CompileContext.current_compile_id()
|
||||
),
|
||||
"fail_type": fail_type,
|
||||
"fail_reason": fail_reason,
|
||||
"is_forward": False,
|
||||
"start_time_us": start_ns // 1000,
|
||||
"end_time_us": end_ns // 1000,
|
||||
"duration_us": (end_ns - start_ns) // 1000,
|
||||
"structured_logging_overhead_us": to_int_us(
|
||||
torch._logging.get_structured_logging_overhead(),
|
||||
),
|
||||
}
|
||||
get_metrics_context().update_outer(metrics)
|
||||
|
||||
if (
|
||||
torch._functorch.config.donated_buffer
|
||||
|
@ -54,10 +54,11 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch import SymInt, Tensor
|
||||
from torch._dynamo.utils import (
|
||||
add_remote_cache_time_saved,
|
||||
codecache_metrics,
|
||||
counters,
|
||||
dynamo_timed,
|
||||
get_chromium_event_logger,
|
||||
get_metrics_context,
|
||||
)
|
||||
from torch._inductor import config, exc, metrics
|
||||
from torch._inductor.codegen.cuda import cuda_env
|
||||
@ -1151,7 +1152,7 @@ class FxGraphCache:
|
||||
"inductor_compile", cached_kernel_names=meta.cached_kernel_names
|
||||
)
|
||||
if len(meta.cached_kernel_names) > 0:
|
||||
get_metrics_context().increment("num_triton_bundles", 1)
|
||||
codecache_metrics["num_triton_bundles"] += 1
|
||||
|
||||
inductor_meta = autotune_cache.inductor_meta_from_config()
|
||||
AutotuneCacheBundler.begin_compile(inductor_meta, code=code)
|
||||
@ -1448,9 +1449,7 @@ class FxGraphCache:
|
||||
|
||||
if (time_saved_ns := compiled_graph._time_taken_ns) is not None:
|
||||
cache_info["time_saved_ns"] = time_saved_ns
|
||||
get_metrics_context().increment(
|
||||
"distributed_ephemeral_timeout_us", time_saved_ns // 1000
|
||||
)
|
||||
add_remote_cache_time_saved(time_saved_ns, is_backward)
|
||||
if (
|
||||
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
|
||||
time_saved_ns
|
||||
|
@ -567,7 +567,7 @@ def compile_fx_inner(
|
||||
"compile_fx_inner",
|
||||
phase_name="inductor_compile",
|
||||
log_pt2_compile_event=True,
|
||||
dynamo_compile_column_us="inductor_cumulative_compile_time_us",
|
||||
fwd_only=False,
|
||||
)
|
||||
)
|
||||
# NB: Why is this the dynamo_compile counter? The rule here is that
|
||||
|
@ -1950,7 +1950,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
"GraphLowering.compile_to_module",
|
||||
phase_name="code_gen",
|
||||
log_pt2_compile_event=True,
|
||||
dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
|
||||
fwd_only=False,
|
||||
):
|
||||
return self._compile_to_module()
|
||||
|
||||
|
@ -45,14 +45,14 @@ remote_fx_cache_get_timed = functools.partial(
|
||||
"FbRemoteFxGraphCache.get",
|
||||
phase_name="remote_fx_graph_cache_get",
|
||||
log_pt2_compile_event=False,
|
||||
dynamo_compile_column_us="remote_fx_graph_cache_get_time_us",
|
||||
fwd_only=False,
|
||||
)
|
||||
remote_fx_cache_put_timed = functools.partial(
|
||||
dynamo_timed,
|
||||
"FbRemoteFxGraphCache.put",
|
||||
phase_name="remote_fx_graph_cache_put",
|
||||
log_pt2_compile_event=False,
|
||||
dynamo_compile_column_us="remote_fx_graph_cache_put_time_us",
|
||||
fwd_only=False,
|
||||
)
|
||||
|
||||
|
||||
|
@ -135,13 +135,7 @@ try:
|
||||
except AttributeError: # Compile workers only have a mock version of torch
|
||||
|
||||
@contextlib.contextmanager
|
||||
def dynamo_timed(
|
||||
key,
|
||||
phase_name=None,
|
||||
fwd_only=True,
|
||||
metadata=None,
|
||||
dynamo_compile_column_us=None,
|
||||
):
|
||||
def dynamo_timed(key, phase_name=None, fwd_only=True):
|
||||
yield
|
||||
|
||||
|
||||
|
@ -1099,6 +1099,7 @@ class LazyString:
|
||||
structured_logging_overhead: Dict[str, float] = defaultdict(float)
|
||||
|
||||
|
||||
# Same principle as add_remote_cache_time_saved, but do it for structured logging
|
||||
def add_structured_logging_overhead(time_spent: float) -> None:
|
||||
global structured_logging_overhead
|
||||
key = None
|
||||
|
Reference in New Issue
Block a user