# Owner(s): ["module: dynamo"]
import io
import logging
import warnings
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)


logger = logging.getLogger(__name__)
logger_test = logging.getLogger("test")

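# Helper functions for IgnoreLogsTests: each interleaves tensor math with a call into
# the `logging` module that Dynamo must either ignore or graph-break on.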
def f_info(x):
    x = x + x
    logger.info("moo")
    x = x * x
    return x


def f_isEnabledFor(x):
    x = x + x
    if logger.isEnabledFor(logging.INFO):
        logger.info("moo")
    x = x * x
    return x

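# Verifies torch._dynamo.config.ignore_logger_methods: when the traced logger method is
# listed there, the call is dropped (no log output, no graph break); otherwise the call
# graph-breaks and the message is still emitted.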
@instantiate_parametrized_tests
class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
    @parametrize(
        "ignore_method, fn, should_ignore_logger",
        [
            (None, f_info, False),
            (logger_test.info, f_info, False),
            (None, f_isEnabledFor, False),
            (logger_test.isEnabledFor, f_isEnabledFor, False),
            (logger.info, f_info, True),
            (logging.Logger.info, f_info, True),
            (logger.isEnabledFor, f_isEnabledFor, True),
            (logging.Logger.isEnabledFor, f_isEnabledFor, True),
        ],
    )
    def test_ignore_logger(self, ignore_method, fn, should_ignore_logger):
        counters.clear()
        x = torch.randn(3, 3)
        orig_out = fn(x)
        with torch._dynamo.config.patch(ignore_logger_methods={ignore_method}):
            opt_f = torch.compile(backend="eager")(fn)
            with self.assertLogs(logger, level="INFO") as captured:
                logger.info("call logger info to avoid error")
                opt_out = opt_f(x)
            printed_output = [entry.split(":", 2)[2] for entry in captured.output]

        self.assertTrue(same(orig_out, opt_out))
        if should_ignore_logger:
            self.assertNotIn("moo", printed_output)
            self.assertEqual(len(counters["graph_break"]), 0)
        else:
            self.assertIn("moo", printed_output)
            self.assertEqual(len(counters["graph_break"]), 1)

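# Exercises torch._dynamo.config.reorderable_logging_functions: calls to the registered
# functions are not traced into the graph; they are deferred and replayed in program
# order with the values computed by the compiled code.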
class ReorderLogsTests(torch._dynamo.test_case.TestCase):
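    # print is not registered as reorderable here, so the call graph-breaks but still
    # runs exactly once with the expected output.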
    def test_dont_reorder_print(self):
        def f(x):
            x = x + x
            print("moo")
            x = x * x
            return x

        counters.clear()
        x = torch.randn(3, 3)
        opt_f = torch.compile(backend="eager")(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(x)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(x)

        self.assertTrue(same(orig_out, opt_out))
        self.assertEqual(printed_output, "moo")
        self.assertEqual(len(counters["graph_break"]), 1)

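    # With print registered as reorderable, the function compiles under fullgraph=True
    # and all three prints are replayed afterwards, in order, with the computed values.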
    @torch._dynamo.config.patch(reorderable_logging_functions={print})
    def test_reorder_print(self):
        def f(x):
            print("moo")
            x1 = x + x
            print(x1)
            x2 = x1 * x1
            print(1, 2, 3)
            x3 = x2 + x2
            return (x1, x3)

        x = torch.ones(3, 3)
        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(x)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(x)

        self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3")
        self.assertTrue(same(orig_out, opt_out))

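    # warnings.warn can also be registered as reorderable; the deferred warnings remain
    # observable through warnings.catch_warnings.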
    @torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn})
    def test_reorder_warnings(self):
        import warnings

        def f(x):
            x1 = x + x
            warnings.warn("moo")
            x2 = x1 * x1
            warnings.warn(f"{x2}")
            x3 = x2 + x2
            return x3

        x = torch.ones(3, 3)
        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
        with warnings.catch_warnings(record=True) as w:
            opt_out = opt_f(x)
            warning_messages = [str(i.message) for i in w]
            orig_out = f(x)

        self.assertTrue(same(orig_out, opt_out))
        self.assertIn("moo", warning_messages)

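    # A reorderable print interacts correctly with an explicit graph break: the printed
    # order and values still match eager execution.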
    @torch._dynamo.config.patch(reorderable_logging_functions={print})
    def test_reorder_print_graph_break(self):
        def f(x):
            x1 = x + x
            print(f"res: {x1}")
            x2 = x1 * x1
            torch._dynamo.graph_break()
            x3 = x2 + x2
            print(1, 2, 3)
            return x3

        x = torch.ones(3, 3)
        opt_f = torch.compile(backend="eager")(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(x)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(x)

        self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3")
        self.assertTrue(same(orig_out, opt_out))

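    # A user-defined logging function can be registered as reorderable via config.patch;
    # its calls are replayed in order with the values computed by the graph.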
    def test_reorder_custom_log_fn(self):
        custom_logs = []

        def custom_log(s: str):
            torch._dynamo.graph_break()
            custom_logs.append(s)

        def f(x):
            custom_log("moo")
            x1 = x + x
            custom_log(f"{x1}")
            return x + x

        x = torch.ones(3, 3)
        counters.clear()
        with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}):
            opt_f = torch.compile(backend="eager")(f)
            opt_f(x)

        self.assertEqual(sum(counters["graph_break"].values()), 1)
        self.assertEqual(custom_logs[0], "moo")
        self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}")

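    # Reordered prints must capture values at the point of the call: the list is mutated
    # afterwards and the .item() calls force graph breaks, yet the printed tensors match
    # eager execution.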
    @torch._dynamo.config.patch(reorderable_logging_functions={print})
    def test_constant_mutation(self):
        def f(x):
            alist = [x]
            alist.append(x + 1)
            print(alist[-1])
            alist[0].sum().item()  # graph break
            res = alist.pop()
            print(alist[-1])
            res.sum().item()  # graph break
            return res

        inputs = (torch.tensor([1]),)
        counters.clear()
        opt_f = torch.compile(backend="eager")(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(*inputs)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(*inputs)

        self.assertEqual(printed_output, "tensor([2])\ntensor([1])")
        self.assertTrue(same(orig_out, opt_out))

        graph_break_key = counters["graph_break"].keys()
        self.assertEqual(len(graph_break_key), 1)
        self.assertExpectedInline(
            next(iter(graph_break_key)),
            """\
Unsupported Tensor.item() call with capture_scalar_outputs=False
  Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
  Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph.

  Developer debug context: call_method TensorVariable() item () {}

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html""",  # noqa: B950
        )

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()