From 85801126821d4f509f3cf5aafa24dbcd3cd11183 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Mon, 13 Oct 2025 15:14:34 +0000
Subject: [PATCH] Revert "[dynamo][DebugMode] mask python keys in
 dispatch_key_set guard checks (#164992)"

This reverts commit 306b344a1847749f0baf085dcd92560f4e99cd1b.

Reverted https://github.com/pytorch/pytorch/pull/164992 on behalf of https://github.com/jeffdaily due to broke ROCm CI test/inductor/test_inductor_scheduler.py::TestSchedulerCUDA::test_flop_counter_op_options0_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18417066364/job/52485636942) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/306b344a1847749f0baf085dcd92560f4e99cd1b) ([comment](https://github.com/pytorch/pytorch/pull/164992#issuecomment-3397927142))
---
 test/distributed/tensor/debug/test_debug_mode.py | 12 ++----------
 torch/csrc/dynamo/guards.h                       |  5 +----
 2 files changed, 3 insertions(+), 14 deletions(-)

diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py
index d122b770b285..aab91ddebe94 100644
--- a/test/distributed/tensor/debug/test_debug_mode.py
+++ b/test/distributed/tensor/debug/test_debug_mode.py
@@ -4,7 +4,6 @@ import contextlib
 
 import torch
 import torch.distributed as dist
-from torch._dynamo.testing import CompileCounterWithBackend
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
 from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
@@ -322,21 +321,14 @@ class TestDTensorDebugMode(TestCase):
         self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
 
     def test_compile(self):
-        cnt = CompileCounterWithBackend("inductor")
-
-        @torch.compile(backend=cnt)
+        @torch.compile
         def f(x):
             return x.sin().cos()
 
         x = torch.randn(8)
         with DebugMode() as debug_mode:
             f(x)
-            self.assertEqual(len(debug_mode.debug_string()), 0)
-            f(x)
-            f(x)
-        self.assertEqual(
-            cnt.frame_count, 1
-        )  # check DebugMode doesn't trigger additional recompilations
+        self.assertEqual(len(debug_mode.debug_string()), 0)
 
 
 instantiate_parametrized_tests(TestDTensorDebugMode)
diff --git a/torch/csrc/dynamo/guards.h b/torch/csrc/dynamo/guards.h
index 38346b97b243..0bb5590283f2 100644
--- a/torch/csrc/dynamo/guards.h
+++ b/torch/csrc/dynamo/guards.h
@@ -21,10 +21,7 @@ struct LocalState {
 
   at::DispatchKeySet apply(at::DispatchKeySet ks) const {
     if (override_dispatch_key_set.empty()) {
-      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_ -
-          c10::DispatchKeySet(
-              {c10::DispatchKey::Python,
-               c10::DispatchKey::PythonTLSSnapshot});
+      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
     } else {
       return override_dispatch_key_set;
     }