mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Partially addresses #123062

Ran lintrunner on:
- `test/jit`

with command:
```bash
lintrunner -a --take UFMT --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623
Approved by: https://github.com/ezyang
149 lines
3.6 KiB
Python
149 lines
3.6 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import io
|
|
import os
|
|
import sys
|
|
import warnings
|
|
from contextlib import redirect_stderr
|
|
|
|
import torch
|
|
from torch.testing import FileCheck
|
|
|
|
# Make the helper files in test/ importable. realpath() resolves any symlinks
# in this file's path first; the two dirname() calls then step up to this
# file's grandparent directory, which is appended to the module search path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
if __name__ == "__main__":
    # The message below directs users to run these tests through
    # test/test_jit.py; executing this module as a script is refused.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
|
|
|
|
|
|
def _assert_warning_count(output, count):
    """Assert the standard test warning appears exactly ``count`` times.

    Args:
        output: Text captured from stderr while a scripted function ran.
        count: Required number of occurrences of the substring
            "UserWarning: I am warning you" (checked with ``exactly=True``).
    """
    FileCheck().check_count(
        str="UserWarning: I am warning you", count=count, exactly=True
    ).run(output)


class TestWarn(JitTestCase):
    """Tests for how often TorchScript prints ``warnings.warn`` messages.

    Each test compiles a function with ``torch.jit.script``, calls it while
    stderr is redirected into an in-memory buffer, and then counts the
    printed warning lines with FileCheck.
    """

    def test_warn(self):
        # A single warn() in a scripted function is printed once.
        @torch.jit.script
        def scripted_fn():
            warnings.warn("I am warning you")

        with redirect_stderr(io.StringIO()) as captured:
            scripted_fn()

        _assert_warning_count(captured.getvalue(), 1)

    def test_warn_only_once(self):
        # A warn() executed on every iteration of a 10-step loop is still
        # printed only once.
        @torch.jit.script
        def scripted_fn():
            for _ in range(10):
                warnings.warn("I am warning you")

        with redirect_stderr(io.StringIO()) as captured:
            scripted_fn()

        _assert_warning_count(captured.getvalue(), 1)

    def test_warn_only_once_in_loop_func(self):
        # Same as above, but the warn() lives in a helper that the scripted
        # loop calls; the helper is compiled along with its scripted caller.
        def emit_warning():
            warnings.warn("I am warning you")

        @torch.jit.script
        def scripted_fn():
            for _ in range(10):
                emit_warning()

        with redirect_stderr(io.StringIO()) as captured:
            scripted_fn()

        _assert_warning_count(captured.getvalue(), 1)

    def test_warn_once_per_func(self):
        # Two distinct helpers with identical warn() calls each print once,
        # giving two warnings in total.
        def first_warner():
            warnings.warn("I am warning you")

        def second_warner():
            warnings.warn("I am warning you")

        @torch.jit.script
        def scripted_fn():
            first_warner()
            second_warner()

        with redirect_stderr(io.StringIO()) as captured:
            scripted_fn()

        _assert_warning_count(captured.getvalue(), 2)

    def test_warn_once_per_func_in_loop(self):
        # Looping over the two helpers ten times still yields one warning
        # per helper, i.e. two in total.
        def first_warner():
            warnings.warn("I am warning you")

        def second_warner():
            warnings.warn("I am warning you")

        @torch.jit.script
        def scripted_fn():
            for _ in range(10):
                first_warner()
                second_warner()

        with redirect_stderr(io.StringIO()) as captured:
            scripted_fn()

        _assert_warning_count(captured.getvalue(), 2)

    def test_warn_multiple_calls_multiple_warnings(self):
        # Each separate call of the scripted function prints its warning
        # again, so two calls give two warnings.
        @torch.jit.script
        def scripted_fn():
            warnings.warn("I am warning you")

        with redirect_stderr(io.StringIO()) as captured:
            scripted_fn()
            scripted_fn()

        _assert_warning_count(captured.getvalue(), 2)

    def test_warn_multiple_calls_same_func_diff_stack(self):
        # One helper reached from two different scripted callers prints one
        # warning per caller; the caller's name is part of the message.
        def warn_from(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn_from("foo")

        @torch.jit.script
        def bar():
            warn_from("bar")

        with redirect_stderr(io.StringIO()) as captured:
            foo()
            bar()

        # Both checks are chained on one FileCheck instance before running it.
        checker = FileCheck()
        checker = checker.check_count(
            str="UserWarning: I am warning you from foo", count=1, exactly=True
        )
        checker = checker.check_count(
            str="UserWarning: I am warning you from bar", count=1, exactly=True
        )
        checker.run(captured.getvalue())
|