[pytorch][dynamo_compile] Log inductor config to dynamo_compile (#140790)

Summary:
Scrubbed inductor config logging to dynamo_compile as json:str.

Scrub RE: `r'((^TYPE_CHECKING$)|(.*_progress$)|(.*TESTING.*)|(.*(rocm|halide).*)|(^trace\..*)|(^_))'` to save some space.

Test Plan:
Staging logger: https://fburl.com/data/ltkt08zm

P1679697917

{F1958428018}

Differential Revision: D65806399

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140790
Approved by: https://github.com/masnesral
This commit is contained in:
Prajesh Praveen Anchalia
2024-11-19 02:39:33 +00:00
committed by PyTorch MergeBot
parent 9ae19ffbed
commit 1e234e63b3
2 changed files with 125 additions and 6 deletions

View File

@ -197,6 +197,7 @@ class TestDynamoTimed(TestCase):
e.dynamo_config = None
e.co_filename = None
e.co_firstlineno = None
e.inductor_config = None
# First event is for the forward. Formatting makes reading diffs
# much easier.
@ -237,6 +238,7 @@ class TestDynamoTimed(TestCase):
'has_guarded_code': True,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,
'inductor_config': None,
'inductor_cumulative_compile_time_us': 0,
'is_forward': True,
'joint_graph_pass_time_us': 0,
@ -300,6 +302,7 @@ class TestDynamoTimed(TestCase):
'has_guarded_code': None,
'inductor_code_gen_cumulative_compile_time_us': 0,
'inductor_compile_time_s': 0.0,
'inductor_config': None,
'inductor_cumulative_compile_time_us': 0,
'is_forward': False,
'joint_graph_pass_time_us': None,
@ -326,6 +329,77 @@ class TestDynamoTimed(TestCase):
)
class TestInductorConfigParsingForLogging(TestCase):
"""
Test for parsing inductor config for logging in CompilationMetrics.
"""
# NOTE(review): TestObject is defined but never used below — the tests use
# the TestCase class object itself as the arbitrary non-serializable value.
# Confirm whether this helper is still needed.
class TestObject:
def __init__(self, a, b):
self.a = a
self.b = b
def test_inductor_config_jsonify(self):
"""
Sanity check if the actual inductor config is parsed correctly
"""
# The real (unmocked) inductor config must always serialize to a JSON string.
inductor_config_json = utils._scrubbed_inductor_config_for_logging()
self.assertTrue(isinstance(inductor_config_json, str))
@mock.patch("torch._dynamo.utils.torch._inductor.config")
def test_inductor_config_parsing_non_conforming_items(self, mocked_inductor_config):
"""
Test if the inductor config is parsed correctly when the config is
- None
- not a dict
- not json serializable
- complex unserializable objects
"""
# An arbitrary class object: not a str (so it is dropped as a key via
# skipkeys) and not JSON serializable (so it is replaced as a value).
obj = TestCase
test_mock_config = {
"some": {1: "0", obj: "this", "name": obj, "some": True},
"data": {1: "0", obj: "this", "name": obj, "some": True},
"list": [
{1: "0", obj: "this", "name": obj, "some": True},
{1: "0", obj: "this", "name": obj, "some": True},
],
"object": {
1: "0",
obj: "this",
"name": obj,
"some": True,
"data": {1: "0", obj: "this", "name": obj, "some": True},
},
}
# Expected: int keys stringified, class-object keys silently skipped,
# class-object values replaced with the sentinel string.
expected = (
"""{"some": {"1": "0", "name": "Value is not JSON serializable", "some": true},"""
""" "data": {"1": "0", "name": "Value is not JSON serializable", "some": true}, "list": """
"""[{"1": "0", "name": "Value is not JSON serializable", "some": true}, """
"""{"1": "0", "name": "Value is not JSON serializable", "some": true}], "object": """
"""{"1": "0", "name": "Value is not JSON serializable", "some": true, "data": """
"""{"1": "0", "name": "Value is not JSON serializable", "some": true}}}"""
)
mocked_inductor_config.get_config_copy.return_value = test_mock_config
inductor_config_json = utils._scrubbed_inductor_config_for_logging()
self.assertEqual(inductor_config_json, expected)
# A dict whose only key is non-str serializes to an empty JSON object.
expected = "{}"
mocked_inductor_config.get_config_copy.return_value = {obj: obj}
inductor_config_json = utils._scrubbed_inductor_config_for_logging()
self.assertEqual(inductor_config_json, expected)
# A non-dict config falls into the exception path and yields the sentinel.
expected = "Inductor Config is not JSON serializable"
mocked_inductor_config.get_config_copy.return_value = obj
inductor_config_json = utils._scrubbed_inductor_config_for_logging()
self.assertEqual(inductor_config_json, expected)
# No config at all is reported as None, not a string.
expected = None
mocked_inductor_config.get_config_copy.return_value = None
inductor_config_json = utils._scrubbed_inductor_config_for_logging()
self.assertEqual(inductor_config_json, expected)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -14,6 +14,7 @@ import gc
import importlib
import inspect
import itertools
import json
import linecache
import logging
import math
@ -61,7 +62,6 @@ from typing_extensions import Literal, TypeIs
import torch
import torch._functorch.config
import torch._inductor.config as inductor_config
import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
@ -865,6 +865,7 @@ class CompilationMetrics:
post_grad_pass_time_us: Optional[int] = None
joint_graph_pass_time_us: Optional[int] = None
log_format_version: int = LOG_FORMAT_VERSION
inductor_config: Optional[str] = None
DEFAULT_COMPILATION_METRICS_LIMIT = 64
@ -912,9 +913,48 @@ def add_compilation_metrics_to_chromium(c: Dict[str, Any]) -> None:
)
def _scrubbed_inductor_config_for_logging() -> Optional[str]:
"""
Method to parse and scrub unintersting configs from inductor config
"""
# TypeSafeSerializer for json.dumps()
# Skips complex types as values in config dict
class TypeSafeSerializer(json.JSONEncoder):
def default(self, o):
try:
return super().default(o)
except Exception:
return "Value is not JSON serializable"
configs_to_scrub_re = r"((^TYPE_CHECKING$)|(.*_progress$)|(.*TESTING.*)|(.*(rocm|halide).*)|(^trace\..*)|(^_))"
keys_to_scrub = set()
inductor_conf_str = None
inductor_config_copy = (
torch._inductor.config.get_config_copy() if torch._inductor.config else None
)
if inductor_config_copy is not None:
try:
for key, val in inductor_config_copy.items():
if not isinstance(key, str) or re.search(configs_to_scrub_re, key):
keys_to_scrub.add(key)
# Convert set() to list for json.dumps()
if isinstance(val, set):
inductor_config_copy[key] = list(val)
# Evict unwanted keys
for key in keys_to_scrub:
del inductor_config_copy[key]
# Stringify Inductor config
inductor_conf_str = json.dumps(
inductor_config_copy, cls=TypeSafeSerializer, skipkeys=True
)
except Exception:
# Don't crash because of runtime logging errors
inductor_conf_str = "Inductor Config is not JSON serializable"
return inductor_conf_str
def record_compilation_metrics(metrics: Dict[str, Any]):
# TODO: Temporary; populate legacy fields from their replacements.
# Remove when we decide we can really deprecate them.
def us_to_s(field):
metric = metrics.get(field, None)
return metric / 1e6 if metric is not None else None
@ -923,7 +963,12 @@ def record_compilation_metrics(metrics: Dict[str, Any]):
metric = metrics.get(field, None)
return metric // 1000 if metric is not None else None
legacy_metrics = {
common_metrics = {
"inductor_config": _scrubbed_inductor_config_for_logging(),
# -------- Any future common metrics go here --------
#
# Legacy metrics go here (TODO: Temporary; populate legacy fields from their replacements.)
# Remove when we decide we can really deprecate them.
"entire_frame_compile_time_s": us_to_s("dynamo_cumulative_compile_time_us"),
"backend_compile_time_s": us_to_s("aot_autograd_cumulative_compile_time_us"),
"inductor_compile_time_s": us_to_s("inductor_cumulative_compile_time_us"),
@ -937,7 +982,7 @@ def record_compilation_metrics(metrics: Dict[str, Any]):
),
}
compilation_metrics = CompilationMetrics(**{**metrics, **legacy_metrics})
compilation_metrics = CompilationMetrics(**{**metrics, **common_metrics})
_compilation_metrics.append(compilation_metrics)
if compilation_metrics.is_forward:
name = "compilation_metrics"
@ -2036,7 +2081,7 @@ def same(
and math.isnan(res_error)
# Some unit test for the accuracy minifier relies on
# returning false in this case.
and not inductor_config.cpp.inject_relu_bug_TESTING_ONLY
and not torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY
):
passes_test = True
if not passes_test: