diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index d122b770b285..aab91ddebe94 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -4,7 +4,6 @@ import contextlib import torch import torch.distributed as dist -from torch._dynamo.testing import CompileCounterWithBackend from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard from torch.distributed.tensor._dtensor_spec import ShardOrderEntry @@ -322,21 +321,14 @@ class TestDTensorDebugMode(TestCase): self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string()) def test_compile(self): - cnt = CompileCounterWithBackend("inductor") - - @torch.compile(backend=cnt) + @torch.compile def f(x): return x.sin().cos() x = torch.randn(8) with DebugMode() as debug_mode: f(x) - self.assertEqual(len(debug_mode.debug_string()), 0) - f(x) - f(x) - self.assertEqual( - cnt.frame_count, 1 - ) # check DebugMode doesn't trigger additional recompilations + self.assertEqual(len(debug_mode.debug_string()), 0) instantiate_parametrized_tests(TestDTensorDebugMode) diff --git a/torch/csrc/dynamo/guards.h b/torch/csrc/dynamo/guards.h index 38346b97b243..0bb5590283f2 100644 --- a/torch/csrc/dynamo/guards.h +++ b/torch/csrc/dynamo/guards.h @@ -21,10 +21,7 @@ struct LocalState { at::DispatchKeySet apply(at::DispatchKeySet ks) const { if (override_dispatch_key_set.empty()) { - return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_ - - c10::DispatchKeySet( - {c10::DispatchKey::Python, - c10::DispatchKey::PythonTLSSnapshot}); + return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_; } else { return override_dispatch_key_set; }