Start recording inductor provenance (#162669)

Summary:
This stores information on where fx graphs come from, which makes it
significantly easier to debug.

One outstanding question:

1) I only stored the kernel stack traces, do we also want the node mappings?

Test Plan:
I wrote an explicit logging test which makes a module, fx traces it, compiles it, and makes sure the logging information shows up.

```
clr@devvm17763 ~/fbsource/fbcode/caffe2/test/dynamo
 % buck2 test @//mode/opt fbcode//caffe2/test/dynamo:test_dynamo -- test_utils

File changed: fbsource//xplat/caffe2/test/dynamo/test_utils.py
File changed: fbcode//caffe2/test/dynamo/test_utils.py
Buck UI: https://www.internalfb.com/buck2/528dea32-2416-4a62-a1ec-39f3c0efdd2e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13229324015574003
Network: Up: 0B  Down: 0B
Executing actions. Remaining     0/2
Command: test.
Time elapsed: 17.3s
Tests finished: Pass 16. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Rollback Plan:

Differential Revision: D82037582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162669
Approved by: https://github.com/yushangdi
This commit is contained in:
Colin L Reliability Rice
2025-10-16 23:05:31 +00:00
committed by PyTorch MergeBot
parent 5b3ea75895
commit 98a488c9aa
4 changed files with 37 additions and 1 deletions

View File

@ -510,6 +510,7 @@ class TestDynamoTimed(TestCase):
raw = dataclasses.asdict(compilation_events[0])
del raw["feature_usage"]
del raw["ir_count"]
del raw["inductor_provenance"]
del raw["param_numel"]
del raw["param_bytes"]
del raw["param_count"]
@ -694,6 +695,7 @@ class TestDynamoTimed(TestCase):
raw = dataclasses.asdict(compilation_events[1])
del raw["feature_usage"]
del raw["ir_count"]
del raw["inductor_provenance"]
del raw["guard_latency_us"]
del raw["param_numel"]
del raw["param_bytes"]
@ -911,6 +913,27 @@ class TestDynamoTimed(TestCase):
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(compilation_events[0].ir_count, second)
@dynamo_config.patch(
{
"log_compilation_metrics": True,
}
)
@inductor_config.patch(
{"trace.enabled": True, "trace.provenance_tracking_level": 1},
)
def test_inductor_provenance(self):
module = torch.nn.Linear(6, 66)
graph_module = torch.fx.symbolic_trace(module)
compilation_events = []
with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
torch.compile(graph_module)(torch.randn(6, 6))
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(
compilation_events[0].inductor_provenance,
{'{"extern_kernels.addmm:1": []}'},
)
@dynamo_config.patch({"log_compilation_metrics": True})
@inductor_config.patch({"force_disable_caches": True})
def test_dynamic_shape_feature_use(self):

View File

@ -1376,6 +1376,7 @@ class CompilationMetrics:
recompile_user_contexts: Optional[set[str]] = None
inline_inbuilt_nn_modules_candidate: Optional[bool] = False
pytorch_version: Optional[str] = None
inductor_provenance: Optional[str] = None
@classmethod
def create(cls, metrics: dict[str, Any]) -> CompilationMetrics:

View File

@ -42,7 +42,12 @@ import torch.distributed as dist
from torch import SymInt, Tensor
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.exc import SkipFrame
from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
from torch._dynamo.utils import (
CompileEventLogger,
counters,
dynamo_timed,
get_metrics_context,
)
from torch._inductor import config, exc, metrics
from torch._inductor.codegen.common import (
custom_backend_codegen_configs,
@ -1281,6 +1286,10 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
},
payload_fn=lambda: graph.inductor_provenance_stack_traces_str,
)
if get_metrics_context().in_progress():
get_metrics_context().add_to_set(
"inductor_provenance", graph.inductor_provenance_stack_traces_str
)
return graph, cache_info
@staticmethod

View File

@ -1544,6 +1544,9 @@ class _InProcessFxCompile(FxCompile):
},
payload_fn=lambda: inductor_kernel_stack_trace_str,
)
get_metrics_context().add_to_set(
"inductor_provenance", inductor_kernel_stack_trace_str
)
node_runtimes = None
if inductor_metrics_log.isEnabledFor(logging.INFO):