[pytorch][counters] Pybind for WaitCounter (#132167)

Summary:
Basic pybind integration for WaitCounter providing a guard API.
Also fixes broken copy/move constructor in WaitGuard (it wasn't really used with the macro-based C++ API).

Test Plan: unit test

Reviewed By: asiab4

Differential Revision: D60463979

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132167
Approved by: https://github.com/asiab4
This commit is contained in:
Andrii Grynenko
2024-07-31 16:04:40 +00:00
committed by PyTorch MergeBot
parent 39a3c98aa6
commit 2c7bd61afa
3 changed files with 58 additions and 0 deletions

View File

@ -45,6 +45,13 @@ class C10_API WaitCounterHandle {
class WaitGuard {
public:
WaitGuard(WaitGuard&& other) noexcept
: handle_{std::exchange(other.handle_, {})},
ctxs_{std::move(other.ctxs_)} {}
WaitGuard(const WaitGuard&) = delete;
WaitGuard& operator=(const WaitGuard&) = delete;
WaitGuard& operator=(WaitGuard&&) = delete;
~WaitGuard() {
stop();
}

View File

@ -16,6 +16,7 @@ from torch.monitor import (
unregister_event_handler,
Stat,
TensorboardEventHandler,
WaitCounter,
)
class TestMonitor(TestCase):
@ -98,6 +99,13 @@ class TestMonitor(TestCase):
log_event(e)
self.assertEqual(len(events), 2)
def test_wait_counter(self) -> None:
wait_counter = WaitCounter(
"test_wait_counter",
)
with wait_counter.guard() as wcg:
pass
@skipIfTorchDynamo("Really weird error")
class TestMonitorTensorboard(TestCase):

View File

@ -1,5 +1,7 @@
#include <utility>
#include <c10/util/WaitCounter.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
@ -296,6 +298,47 @@ void initMonitorBindings(PyObject* module) {
after calling ``register_event_handler``. After this returns the event
handler will no longer receive events.
)DOC");
// Helper exposed to Python that pairs a wait counter handle with the guard
// (if any) for the unit of work currently being tracked.
struct WaitCounterTracker {
  // Stores its own copy of `counter` in `handle`.
  explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& counter)
      : handle(counter) {}

  // Counter on which guards are started.
  c10::monitor::WaitCounterHandle handle;
  // Empty until a guard is started. Declared after `handle`, so it is
  // destroyed before `handle` is.
  // NOTE(review): presumably the guard refers back to `handle`, which would
  // make this declaration order load-bearing — confirm before reordering.
  std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
};
// Context-manager protocol for the tracker: entering starts a guard on the
// tracker's counter, exiting releases it.
py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
    m, "WaitCounterTracker")
    .def(
        "__enter__",
        // `self` is borrowed by reference instead of through the shared
        // holder: the body never shares ownership, and a reference parameter
        // makes pybind11 reject ``None`` with a TypeError rather than handing
        // the lambda an empty holder to dereference.
        [](WaitCounterTracker& self) {
          // Entering again replaces any guard that is still held.
          self.guard.emplace(self.handle.start());
        })
    .def(
        "__exit__",
        // The arguments supplied by the ``with`` statement are accepted and
        // ignored; nothing is returned, so exceptions are not suppressed.
        [](WaitCounterTracker& self, const py::args&) { self.guard.reset(); });
// Python class ``WaitCounter`` backed by c10::monitor::WaitCounterHandle.
py::class_<c10::monitor::WaitCounterHandle>(
    m,
    "WaitCounter",
    R"DOC(
WaitCounter represents a named duration counter.
Multiple units of work can be tracked by the same WaitCounter. Depending
on the backend, the WaitCounter may track the number of units of work,
their duration etc.
)DOC")
    .def(
        // Factory-style constructor: builds the handle from the counter key.
        py::init([](const std::string& key) {
          return std::make_unique<c10::monitor::WaitCounterHandle>(key);
        }),
        py::arg("key"))
    .def(
        "guard",
        // `self` is taken by const reference rather than by pointer so that
        // pybind11 rejects ``None`` with a TypeError instead of passing a
        // null pointer that would be dereferenced below. The tracker is
        // constructed from `self` and keeps its own copy of the handle.
        [](const c10::monitor::WaitCounterHandle& self) {
          return std::make_shared<WaitCounterTracker>(self);
        },
        R"DOC(
Creates a guard that manages a single unit of work.
)DOC");
}
} // namespace monitor