diff --git a/test/jit/test_jit_utils.py b/test/jit/test_jit_utils.py index b7ec27d5b554..ceb46489f20d 100644 --- a/test/jit/test_jit_utils.py +++ b/test/jit/test_jit_utils.py @@ -94,3 +94,15 @@ class TestJitUtils(JitTestCase): """) self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn") + + def test_no_tracer_warn_context_manager(self): + torch._C._jit_set_tracer_state_warn(True) + with jit_utils.NoTracerWarnContextManager() as no_warn: + self.assertEqual( + False, + torch._C._jit_get_tracer_state_warn() + ) + self.assertEqual( + True, + torch._C._jit_get_tracer_state_warn() + ) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index cb14fe573358..f69e217a68b1 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -24,7 +24,7 @@ from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH enable_profiling_mode_for_profiling_tests, slowTest from torch.testing._internal.jit_utils import JitTestCase, \ RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \ - clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions + clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \ @@ -2624,23 +2624,23 @@ def f({', '.join(param_names)}): @onlyCPU @ops(op_db, dtypes=OpDTypes.supported) def test_nnc_correctness(self, device, dtype, op): - variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) + with NoTracerWarnContextManager() as no_warn: + variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op) - for variant, sample in variant_sample_pairs: - trace = create_traced_fn(self, variant, cache_traced_fn=True) - ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + for variant, sample in variant_sample_pairs: + trace = create_traced_fn(self, variant, cache_traced_fn=True) + ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) - trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) - val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs) + self.assertEqual(ref, val) - self.assertEqual(ref, val) - - # https://github.com/pytorch/pytorch/issues/35600 - # each torch.jit.trace adds state to the _python_cu compilation unit - # since this test traces a lot of functions, out-of-memory can occur - # if the CU is not cleared. - torch.jit._state._python_cu.drop_all_functions() + # https://github.com/pytorch/pytorch/issues/35600 + # each torch.jit.trace adds state to the _python_cu compilation unit + # since this test traces a lot of functions, out-of-memory can occur + # if the CU is not cleared. + torch.jit._state._python_cu.drop_all_functions() only_for = ("cpu", "cuda") instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 8fdc8f390804..42e32f1eb3b7 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -66,6 +66,12 @@ void badArgType(const T& v) { thread_local std::shared_ptr tracing_state; } // namespace detail +static std::atomic tracer_state_warn_mode{true}; + +std::atomic& getTracerStateWarnMode() { + return tracer_state_warn_mode; +} + std::function pauseTracing() { // NOLINTNEXTLINE std::shared_ptr state = getTracingState(); diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index de66f0841a6a..4f1e8b0c7d34 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -40,6 +40,8 @@ using ::c10::ivalue::ConstantString; using torch::autograd::Variable; using variable_list = std::vector; +TORCH_API std::atomic& getTracerStateWarnMode(); + struct TORCH_API TracingState : public std::enable_shared_from_this { TracingState(); @@ -48,7 +50,7 @@ struct TORCH_API TracingState // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::shared_ptr graph; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - bool warn = true; + bool warn = getTracerStateWarnMode(); // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool strict = true; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index f7d97087aed6..6842e751d898 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -662,6 +662,17 @@ void initJITBindings(PyObject* module) { .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled) .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled) #endif + .def( + "_jit_set_tracer_state_warn", + [](bool new_warn) { + jit::tracer::getTracerStateWarnMode() = new_warn; + }) + .def( + "_jit_get_tracer_state_warn", + []() { + bool current_tracer_warn = jit::tracer::getTracerStateWarnMode(); + return current_tracer_warn; + }) .def( "_jit_set_nvfuser_skip_node_kind", // Args: diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 95c55e7db870..707529181b63 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -645,6 +645,14 @@ class JitTestCase(JitCommonTestCase): return sm +class NoTracerWarnContextManager(object): + def __enter__(self): + self.prev = torch._C._jit_get_tracer_state_warn() + torch._C._jit_set_tracer_state_warn(False) + + def __exit__(self, *args): + torch._C._jit_set_tracer_state_warn(self.prev) + @contextmanager def inline_everything_mode(should_inline): old = torch._C._jit_get_inline_everything_mode()