mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "[dynamo][DebugMode] mask python keys in dispatch_key_set guard checks (#164992)"
This reverts commit 306b344a1847749f0baf085dcd92560f4e99cd1b.
Reverted https://github.com/pytorch/pytorch/pull/164992 on behalf of https://github.com/jeffdaily due to broke ROCm CI test/inductor/test_inductor_scheduler.py::TestSchedulerCUDA::test_flop_counter_op_options0_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18417066364/job/52485636942) [HUD commit link](306b344a18
) ([comment](https://github.com/pytorch/pytorch/pull/164992#issuecomment-3397927142))
This commit is contained in:
@ -4,7 +4,6 @@ import contextlib
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.testing import CompileCounterWithBackend
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
@ -322,21 +321,14 @@ class TestDTensorDebugMode(TestCase):
|
||||
self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
|
||||
|
||||
def test_compile(self):
|
||||
cnt = CompileCounterWithBackend("inductor")
|
||||
|
||||
@torch.compile(backend=cnt)
|
||||
@torch.compile
|
||||
def f(x):
|
||||
return x.sin().cos()
|
||||
|
||||
x = torch.randn(8)
|
||||
with DebugMode() as debug_mode:
|
||||
f(x)
|
||||
self.assertEqual(len(debug_mode.debug_string()), 0)
|
||||
f(x)
|
||||
f(x)
|
||||
self.assertEqual(
|
||||
cnt.frame_count, 1
|
||||
) # check DebugMode doesn't trigger additional recompilations
|
||||
self.assertEqual(len(debug_mode.debug_string()), 0)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
@ -21,10 +21,7 @@ struct LocalState {
|
||||
|
||||
at::DispatchKeySet apply(at::DispatchKeySet ks) const {
|
||||
if (override_dispatch_key_set.empty()) {
|
||||
return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_ -
|
||||
c10::DispatchKeySet(
|
||||
{c10::DispatchKey::Python,
|
||||
c10::DispatchKey::PythonTLSSnapshot});
|
||||
return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
|
||||
} else {
|
||||
return override_dispatch_key_set;
|
||||
}
|
||||
|
Reference in New Issue
Block a user