mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
[Reland][Dynamo] Don't log compilation metrics for PyTorch unit tests (#115571)
Reland #115452, which was reverted to simplify a merge conflict with #115386 Pull Request resolved: https://github.com/pytorch/pytorch/pull/115571 Approved by: https://github.com/yanboliang
This commit is contained in:
committed by
PyTorch MergeBot
parent
064846dbc2
commit
89ee3af076
@ -314,6 +314,9 @@ capture_autograd_function = True
|
||||
# enable/disable dynamo tracing for `torch.func` transforms
|
||||
capture_func_transforms = False
|
||||
|
||||
# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
|
||||
log_compilation_metrics = True
|
||||
|
||||
# simulates what would happen if we didn't have support for BUILD_SET opcode,
|
||||
# used for testing
|
||||
inject_BUILD_SET_unimplemented_TESTING_ONLY = False
|
||||
|
||||
@ -620,55 +620,58 @@ def _compile(
|
||||
e.__traceback__
|
||||
) from None
|
||||
finally:
|
||||
from .utils import curr_frame
|
||||
if config.log_compilation_metrics:
|
||||
from .utils import curr_frame
|
||||
|
||||
frame_key = str(curr_frame)
|
||||
if (
|
||||
fail_reason is None
|
||||
and output is not None
|
||||
and frame_key in frame_phase_timing
|
||||
):
|
||||
guard_count = len(output.guards)
|
||||
graph_op_count = output.count_calls()
|
||||
graph_node_count = len(output.graph.nodes)
|
||||
graph_input_count = len(output.placeholders)
|
||||
entire_frame_compile_time = frame_phase_timing[frame_key].get(
|
||||
"entire_frame_compile", None
|
||||
frame_key = str(curr_frame)
|
||||
if (
|
||||
fail_reason is None
|
||||
and output is not None
|
||||
and frame_key in frame_phase_timing
|
||||
):
|
||||
guard_count = len(output.guards)
|
||||
graph_op_count = output.count_calls()
|
||||
graph_node_count = len(output.graph.nodes)
|
||||
graph_input_count = len(output.placeholders)
|
||||
entire_frame_compile_time = frame_phase_timing[frame_key].get(
|
||||
"entire_frame_compile", None
|
||||
)
|
||||
backend_compile_time = frame_phase_timing[frame_key].get(
|
||||
"backend_compile", None
|
||||
)
|
||||
non_compliant_ops = {
|
||||
op.__qualname__ for op in output.non_compliant_ops
|
||||
}
|
||||
compliant_custom_ops = {
|
||||
op.__qualname__ for op in output.compliant_custom_ops
|
||||
}
|
||||
else:
|
||||
guard_count = None
|
||||
graph_op_count = None
|
||||
graph_node_count = None
|
||||
graph_input_count = None
|
||||
entire_frame_compile_time = None
|
||||
backend_compile_time = None
|
||||
non_compliant_ops = set({})
|
||||
compliant_custom_ops = set({})
|
||||
metrics = CompilationMetrics(
|
||||
frame_key,
|
||||
code.co_name,
|
||||
code.co_filename,
|
||||
code.co_firstlineno,
|
||||
cache_size.num_cache_entries_with_same_id_matched_objs,
|
||||
cache_size.num_cache_entries,
|
||||
guard_count,
|
||||
graph_op_count,
|
||||
graph_node_count,
|
||||
graph_input_count,
|
||||
entire_frame_compile_time,
|
||||
backend_compile_time,
|
||||
fail_reason,
|
||||
non_compliant_ops,
|
||||
compliant_custom_ops,
|
||||
)
|
||||
backend_compile_time = frame_phase_timing[frame_key].get(
|
||||
"backend_compile", None
|
||||
)
|
||||
non_compliant_ops = {op.__qualname__ for op in output.non_compliant_ops}
|
||||
compliant_custom_ops = {
|
||||
op.__qualname__ for op in output.compliant_custom_ops
|
||||
}
|
||||
else:
|
||||
guard_count = None
|
||||
graph_op_count = None
|
||||
graph_node_count = None
|
||||
graph_input_count = None
|
||||
entire_frame_compile_time = None
|
||||
backend_compile_time = None
|
||||
non_compliant_ops = set({})
|
||||
compliant_custom_ops = set({})
|
||||
metrics = CompilationMetrics(
|
||||
frame_key,
|
||||
code.co_name,
|
||||
code.co_filename,
|
||||
code.co_firstlineno,
|
||||
cache_size.num_cache_entries_with_same_id_matched_objs,
|
||||
cache_size.num_cache_entries,
|
||||
guard_count,
|
||||
graph_op_count,
|
||||
graph_node_count,
|
||||
graph_input_count,
|
||||
entire_frame_compile_time,
|
||||
backend_compile_time,
|
||||
fail_reason,
|
||||
non_compliant_ops,
|
||||
compliant_custom_ops,
|
||||
)
|
||||
log_compilation_event(metrics)
|
||||
log_compilation_event(metrics)
|
||||
|
||||
|
||||
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
|
||||
|
||||
@ -52,7 +52,11 @@ class TestCase(TorchTestCase):
|
||||
super().setUpClass()
|
||||
cls._exit_stack = contextlib.ExitStack()
|
||||
cls._exit_stack.enter_context(
|
||||
config.patch(raise_on_ctx_manager_usage=True, suppress_errors=False),
|
||||
config.patch(
|
||||
raise_on_ctx_manager_usage=True,
|
||||
suppress_errors=False,
|
||||
log_compilation_metrics=False,
|
||||
),
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@ -1325,6 +1325,8 @@ if TEST_WITH_TORCHDYNAMO:
|
||||
import torch._dynamo
|
||||
# Do not spend time on helper functions that are called with different inputs
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 8
|
||||
# Do not log compilation metrics from unit tests
|
||||
torch._dynamo.config.log_compilation_metrics = False
|
||||
if TEST_WITH_TORCHINDUCTOR:
|
||||
import torch._inductor.config
|
||||
torch._inductor.config.fallback_random = True
|
||||
|
||||
Reference in New Issue
Block a user