Start recording inductor provenance (#162669)
Summary: This stores information on where fx graphs come from, which makes them significantly easier to debug.

One outstanding question: 1) I only stored the kernel stack traces; do we also want the node mappings?

Test Plan: I wrote an explicit logging test which makes a module, fx-traces it, compiles it, and makes sure the logging information shows up.

```
clr@devvm17763 ~/fbsource/fbcode/caffe2/test/dynamo % buck2 test @//mode/opt fbcode//caffe2/test/dynamo:test_dynamo -- test_utils
File changed: fbsource//xplat/caffe2/test/dynamo/test_utils.py
File changed: fbcode//caffe2/test/dynamo/test_utils.py
Buck UI: https://www.internalfb.com/buck2/528dea32-2416-4a62-a1ec-39f3c0efdd2e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/13229324015574003
Network: Up: 0B  Down: 0B  (executing actions, remaining 0/2)
Command: test. Time elapsed: 17.3s
Tests finished: Pass 16. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Rollback Plan:

Differential Revision: D82037582

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162669
Approved by: https://github.com/yushangdi
Committed by: PyTorch MergeBot
Parent: 5b3ea75895
Commit: 98a488c9aa
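For reference, a minimal sketch of how the new metric can be observed, based on the `test_inductor_provenance` test added in this diff. The config keys and the mock target are taken from the hunks below; the exact provenance payload depends on the compiled graph, so treat the printed value as illustrative rather than guaranteed.

```python
# Sketch only: capture CompilationMetrics events and read the new
# inductor_provenance field, using the same config patches as the test in this PR.
from unittest import mock

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config

with dynamo_config.patch({"log_compilation_metrics": True}), inductor_config.patch(
    {"trace.enabled": True, "trace.provenance_tracking_level": 1}
):
    graph_module = torch.fx.symbolic_trace(torch.nn.Linear(6, 66))
    with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
        torch.compile(graph_module)(torch.randn(6, 6))
        events = [args[0][0] for args in log_event.call_args_list]
    # Each logged event is a CompilationMetrics instance; inductor_provenance
    # collects the per-graph JSON strings mapping generated kernels to the
    # stack traces they came from.
    print(events[0].inductor_provenance)
```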
@@ -510,6 +510,7 @@ class TestDynamoTimed(TestCase):
         raw = dataclasses.asdict(compilation_events[0])
         del raw["feature_usage"]
         del raw["ir_count"]
+        del raw["inductor_provenance"]
         del raw["param_numel"]
         del raw["param_bytes"]
         del raw["param_count"]
@@ -694,6 +695,7 @@ class TestDynamoTimed(TestCase):
         raw = dataclasses.asdict(compilation_events[1])
         del raw["feature_usage"]
         del raw["ir_count"]
+        del raw["inductor_provenance"]
         del raw["guard_latency_us"]
         del raw["param_numel"]
         del raw["param_bytes"]
@@ -911,6 +913,27 @@ class TestDynamoTimed(TestCase):
         compilation_events = [arg[0][0] for arg in log_event.call_args_list]
         self.assertEqual(compilation_events[0].ir_count, second)
 
+    @dynamo_config.patch(
+        {
+            "log_compilation_metrics": True,
+        }
+    )
+    @inductor_config.patch(
+        {"trace.enabled": True, "trace.provenance_tracking_level": 1},
+    )
+    def test_inductor_provenance(self):
+        module = torch.nn.Linear(6, 66)
+        graph_module = torch.fx.symbolic_trace(module)
+
+        compilation_events = []
+        with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
+            torch.compile(graph_module)(torch.randn(6, 6))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+        self.assertEqual(
+            compilation_events[0].inductor_provenance,
+            {'{"extern_kernels.addmm:1": []}'},
+        )
+
     @dynamo_config.patch({"log_compilation_metrics": True})
     @inductor_config.patch({"force_disable_caches": True})
     def test_dynamic_shape_feature_use(self):
@@ -1376,6 +1376,7 @@ class CompilationMetrics:
     recompile_user_contexts: Optional[set[str]] = None
     inline_inbuilt_nn_modules_candidate: Optional[bool] = False
     pytorch_version: Optional[str] = None
+    inductor_provenance: Optional[str] = None
 
     @classmethod
     def create(cls, metrics: dict[str, Any]) -> CompilationMetrics:
@@ -42,7 +42,12 @@ import torch.distributed as dist
 from torch import SymInt, Tensor
 from torch._dynamo.device_interface import get_interface_for_device
 from torch._dynamo.exc import SkipFrame
-from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
+from torch._dynamo.utils import (
+    CompileEventLogger,
+    counters,
+    dynamo_timed,
+    get_metrics_context,
+)
 from torch._inductor import config, exc, metrics
 from torch._inductor.codegen.common import (
     custom_backend_codegen_configs,
@@ -1281,6 +1286,10 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
                 },
                 payload_fn=lambda: graph.inductor_provenance_stack_traces_str,
             )
+            if get_metrics_context().in_progress():
+                get_metrics_context().add_to_set(
+                    "inductor_provenance", graph.inductor_provenance_stack_traces_str
+                )
         return graph, cache_info
 
     @staticmethod
@@ -1544,6 +1544,9 @@ class _InProcessFxCompile(FxCompile):
                     },
                     payload_fn=lambda: inductor_kernel_stack_trace_str,
                 )
+                get_metrics_context().add_to_set(
+                    "inductor_provenance", inductor_kernel_stack_trace_str
+                )
 
             node_runtimes = None
             if inductor_metrics_log.isEnabledFor(logging.INFO):
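A note on how the metric surfaces, as inferred from the hunks above: both the fresh-compile path (`_InProcessFxCompile`) and the cache-hit path (`FxGraphCache`) call `get_metrics_context().add_to_set("inductor_provenance", ...)`, so although the `CompilationMetrics` field is declared `Optional[str]`, the value observed by the new test is a set of JSON strings, one per compiled graph. A hypothetical helper for merging those payloads, assuming the JSON shape asserted by `test_inductor_provenance`:

```python
# Sketch only: merge the per-graph provenance payloads found in
# CompilationMetrics.inductor_provenance into one kernel -> stack-traces dict.
# The payload format is the one asserted by the test above; the helper name
# is hypothetical, not part of this PR.
import json


def merge_inductor_provenance(payloads):
    merged: dict[str, list[str]] = {}
    for payload in payloads or ():
        merged.update(json.loads(payload))
    return merged


# Example, using the value from the test above:
# merge_inductor_provenance({'{"extern_kernels.addmm:1": []}'})
# -> {"extern_kernels.addmm:1": []}
```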