mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Reorder logs (#116106)
Currently when there is a print/warning in the graph, dynamo graph breaks causing export to fail. However export would like to just skip over these print/warning calls: https://github.com/pytorch/pytorch/issues/113792. Additionally there's a torch.compile feature request to "reorder prints" so that instead of graph breaking when hitting prints/logging, we can skip over these prints to create larger compiled graphs, and then print the results out after those compiled graphs: https://github.com/pytorch/pytorch/issues/93739. This PR also adds the `reorderable_logging_functions` config for users to register logging functions to be reordered (like `print` or a custom logging function). Printout of the bytecode after reordering the prints looks like the following: P914736600 There are some limitations to the printing right now: * You can only register logging functions, not methods * Inputs to the logging functions can only be tensors, constants, and format strings * Inputs to the logging functions which will later be mutated in-place will not be printed correctly TODO: Add the following tests * print function with argument of nested data structure; * print function with argument of nested data structure being updated inside of compile region (this would test if we handle side effect correctly); * custom defined logging functions with nn.Module or nn.Module attribute arguments; * custom defined logging functions with submodule input/output as arguments (we need to handle the mapping and fused-out value); * custom defined logging functions with tensor argument and mutation inside of the function (TBD: this may increase memory usage); Pull Request resolved: https://github.com/pytorch/pytorch/pull/116106 Approved by: https://github.com/yanboliang
This commit is contained in:
committed by
PyTorch MergeBot
parent
9fc56f8209
commit
c844b377fa
152
test/dynamo/test_reorder_logs.py
Normal file
152
test/dynamo/test_reorder_logs.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import io
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch._dynamo.testing import same
|
||||
from torch._dynamo.utils import counters
|
||||
|
||||
|
||||
class ReorderLogsTests(torch._dynamo.test_case.TestCase):
|
||||
def test_dont_reorder_print(self):
|
||||
def f(x):
|
||||
x = x + x
|
||||
print("moo")
|
||||
x = x * x
|
||||
return x
|
||||
|
||||
counters.clear()
|
||||
x = torch.randn(3, 3)
|
||||
opt_f = torch.compile(backend="eager")(f)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
opt_out = opt_f(x)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
orig_out = f(x)
|
||||
|
||||
self.assertTrue(same(orig_out, opt_out))
|
||||
self.assertEqual(printed_output, "moo")
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
|
||||
@torch._dynamo.config.patch(reorderable_logging_functions={print})
|
||||
def test_reorder_print(self):
|
||||
def f(x):
|
||||
print("moo")
|
||||
x1 = x + x
|
||||
print(x1)
|
||||
x2 = x1 * x1
|
||||
print(1, 2, 3)
|
||||
x3 = x2 + x2
|
||||
return (x1, x3)
|
||||
|
||||
x = torch.ones(3, 3)
|
||||
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
opt_out = opt_f(x)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
orig_out = f(x)
|
||||
|
||||
self.assertEqual(printed_output, f"moo\n{torch.ones(3, 3) * 2}\n1 2 3")
|
||||
self.assertTrue(same(orig_out, opt_out))
|
||||
|
||||
@torch._dynamo.config.patch(reorderable_logging_functions={warnings.warn})
|
||||
def test_reorder_warnings(self):
|
||||
import warnings
|
||||
|
||||
def f(x):
|
||||
x1 = x + x
|
||||
warnings.warn("moo")
|
||||
x2 = x1 * x1
|
||||
warnings.warn(f"{x2}")
|
||||
x3 = x2 + x2
|
||||
return x3
|
||||
|
||||
x = torch.ones(3, 3)
|
||||
opt_f = torch.compile(backend="eager", fullgraph=True)(f)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
opt_out = opt_f(x)
|
||||
warning_messages = [str(i.message) for i in w]
|
||||
orig_out = f(x)
|
||||
|
||||
self.assertTrue(same(orig_out, opt_out))
|
||||
self.assertIn("moo", warning_messages)
|
||||
|
||||
@torch._dynamo.config.patch(reorderable_logging_functions={print})
|
||||
def test_reorder_print_graph_break(self):
|
||||
def f(x):
|
||||
x1 = x + x
|
||||
print(f"res: {x1}")
|
||||
x2 = x1 * x1
|
||||
torch._dynamo.graph_break()
|
||||
x3 = x2 + x2
|
||||
print(1, 2, 3)
|
||||
return x3
|
||||
|
||||
x = torch.ones(3, 3)
|
||||
opt_f = torch.compile(backend="eager")(f)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
opt_out = opt_f(x)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
orig_out = f(x)
|
||||
|
||||
self.assertEqual(printed_output, f"res: {torch.ones(3, 3) * 2}\n1 2 3")
|
||||
self.assertTrue(same(orig_out, opt_out))
|
||||
|
||||
def test_reorder_custom_log_fn(self):
|
||||
custom_logs = []
|
||||
|
||||
def custom_log(s: str):
|
||||
torch._dynamo.graph_break()
|
||||
custom_logs.append(s)
|
||||
|
||||
def f(x):
|
||||
custom_log("moo")
|
||||
x1 = x + x
|
||||
custom_log(f"{x1}")
|
||||
return x + x
|
||||
|
||||
x = torch.ones(3, 3)
|
||||
counters.clear()
|
||||
with torch._dynamo.config.patch(reorderable_logging_functions={custom_log}):
|
||||
opt_f = torch.compile(backend="eager")(f)
|
||||
opt_out = opt_f(x)
|
||||
|
||||
self.assertEqual(sum(counters["graph_break"].values()), 1)
|
||||
self.assertEqual(custom_logs[0], "moo")
|
||||
self.assertEqual(custom_logs[1], f"{torch.ones(3, 3) * 2}")
|
||||
|
||||
@torch._dynamo.config.patch(reorderable_logging_functions={print})
|
||||
def test_constant_mutation(self):
|
||||
def f(x):
|
||||
alist = [x]
|
||||
alist.append(x + 1)
|
||||
print(alist[-1])
|
||||
alist[0].sum().item() # graph break
|
||||
res = alist.pop()
|
||||
print(alist[-1])
|
||||
res.sum().item() # graph break
|
||||
return res
|
||||
|
||||
inputs = (torch.tensor([1]),)
|
||||
counters.clear()
|
||||
opt_f = torch.compile(backend="eager")(f)
|
||||
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
|
||||
opt_out = opt_f(*inputs)
|
||||
printed_output = mock_stdout.getvalue().strip()
|
||||
orig_out = f(*inputs)
|
||||
|
||||
self.assertEqual(printed_output, "tensor([2])\ntensor([1])")
|
||||
self.assertTrue(same(orig_out, opt_out))
|
||||
|
||||
graph_break_key = counters["graph_break"].keys()
|
||||
self.assertEqual(len(graph_break_key), 1)
|
||||
self.assertEqual(next(iter(graph_break_key)), "Tensor.item")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -4,6 +4,7 @@ import copy
|
||||
import dataclasses
|
||||
import io
|
||||
import unittest
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from re import escape
|
||||
@ -3610,6 +3611,45 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
self.assertEqual(res[0], torch.tensor(16))
|
||||
self.assertEqual(res[1], None)
|
||||
|
||||
def test_print(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
print("start")
|
||||
x1 = x + x
|
||||
print(x1)
|
||||
x2 = x1 * x1
|
||||
print(1, 2, 3)
|
||||
x3 = x2 + x2
|
||||
return (x1, x3)
|
||||
|
||||
gm = export(M(), (torch.randn(3, 3),)).graph_module
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
|
||||
mul = torch.ops.aten.mul.Tensor(add, add)
|
||||
add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
|
||||
return (add, add_1)""",
|
||||
)
|
||||
|
||||
def test_warning(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
warnings.warn("moo")
|
||||
res = x + x
|
||||
warnings.warn(f"{res}")
|
||||
return res
|
||||
|
||||
gm = export(M(), (torch.randn(3, 3),)).graph_module
|
||||
self.assertExpectedInline(
|
||||
gm.code.strip(),
|
||||
"""\
|
||||
def forward(self, arg0_1):
|
||||
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
|
||||
return (add,)""",
|
||||
)
|
||||
|
||||
def test_constant_fqn(self):
|
||||
class Nested(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
||||
@ -5,7 +5,7 @@ import re
|
||||
import sys
|
||||
import tempfile
|
||||
from os.path import abspath, dirname
|
||||
from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -370,6 +370,12 @@ optimize_user_defined_triton_kernels = True
|
||||
# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
|
||||
log_compilation_metrics = True
|
||||
|
||||
# A set of logging functions which will be reordered to the end of graph breaks,
|
||||
# allowing dynamo to construct larget graph. Note that there are some
|
||||
# limitations to this, such as how it does not correctly print objects that were
|
||||
# mutated after the print statement.
|
||||
reorderable_logging_functions: Set[Callable[[Any], None]] = set()
|
||||
|
||||
# simulates what would happen if we didn't have support for BUILD_SET opcode,
|
||||
# used for testing
|
||||
inject_BUILD_SET_unimplemented_TESTING_ONLY = False
|
||||
|
||||
@ -955,6 +955,7 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
and all(isinstance(x, TensorVariable) for x in stack_values)
|
||||
and len(set(stack_values)) == len(stack_values)
|
||||
and self.side_effects.is_empty()
|
||||
and not len(tx.debug_locals) != 0
|
||||
and not self.backward_state
|
||||
):
|
||||
append_prefix_insts()
|
||||
@ -1004,6 +1005,14 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
cg.store_attr(name)
|
||||
self.side_effects.codegen_hooks(cg)
|
||||
self.side_effects.codegen_save_tempvars(cg)
|
||||
|
||||
# Return variables used for logging at the end
|
||||
for debug_var, args in tx.debug_locals:
|
||||
cg(debug_var)
|
||||
for arg in args:
|
||||
cg(arg)
|
||||
cg.extend_output(create_call_function(len(args), True))
|
||||
|
||||
cg.restore_stack(stack_values, value_from_source=not tx.export)
|
||||
self.side_effects.codegen_update_mutated(cg)
|
||||
|
||||
|
||||
@ -2139,6 +2139,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
for k in vars
|
||||
if k in f_locals
|
||||
}
|
||||
self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
|
||||
if export:
|
||||
# export gets confused if we never realize unused inputs
|
||||
# in export mode just eagerly realize everything
|
||||
|
||||
@ -129,6 +129,7 @@ from .misc import (
|
||||
AutogradFunctionContextVariable,
|
||||
AutogradFunctionVariable,
|
||||
ComptimeVariable,
|
||||
DebuggingVariable,
|
||||
GetAttrVariable,
|
||||
GetSetDescriptorVariable,
|
||||
InspectSignatureVariable,
|
||||
@ -496,6 +497,11 @@ class VariableBuilder:
|
||||
elif isinstance(value, enum.Enum):
|
||||
self.install_guards(GuardBuilder.ID_MATCH)
|
||||
return EnumVariable(value=value, source=self.source)
|
||||
elif DebuggingVariable.is_reorderable_logging_function(value):
|
||||
# Put this above builtin_callable so that print() can be handled
|
||||
# along with other builtin debugging functions
|
||||
self.install_guards(GuardBuilder.BUILTIN_MATCH)
|
||||
return DebuggingVariable(value, source=self.source)
|
||||
elif is_utils_checkpoint(value):
|
||||
return build_checkpoint_variable(source=self.source)
|
||||
elif isinstance(value, functools.partial):
|
||||
|
||||
@ -11,6 +11,7 @@ from typing import Dict, List
|
||||
|
||||
import torch._C
|
||||
import torch._numpy as tnp
|
||||
import torch.utils._pytree as pytree
|
||||
from .. import config, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import unimplemented
|
||||
@ -817,3 +818,58 @@ class StringFormatVariable(VariableTracker):
|
||||
}
|
||||
codegen(variables.ConstDictVariable(kwargs))
|
||||
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
|
||||
|
||||
|
||||
class DebuggingVariable(VariableTracker):
|
||||
"""
|
||||
Represents a call to a debugging function like print(), or something
|
||||
registered to config.reorderable_logging_functions.
|
||||
"""
|
||||
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.value = value
|
||||
|
||||
@staticmethod
|
||||
def is_reorderable_logging_function(obj):
|
||||
return (
|
||||
callable(obj)
|
||||
and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
|
||||
and obj in torch._dynamo.config.reorderable_logging_functions
|
||||
)
|
||||
|
||||
def call_function(self, tx, args, kwargs):
|
||||
if tx.export:
|
||||
# For export cases, we can just make debugging functions no-ops
|
||||
return
|
||||
|
||||
if not self.can_reorder_logs(self.value, args, kwargs):
|
||||
unimplemented(
|
||||
f"Reordering debugging function {self.value} "
|
||||
f"with inputs {args} {kwargs} is not yet implemented."
|
||||
)
|
||||
|
||||
tx.debug_locals.append((self, list(args)))
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
return self.source.reconstruct(codegen)
|
||||
|
||||
@staticmethod
|
||||
def can_reorder_logs(fn, args, kwargs) -> True:
|
||||
"""
|
||||
Run some additional checks for what sort of function calls can we
|
||||
actually reorder.
|
||||
"""
|
||||
|
||||
allowed_input_types = (
|
||||
variables.TensorVariable,
|
||||
variables.ConstantVariable,
|
||||
StringFormatVariable,
|
||||
)
|
||||
|
||||
flat_args = pytree.tree_leaves([args, kwargs])
|
||||
for arg in flat_args:
|
||||
if not isinstance(arg, allowed_input_types):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@ -4,6 +4,7 @@ import inspect
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
@ -73,9 +74,23 @@ class ExportDynamoConfig:
|
||||
"""
|
||||
|
||||
allow_rnn: bool = True
|
||||
reorderable_logging_functions: Set[Callable] = dataclasses.field(
|
||||
default_factory=set
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
|
||||
DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
|
||||
logging.critical,
|
||||
logging.debug,
|
||||
logging.error,
|
||||
logging.exception,
|
||||
logging.info,
|
||||
logging.log,
|
||||
logging.warning,
|
||||
print,
|
||||
warnings.warn,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
Reference in New Issue
Block a user