Compare commits

...

1 Commits

Author SHA1 Message Date
84dece5154 Move config out of recompile() 2025-11-06 17:36:44 -08:00
5 changed files with 19 additions and 14 deletions

View File

@ -7486,7 +7486,7 @@ class TestFXMemoryProfiler(TestCase):
return fx_frames
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_fx_memory_profiler_augmentation(self):
"""Test that memory snapshots are augmented with FX debug information."""

View File

@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)

View File

@ -739,11 +739,8 @@ enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
enrich_profiler_metadata: bool = False
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403

View File

@ -2,6 +2,8 @@ import os
import sys
from typing import Optional
from torch.utils._config_module import Config, install_config_module
# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
no_data_dependent_graph_break = (
@ -100,7 +102,11 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
skip_dtype_check_in_meta_registrations = False
from torch.utils._config_module import install_config_module
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
default=False,
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)
install_config_module(sys.modules[__name__])

View File

@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer
from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
_BoxedCodeGen,
_custom_builtins,
@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
called after editing the contained ``graph``, otherwise the generated
code of this ``GraphModule`` will be out of date.
"""
# Do not import anything inside recompile, it might slow down the
# function and cause perf regression. Import outside of the method instead.
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
root_module="self",
record_func=fx_experimental_config.enrich_profiler_metadata,
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
if dynamo_config.enrich_profiler_metadata:
if fx_experimental_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
node_metadata: dict[int, dict[str, Any]] = {}
for i, node in enumerate(self._graph.nodes):