Disable TracerWarnings on NNC opinfo tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78756

Approved by: https://github.com/davidberard98
This commit is contained in:
goldenxuett
2022-06-02 13:12:06 -07:00
committed by PyTorch MergeBot
parent c5a0d8dccc
commit eb49dde9cf
6 changed files with 54 additions and 15 deletions

View File

@ -94,3 +94,15 @@ class TestJitUtils(JitTestCase):
""") """)
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn") self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
def test_no_tracer_warn_context_manager(self):
torch._C._jit_set_tracer_state_warn(True)
with jit_utils.NoTracerWarnContextManager() as no_warn:
self.assertEqual(
False,
torch._C._jit_get_tracer_state_warn()
)
self.assertEqual(
True,
torch._C._jit_get_tracer_state_warn()
)

View File

@ -24,7 +24,7 @@ from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH
enable_profiling_mode_for_profiling_tests, slowTest enable_profiling_mode_for_profiling_tests, slowTest
from torch.testing._internal.jit_utils import JitTestCase, \ from torch.testing._internal.jit_utils import JitTestCase, \
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \ RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager
from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \ from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \
@ -2624,23 +2624,23 @@ def f({', '.join(param_names)}):
@onlyCPU @onlyCPU
@ops(op_db, dtypes=OpDTypes.supported) @ops(op_db, dtypes=OpDTypes.supported)
def test_nnc_correctness(self, device, dtype, op):
    """Trace each opinfo variant and check the traced result matches eager."""
    # Silence TracerWarnings for the entire opinfo sweep; this test traces
    # a very large number of functions and would otherwise spam warnings.
    with NoTracerWarnContextManager() as no_warn:
        for variant, sample in get_traced_sample_variant_pairs(device, dtype, op):
            traced = create_traced_fn(self, variant, cache_traced_fn=True)
            sample_args = (sample.input, *sample.args)
            # Eager reference on a fresh clone of the inputs.
            expected = variant(*clone_inputs(sample_args), **sample.kwargs)
            # First call records the trace; second call executes the graph.
            traced(*clone_inputs(sample_args), **sample.kwargs)
            actual = traced(*clone_inputs(sample_args), **sample.kwargs)
            self.assertEqual(expected, actual)
        # https://github.com/pytorch/pytorch/issues/35600
        # each torch.jit.trace adds state to the _python_cu compilation unit
        # since this test traces a lot of functions, out-of-memory can occur
        # if the CU is not cleared.
        torch.jit._state._python_cu.drop_all_functions()
only_for = ("cpu", "cuda") only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for) instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)

View File

@ -66,6 +66,12 @@ void badArgType(const T& v) {
thread_local std::shared_ptr<TracingState> tracing_state; thread_local std::shared_ptr<TracingState> tracing_state;
} // namespace detail } // namespace detail
namespace {
// Process-wide default for whether tracing emits TracerWarnings. New
// TracingState objects read this flag at construction time.
std::atomic<bool> tracer_state_warn_mode{true};
} // namespace

// Mutable accessor: the Python bindings (_jit_set/_get_tracer_state_warn)
// both read and assign through the returned reference.
std::atomic<bool>& getTracerStateWarnMode() {
  return tracer_state_warn_mode;
}
std::function<void()> pauseTracing() { std::function<void()> pauseTracing() {
// NOLINTNEXTLINE // NOLINTNEXTLINE
std::shared_ptr<tracer::TracingState> state = getTracingState(); std::shared_ptr<tracer::TracingState> state = getTracingState();

View File

@ -40,6 +40,8 @@ using ::c10::ivalue::ConstantString;
using torch::autograd::Variable; using torch::autograd::Variable;
using variable_list = std::vector<Variable>; using variable_list = std::vector<Variable>;
TORCH_API std::atomic<bool>& getTracerStateWarnMode();
struct TORCH_API TracingState struct TORCH_API TracingState
: public std::enable_shared_from_this<TracingState> { : public std::enable_shared_from_this<TracingState> {
TracingState(); TracingState();
@ -48,7 +50,7 @@ struct TORCH_API TracingState
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::shared_ptr<Graph> graph; std::shared_ptr<Graph> graph;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool warn = true; bool warn = getTracerStateWarnMode();
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool strict = true; bool strict = true;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)

View File

@ -662,6 +662,17 @@ void initJITBindings(PyObject* module) {
.def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled) .def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
.def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled) .def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
#endif #endif
.def(
"_jit_set_tracer_state_warn",
[](bool new_warn) {
jit::tracer::getTracerStateWarnMode() = new_warn;
})
.def(
"_jit_get_tracer_state_warn",
[]() {
bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
return current_tracer_warn;
})
.def( .def(
"_jit_set_nvfuser_skip_node_kind", "_jit_set_nvfuser_skip_node_kind",
// Args: // Args:

View File

@ -645,6 +645,14 @@ class JitTestCase(JitCommonTestCase):
return sm return sm
class NoTracerWarnContextManager(object):
    """Context manager that disables TracerWarnings for its scope.

    On entry the current global tracer-warn flag is saved and set to
    False; on exit the saved value is restored, so pre-existing state
    (and nesting) is preserved.
    """

    def __enter__(self):
        # Save whatever the flag currently is so __exit__ can restore it.
        self.prev = torch._C._jit_get_tracer_state_warn()
        torch._C._jit_set_tracer_state_warn(False)
        # Bug fix: without this return, `with ... as no_warn:` binds None.
        return self

    def __exit__(self, *args):
        torch._C._jit_set_tracer_state_warn(self.prev)
@contextmanager @contextmanager
def inline_everything_mode(should_inline): def inline_everything_mode(should_inline):
old = torch._C._jit_get_inline_everything_mode() old = torch._C._jit_get_inline_everything_mode()