Compare commits

...

7 Commits

Author SHA1 Message Date
71daa72b30 Update guards.h 2025-10-10 00:46:32 -07:00
5a1255a223 Update guards.cpp 2025-10-10 00:44:20 -07:00
573301c0b3 Update guards.cpp 2025-10-09 17:06:12 -07:00
e5e8f80f0b Update guards.cpp 2025-10-09 16:08:14 -07:00
2499e30f15 lint 2025-10-09 11:35:48 -07:00
f812a89371 nit 2025-10-08 14:29:23 -07:00
155d50c268 init 2025-10-08 14:20:11 -07:00
2 changed files with 14 additions and 3 deletions

View File

@ -4,6 +4,7 @@ import contextlib
import torch
import torch.distributed as dist
from torch._dynamo.testing import CompileCounterWithBackend
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.testing._internal.common_utils import (
@ -262,14 +263,21 @@ class TestDTensorDebugMode(TestCase):
self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
def test_compile(self):
@torch.compile
cnt = CompileCounterWithBackend("inductor")
@torch.compile(backend=cnt)
def f(x):
return x.sin().cos()
x = torch.randn(8)
with DebugMode() as debug_mode:
f(x)
self.assertEqual(len(debug_mode.debug_string()), 0)
self.assertEqual(len(debug_mode.debug_string()), 0)
f(x)
f(x)
self.assertEqual(
cnt.frame_count, 1
) # check DebugMode doesn't trigger additional recompilations
instantiate_parametrized_tests(TestDTensorDebugMode)

View File

@ -21,7 +21,10 @@ struct LocalState {
at::DispatchKeySet apply(at::DispatchKeySet ks) const {
if (override_dispatch_key_set.empty()) {
return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_ -
c10::DispatchKeySet(
{c10::DispatchKey::Python,
c10::DispatchKey::PythonTLSSnapshot});
} else {
return override_dispatch_key_set;
}