mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Disable TracerWarnings on NNC opinfo tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78756 Approved by: https://github.com/davidberard98
This commit is contained in:
committed by
PyTorch MergeBot
parent
c5a0d8dccc
commit
eb49dde9cf
@ -94,3 +94,15 @@ class TestJitUtils(JitTestCase):
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
|
self.checkScriptRaisesRegex(s, (), Exception, "range", name="fn")
|
||||||
|
|
||||||
|
def test_no_tracer_warn_context_manager(self):
|
||||||
|
torch._C._jit_set_tracer_state_warn(True)
|
||||||
|
with jit_utils.NoTracerWarnContextManager() as no_warn:
|
||||||
|
self.assertEqual(
|
||||||
|
False,
|
||||||
|
torch._C._jit_get_tracer_state_warn()
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
True,
|
||||||
|
torch._C._jit_get_tracer_state_warn()
|
||||||
|
)
|
||||||
|
@ -24,7 +24,7 @@ from torch.testing._internal.common_utils import run_tests, ProfilingMode, GRAPH
|
|||||||
enable_profiling_mode_for_profiling_tests, slowTest
|
enable_profiling_mode_for_profiling_tests, slowTest
|
||||||
from torch.testing._internal.jit_utils import JitTestCase, \
|
from torch.testing._internal.jit_utils import JitTestCase, \
|
||||||
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
|
RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward, set_fusion_group_inlining, \
|
||||||
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions
|
clone_inputs, get_traced_sample_variant_pairs, TensorExprTestOptions, NoTracerWarnContextManager
|
||||||
|
|
||||||
from torch.testing._internal.common_methods_invocations import op_db
|
from torch.testing._internal.common_methods_invocations import op_db
|
||||||
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \
|
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests, \
|
||||||
@ -2624,23 +2624,23 @@ def f({', '.join(param_names)}):
|
|||||||
@onlyCPU
|
@onlyCPU
|
||||||
@ops(op_db, dtypes=OpDTypes.supported)
|
@ops(op_db, dtypes=OpDTypes.supported)
|
||||||
def test_nnc_correctness(self, device, dtype, op):
|
def test_nnc_correctness(self, device, dtype, op):
|
||||||
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
|
with NoTracerWarnContextManager() as no_warn:
|
||||||
|
variant_sample_pairs = get_traced_sample_variant_pairs(device, dtype, op)
|
||||||
|
|
||||||
for variant, sample in variant_sample_pairs:
|
for variant, sample in variant_sample_pairs:
|
||||||
trace = create_traced_fn(self, variant, cache_traced_fn=True)
|
trace = create_traced_fn(self, variant, cache_traced_fn=True)
|
||||||
ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
ref = variant(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
||||||
|
|
||||||
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
||||||
|
val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
||||||
|
|
||||||
val = trace(*clone_inputs((sample.input, *sample.args)), **sample.kwargs)
|
self.assertEqual(ref, val)
|
||||||
|
|
||||||
self.assertEqual(ref, val)
|
# https://github.com/pytorch/pytorch/issues/35600
|
||||||
|
# each torch.jit.trace adds state to the _python_cu compilation unit
|
||||||
# https://github.com/pytorch/pytorch/issues/35600
|
# since this test traces a lot of functions, out-of-memory can occur
|
||||||
# each torch.jit.trace adds state to the _python_cu compilation unit
|
# if the CU is not cleared.
|
||||||
# since this test traces a lot of functions, out-of-memory can occur
|
torch.jit._state._python_cu.drop_all_functions()
|
||||||
# if the CU is not cleared.
|
|
||||||
torch.jit._state._python_cu.drop_all_functions()
|
|
||||||
|
|
||||||
only_for = ("cpu", "cuda")
|
only_for = ("cpu", "cuda")
|
||||||
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
|
instantiate_device_type_tests(TestNNCOpInfo, globals(), only_for=only_for)
|
||||||
|
@ -66,6 +66,12 @@ void badArgType(const T& v) {
|
|||||||
thread_local std::shared_ptr<TracingState> tracing_state;
|
thread_local std::shared_ptr<TracingState> tracing_state;
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
|
static std::atomic<bool> tracer_state_warn_mode{true};
|
||||||
|
|
||||||
|
std::atomic<bool>& getTracerStateWarnMode() {
|
||||||
|
return tracer_state_warn_mode;
|
||||||
|
}
|
||||||
|
|
||||||
std::function<void()> pauseTracing() {
|
std::function<void()> pauseTracing() {
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
std::shared_ptr<tracer::TracingState> state = getTracingState();
|
std::shared_ptr<tracer::TracingState> state = getTracingState();
|
||||||
|
@ -40,6 +40,8 @@ using ::c10::ivalue::ConstantString;
|
|||||||
using torch::autograd::Variable;
|
using torch::autograd::Variable;
|
||||||
using variable_list = std::vector<Variable>;
|
using variable_list = std::vector<Variable>;
|
||||||
|
|
||||||
|
TORCH_API std::atomic<bool>& getTracerStateWarnMode();
|
||||||
|
|
||||||
struct TORCH_API TracingState
|
struct TORCH_API TracingState
|
||||||
: public std::enable_shared_from_this<TracingState> {
|
: public std::enable_shared_from_this<TracingState> {
|
||||||
TracingState();
|
TracingState();
|
||||||
@ -48,7 +50,7 @@ struct TORCH_API TracingState
|
|||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
std::shared_ptr<Graph> graph;
|
std::shared_ptr<Graph> graph;
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
bool warn = true;
|
bool warn = getTracerStateWarnMode();
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
bool strict = true;
|
bool strict = true;
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||||
|
@ -662,6 +662,17 @@ void initJITBindings(PyObject* module) {
|
|||||||
.def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
|
.def("_jit_set_llga_enabled", &RegisterLlgaFuseGraph::setEnabled)
|
||||||
.def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
|
.def("_jit_llga_enabled", &RegisterLlgaFuseGraph::isEnabled)
|
||||||
#endif
|
#endif
|
||||||
|
.def(
|
||||||
|
"_jit_set_tracer_state_warn",
|
||||||
|
[](bool new_warn) {
|
||||||
|
jit::tracer::getTracerStateWarnMode() = new_warn;
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"_jit_get_tracer_state_warn",
|
||||||
|
[]() {
|
||||||
|
bool current_tracer_warn = jit::tracer::getTracerStateWarnMode();
|
||||||
|
return current_tracer_warn;
|
||||||
|
})
|
||||||
.def(
|
.def(
|
||||||
"_jit_set_nvfuser_skip_node_kind",
|
"_jit_set_nvfuser_skip_node_kind",
|
||||||
// Args:
|
// Args:
|
||||||
|
@ -645,6 +645,14 @@ class JitTestCase(JitCommonTestCase):
|
|||||||
|
|
||||||
return sm
|
return sm
|
||||||
|
|
||||||
|
class NoTracerWarnContextManager(object):
|
||||||
|
def __enter__(self):
|
||||||
|
self.prev = torch._C._jit_get_tracer_state_warn()
|
||||||
|
torch._C._jit_set_tracer_state_warn(False)
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
torch._C._jit_set_tracer_state_warn(self.prev)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def inline_everything_mode(should_inline):
|
def inline_everything_mode(should_inline):
|
||||||
old = torch._C._jit_get_inline_everything_mode()
|
old = torch._C._jit_get_inline_everything_mode()
|
||||||
|
Reference in New Issue
Block a user