diff --git a/c10/util/WaitCounter.h b/c10/util/WaitCounter.h index 43202f5a6c9a..dba8f82f3ca3 100644 --- a/c10/util/WaitCounter.h +++ b/c10/util/WaitCounter.h @@ -45,6 +45,13 @@ class C10_API WaitCounterHandle { class WaitGuard { public: + WaitGuard(WaitGuard&& other) noexcept + : handle_{std::exchange(other.handle_, {})}, + ctxs_{std::move(other.ctxs_)} {} + WaitGuard(const WaitGuard&) = delete; + WaitGuard& operator=(const WaitGuard&) = delete; + WaitGuard& operator=(WaitGuard&&) = delete; + ~WaitGuard() { stop(); } diff --git a/test/test_monitor.py b/test/test_monitor.py index 59d763421d00..e84163b94fab 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -16,6 +16,7 @@ from torch.monitor import ( unregister_event_handler, Stat, TensorboardEventHandler, + WaitCounter, ) class TestMonitor(TestCase): @@ -98,6 +99,13 @@ class TestMonitor(TestCase): log_event(e) self.assertEqual(len(events), 2) + def test_wait_counter(self) -> None: + wait_counter = WaitCounter( + "test_wait_counter", + ) + with wait_counter.guard() as wcg: + pass + @skipIfTorchDynamo("Really weird error") class TestMonitorTensorboard(TestCase): diff --git a/torch/csrc/monitor/python_init.cpp b/torch/csrc/monitor/python_init.cpp index d6ac4f312c41..2e1a9c6dcc01 100644 --- a/torch/csrc/monitor/python_init.cpp +++ b/torch/csrc/monitor/python_init.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -296,6 +298,47 @@ void initMonitorBindings(PyObject* module) { after calling ``register_event_handler``. After this returns the event handler will no longer receive events. )DOC"); + + struct WaitCounterTracker { + explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h) + : handle{h} {} + c10::monitor::WaitCounterHandle handle; + std::optional guard; + }; + py::class_>( + m, "WaitCounterTracker") + .def( + "__enter__", + [](const std::shared_ptr& self) { + self->guard.emplace(self->handle.start()); + }) + .def( + "__exit__", + [](const std::shared_ptr& self, + const pybind11::args&) { self->guard.reset(); }); + + py::class_( + m, + "WaitCounter", + R"DOC( + WaitCounter represents a named duration counter. + Multiple units of work can be tracked by the same WaitCounter. Depending + on the backend, the WaitCounter may track the number of units of work, + their duration etc. + )DOC") + .def( + py::init([](const std::string& key) { + return std::make_unique(key); + }), + py::arg("key")) + .def( + "guard", + [](const c10::monitor::WaitCounterHandle* self) { + return std::make_shared(*self); + }, + R"DOC( + Creates a guard that manages a single unit of work. + )DOC"); } } // namespace monitor