Disable TracerWarnings on NNC opinfo tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78756

Approved by: https://github.com/davidberard98
This commit is contained in:
goldenxuett
2022-06-02 13:12:06 -07:00
committed by PyTorch MergeBot
parent c5a0d8dccc
commit eb49dde9cf
6 changed files with 54 additions and 15 deletions

View File

@ -94,3 +94,15 @@ class TestJitUtils(JitTestCase):
""")
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
def test_no_tracer_warn_context_manager(self):
torch._C._jit_set_tracer_state_warn(True)
with jit_utils.NoTracerWarnContextManager() as no_warn:
self.assertEqual(
False,
torch._C._jit_get_tracer_state_warn()
)
self.assertEqual(
True,
torch._C._jit_get_tracer_state_warn()
)

View File

@ -24,7 +24,7 @@ from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH
enable_profiling_mode_for_profiling_tests, slowTest
from torch.testing._internal.jit_utils import JitTestCase, \
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \
@ -2624,6 +2624,7 @@ def f({', '.join(param_names)}):
@onlyCPU
@ops(op_db, dtypes=OpDTypes.supported)
def test_nnc_correctness(self, device, dtype, op):
with NoTracerWarnContextManager() as no_warn:
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
for variant, sample in variant_sample_pairs:
@ -2631,7 +2632,6 @@ def f({', '.join(param_names)}):
ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
self.assertEqual(ref, val)

View File

@ -66,6 +66,12 @@ void badArgType(const T& v) {
thread_local std::shared_ptr<TracingState> tracing_state;
} // namespace detail
static std::atomic<bool> tracer_state_warn_mode{true};
std::atomic<bool>& getTracerStateWarnMode() {
return tracer_state_warn_mode;
}
std::function<void()> pauseTracing() {
// NOLINTNEXTLINE
std::shared_ptr<tracer::TracingState> state = getTracingState();

View File

@ -40,6 +40,8 @@ using ::c10::ivalue::ConstantString;
using torch::autograd::Variable;
using variable_list = std::vector<Variable>;
TORCH_API std::atomic<bool>& getTracerStateWarnMode();
struct TORCH_API TracingState
: public std::enable_shared_from_this<TracingState> {
TracingState();
@ -48,7 +50,7 @@ struct TORCH_API TracingState
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::shared_ptr<Graph> graph;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool warn = true;
bool warn = getTracerStateWarnMode();
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool strict = true;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)

View File

@ -662,6 +662,17 @@ void initJITBindings(PyObject* module) {
.def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
.def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
#endif
.def(
"_jit_set_tracer_state_warn",
[](bool new_warn) {
jit::tracer::getTracerStateWarnMode() = new_warn;
})
.def(
"_jit_get_tracer_state_warn",
[]() {
bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
return current_tracer_warn;
})
.def(
"_jit_set_nvfuser_skip_node_kind",
// Args:

View File

@ -645,6 +645,14 @@ class JitTestCase(JitCommonTestCase):
return sm
class NoTracerWarnContextManager(object):
def __enter__(self):
self.prev = torch._C._jit_get_tracer_state_warn()
torch._C._jit_set_tracer_state_warn(False)
def __exit__(self, *args):
torch._C._jit_set_tracer_state_warn(self.prev)
@contextmanager
def inline_everything_mode(should_inline):
old = torch._C._jit_get_inline_everything_mode()