Revert "[dynamo][DebugMode] mask python keys in dispatch_key_set guard checks (#164992)"

This reverts commit 306b344a1847749f0baf085dcd92560f4e99cd1b.

Reverted https://github.com/pytorch/pytorch/pull/164992 on behalf of https://github.com/jeffdaily due to broke ROCm CI test/inductor/test_inductor_scheduler.py::TestSchedulerCUDA::test_flop_counter_op_options0_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18417066364/job/52485636942) [HUD commit link](306b344a18) ([comment](https://github.com/pytorch/pytorch/pull/164992#issuecomment-3397927142))
This commit is contained in:
PyTorch MergeBot
2025-10-13 15:14:34 +00:00
parent 4874cce52f
commit 8580112682
2 changed files with 3 additions and 14 deletions

View File

@ -4,7 +4,6 @@ import contextlib
import torch
import torch.distributed as dist
from torch._dynamo.testing import CompileCounterWithBackend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
@ -322,21 +321,14 @@ class TestDTensorDebugMode(TestCase):
self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
def test_compile(self):
cnt = CompileCounterWithBackend("inductor")
@torch.compile(backend=cnt)
@torch.compile
def f(x):
return x.sin().cos()
x = torch.randn(8)
with DebugMode() as debug_mode:
f(x)
self.assertEqual(len(debug_mode.debug_string()), 0)
f(x)
f(x)
self.assertEqual(
cnt.frame_count, 1
) # check DebugMode doesn't trigger additional recompilations
self.assertEqual(len(debug_mode.debug_string()), 0)
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -21,10 +21,7 @@ struct LocalState {
at::DispatchKeySet apply(at::DispatchKeySet ks) const {
if (override_dispatch_key_set.empty()) {
return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_ -
c10::DispatchKeySet(
{c10::DispatchKey::Python,
c10::DispatchKey::PythonTLSSnapshot});
return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
} else {
return override_dispatch_key_set;
}