[dynamo] Reorder logs (#116106)

Currently, when there is a print/warning call in the traced code, Dynamo graph breaks, which causes export to fail. Export would instead like to simply skip over these print/warning calls: https://github.com/pytorch/pytorch/issues/113792.

Additionally, there is a torch.compile feature request to "reorder prints": instead of graph breaking when hitting prints/logging calls, we skip over them to create larger compiled graphs and then emit the prints after those graphs run: https://github.com/pytorch/pytorch/issues/93739. This PR also adds the `reorderable_logging_functions` config so users can register logging functions (like `print` or a custom logging function) to be reordered. A printout of the bytecode after reordering the prints looks like the following: P914736600
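
As a rough usage sketch (mirroring the tests added in this PR; the example function and values are illustrative, not taken from the paste above):

```python
import torch

# Register print as a reorderable logging function (the tests below use
# torch._dynamo.config.patch for the same effect).
torch._dynamo.config.reorderable_logging_functions = {print}

def f(x):
    x1 = x + x
    print("intermediate:", x1)  # deferred: emitted after the compiled graph runs
    x2 = x1 * x1
    return x2

# With print registered, this compiles as one graph instead of graph breaking.
opt_f = torch.compile(f, backend="eager", fullgraph=True)
print(opt_f(torch.ones(3)))
```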

There are some limitations to the printing right now:
* You can only register logging functions, not methods
* Inputs to the logging functions can only be tensors, constants, and format strings
* Inputs to the logging functions that are later mutated in-place will not be printed correctly (see the sketch after this list)
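
A minimal sketch of that last caveat, assuming `print` is registered as in the example above:

```python
import torch

torch._dynamo.config.reorderable_logging_functions = {print}

@torch.compile(backend="eager")
def g(x):
    y = x + 1
    print(y)    # reordered: only runs after the compiled graph finishes
    y.add_(10)  # in-place mutation executes inside the graph, before the deferred print
    return y

# The deferred print observes y *after* add_, i.e. x + 11, rather than the
# x + 1 it would show at the original program point in eager mode.
g(torch.zeros(3))
```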

TODO: Add the following tests (a rough sketch of the first one follows this list):
* print function with an argument that is a nested data structure
* print function with an argument that is a nested data structure updated inside the compiled region (this tests whether we handle side effects correctly)
* custom-defined logging functions with nn.Module or nn.Module attribute arguments
* custom-defined logging functions with submodule inputs/outputs as arguments (we need to handle the mapping and fused-out values)
* custom-defined logging functions with a tensor argument that is mutated inside the function (TBD: this may increase memory usage)
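
A sketch of how the first TODO item might look as a test method inside the new `ReorderLogsTests` below (the test name and expectations are hypothetical; it only checks that compiled and eager runs agree on outputs and printed text, not whether the print is reordered or graph breaks):

```python
@torch._dynamo.config.patch(reorderable_logging_functions={print})
def test_reorder_print_nested_structure(self):
    def f(x):
        x1 = x + x
        # Nested dict/list/tuple argument to the logging call
        print({"first": x1, "rest": [x1 * 2, (x1, 3)]})
        return x1 * x1

    x = torch.ones(3, 3)
    opt_f = torch.compile(backend="eager")(f)
    with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
        opt_out = opt_f(x)
        compiled_printed = mock_stdout.getvalue().strip()
    with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
        orig_out = f(x)
        eager_printed = mock_stdout.getvalue().strip()

    self.assertTrue(same(orig_out, opt_out))
    self.assertEqual(compiled_printed, eager_printed)
```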

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116106
Approved by: https://github.com/yanboliang
angelayi
2024-03-01 17:04:24 +00:00
committed by PyTorch MergeBot
parent 9fc56f8209
commit c844b377fa
8 changed files with 286 additions and 1 deletion


@@ -0,0 +1,152 @@
# Owner(s): ["module: dynamo"]
import io
import warnings
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch._dynamo.utils import counters


class ReorderLogsTests(torch._dynamo.test_case.TestCase):
    def test_dont_reorder_print(self):
        def f(x):
            x = x + x
            print("moo")
            x = x * x
            return x

        counters.clear()
        x = torch.randn(3, 3)
        opt_f = torch.compile(backend="eager")(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(x)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(x)

        self.assertTrue(same(orig_out, opt_out))
        self.assertEqual(printed_output, "moo")
        self.assertEqual(len(counters["graph_break"]), 1)

    @torch._dynamo.config.patch(reorderable_logging_functions={print})
    def test_reorder_print(self):
        def f(x):
            print("moo")
            x1 = x + x
            print(x1)
            x2 = x1 * x1
            print(1, 2, 3)
            x3 = x2 + x2
            return (x1, x3)

        x = torch.ones(3, 3)
        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(x)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(x)

        self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3")
        self.assertTrue(same(orig_out, opt_out))

    @torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn})
    def test_reorder_warnings(self):
        import warnings

        def f(x):
            x1 = x + x
            warnings.warn("moo")
            x2 = x1 * x1
            warnings.warn(f"{x2}")
            x3 = x2 + x2
            return x3

        x = torch.ones(3, 3)
        opt_f = torch.compile(backend="eager", fullgraph=True)(f)
        with warnings.catch_warnings(record=True) as w:
            opt_out = opt_f(x)
            warning_messages = [str(i.message) for i in w]
            orig_out = f(x)

        self.assertTrue(same(orig_out, opt_out))
        self.assertIn("moo", warning_messages)

    @torch._dynamo.config.patch(reorderable_logging_functions={print})
    def test_reorder_print_graph_break(self):
        def f(x):
            x1 = x + x
            print(f"res: {x1}")
            x2 = x1 * x1
            torch._dynamo.graph_break()
            x3 = x2 + x2
            print(1, 2, 3)
            return x3

        x = torch.ones(3, 3)
        opt_f = torch.compile(backend="eager")(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(x)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(x)

        self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3")
        self.assertTrue(same(orig_out, opt_out))

    def test_reorder_custom_log_fn(self):
        custom_logs = []

        def custom_log(s: str):
            torch._dynamo.graph_break()
            custom_logs.append(s)

        def f(x):
            custom_log("moo")
            x1 = x + x
            custom_log(f"{x1}")
            return x + x

        x = torch.ones(3, 3)
        counters.clear()
        with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}):
            opt_f = torch.compile(backend="eager")(f)
            opt_out = opt_f(x)

        self.assertEqual(sum(counters["graph_break"].values()), 1)
        self.assertEqual(custom_logs[0], "moo")
        self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}")

    @torch._dynamo.config.patch(reorderable_logging_functions={print})
    def test_constant_mutation(self):
        def f(x):
            alist = [x]
            alist.append(x + 1)
            print(alist[-1])
            alist[0].sum().item()  # graph break
            res = alist.pop()
            print(alist[-1])
            res.sum().item()  # graph break
            return res

        inputs = (torch.tensor([1]),)
        counters.clear()
        opt_f = torch.compile(backend="eager")(f)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_out = opt_f(*inputs)
            printed_output = mock_stdout.getvalue().strip()
            orig_out = f(*inputs)

        self.assertEqual(printed_output, "tensor([2])\ntensor([1])")
        self.assertTrue(same(orig_out, opt_out))

        graph_break_key = counters["graph_break"].keys()
        self.assertEqual(len(graph_break_key), 1)
        self.assertEqual(next(iter(graph_break_key)), "Tensor.item")


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()


@@ -4,6 +4,7 @@ import copy
import dataclasses
import io
import unittest
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
@@ -3610,6 +3611,45 @@ def forward(self, arg0_1, arg1_1, arg2_1):
        self.assertEqual(res[0], torch.tensor(16))
        self.assertEqual(res[1], None)

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                print("start")
                x1 = x + x
                print(x1)
                x2 = x1 * x1
                print(1, 2, 3)
                x3 = x2 + x2
                return (x1, x3)

        gm = export(M(), (torch.randn(3, 3),)).graph_module
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
    mul = torch.ops.aten.mul.Tensor(add, add)
    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
    return (add, add_1)""",
        )

    def test_warning(self):
        class M(torch.nn.Module):
            def forward(self, x):
                warnings.warn("moo")
                res = x + x
                warnings.warn(f"{res}")
                return res

        gm = export(M(), (torch.randn(3, 3),)).graph_module
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
    return (add,)""",
        )

    def test_constant_fqn(self):
        class Nested(torch.nn.Module):
            def __init__(self):


@@ -5,7 +5,7 @@ import re
import sys
import tempfile
from os.path import abspath, dirname
from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
import torch
@@ -370,6 +370,12 @@ optimize_user_defined_triton_kernels = True
# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True
# A set of logging functions which will be reordered to the end of graph breaks,
# allowing dynamo to construct larger graphs. Note that there are some
# limitations to this, such as how it does not correctly print objects that were
# mutated after the print statement.
reorderable_logging_functions: Set[Callable[[Any], None]] = set()
# simulates what would happen if we didn't have support for BUILD_SET opcode,
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False


@@ -955,6 +955,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
            and all(isinstance(x, TensorVariable) for x in stack_values)
            and len(set(stack_values)) == len(stack_values)
            and self.side_effects.is_empty()
            and not len(tx.debug_locals) != 0
            and not self.backward_state
        ):
            append_prefix_insts()
@@ -1004,6 +1005,14 @@ class OutputGraph(Checkpointable[OutputGraphState]):
                    cg.store_attr(name)
            self.side_effects.codegen_hooks(cg)
            self.side_effects.codegen_save_tempvars(cg)

            # Return variables used for logging at the end
            for debug_var, args in tx.debug_locals:
                cg(debug_var)
                for arg in args:
                    cg(arg)
                cg.extend_output(create_call_function(len(args), True))

            cg.restore_stack(stack_values, value_from_source=not tx.export)
            self.side_effects.codegen_update_mutated(cg)


@@ -2139,6 +2139,7 @@ class InstructionTranslator(InstructionTranslatorBase):
            for k in vars
            if k in f_locals
        }
        self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
        if export:
            # export gets confused if we never realize unused inputs
            # in export mode just eagerly realize everything


@@ -129,6 +129,7 @@ from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    DebuggingVariable,
    GetAttrVariable,
    GetSetDescriptorVariable,
    InspectSignatureVariable,
@@ -496,6 +497,11 @@ class VariableBuilder:
        elif isinstance(value, enum.Enum):
            self.install_guards(GuardBuilder.ID_MATCH)
            return EnumVariable(value=value, source=self.source)
        elif DebuggingVariable.is_reorderable_logging_function(value):
            # Put this above builtin_callable so that print() can be handled
            # along with other builtin debugging functions
            self.install_guards(GuardBuilder.BUILTIN_MATCH)
            return DebuggingVariable(value, source=self.source)
        elif is_utils_checkpoint(value):
            return build_checkpoint_variable(source=self.source)
        elif isinstance(value, functools.partial):


@@ -11,6 +11,7 @@ from typing import Dict, List
import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree
from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
@@ -817,3 +818,58 @@ class StringFormatVariable(VariableTracker):
        }
        codegen(variables.ConstDictVariable(kwargs))
        codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))


class DebuggingVariable(VariableTracker):
    """
    Represents a call to a debugging function like print(), or something
    registered to config.reorderable_logging_functions.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    @staticmethod
    def is_reorderable_logging_function(obj):
        return (
            callable(obj)
            and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
            and obj in torch._dynamo.config.reorderable_logging_functions
        )

    def call_function(self, tx, args, kwargs):
        if tx.export:
            # For export cases, we can just make debugging functions no-ops
            return

        if not self.can_reorder_logs(self.value, args, kwargs):
            unimplemented(
                f"Reordering debugging function {self.value} "
                f"with inputs {args} {kwargs} is not yet implemented."
            )

        tx.debug_locals.append((self, list(args)))

    def reconstruct(self, codegen):
        return self.source.reconstruct(codegen)

    @staticmethod
    def can_reorder_logs(fn, args, kwargs) -> bool:
        """
        Run some additional checks for what sort of function calls we can
        actually reorder.
        """
        allowed_input_types = (
            variables.TensorVariable,
            variables.ConstantVariable,
            StringFormatVariable,
        )

        flat_args = pytree.tree_leaves([args, kwargs])
        for arg in flat_args:
            if not isinstance(arg, allowed_input_types):
                return False

        return True


@@ -4,6 +4,7 @@ import inspect
import logging
import re
import time
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -73,9 +74,23 @@ class ExportDynamoConfig:
    """

    allow_rnn: bool = True
    reorderable_logging_functions: Set[Callable] = dataclasses.field(
        default_factory=set
    )


DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
    logging.critical,
    logging.debug,
    logging.error,
    logging.exception,
    logging.info,
    logging.log,
    logging.warning,
    print,
    warnings.warn,
}

@contextmanager