[invoke_subgraph] Add logging (#155284)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155284
Approved by: https://github.com/zou3519
ghstack dependencies: #155270
Author: Animesh Jain
Date: 2025-06-06 23:22:40 -07:00
Committed by: PyTorch MergeBot
parent 0f3f59784d
commit db491825e0
5 changed files with 72 additions and 7 deletions

View File

@@ -169,6 +169,22 @@ class LoggingTests(LoggingTestCase):
        self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
        self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)

    @make_logging_test(hierarchical_compile=True)
    def test_hierarchical_compile(self, records):
        from torch._higher_order_ops.invoke_subgraph import mark_compile_region

        @mark_compile_region
        def gn(x):
            return x * 2

        def fn(x):
            return gn(x)

        fn_opt = torch.compile(fn, backend="inductor")
        fn_opt(torch.ones(1000, 1000))
        fn_opt(torch.ones(1000, 1000))
        self.assertGreater(len(records), 0)

    @make_logging_test()
    def test_dynamo_error(self, records):
        try:
@@ -960,6 +976,7 @@ exclusions = {
    "loop_tiling",
    "autotuning",
    "graph_region_expansion",
    "hierarchical_compile",
}
for name in torch._logging._internal.log_registry.artifact_names:
    if name not in exclusions:

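Note (not part of the diff): the new test above can be reproduced as a standalone script. A minimal sketch, assuming the set_logs switch added later in this commit:

# Standalone sketch mirroring test_hierarchical_compile: enable the new
# "hierarchical_compile" artifact, then compile a function whose callee is
# wrapped in mark_compile_region so the invoke_subgraph logging has
# something to report.
import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region

torch._logging.set_logs(hierarchical_compile=True)

@mark_compile_region
def gn(x):
    return x * 2

def fn(x):
    return gn(x)

fn_opt = torch.compile(fn, backend="inductor")
fn_opt(torch.ones(1000, 1000))
fn_opt(torch.ones(1000, 1000))  # called twice, as in the test
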
View File

@@ -65,6 +65,7 @@ if TYPE_CHECKING:

log = logging.getLogger(__name__)
hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")


def raise_hard_error_if_graph_break(reason):
@@ -261,7 +262,7 @@ def _check_supported_callable_arg(
        )


def are_same_graph_modules(a_mod, b_mod, fake_mode):
def are_same_graph_modules(fn_name, a_mod, b_mod, fake_mode):
    from torch._subclasses._fake_tensor_utils import _CacheKeyState
    from torch._subclasses.fake_tensor import extract_tensor_metadata
@@ -322,7 +323,11 @@ def are_same_graph_modules(a_mod, b_mod, fake_mode):
            a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
            b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
            if not check_all_args(a_flat, b_flat):
                # print("call_function args failed")
                hc_log.debug(
                    "%s: Graph comparison failed at node (call_function): %s",
                    fn_name,
                    a_node,
                )
                return False
        elif a_node.op == "call_method":
            if a_node.target != b_node.target:
@@ -330,13 +335,17 @@ def are_same_graph_modules(a_mod, b_mod, fake_mode):
            a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
            b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
            if not check_all_args(a_flat, b_flat):
                # print("call_method args failed")
                hc_log.debug(
                    "%s: Graph comparison failed at node (call_method): %s",
                    fn_name,
                    a_node,
                )
                return False
        elif a_node.op == "output":
            a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
            b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
            if not check_all_args(a_flat, b_flat):
                # print("output args failed")
                hc_log.debug("%s: Graph comparison failed at the output node", fn_name)
                return False
        elif a_node.op == "get_attr":
            a_attr = getattr(a_mod, a_node.target)
@@ -345,7 +354,7 @@ def are_same_graph_modules(a_mod, b_mod, fake_mode):
                if not isinstance(b_attr, torch.fx.GraphModule):
                    return False
                # This is an example of a HOP inside a HOP
                if not are_same_graph_modules(a_attr, b_attr, fake_mode):
                if not are_same_graph_modules(fn_name, a_attr, b_attr, fake_mode):
                    return False
            else:
                # TODO - write an example with tensor as a graph attribute in
@@ -3359,9 +3368,11 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
        if isinstance(fn_vt, UserFunctionVariable):
            fn_id = id(fn_vt.get_function())
            fn_name = fn_vt.get_function().__name__
        else:
            assert isinstance(fn_vt, UnspecializedNNModuleVariable)
            fn_id = id(fn_vt.value.forward.__func__)
            fn_name = fn_vt.value.forward.__name__

        previously_installed_submodules = []
        if invoke_subgraph_cache:
            previously_installed_submodules = (
@@ -3373,12 +3384,21 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
        for submodule_name in reversed(previously_installed_submodules):
            assert submodule_name in tx.output.nn_modules
            previous_mod = tx.output.nn_modules[submodule_name]
            if are_same_graph_modules(previous_mod, current_mod, tx.fake_mode):
            if are_same_graph_modules(
                fn_name, previous_mod, current_mod, tx.fake_mode
            ):
                return submodule_name

        body_name = super().install_subgraph_in_output_graph(
            tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph"
        )
        hc_log.debug(
            "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s",
            fn_name,
            body_name,
            fn_name,
            len(previously_installed_submodules) + 1,
        )
        if invoke_subgraph_cache:
            invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name)

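Note (not part of the diff): the hc_log messages above fire when the node-by-node comparison in are_same_graph_modules rejects a previously installed subgraph. A simplified, standalone illustration of that comparison idea (the helper below is made up and much looser than the real check, which also compares constants and tensor metadata under the fake mode):

# Two FX graphs are treated as interchangeable when their nodes line up in
# op, target, and flattened args/kwargs; the real implementation logs the
# failing node via hc_log.debug before returning False.
import torch
import torch.fx
from torch.utils import _pytree as pytree

def roughly_same_graphs(a_mod: torch.fx.GraphModule, b_mod: torch.fx.GraphModule) -> bool:
    a_nodes = list(a_mod.graph.nodes)
    b_nodes = list(b_mod.graph.nodes)
    if len(a_nodes) != len(b_nodes):
        return False
    for a_node, b_node in zip(a_nodes, b_nodes):
        if a_node.op != b_node.op or a_node.target != b_node.target:
            return False
        a_flat, _ = pytree.tree_flatten((a_node.args, a_node.kwargs))
        b_flat, _ = pytree.tree_flatten((b_node.args, b_node.kwargs))
        if len(a_flat) != len(b_flat):
            return False
    return True

def f(x):
    return x * 2

print(roughly_same_graphs(torch.fx.symbolic_trace(f), torch.fx.symbolic_trace(f)))  # True
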
View File

@@ -251,6 +251,7 @@ def set_logs(
    autotuning: bool = False,
    graph_region_expansion: bool = False,
    inductor_metrics: bool = False,
    hierarchical_compile: bool = False,
) -> None:
    """
    Sets the log level for individual components and toggles individual log
@@ -448,6 +449,9 @@ def set_logs(
        inductor_metrics (:class:`bool`):
            Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False``

        hierarchical_compile (:class:`bool`):
            Whether to emit debug info for hierarchical compilation. Default: ``False``

    Example::
        >>> # xdoctest: +SKIP
@@ -560,6 +564,7 @@ def set_logs(
        autotuning=autotuning,
        graph_region_expansion=graph_region_expansion,
        inductor_metrics=inductor_metrics,
        hierarchical_compile=hierarchical_compile,
    )

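Note (not part of the diff): the new keyword composes with the existing set_logs switches; a quick sketch:

# Enable dynamo INFO logs alongside the new hierarchical_compile artifact.
import logging
import torch

torch._logging.set_logs(dynamo=logging.INFO, hierarchical_compile=True)
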
View File

@@ -233,5 +233,9 @@ register_artifact(
    "Logs Inductor metrics, such as num_bytes, nodes_num_elem, node_runtimes",
    off_by_default=True,
)
register_artifact(
    "hierarchical_compile",
    "Logs debug info for hierarchical compilation",
    off_by_default=True,
)
register_artifact("custom_format_test_artifact", "Testing only", log_format="")

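Note (not part of the diff): this registration is what lets the getArtifactLogger(__name__, "hierarchical_compile") calls in the other files resolve, and it makes the name a valid key for set_logs and the TORCH_LOGS environment variable. A quick sanity check using the internal registry attribute that the test file above also reads:

# Confirm the new artifact name is registered (internal attribute, mirrored
# from the test_logging.py loop above).
import torch

print("hierarchical_compile" in torch._logging._internal.log_registry.artifact_names)
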
View File

@@ -62,6 +62,7 @@ if TYPE_CHECKING:
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

log = logging.getLogger(__name__)
hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile")

# TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186
# Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105
@@ -1433,6 +1434,15 @@ class FakeTensorMode(TorchDispatchMode):
            key = self._cache_key(state, func, args, kwargs)
        except _BypassDispatchCache as e:
            # We couldn't create the cache key at all
            if (
                isinstance(func, torch._ops.HigherOrderOperator)
                and func.name() == "invoke_subgraph"
            ):
                hc_log.debug(
                    "Fake tensor cache failed: identifier = %s, reason = %s",
                    args[1],
                    e.reason,
                )
            FakeTensorMode.cache_bypasses[e.reason] += 1

        if key is None:
@@ -1477,6 +1487,15 @@ class FakeTensorMode(TorchDispatchMode):
            # We ran "extra" checks on the cache key and determined that it's no
            # good. Record the reason and mark it so we don't bother validating
            # again.
            if (
                isinstance(func, torch._ops.HigherOrderOperator)
                and func.name() == "invoke_subgraph"
            ):
                hc_log.debug(
                    "Fake tensor cache failed: identifier = %s, reason = %s",
                    args[1],
                    e.reason,
                )
            FakeTensorMode.cache_bypasses[e.reason] += 1
            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
            return output
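
Note (not part of the diff): the counter incremented next to the new debug calls is a class-level mapping on FakeTensorMode, so bypass reasons can also be inspected directly after running a compiled function. A rough sketch, assuming cache_bypasses stays a plain reason-to-count mapping as the surrounding code suggests:

# Run something through torch.compile, then look at which fake tensor cache
# bypass reasons (if any) were recorded; the hierarchical_compile artifact
# additionally names the invoke_subgraph identifier for each bypass.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

@torch.compile
def f(x):
    return x.sin() + 1

f(torch.randn(4))
print(dict(FakeTensorMode.cache_bypasses))  # e.g. {} when nothing was bypassed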