Files
pytorch/test/jit/test_warn.py
Jane Xu 09c7771e9c Set test owners for jit tests (#66808)
Summary:
Action following https://github.com/pytorch/pytorch/issues/66232

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66808

Reviewed By: mrshenli

Differential Revision: D31761414

Pulled By: janeyx99

fbshipit-source-id: baf8c49ff9c4bcda7b0ea0f6aafd26380586e72d
2021-10-25 07:51:10 -07:00

168 lines
4.1 KiB
Python

# Owner(s): ["oncall: jit"]
import os
import sys
import io
import torch
import warnings
from contextlib import redirect_stderr
from torch.testing import FileCheck
# Make the helper files in test/ importable: starting from this file's
# resolved location (test/jit/test_warn.py), step up two directory levels
# to reach test/ and append it to the module search path.
_here = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(os.path.dirname(_here))
sys.path.append(pytorch_test_dir)
del _here
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # These test cases are meant to be run through the test/test_jit.py
    # driver (see the message below), so refuse direct execution.
    _usage = ("This test file is not meant to be run directly, use:\n\n"
              "\tpython test/test_jit.py TESTNAME\n\n"
              "instead.")
    raise RuntimeError(_usage)
class TestWarn(JitTestCase):
    """Checks how ``warnings.warn`` calls inside TorchScript reach Python.

    Each test compiles a function with ``torch.jit.script``, runs it, and
    inspects the UserWarnings that arrive in Python's ``warnings`` machinery.
    """

    def _record_user_warnings(self, *calls):
        """Call each of ``calls`` in order and return the UserWarning texts raised.

        Warnings are captured with ``warnings.catch_warnings(record=True)``
        under an ``"always"`` filter rather than by scraping ``sys.stderr``.
        Scraping stderr depends on global state:
        - A global ``"ignore"`` or ``"error"`` filter suppresses the warning
          or raises it as an exception.
        - A harness that redirects ``warnings.showwarning`` keeps the text off
          stderr entirely. pytest's warnings plugin records warnings that way.

        The ``"always"`` filter also disables Python-level de-duplication.
        Any once-only behaviour the tests observe therefore comes from
        TorchScript itself.

        Args:
            *calls: zero-argument callables, invoked in the given order.

        Returns:
            list[str]: message texts of the captured UserWarnings (including
            subclasses), in the order they were issued.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            for call in calls:
                call()
        return [str(w.message) for w in caught
                if issubclass(w.category, UserWarning)]

    def test_warn(self):
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        messages = self._record_user_warnings(fn)
        self.assertEqual(messages.count("I am warning you"), 1)

    def test_warn_only_once(self):
        # Expected: a warn executed 10 times in one call is reported once.
        @torch.jit.script
        def fn():
            for _ in range(10):
                warnings.warn("I am warning you")

        messages = self._record_user_warnings(fn)
        self.assertEqual(messages.count("I am warning you"), 1)

    def test_warn_only_once_in_loop_func(self):
        # Same as above, but the warn lives in a callee invoked from the loop.
        def w():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w()

        messages = self._record_user_warnings(fn)
        self.assertEqual(messages.count("I am warning you"), 1)

    def test_warn_once_per_func(self):
        # Expected: two distinct callees with identical messages each warn
        # once, so the text appears twice.
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            w1()
            w2()

        messages = self._record_user_warnings(fn)
        self.assertEqual(messages.count("I am warning you"), 2)

    def test_warn_once_per_func_in_loop(self):
        # As above, with both callees invoked repeatedly inside a loop.
        def w1():
            warnings.warn("I am warning you")

        def w2():
            warnings.warn("I am warning you")

        @torch.jit.script
        def fn():
            for _ in range(10):
                w1()
                w2()

        messages = self._record_user_warnings(fn)
        self.assertEqual(messages.count("I am warning you"), 2)

    def test_warn_multiple_calls_multiple_warnings(self):
        # Expected: calling the scripted function twice warns twice.
        @torch.jit.script
        def fn():
            warnings.warn("I am warning you")

        messages = self._record_user_warnings(fn, fn)
        self.assertEqual(messages.count("I am warning you"), 2)

    def test_warn_multiple_calls_same_func_diff_stack(self):
        # One helper reached from two different scripted callers warns once
        # per caller.
        def warn(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn("foo")

        @torch.jit.script
        def bar():
            warn("bar")

        messages = self._record_user_warnings(foo, bar)
        self.assertEqual(messages.count("I am warning you from foo"), 1)
        self.assertEqual(messages.count("I am warning you from bar"), 1)
        # foo() runs first, so its warning must be reported first.
        self.assertLess(messages.index("I am warning you from foo"),
                        messages.index("I am warning you from bar"))