Files
pytorch/test/jit/test_warn.py
Yuanhao Ji 604c9c5601 Enable UFMT on all of test/jit (#123623)
Partially addresses #123062

Ran lintrunner on:

- `test/jit`

with command:

```bash
lintrunner -a --take UFMT --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123623
Approved by: https://github.com/ezyang
2024-04-11 23:45:05 +00:00

149 lines
3.6 KiB
Python

# Owner(s): ["oncall: jit"]
import io
import os
import sys
import warnings
from contextlib import redirect_stderr
import torch
from torch.testing import FileCheck
# Make the helper files in test/ importable: append the test/ directory to
# sys.path. This file lives in test/jit/, so two dirname() calls on its
# resolved real path yield test/.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase

# This module only defines test cases. Running it as a script is refused with
# a message pointing at the suite's entry point.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestWarn(JitTestCase):
def test_warn(self):
@torch.jit.script
def fn():
warnings.warn("I am warning you")
f = io.StringIO()
with redirect_stderr(f):
fn()
FileCheck().check_count(
str="UserWarning: I am warning you", count=1, exactly=True
).run(f.getvalue())
def test_warn_only_once(self):
@torch.jit.script
def fn():
for _ in range(10):
warnings.warn("I am warning you")
f = io.StringIO()
with redirect_stderr(f):
fn()
FileCheck().check_count(
str="UserWarning: I am warning you", count=1, exactly=True
).run(f.getvalue())
def test_warn_only_once_in_loop_func(self):
def w():
warnings.warn("I am warning you")
@torch.jit.script
def fn():
for _ in range(10):
w()
f = io.StringIO()
with redirect_stderr(f):
fn()
FileCheck().check_count(
str="UserWarning: I am warning you", count=1, exactly=True
).run(f.getvalue())
def test_warn_once_per_func(self):
def w1():
warnings.warn("I am warning you")
def w2():
warnings.warn("I am warning you")
@torch.jit.script
def fn():
w1()
w2()
f = io.StringIO()
with redirect_stderr(f):
fn()
FileCheck().check_count(
str="UserWarning: I am warning you", count=2, exactly=True
).run(f.getvalue())
def test_warn_once_per_func_in_loop(self):
def w1():
warnings.warn("I am warning you")
def w2():
warnings.warn("I am warning you")
@torch.jit.script
def fn():
for _ in range(10):
w1()
w2()
f = io.StringIO()
with redirect_stderr(f):
fn()
FileCheck().check_count(
str="UserWarning: I am warning you", count=2, exactly=True
).run(f.getvalue())
def test_warn_multiple_calls_multiple_warnings(self):
@torch.jit.script
def fn():
warnings.warn("I am warning you")
f = io.StringIO()
with redirect_stderr(f):
fn()
fn()
FileCheck().check_count(
str="UserWarning: I am warning you", count=2, exactly=True
).run(f.getvalue())
def test_warn_multiple_calls_same_func_diff_stack(self):
def warn(caller: str):
warnings.warn("I am warning you from " + caller)
@torch.jit.script
def foo():
warn("foo")
@torch.jit.script
def bar():
warn("bar")
f = io.StringIO()
with redirect_stderr(f):
foo()
bar()
FileCheck().check_count(
str="UserWarning: I am warning you from foo", count=1, exactly=True
).check_count(
str="UserWarning: I am warning you from bar", count=1, exactly=True
).run(
f.getvalue()
)