mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Disable TracerWarnings on NNC opinfo tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78756 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
c5a0d8dccc
commit
eb49dde9cf
@ -94,3 +94,15 @@ class TestJitUtils(JitTestCase):
|
||||
""")
|
||||
|
||||
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
|
||||
|
||||
def test_no_tracer_warn_context_manager(self):
|
||||
torch._C._jit_set_tracer_state_warn(True)
|
||||
with jit_utils.NoTracerWarnContextManager() as no_warn:
|
||||
self.assertEqual(
|
||||
False,
|
||||
torch._C._jit_get_tracer_state_warn()
|
||||
)
|
||||
self.assertEqual(
|
||||
True,
|
||||
torch._C._jit_get_tracer_state_warn()
|
||||
)
|
||||
|
@ -24,7 +24,7 @@ from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH
|
||||
enable_profiling_mode_for_profiling_tests, slowTest
|
||||
from torch.testing._internal.jit_utils import JitTestCase, \
|
||||
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
|
||||
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions
|
||||
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager
|
||||
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \
|
||||
@ -2624,6 +2624,7 @@ def f({', '.join(param_names)}):
|
||||
@onlyCPU
|
||||
@ops(op_db, dtypes=OpDTypes.supported)
|
||||
def test_nnc_correctness(self, device, dtype, op):
|
||||
with NoTracerWarnContextManager() as no_warn:
|
||||
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
|
||||
|
||||
for variant, sample in variant_sample_pairs:
|
||||
@ -2631,7 +2632,6 @@ def f({', '.join(param_names)}):
|
||||
ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
||||
|
||||
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
||||
|
||||
val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
||||
|
||||
self.assertEqual(ref, val)
|
||||
|
@ -66,6 +66,12 @@ void badArgType(const T& v) {
|
||||
thread_local std::shared_ptr<TracingState> tracing_state;
|
||||
} // namespace detail
|
||||
|
||||
static std::atomic<bool> tracer_state_warn_mode{true};
|
||||
|
||||
std::atomic<bool>& getTracerStateWarnMode() {
|
||||
return tracer_state_warn_mode;
|
||||
}
|
||||
|
||||
std::function<void()> pauseTracing() {
|
||||
// NOLINTNEXTLINE
|
||||
std::shared_ptr<tracer::TracingState> state = getTracingState();
|
||||
|
@ -40,6 +40,8 @@ using ::c10::ivalue::ConstantString;
|
||||
using torch::autograd::Variable;
|
||||
using variable_list = std::vector<Variable>;
|
||||
|
||||
TORCH_API std::atomic<bool>& getTracerStateWarnMode();
|
||||
|
||||
struct TORCH_API TracingState
|
||||
: public std::enable_shared_from_this<TracingState> {
|
||||
TracingState();
|
||||
@ -48,7 +50,7 @@ struct TORCH_API TracingState
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
std::shared_ptr<Graph> graph;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
bool warn = true;
|
||||
bool warn = getTracerStateWarnMode();
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
bool strict = true;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
|
@ -662,6 +662,17 @@ void initJITBindings(PyObject* module) {
|
||||
.def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
|
||||
.def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
|
||||
#endif
|
||||
.def(
|
||||
"_jit_set_tracer_state_warn",
|
||||
[](bool new_warn) {
|
||||
jit::tracer::getTracerStateWarnMode() = new_warn;
|
||||
})
|
||||
.def(
|
||||
"_jit_get_tracer_state_warn",
|
||||
[]() {
|
||||
bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
|
||||
return current_tracer_warn;
|
||||
})
|
||||
.def(
|
||||
"_jit_set_nvfuser_skip_node_kind",
|
||||
// Args:
|
||||
|
@ -645,6 +645,14 @@ class JitTestCase(JitCommonTestCase):
|
||||
|
||||
return sm
|
||||
|
||||
class NoTracerWarnContextManager(object):
|
||||
def __enter__(self):
|
||||
self.prev = torch._C._jit_get_tracer_state_warn()
|
||||
torch._C._jit_set_tracer_state_warn(False)
|
||||
|
||||
def __exit__(self, *args):
|
||||
torch._C._jit_set_tracer_state_warn(self.prev)
|
||||
|
||||
@contextmanager
|
||||
def inline_everything_mode(should_inline):
|
||||
old = torch._C._jit_get_inline_everything_mode()
|
||||
|
Reference in New Issue
Block a user