mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157639 Approved by: https://github.com/yewentao256, https://github.com/jansel ghstack dependencies: #157638
119 lines
3.9 KiB
Python
119 lines
3.9 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
from torch._dynamo.metrics_context import MetricsContext, TopN
|
|
from torch._dynamo.test_case import run_tests, TestCase
|
|
|
|
|
|
class TestMetricsContext(TestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.metrics = {}
|
|
|
|
def _on_exit(self, start_ns, end_ns, metrics, exc_type, exc_value):
|
|
# Save away the metrics to be validated in the test.
|
|
self.metrics = metrics.copy()
|
|
|
|
def test_context_exists(self):
|
|
"""
|
|
Setting a value without entering the context should raise.
|
|
"""
|
|
context = MetricsContext(self._on_exit)
|
|
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
|
|
context.increment("m", 1)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
|
|
context.set("m", 1)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "outside of a MetricsContext"):
|
|
context.update({"m", 1})
|
|
|
|
def test_nested_context(self):
|
|
"""
|
|
Only the outermost context should get an on_exit call, and it should
|
|
include everything.
|
|
"""
|
|
context = MetricsContext(self._on_exit)
|
|
with context:
|
|
with context:
|
|
context.set("m1", 1)
|
|
self.assertEqual(self.metrics, {})
|
|
context.set("m2", 2)
|
|
self.assertEqual(self.metrics, {"m1": 1, "m2": 2})
|
|
|
|
def test_set(self):
|
|
"""
|
|
Validate various ways to set metrics.
|
|
"""
|
|
with MetricsContext(self._on_exit) as context:
|
|
context.set("m1", 1)
|
|
context.set("m2", 2)
|
|
context.update({"m3": 3, "m4": 4})
|
|
|
|
self.assertEqual(self.metrics, {"m1": 1, "m2": 2, "m3": 3, "m4": 4})
|
|
|
|
def test_set_disallow_overwrite(self):
|
|
"""
|
|
Validate set won't overwrite.
|
|
"""
|
|
with MetricsContext(self._on_exit) as context:
|
|
context.set("m1", 1)
|
|
with self.assertRaisesRegex(RuntimeError, "already been set"):
|
|
context.set("m1", 2)
|
|
|
|
self.assertEqual(self.metrics, {"m1": 1})
|
|
|
|
def test_update_disallow_overwrite(self):
|
|
"""
|
|
Validate update won't overwrite.
|
|
"""
|
|
with MetricsContext(self._on_exit) as context:
|
|
context.update({"m1": 1, "m2": 2})
|
|
with self.assertRaisesRegex(RuntimeError, "already been set"):
|
|
context.update({"m1": 7, "m3": 3})
|
|
|
|
def test_update_allow_overwrite(self):
|
|
"""
|
|
Validate update will overwrite when given param.
|
|
"""
|
|
with MetricsContext(self._on_exit) as context:
|
|
context.update({"m1": 1, "m2": 2})
|
|
context.update({"m1": 7, "m3": 3}, overwrite=True)
|
|
|
|
self.assertEqual(self.metrics, {"m1": 7, "m2": 2, "m3": 3})
|
|
|
|
def test_add_to_set(self):
|
|
"""
|
|
Validate add_to_set.
|
|
"""
|
|
with MetricsContext(self._on_exit) as context:
|
|
context.add_to_set("m1", 1)
|
|
context.add_to_set("m1", 2)
|
|
context.add_to_set("m2", 3)
|
|
context.add_to_set("m2", 4)
|
|
|
|
self.assertEqual(self.metrics, {"m1": {1, 2}, "m2": {3, 4}})
|
|
self.assertTrue(isinstance(self.metrics["m1"], set))
|
|
self.assertTrue(isinstance(self.metrics["m2"], set))
|
|
|
|
def test_set_key_value(self):
|
|
with MetricsContext(self._on_exit) as context:
|
|
context.set_key_value("feature_usage", "k", True)
|
|
# Overrides allowed
|
|
context.set_key_value("feature_usage", "k2", True)
|
|
context.set_key_value("feature_usage", "k2", False)
|
|
|
|
self.assertEqual(self.metrics, {"feature_usage": {"k": True, "k2": False}})
|
|
|
|
def test_top_n(self):
|
|
top_n = TopN(3)
|
|
for k, v in (("seven", 7), ("four", 4), ("five", 5), ("six", 6), ("eight", 8)):
|
|
top_n.add(k, v)
|
|
|
|
self.assertEqual(len(top_n), 3)
|
|
print(list(top_n))
|
|
self.assertEqual(list(top_n), [("eight", 8), ("seven", 7), ("six", 6)])
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this module as a script delegates to the dynamo test runner
    # imported from torch._dynamo.test_case.
    run_tests()
|