Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add logging for num_triton_bundles (#139807)
Summary: Add logging for the number of Inductor cache Triton bundles.

Test Plan: Ran ad hoc code and checked dynamo_compile/sandbox: https://fburl.com/scuba/dynamo_compile/sandbox/nhktfy19

Differential Revision: D65490826

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139807
Approved by: https://github.com/masnesral
Committed by: PyTorch MergeBot
Parent: 9018326bb8
Commit: 1270c78268
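Before the hunks, a minimal standalone sketch of the mechanism this commit wires up: a module-level collections.Counter that Inductor's cache bumps once per Triton bundle it loads, cleared at the start of each compile and read with .get(..., None) so a frame that never touched a bundle logs None rather than 0. codecache_metrics and the "num_triton_bundles" key come from the diff below; the helper names record_bundle and num_triton_bundles_for_logging are illustrative only.

    import collections
    from typing import Counter, Optional

    codecache_metrics: Counter[str] = collections.Counter()

    def record_bundle(cached_kernel_names: list) -> None:
        # Count a bundle only when it actually shipped cached kernels,
        # mirroring the FxGraphCache hunk at the end of this diff.
        if len(cached_kernel_names) > 0:
            codecache_metrics["num_triton_bundles"] += 1

    def num_triton_bundles_for_logging() -> Optional[int]:
        # .get(..., None) keeps "never bundled" (None) distinct from 0.
        return codecache_metrics.get("num_triton_bundles", None)

    codecache_metrics.clear()                    # start of a compile, as in _compile
    record_bundle(["triton_poi_fused_add_0"])    # a cache hit with one kernel
    assert num_triton_bundles_for_logging() == 1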
@@ -105,6 +105,7 @@ from .symbolic_convert import (
 from .trace_rules import is_numpy
 from .utils import (
     CleanupManager,
+    codecache_metrics,
     CompilationMetrics,
     counters,
     dynamo_timed,
@@ -973,6 +974,7 @@ def _compile(
     fail_user_frame_lineno: Optional[int] = None
     torch._dynamo.utils.ReinplaceCounters.clear()
     guarded_code = None
+    codecache_metrics.clear()
     try:
         guarded_code = compile_inner(code, one_graph, hooks, transform)
         return guarded_code
@@ -1058,6 +1060,7 @@ def _compile(
             remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get(
                 "remote_fx_graph_cache_put", None
             )
+            num_triton_bundles = codecache_metrics.get("num_triton_bundles", None)
             torch._dynamo.utils.ReinplaceCounters.log()
 
         else:
@@ -1078,6 +1081,7 @@ def _compile(
             remote_cache_time_saved = None
             remote_fx_graph_cache_get_time = None
             remote_fx_graph_cache_put_time = None
+            num_triton_bundles = None
 
         structured_logging_overhead_s = (
             torch._logging.get_structured_logging_overhead()
@@ -1146,6 +1150,7 @@ def _compile(
             config.specialize_float,
             json.dumps(config_dict),
             True,  # is_forward
+            num_triton_bundles,
             to_int_ms(remote_fx_graph_cache_get_time),
             to_int_ms(remote_fx_graph_cache_put_time),
             start_time_us=start_time_ns // 1000,
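The call sites above pass cache timings through to_int_ms. A sketch of the seconds-to-milliseconds conversion that helper is assumed to perform (None-preserving, so an absent timing stays absent in the logs); the body here is an assumption, not copied from the source:

    from typing import Optional

    def to_int_ms(v: Optional[float]) -> Optional[int]:
        # Assumed behavior: None-preserving seconds -> integer milliseconds.
        return None if v is None else int(v * 1000)

    assert to_int_ms(None) is None
    assert to_int_ms(0.25) == 250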
@@ -144,6 +144,8 @@ frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
     lambda: collections.defaultdict(float)
 )
 
+codecache_metrics: Counter[str] = collections.Counter()
+
 timer_counter = itertools.count()
 
 
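A note on why the readers use codecache_metrics.get(...) rather than indexing: Counter.__getitem__ returns 0 for missing keys, which would make "no bundle data recorded" indistinguishable from "zero bundles". .get() keeps the None default:

    import collections

    c = collections.Counter()
    print(c["num_triton_bundles"])            # 0 -- Counter's default for missing keys
    print(c.get("num_triton_bundles", None))  # None -- what gets logged
    c["num_triton_bundles"] += 1
    print(c.get("num_triton_bundles", None))  # 1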
@@ -419,6 +421,9 @@ def dynamo_timed(
                 remote_cache_time_saved_s=remote_cache_time_saved,
                 structured_logging_overhead_s=structured_logging_overhead_s,
                 is_forward=False,  # is_forward
+                num_triton_bundles=codecache_metrics.get(
+                    "num_triton_bundles", None
+                ),
                 remote_fx_graph_cache_get_time_ms=to_int_ms(
                     remote_fx_graph_cache_get_time
                 ),
@@ -899,6 +904,7 @@ class CompilationMetrics:
     specialize_float: Optional[bool] = None
     dynamo_config: Optional[str] = None
     is_forward: Optional[bool] = None
+    num_triton_bundles: Optional[int] = None
     remote_fx_graph_cache_get_time_ms: Optional[int] = None
     remote_fx_graph_cache_put_time_ms: Optional[int] = None
     start_time_us: Optional[int] = None
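For context, CompilationMetrics is a dataclass whose optional fields default to None when a compile produces no corresponding measurement. A cut-down stand-in showing only the fields touched by this diff (the real class has many more fields):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class CompilationMetrics:
        is_forward: Optional[bool] = None
        num_triton_bundles: Optional[int] = None
        remote_fx_graph_cache_get_time_ms: Optional[int] = None
        remote_fx_graph_cache_put_time_ms: Optional[int] = None

    print(CompilationMetrics(is_forward=True, num_triton_bundles=2))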
@@ -55,6 +55,7 @@ import torch.distributed as dist
 from torch import SymInt, Tensor
 from torch._dynamo.utils import (
     add_remote_cache_time_saved,
+    codecache_metrics,
     counters,
     dynamo_timed,
     get_chromium_event_logger,
@@ -1150,6 +1151,8 @@ class FxGraphCache:
                 logger.add_event_data(
                     "inductor_compile", cached_kernel_names=meta.cached_kernel_names
                 )
+                if len(meta.cached_kernel_names) > 0:
+                    codecache_metrics["num_triton_bundles"] += 1
 
         inductor_meta = autotune_cache.inductor_meta_from_config()
         AutotuneCacheBundler.begin_compile(inductor_meta, code=code)