mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
pytorch/features: Make a feature logger and record triton bundling (#141056)
This modifies metrics_context to allow us to store whether a feature was used or not. This also starts recording this for triton bundling. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141056 Approved by: https://github.com/masnesral
This commit is contained in:
committed by
PyTorch MergeBot
parent
0155a112fd
commit
f5d00f1456
@ -87,6 +87,15 @@ class TestMetricsContext(TestCase):
|
||||
self.assertTrue(isinstance(self.metrics["m1"], set))
|
||||
self.assertTrue(isinstance(self.metrics["m2"], set))
|
||||
|
||||
def test_set_key_value(self):
|
||||
with MetricsContext(self._on_exit) as context:
|
||||
context.set_key_value("feature_usage", "k", True)
|
||||
# Overrides allowed
|
||||
context.set_key_value("feature_usage", "k2", True)
|
||||
context.set_key_value("feature_usage", "k2", False)
|
||||
|
||||
self.assertEqual(self.metrics, {"feature_usage": {"k": True, "k2": False}})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -204,8 +204,10 @@ class TestDynamoTimed(TestCase):
|
||||
|
||||
# First event is for the forward. Formatting makes reading diffs
|
||||
# much easier.
|
||||
raw = dataclasses.asdict(compilation_events[0])
|
||||
del raw["feature_usage"]
|
||||
self.assertExpectedInline(
|
||||
pprint.pformat(dataclasses.asdict(compilation_events[0])),
|
||||
pprint.pformat(raw),
|
||||
"""\
|
||||
{'accumulated_cache_size': 0,
|
||||
'aot_autograd_cumulative_compile_time_us': 0,
|
||||
@ -274,8 +276,10 @@ class TestDynamoTimed(TestCase):
|
||||
)
|
||||
|
||||
# Second event is for the backward
|
||||
raw = dataclasses.asdict(compilation_events[1])
|
||||
del raw["feature_usage"]
|
||||
self.assertExpectedInline(
|
||||
pprint.pformat(dataclasses.asdict(compilation_events[1])),
|
||||
pprint.pformat(raw),
|
||||
"""\
|
||||
{'accumulated_cache_size': None,
|
||||
'aot_autograd_cumulative_compile_time_us': None,
|
||||
|
@ -71,6 +71,20 @@ class MetricsContext:
|
||||
)
|
||||
self._metrics[metric] = value
|
||||
|
||||
def set_key_value(self, metric: str, key: str, value: Any) -> None:
|
||||
"""
|
||||
Treats a give metric as a dictionary and set the k and value within it.
|
||||
Note that the metric must be a dictionary or not present.
|
||||
|
||||
We allow this to be called multiple times (i.e. for features, it's not uncommon
|
||||
for them to be used multiple times within a single compilation).
|
||||
"""
|
||||
if self._level == 0:
|
||||
raise RuntimeError(f"Cannot set {metric} outside of a MetricsContext")
|
||||
if metric not in self._metrics:
|
||||
self._metrics[metric] = {}
|
||||
self._metrics[metric][key] = value
|
||||
|
||||
def update(self, values: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Set multiple metrics directly. This method does NOT increment. Raises if any
|
||||
|
@ -870,6 +870,7 @@ class CompilationMetrics:
|
||||
inductor_fx_remote_cache_backend_type: Optional[str] = None
|
||||
inductor_fx_remote_cache_hit_keys: Optional[str] = None
|
||||
inductor_fx_remote_cache_miss_keys: Optional[str] = None
|
||||
feature_usage: Optional[dict[str, bool]] = None
|
||||
|
||||
|
||||
DEFAULT_COMPILATION_METRICS_LIMIT = 64
|
||||
@ -3587,3 +3588,11 @@ class CompileTimeInstructionCounter:
|
||||
finally:
|
||||
if config.record_compile_time_instruction_count:
|
||||
cls.end()
|
||||
|
||||
|
||||
def set_feature_use(feature: str, usage: bool):
|
||||
"""
|
||||
Records whether we are using a feature
|
||||
Generally a feature is a JK.
|
||||
"""
|
||||
get_metrics_context().set_key_value("feature_usage", feature, usage)
|
||||
|
@ -5,7 +5,7 @@ import uuid
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from torch._dynamo.utils import counters, dynamo_timed
|
||||
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
|
||||
from torch._utils_internal import justknobs_check
|
||||
|
||||
from .runtime.runtime_utils import triton_cache_dir
|
||||
@ -143,7 +143,13 @@ class TritonBundler:
|
||||
"""
|
||||
if not TritonBundler.is_enabled():
|
||||
cls.end_compile()
|
||||
set_feature_use(
|
||||
"pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2", False
|
||||
)
|
||||
return [], None
|
||||
set_feature_use(
|
||||
"pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2", True
|
||||
)
|
||||
|
||||
with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True):
|
||||
entries = cls._entries
|
||||
|
Reference in New Issue
Block a user