From 0e5773b7fadef9e29b006af470b771fad55b5206 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 2 Oct 2025 08:59:57 +0000 Subject: [PATCH] [dynamo][export] Do not graph break on torch.autograd._profiler_enabled for export (#164418) Actually we would like to not graph break even in the case of Dynamo. But there is a weird-unsolved bug with Kineto + Dynamo when there are distributed jobs that lead to NCCL timeouts. This bug is a rare edege case, but we have not been able to root cause it yet. But for export, we do not anticipate JIT tracing in distributed job training and therefore this PR is safe for export. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/164418 Approved by: https://github.com/StrongerXi, https://github.com/williamwen42 --- test/dynamo/test_profiler.py | 28 ++++++++++++++++++++++++++++ torch/_dynamo/config.py | 5 +++++ torch/_dynamo/eval_frame.py | 1 + torch/_dynamo/functional_export.py | 1 + torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/torch.py | 20 ++++++++++++++++++++ 6 files changed, 56 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 61dc63ed2d5c..921d7021650f 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -162,6 +162,34 @@ class DynamoProfilerTests(torch._dynamo.test_case.TestCase): any(e.name == "TorchDynamo Cache Lookup" for e in prof.events()) ) + def test_profiler_enabled_export(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = torch.sin(x) + if torch.autograd._profiler_enabled(): + return torch.cos(x) + else: + return torch.sigmoid(x) + + mod = Mod() + + x = torch.randn(4) + opt_mod = torch._dynamo.export(mod, (x)) + + ref = mod(x) + res = opt_mod.graph_module(x) + self.assertEqual(ref, res) + + with torch.autograd.profiler.profile(): + ref = mod(x) + # Reexport because export skips guards + opt_mod = torch._dynamo.export(mod, (x)) + res = opt_mod.graph_module(x) + self.assertEqual(ref, res) + def test_profiler_dynamo_compiled_region(self): def fn(x, y): r = y.sum(dim=1) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index b8d1008dec8e..c572f900cfff 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -675,6 +675,11 @@ run_gc_after_compile = Config( # type: ignore[var-annotated] env_name_default="TORCH_DYNAMO_RUN_GC_AFTER_COMPILE", ) +# Does not graph break on torch.autograd._profiler_enabled if set to True. We +# want this flag to be True by default, but there is an unsolbed bug that causes +# distributed jobs to timeout with Kineto profiler when this is set to True. +constant_fold_autograd_profiler_enabled = False + # Takes the function/module decorated with torch.compile and passes it through a # wrapper. This ensures that nn.module hooks are also compiled in the same frame. wrap_top_frame = False diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e7e134d27bdf..8b55ac48cca2 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -2038,6 +2038,7 @@ def export( automatic_dynamic_shapes=False, capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, + constant_fold_autograd_profiler_enabled=True, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, ), _compiling_state_context(), diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 13acd9280056..af418c04105d 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -451,6 +451,7 @@ def _dynamo_graph_capture_for_export( automatic_dynamic_shapes=False, capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, + constant_fold_autograd_profiler_enabled=True, log_graph_in_out_metadata=True, ) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 245094cf92dc..d72a8b0ce7be 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -179,7 +179,6 @@ manual_torch_name_rule_map: dict[ "torch.compiler.is_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_exporting": TorchInGraphFunctionVariable, - "torch.autograd._profiler_enabled": SkipFunctionVariable, "torch._C._to_dlpack": SkipFunctionVariable, "torch.to_dlpack": SkipFunctionVariable, # We graph break on RNG state setters or getters like @@ -2445,6 +2444,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch.atleast_3d", "torch.autograd._calculate_shape", "torch.autograd._is_checkpoint_valid", + "torch.autograd._profiler_enabled", "torch.autograd._make_grads", "torch.autograd._register_py_tensor_class_for_device", "torch.autograd._tensor_or_tensors_to_tuple", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index ad84cff320ff..2858e2af9252 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -270,6 +270,26 @@ class BaseTorchVariable(VariableTracker): def can_constant_fold_through(self): if self.value in constant_fold_functions: return True + + if ( + self.value is torch.autograd._profiler_enabled + and config.constant_fold_autograd_profiler_enabled + ): + # The relevant flag is enabled only for export. One might wonder + # why? + # + # Actually we would like to not graph break even in the case of + # Dynamo. But there is a weird-unsolved bug with Kineto + Dynamo + # when there are distributed jobs that lead to NCCL timeouts. This + # bug is a rare edege case, but we have not been able to root cause + # it yet. See https://www.internalfb.com/sevmanager/view/560336 for + # more details. + # + # So is this safe for export? Yes, for export, we do not anticipate + # JIT tracing in distributed job training, and the weird edge-case + # interaction with Kineto is not a valid usecase. So, this is ok. + return True + return getattr(self.value, "__module__", None) == "math"