mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pytorch][counters] Pybind for WaitCounter (#132167)
Summary: Basic pybind integration for WaitCounter providing a guard API. Also fixes broken copy/move constructor in WaitGuard (it wasn't really used with the macro-based C++ API). Test Plan: unit test Reviewed By: asiab4 Differential Revision: D60463979 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132167 Approved by: https://github.com/asiab4
This commit is contained in:
committed by
PyTorch MergeBot
parent
39a3c98aa6
commit
2c7bd61afa
@ -45,6 +45,13 @@ class C10_API WaitCounterHandle {
|
||||
|
||||
class WaitGuard {
|
||||
public:
|
||||
WaitGuard(WaitGuard&& other) noexcept
|
||||
: handle_{std::exchange(other.handle_, {})},
|
||||
ctxs_{std::move(other.ctxs_)} {}
|
||||
WaitGuard(const WaitGuard&) = delete;
|
||||
WaitGuard& operator=(const WaitGuard&) = delete;
|
||||
WaitGuard& operator=(WaitGuard&&) = delete;
|
||||
|
||||
~WaitGuard() {
|
||||
stop();
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ from torch.monitor import (
|
||||
unregister_event_handler,
|
||||
Stat,
|
||||
TensorboardEventHandler,
|
||||
WaitCounter,
|
||||
)
|
||||
|
||||
class TestMonitor(TestCase):
|
||||
@ -98,6 +99,13 @@ class TestMonitor(TestCase):
|
||||
log_event(e)
|
||||
self.assertEqual(len(events), 2)
|
||||
|
||||
def test_wait_counter(self) -> None:
|
||||
wait_counter = WaitCounter(
|
||||
"test_wait_counter",
|
||||
)
|
||||
with wait_counter.guard() as wcg:
|
||||
pass
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Really weird error")
|
||||
class TestMonitorTensorboard(TestCase):
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include <c10/util/WaitCounter.h>
|
||||
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/python_numbers.h>
|
||||
@ -296,6 +298,47 @@ void initMonitorBindings(PyObject* module) {
|
||||
after calling ``register_event_handler``. After this returns the event
|
||||
handler will no longer receive events.
|
||||
)DOC");
|
||||
|
||||
struct WaitCounterTracker {
|
||||
explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h)
|
||||
: handle{h} {}
|
||||
c10::monitor::WaitCounterHandle handle;
|
||||
std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
|
||||
};
|
||||
py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
|
||||
m, "WaitCounterTracker")
|
||||
.def(
|
||||
"__enter__",
|
||||
[](const std::shared_ptr<WaitCounterTracker>& self) {
|
||||
self->guard.emplace(self->handle.start());
|
||||
})
|
||||
.def(
|
||||
"__exit__",
|
||||
[](const std::shared_ptr<WaitCounterTracker>& self,
|
||||
const pybind11::args&) { self->guard.reset(); });
|
||||
|
||||
py::class_<c10::monitor::WaitCounterHandle>(
|
||||
m,
|
||||
"WaitCounter",
|
||||
R"DOC(
|
||||
WaitCounter represents a named duration counter.
|
||||
Multiple units of work can be tracked by the same WaitCounter. Depending
|
||||
on the backend, the WaitCounter may track the number of units of work,
|
||||
their duration etc.
|
||||
)DOC")
|
||||
.def(
|
||||
py::init([](const std::string& key) {
|
||||
return std::make_unique<c10::monitor::WaitCounterHandle>(key);
|
||||
}),
|
||||
py::arg("key"))
|
||||
.def(
|
||||
"guard",
|
||||
[](const c10::monitor::WaitCounterHandle* self) {
|
||||
return std::make_shared<WaitCounterTracker>(*self);
|
||||
},
|
||||
R"DOC(
|
||||
Creates a guard that manages a single unit of work.
|
||||
)DOC");
|
||||
}
|
||||
|
||||
} // namespace monitor
|
||||
|
Reference in New Issue
Block a user