[Inductor][Observability] Add logging for split cat pass (#116442)

Summary: Log the inductor counters (which record split cat and other pattern-matcher rewrites) after both the pre grad and post grad passes, so the effect of each stage on the FX graph is observable. The upper bound in `test_inductor_debug` moves from 15 to 17 to account for the two new debug records.

Test Plan:
```
buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode split_batch
```
```
[2023-12-26 17:14:24,203] [0/0] torch._inductor.fx_passes.post_grad: [INFO] counters of inductor dict after apply the split cat in the post grad pass: Counter({'pattern_matcher_nodes': 4076, 'pattern_matcher_count': 2917, 'remove_split_with_size_one': 1322, 'split_cat_norm': 461, 'consecutive_split_merged': 371, 'scmerge_cat_removed': 41, 'scmerge_cat_added': 32, 'scmerge_split_removed': 28, 'getitem_cat_merged': 11, 'batch_fusion': 7, 'scmerge_split_sections_removed': 3, 'scmerge_split_added': 2, 'split_squeeze_replaced': 2})

[2023-12-26 17:16:28,437] torch._inductor.fx_passes.post_grad: [INFO] counters of inductor dict after apply the split cat in the post grad pass: Counter({'pattern_matcher_nodes': 4122, 'pattern_matcher_count': 2935, 'remove_split_with_size_one': 1322, 'split_cat_norm': 461, 'consecutive_split_merged': 371, 'scmerge_cat_removed': 41, 'batch_fusion': 39, 'scmerge_cat_added': 32, 'scmerge_split_removed': 28, 'getitem_cat_merged': 11, 'scmerge_split_sections_removed': 3, 'scmerge_split_added': 2, 'split_squeeze_replaced': 2})
```
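Outside the internal buck2 harness, a minimal sketch of surfacing these records on a toy model follows. The `toy` function and its split sizes are made up for illustration; `torch._logging.set_logs(inductor=logging.DEBUG)` is the programmatic equivalent of running with `TORCH_LOGS="+inductor"`:

```
# Minimal repro sketch; the toy model is an assumption, not part of this PR.
import logging

import torch
import torch._logging
from torch._dynamo.utils import counters


def toy(x):
    # A split followed by a cat gives the split cat passes something to rewrite.
    a, b = torch.split(x, [2, 2], dim=1)
    return torch.cat([a, b], dim=1)


# Surface DEBUG records from the torch._inductor loggers, which is where the
# new counter logs are emitted.
torch._logging.set_logs(inductor=logging.DEBUG)

torch.compile(toy)(torch.randn(4, 4))

# The same Counter the new log lines print can also be read directly.
print(counters["inductor"])
```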

Differential Revision: D52425400

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116442
Approved by: https://github.com/yanboliang
Author: Menglu Yu
Date: 2023-12-29 05:10:45 +00:00
Committed by: PyTorch MergeBot
Parent: 8deaa13417
Commit: df85a920cf
2 changed files with 10 additions and 2 deletions


```
@@ -143,7 +143,7 @@ from user code:
 )
 
 test_aot = within_range_record_test(2, 6, aot=logging.INFO)
-test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
+test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG)
 test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)
 
 @make_logging_test()
```


```
@@ -32,7 +32,7 @@ from torch._dynamo import (
     logging as dynamo_logging,
     utils as dynamo_utils,
 )
-from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
+from torch._dynamo.utils import counters, detect_fake_mode, lazy_format_graph_code
 from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
 from torch._inductor.codecache import code_hash, CompiledFxGraph, FxGraphCache
@@ -511,6 +511,10 @@ def fx_codegen_and_compile(
         post_grad_passes(gm, is_inference=is_inference)
         V.debug.fx_graph_transformed(gm, example_inputs)
         post_grad_graphs_log.debug("%s", lazy_format_graph_code("AFTER POST GRAD", gm))
+        log.debug(
+            "counters of inductor dict after applying passes on the input FX graph in the post grad pass: %s",
+            counters["inductor"],
+        )
 
     with V.set_fake_mode(fake_mode):
         graph = GraphLowering(
@@ -1010,6 +1014,10 @@ def compile_fx(
             )
 
         model_ = pre_grad_passes(model_, example_inputs_)
+        log.debug(
+            "counters of inductor dict after applying passes on the input FX graph in the pre grad pass: %s",
+            counters["inductor"],
+        )
 
     if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_):
         return flatten_graph_inputs(
```
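
For context on what the new calls print: `counters`, imported from `torch._dynamo.utils` in the hunk above, is a `defaultdict(Counter)`, so each pass bumps named entries without registering them first. A standalone illustration of that structure, with numbers copied from the test plan output rather than computed:

```
# Standalone illustration of the counters structure; this mirrors
# torch._dynamo.utils.counters rather than importing it.
from collections import Counter, defaultdict

counters = defaultdict(Counter)

# Each graph pass increments named entries as it rewrites the FX graph.
counters["inductor"]["remove_split_with_size_one"] += 1322
counters["inductor"]["consecutive_split_merged"] += 371
counters["inductor"]["batch_fusion"] += 7

print(counters["inductor"])
# Counter({'remove_split_with_size_one': 1322,
#          'consecutive_split_merged': 371, 'batch_fusion': 7})
```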