Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[invoke_subgraph] Add logging (#155284)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155284
Approved by: https://github.com/zou3519
ghstack dependencies: #155270
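As a rough usage sketch (not part of the diff itself): the snippet below turns on the new hierarchical_compile artifact via the set_logs parameter added in this PR and exercises it the same way the new test does. The TORCH_LOGS spelling mentioned in the comment is an assumption based on the usual artifact-name convention, not something this diff adds.

# Sketch only: enable the "hierarchical_compile" artifact added by this PR,
# then compile a region marked with mark_compile_region so the debug logs fire.
import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region

# Via the Python API (parameter added in this diff); as an assumption, the
# environment-variable form would be TORCH_LOGS="hierarchical_compile".
torch._logging.set_logs(hierarchical_compile=True)

@mark_compile_region
def gn(x):
    return x * 2

def fn(x):
    return gn(x)

fn_opt = torch.compile(fn, backend="inductor")
fn_opt(torch.ones(1000, 1000))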
Committed by: PyTorch MergeBot
Parent: 0f3f59784d
Commit: db491825e0
@@ -169,6 +169,22 @@ class LoggingTests(LoggingTestCase):
         self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
         self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)
 
+    @make_logging_test(hierarchical_compile=True)
+    def test_hierarchical_compile(self, records):
+        from torch._higher_order_ops.invoke_subgraph import mark_compile_region
+
+        @mark_compile_region
+        def gn(x):
+            return x * 2
+
+        def fn(x):
+            return gn(x)
+
+        fn_opt = torch.compile(fn, backend="inductor")
+        fn_opt(torch.ones(1000, 1000))
+        fn_opt(torch.ones(1000, 1000))
+        self.assertGreater(len(records), 0)
+
     @make_logging_test()
     def test_dynamo_error(self, records):
         try:
@@ -960,6 +976,7 @@ exclusions = {
     "loop_tiling",
     "autotuning",
     "graph_region_expansion",
+    "hierarchical_compile",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
     if name not in exclusions:
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
 
 
 log = logging.getLogger(__name__)
+hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")
 
 
 def raise_hard_error_if_graph_break(reason):
@@ -261,7 +262,7 @@ def _check_supported_callable_arg(
         )
 
 
-def are_same_graph_modules(a_mod, b_mod, fake_mode):
+def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode):
     from torch._subclasses._fake_tensor_utils import _CacheKeyState
     from torch._subclasses.fake_tensor import extract_tensor_metadata
 
@@ -322,7 +323,11 @@ def are_same_graph_modules(a_mod, b_mod, fake_mode):
             a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
             b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
             if not check_all_args(a_flat, b_flat):
-                # print("call_function args failed")
+                hc_log.debug(
+                    "%s: Graph comparison failed at node (call_function): %s",
+                    fn_name,
+                    a_node,
+                )
                 return False
         elif a_node.op == "call_method":
             if a_node.target != b_node.target:
@@ -330,13 +335,17 @@ def are_same_graph_modules(a_mod, b_mod, fake_mode):
             a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
             b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
             if not check_all_args(a_flat, b_flat):
-                # print("call_method args failed")
+                hc_log.debug(
+                    "%s: Graph comparison failed at node (call_method) : %s",
+                    fn_name,
+                    a_node,
+                )
                 return False
         elif a_node.op == "output":
             a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
             b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
             if not check_all_args(a_flat, b_flat):
-                # print("output args failed")
+                hc_log.debug("%s: Graph comparison failed at the output node", fn_name)
                 return False
         elif a_node.op == "get_attr":
             a_attr = getattr(a_mod, a_node.target)
@@ -345,7 +354,7 @@ def are_same_graph_modules(a_mod, b_mod, fake_mode):
                 if not isinstance(b_attr, torch.fx.GraphModule):
                     return False
                 # This is an example of a HOP inside a HOP
-                if not are_same_graph_modules(a_attr, b_attr, fake_mode):
+                if not are_same_graph_modules(fn_name, a_attr, b_attr, fake_mode):
                     return False
             else:
                 # TODO - write an example with tensor as a graph attribute in
@@ -3359,9 +3368,11 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
 
         if isinstance(fn_vt, UserFunctionVariable):
             fn_id = id(fn_vt.get_function())
+            fn_name = fn_vt.get_function().__name__
         else:
             assert isinstance(fn_vt, UnspecializedNNModuleVariable)
             fn_id = id(fn_vt.value.forward.__func__)
+            fn_name = fn_vt.value.forward.__name__
         previously_installed_submodules = []
         if invoke_subgraph_cache:
             previously_installed_submodules = (
@@ -3373,12 +3384,21 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
         for submodule_name in reversed(previously_installed_submodules):
             assert submodule_name in tx.output.nn_modules
             previous_mod = tx.output.nn_modules[submodule_name]
-            if are_same_graph_modules(previous_mod, current_mod, tx.fake_mode):
+            if are_same_graph_modules(
+                fn_name, previous_mod, current_mod, tx.fake_mode
+            ):
                 return submodule_name
 
         body_name = super().install_subgraph_in_output_graph(
             tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph"
         )
+        hc_log.debug(
+            "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s",
+            fn_name,
+            body_name,
+            fn_name,
+            len(previously_installed_submodules) + 1,
+        )
         if invoke_subgraph_cache:
             invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name)
 
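For illustration only: rendering the format string from the hc_log.debug call above with placeholder values gives a sense of what the emitted line looks like. The function name "gn" and identifier "subgraph_0" below are hypothetical, not values taken from this diff.

# Hypothetical rendering of the debug message added above; "gn" and
# "subgraph_0" stand in for fn_name and body_name.
msg = "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s" % (
    "gn",
    "subgraph_0",
    "gn",
    1,
)
print(msg)
# gn: Installing subgraph with identifier 'subgraph_0', bringing total count for 'gn' function to 1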
@@ -251,6 +251,7 @@ def set_logs(
     autotuning: bool = False,
     graph_region_expansion: bool = False,
     inductor_metrics: bool = False,
+    hierarchical_compile: bool = False,
 ) -> None:
     """
     Sets the log level for individual components and toggles individual log
@@ -448,6 +449,9 @@ def set_logs(
         inductor_metrics (:class:`bool`):
             Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False``
 
+        hierarchical_compile (:class:`bool`):
+            Whether to emit debug info for hierarchical compilation. Default: ``False``
+
     Example::
 
         >>> # xdoctest: +SKIP
@@ -560,6 +564,7 @@ def set_logs(
         autotuning=autotuning,
         graph_region_expansion=graph_region_expansion,
         inductor_metrics=inductor_metrics,
+        hierarchical_compile=hierarchical_compile,
     )
 
 
@@ -233,5 +233,9 @@ register_artifact(
     "Logs Inductor metrics, such as num_bytes, nodes_num_elem, node_runtimes",
     off_by_default=True,
 )
 
+register_artifact(
+    "hierarchical_compile",
+    "Logs debug info for hierarchical compilation",
+    off_by_default=True,
+)
 register_artifact("custom_format_test_artifact", "Testing only", log_format="")
@@ -62,6 +62,7 @@ if TYPE_CHECKING:
     from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
 
 log = logging.getLogger(__name__)
+hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")
 
 # TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186
 # Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105
@@ -1433,6 +1434,15 @@ class FakeTensorMode(TorchDispatchMode):
             key = self._cache_key(state, func, args, kwargs)
         except _BypassDispatchCache as e:
             # We couldn't create the cache key at all
+            if (
+                isinstance(func, torch._ops.HigherOrderOperator)
+                and func.name() == "invoke_subgraph"
+            ):
+                hc_log.debug(
+                    "Fake tensor cache failed: identifier = %s, reason = %s",
+                    args[1],
+                    e.reason,
+                )
             FakeTensorMode.cache_bypasses[e.reason] += 1
 
         if key is None:
@@ -1477,6 +1487,15 @@ class FakeTensorMode(TorchDispatchMode):
                 # We ran "extra" checks on the cache key and determined that it's no
                 # good. Record the reason and mark it so we don't bother validating
                 # again.
+                if (
+                    isinstance(func, torch._ops.HigherOrderOperator)
+                    and func.name() == "invoke_subgraph"
+                ):
+                    hc_log.debug(
+                        "Fake tensor cache failed: identifier = %s, reason = %s",
+                        args[1],
+                        e.reason,
+                    )
                 FakeTensorMode.cache_bypasses[e.reason] += 1
                 set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
                 return output