[pytorch][counters] Pybind for WaitCounter (#132357)

Summary:
Basic pybind integration for WaitCounter providing a guard API.
Also fixes broken copy/move constructor in WaitGuard (it wasn't really used with the macro-based C++ API).

Test Plan: unit test

Differential Revision: D60557660

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132357
Approved by: https://github.com/jamesperng, https://github.com/asiab4
This commit is contained in:
Andrii Grynenko
2024-08-02 16:08:10 +00:00
committed by PyTorch MergeBot
parent d224857b3a
commit fca2dba7ca
4 changed files with 64 additions and 7 deletions

View File

@ -45,6 +45,13 @@ class C10_API WaitCounterHandle {
class WaitGuard {
public:
WaitGuard(WaitGuard&& other) noexcept
: handle_{std::exchange(other.handle_, {})},
ctxs_{std::move(other.ctxs_)} {}
WaitGuard(const WaitGuard&) = delete;
WaitGuard& operator=(const WaitGuard&) = delete;
WaitGuard& operator=(WaitGuard&&) = delete;
~WaitGuard() {
stop();
}

View File

@ -1,22 +1,21 @@
# Owner(s): ["oncall: r2p"]
from torch.testing._internal.common_utils import (
TestCase, run_tests, skipIfTorchDynamo,
)
from datetime import timedelta, datetime
import tempfile
import time
from datetime import datetime, timedelta
from torch.monitor import (
Aggregation,
Event,
log_event,
register_event_handler,
unregister_event_handler,
Stat,
TensorboardEventHandler,
unregister_event_handler,
_WaitCounter,
)
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
class TestMonitor(TestCase):
def test_interval_stat(self) -> None:
@ -98,6 +97,13 @@ class TestMonitor(TestCase):
log_event(e)
self.assertEqual(len(events), 2)
def test_wait_counter(self) -> None:
wait_counter = _WaitCounter(
"test_wait_counter",
)
with wait_counter.guard() as wcg:
pass
@skipIfTorchDynamo("Really weird error")
class TestMonitorTensorboard(TestCase):

View File

@ -1,5 +1,7 @@
#include <utility>
#include <c10/util/WaitCounter.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
@ -296,6 +298,47 @@ void initMonitorBindings(PyObject* module) {
after calling ``register_event_handler``. After this returns the event
handler will no longer receive events.
)DOC");
struct WaitCounterTracker {
explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h)
: handle{h} {}
c10::monitor::WaitCounterHandle handle;
std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
};
py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
m, "_WaitCounterTracker")
.def(
"__enter__",
[](const std::shared_ptr<WaitCounterTracker>& self) {
self->guard.emplace(self->handle.start());
})
.def(
"__exit__",
[](const std::shared_ptr<WaitCounterTracker>& self,
const pybind11::args&) { self->guard.reset(); });
py::class_<c10::monitor::WaitCounterHandle>(
m,
"_WaitCounter",
R"DOC(
WaitCounter represents a named duration counter.
Multiple units of work can be tracked by the same WaitCounter. Depending
on the backend, the WaitCounter may track the number of units of work,
their duration etc.
)DOC")
.def(
py::init([](const std::string& key) {
return std::make_unique<c10::monitor::WaitCounterHandle>(key);
}),
py::arg("key"))
.def(
"guard",
[](const c10::monitor::WaitCounterHandle* self) {
return std::make_shared<WaitCounterTracker>(*self);
},
R"DOC(
Creates a guard that manages a single unit of work.
)DOC");
}
} // namespace monitor

View File

@ -1,7 +1,8 @@
from torch._C._monitor import * # noqa: F403
from typing import TYPE_CHECKING
from torch._C._monitor import _WaitCounter # type: ignore[attr-defined]
if TYPE_CHECKING:
from torch.utils.tensorboard import SummaryWriter