# Source: pytorch/test/dynamo/test_reorder_logs.py
# (file-viewer metadata: 222 lines, 7.2 KiB, Python)

# Owner(s): ["module: dynamo"]
import io
import logging
import warnings
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
# Logger for this module; f_info and f_isEnabledFor emit their "moo" record
# through it, and IgnoreLogsTests captures its output with assertLogs.
logger = logging.getLogger(__name__)
# A second, distinct logger named "test". Its bound methods are used in the
# IgnoreLogsTests parametrization for the cases where the module logger is
# expected NOT to be ignored.
logger_test = logging.getLogger("test")
def f_info(x):
    """Double ``x``, log "moo" at INFO level on the module logger, then square.

    Returns ``(x + x) * (x + x)``; the logging call sits between the two
    arithmetic operations.
    """
    doubled = x + x
    logger.info("moo")
    return doubled * doubled
def f_isEnabledFor(x):
    """Double ``x``, conditionally log "moo", then square.

    The "moo" record is emitted only when the module logger reports that the
    INFO level is enabled. Returns ``(x + x) * (x + x)``.
    """
    doubled = x + x
    info_enabled = logger.isEnabledFor(logging.INFO)
    if info_enabled:
        logger.info("moo")
    return doubled * doubled
@instantiate_parametrized_tests
class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
    """Checks the effect of ``ignore_logger_methods`` on compiled functions."""

    @parametrize(
        "ignore_method, fn, should_ignore_logger",
        [
            (None, f_info, False),
            (logger_test.info, f_info, False),
            (None, f_isEnabledFor, False),
            (logger_test.isEnabledFor, f_isEnabledFor, False),
            (logger.info, f_info, True),
            (logging.Logger.info, f_info, True),
            (logger.isEnabledFor, f_isEnabledFor, True),
            (logging.Logger.isEnabledFor, f_isEnabledFor, True),
        ],
    )
    def test_ignore_logger(self, ignore_method, fn, should_ignore_logger):
        """Compile ``fn`` with ``ignore_method`` configured and inspect the logs.

        When ``should_ignore_logger`` is true, "moo" must be absent from the
        captured records and no graph break may be counted; otherwise "moo"
        must be present and exactly one graph-break entry is expected.
        """
        counters.clear()
        sample = torch.randn(3, 3)
        eager_result = fn(sample)

        ignored_methods = {ignore_method}
        with torch._dynamo.config.patch(ignore_logger_methods=ignored_methods):
            compile_eager = torch.compile(backend="eager")
            compiled_fn = compile_eager(fn)
            with self.assertLogs(logger, level="INFO") as log_capture:
                # assertLogs fails if nothing is logged, so always emit one
                # record of our own before running the compiled function.
                logger.info("call logger info to avoid error")
                compiled_result = compiled_fn(sample)
            # Captured entries look like "LEVEL:logger_name:message"; keep
            # only the message part.
            messages = [line.split(":", 2)[2] for line in log_capture.output]

        self.assertTrue(same(eager_result, compiled_result))

        check_moo = self.assertNotIn if should_ignore_logger else self.assertIn
        expected_break_count = 0 if should_ignore_logger else 1
        check_moo("moo", messages)
        self.assertEqual(len(counters["graph_break"]), expected_break_count)
class ReorderLogsTests(torch._dynamo.test_case.TestCase):
def test_dont_reorder_print(self):
def f(x):
x = x + x
print("moo")
x = x * x
return x
counters.clear()
x = torch.randn(3, 3)
opt_f = torch.compile(backend="eager")(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(x)
self.assertTrue(same(orig_out, opt_out))
self.assertEqual(printed_output, "moo")
self.assertEqual(len(counters["graph_break"]), 1)
@torch._dynamo.config.patch(reorderable_logging_functions={print})
def test_reorder_print(self):
def f(x):
print("moo")
x1 = x + x
print(x1)
x2 = x1 * x1
print(1, 2, 3)
x3 = x2 + x2
return (x1, x3)
x = torch.ones(3, 3)
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(x)
self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3")
self.assertTrue(same(orig_out, opt_out))
@torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn})
def test_reorder_warnings(self):
import warnings
def f(x):
x1 = x + x
warnings.warn("moo")
x2 = x1 * x1
warnings.warn(f"{x2}")
x3 = x2 + x2
return x3
x = torch.ones(3, 3)
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
with warnings.catch_warnings(record=True) as w:
opt_out = opt_f(x)
warning_messages = [str(i.message) for i in w]
orig_out = f(x)
self.assertTrue(same(orig_out, opt_out))
self.assertIn("moo", warning_messages)
@torch._dynamo.config.patch(reorderable_logging_functions={print})
def test_reorder_print_graph_break(self):
def f(x):
x1 = x + x
print(f"res: {x1}")
x2 = x1 * x1
torch._dynamo.graph_break()
x3 = x2 + x2
print(1, 2, 3)
return x3
x = torch.ones(3, 3)
opt_f = torch.compile(backend="eager")(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(x)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(x)
self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3")
self.assertTrue(same(orig_out, opt_out))
def test_reorder_custom_log_fn(self):
custom_logs = []
def custom_log(s: str):
torch._dynamo.graph_break()
custom_logs.append(s)
def f(x):
custom_log("moo")
x1 = x + x
custom_log(f"{x1}")
return x + x
x = torch.ones(3, 3)
counters.clear()
with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}):
opt_f = torch.compile(backend="eager")(f)
opt_f(x)
self.assertEqual(sum(counters["graph_break"].values()), 1)
self.assertEqual(custom_logs[0], "moo")
self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}")
@torch._dynamo.config.patch(reorderable_logging_functions={print})
def test_constant_mutation(self):
def f(x):
alist = [x]
alist.append(x + 1)
print(alist[-1])
alist[0].sum().item() # graph break
res = alist.pop()
print(alist[-1])
res.sum().item() # graph break
return res
inputs = (torch.tensor([1]),)
counters.clear()
opt_f = torch.compile(backend="eager")(f)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
opt_out = opt_f(*inputs)
printed_output = mock_stdout.getvalue().strip()
orig_out = f(*inputs)
self.assertEqual(printed_output, "tensor([2])\ntensor([1])")
self.assertTrue(same(orig_out, opt_out))
graph_break_key = counters["graph_break"].keys()
self.assertEqual(len(graph_break_key), 1)
self.assertExpectedInline(
next(iter(graph_break_key)),
"""\
Unsupported Tensor.item() call with capture_scalar_outputs=False
Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph.
Developer debug context: call_method TensorVariable() item () {}
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html""", # noqa: B950
)
if __name__ == "__main__":
    # Run this module's tests through Dynamo's test runner; the
    # torch._dynamo.test_case module is already imported at the top of the file.
    torch._dynamo.test_case.run_tests()