Files
pytorch/test/dynamo/test_metrics_context.py

119 lines
3.9 KiB
Python

# Owner(s): ["module: dynamo"]
from torch._dynamo.metrics_context import MetricsContext, TopN
from torch._dynamo.test_case import run_tests, TestCase
class TestMetricsContext(TestCase):
def setUp(self):
super().setUp()
self.metrics = {}
def _on_exit(self, start_ns, end_ns, metrics, exc_type, exc_value):
# Save away the metrics to be validated in the test.
self.metrics = metrics.copy()
def test_context_exists(self):
"""
Setting a value without entering the context should raise.
"""
context = MetricsContext(self._on_exit)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.increment("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.set("m", 1)
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
context.update({"m", 1})
def test_nested_context(self):
"""
Only the outermost context should get an on_exit call, and it should
include everything.
"""
context = MetricsContext(self._on_exit)
with context:
with context:
context.set("m1", 1)
self.assertEqual(self.metrics, {})
context.set("m2", 2)
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
def test_set(self):
"""
Validate various ways to set metrics.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
context.set("m2", 2)
context.update({"m3": 3, "m4": 4})
self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})
def test_set_disallow_overwrite(self):
"""
Validate set won't overwrite.
"""
with MetricsContext(self._on_exit) as context:
context.set("m1", 1)
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.set("m1", 2)
self.assertEqual(self.metrics, {"m1": 1})
def test_update_disallow_overwrite(self):
"""
Validate update won't overwrite.
"""
with MetricsContext(self._on_exit) as context:
context.update({"m1": 1, "m2": 2})
with self.assertRaisesRegex(RuntimeError, "already been set"):
context.update({"m1": 7, "m3": 3})
def test_update_allow_overwrite(self):
"""
Validate update will overwrite when given param.
"""
with MetricsContext(self._on_exit) as context:
context.update({"m1": 1, "m2": 2})
context.update({"m1": 7, "m3": 3}, overwrite=True)
self.assertEqual(self.metrics, {"m1": 7, "m2": 2, "m3": 3})
def test_add_to_set(self):
"""
Validate add_to_set.
"""
with MetricsContext(self._on_exit) as context:
context.add_to_set("m1", 1)
context.add_to_set("m1", 2)
context.add_to_set("m2", 3)
context.add_to_set("m2", 4)
self.assertEqual(self.metrics, {"m1": {1, 2}, "m2": {3, 4}})
self.assertTrue(isinstance(self.metrics["m1"], set))
self.assertTrue(isinstance(self.metrics["m2"], set))
def test_set_key_value(self):
with MetricsContext(self._on_exit) as context:
context.set_key_value("feature_usage", "k", True)
# Overrides allowed
context.set_key_value("feature_usage", "k2", True)
context.set_key_value("feature_usage", "k2", False)
self.assertEqual(self.metrics, {"feature_usage": {"k": True, "k2": False}})
def test_top_n(self):
top_n = TopN(3)
for k, v in (("seven", 7), ("four", 4), ("five", 5), ("six", 6), ("eight", 8)):
top_n.add(k, v)
self.assertEqual(len(top_n), 3)
print(list(top_n))
self.assertEqual(list(top_n), [("eight", 8), ("seven", 7), ("six", 6)])
if __name__ == "__main__":
run_tests()