Files
pytorch/test/jit/test_warn.py
Xuehai Pan 6ff1e43a41 [BE][Easy][13/19] enforce style for empty lines in import segments in test/j*/ (#129764)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129764
Approved by: https://github.com/ezyang
2024-08-01 12:13:42 +00:00

151 lines
3.6 KiB
Python

# Owner(s): ["oncall: jit"]
import io
import os
import sys
import warnings
from contextlib import redirect_stderr
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable: resolve this file's real path,
# climb two directory levels, and append the resulting directory to sys.path.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
# Direct execution is rejected; the message points at test/test_jit.py as the
# entry point to use instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n"
        "\n"
        "\tpython test/test_jit.py TESTNAME\n"
        "\n"
        "instead."
    )
class TestWarn(JitTestCase):
    """Check how ``warnings.warn`` calls in scripted functions show up on stderr.

    Each test scripts a small function with ``torch.jit.script``, runs it while
    ``sys.stderr`` is redirected into an in-memory buffer, and then uses
    ``FileCheck`` to count occurrences of the expected ``UserWarning`` text in
    the captured output.
    """

    def _assert_warning_count(self, captured_text, expected):
        """Require exactly ``expected`` "I am warning you" hits in ``captured_text``."""
        FileCheck().check_count(
            str="UserWarning: I am warning you", count=expected, exactly=True
        ).run(captured_text)

    def test_warn(self):
        """One call to a scripted function with one warn yields one warning."""

        @torch.jit.script
        def warn_once():
            warnings.warn("I am warning you")

        with redirect_stderr(io.StringIO()) as stderr_buffer:
            warn_once()

        self._assert_warning_count(stderr_buffer.getvalue(), expected=1)

    def test_warn_only_once(self):
        """A warn executed ten times inside a scripted loop is reported once."""

        @torch.jit.script
        def warn_in_loop():
            for _ in range(10):
                warnings.warn("I am warning you")

        with redirect_stderr(io.StringIO()) as stderr_buffer:
            warn_in_loop()

        self._assert_warning_count(stderr_buffer.getvalue(), expected=1)

    def test_warn_only_once_in_loop_func(self):
        """A warning helper called ten times from a scripted loop is reported once."""

        def emit():
            warnings.warn("I am warning you")

        @torch.jit.script
        def call_emit_in_loop():
            for _ in range(10):
                emit()

        with redirect_stderr(io.StringIO()) as stderr_buffer:
            call_emit_in_loop()

        self._assert_warning_count(stderr_buffer.getvalue(), expected=1)

    def test_warn_once_per_func(self):
        """Two distinct helpers with the same message each report their warning."""

        def emit_first():
            warnings.warn("I am warning you")

        def emit_second():
            warnings.warn("I am warning you")

        @torch.jit.script
        def call_both():
            emit_first()
            emit_second()

        with redirect_stderr(io.StringIO()) as stderr_buffer:
            call_both()

        self._assert_warning_count(stderr_buffer.getvalue(), expected=2)

    def test_warn_once_per_func_in_loop(self):
        """Two distinct helpers called in a ten-iteration loop report two warnings."""

        def emit_first():
            warnings.warn("I am warning you")

        def emit_second():
            warnings.warn("I am warning you")

        @torch.jit.script
        def call_both_in_loop():
            for _ in range(10):
                emit_first()
                emit_second()

        with redirect_stderr(io.StringIO()) as stderr_buffer:
            call_both_in_loop()

        self._assert_warning_count(stderr_buffer.getvalue(), expected=2)

    def test_warn_multiple_calls_multiple_warnings(self):
        """Calling the same scripted function twice reports the warning twice."""

        @torch.jit.script
        def warn_once():
            warnings.warn("I am warning you")

        with redirect_stderr(io.StringIO()) as stderr_buffer:
            warn_once()
            warn_once()

        self._assert_warning_count(stderr_buffer.getvalue(), expected=2)

    def test_warn_multiple_calls_same_func_diff_stack(self):
        """One helper reached from two scripted callers reports once per caller."""

        def warn_from(caller: str):
            warnings.warn("I am warning you from " + caller)

        @torch.jit.script
        def foo():
            warn_from("foo")

        @torch.jit.script
        def bar():
            warn_from("bar")

        with redirect_stderr(io.StringIO()) as stderr_buffer:
            foo()
            bar()

        # A single FileCheck carries both checks, so they are applied in the
        # order registered: the "foo" message first, then the "bar" message.
        checker = FileCheck().check_count(
            str="UserWarning: I am warning you from foo", count=1, exactly=True
        )
        checker = checker.check_count(
            str="UserWarning: I am warning you from bar", count=1, exactly=True
        )
        checker.run(stderr_buffer.getvalue())