mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[pytorch][counters] Pybind for WaitCounter (#132357)
Summary: Basic pybind integration for WaitCounter providing a guard API. Also fixes broken copy/move constructor in WaitGuard (it wasn't really used with the macro-based C++ API). Test Plan: unit test Differential Revision: D60557660 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132357 Approved by: https://github.com/jamesperng, https://github.com/asiab4
This commit is contained in:
committed by
PyTorch MergeBot
parent
d224857b3a
commit
fca2dba7ca
@ -45,6 +45,13 @@ class C10_API WaitCounterHandle {
|
||||
|
||||
class WaitGuard {
|
||||
public:
|
||||
WaitGuard(WaitGuard&& other) noexcept
|
||||
: handle_{std::exchange(other.handle_, {})},
|
||||
ctxs_{std::move(other.ctxs_)} {}
|
||||
WaitGuard(const WaitGuard&) = delete;
|
||||
WaitGuard& operator=(const WaitGuard&) = delete;
|
||||
WaitGuard& operator=(WaitGuard&&) = delete;
|
||||
|
||||
~WaitGuard() {
|
||||
stop();
|
||||
}
|
||||
|
@ -1,22 +1,21 @@
|
||||
# Owner(s): ["oncall: r2p"]
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests, skipIfTorchDynamo,
|
||||
)
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from torch.monitor import (
|
||||
Aggregation,
|
||||
Event,
|
||||
log_event,
|
||||
register_event_handler,
|
||||
unregister_event_handler,
|
||||
Stat,
|
||||
TensorboardEventHandler,
|
||||
unregister_event_handler,
|
||||
_WaitCounter,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||
|
||||
class TestMonitor(TestCase):
|
||||
def test_interval_stat(self) -> None:
|
||||
@ -98,6 +97,13 @@ class TestMonitor(TestCase):
|
||||
log_event(e)
|
||||
self.assertEqual(len(events), 2)
|
||||
|
||||
def test_wait_counter(self) -> None:
|
||||
wait_counter = _WaitCounter(
|
||||
"test_wait_counter",
|
||||
)
|
||||
with wait_counter.guard() as wcg:
|
||||
pass
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Really weird error")
|
||||
class TestMonitorTensorboard(TestCase):
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include <c10/util/WaitCounter.h>
|
||||
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/python_numbers.h>
|
||||
@ -296,6 +298,47 @@ void initMonitorBindings(PyObject* module) {
|
||||
after calling ``register_event_handler``. After this returns the event
|
||||
handler will no longer receive events.
|
||||
)DOC");
|
||||
|
||||
struct WaitCounterTracker {
|
||||
explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h)
|
||||
: handle{h} {}
|
||||
c10::monitor::WaitCounterHandle handle;
|
||||
std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
|
||||
};
|
||||
py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
|
||||
m, "_WaitCounterTracker")
|
||||
.def(
|
||||
"__enter__",
|
||||
[](const std::shared_ptr<WaitCounterTracker>& self) {
|
||||
self->guard.emplace(self->handle.start());
|
||||
})
|
||||
.def(
|
||||
"__exit__",
|
||||
[](const std::shared_ptr<WaitCounterTracker>& self,
|
||||
const pybind11::args&) { self->guard.reset(); });
|
||||
|
||||
py::class_<c10::monitor::WaitCounterHandle>(
|
||||
m,
|
||||
"_WaitCounter",
|
||||
R"DOC(
|
||||
WaitCounter represents a named duration counter.
|
||||
Multiple units of work can be tracked by the same WaitCounter. Depending
|
||||
on the backend, the WaitCounter may track the number of units of work,
|
||||
their duration etc.
|
||||
)DOC")
|
||||
.def(
|
||||
py::init([](const std::string& key) {
|
||||
return std::make_unique<c10::monitor::WaitCounterHandle>(key);
|
||||
}),
|
||||
py::arg("key"))
|
||||
.def(
|
||||
"guard",
|
||||
[](const c10::monitor::WaitCounterHandle* self) {
|
||||
return std::make_shared<WaitCounterTracker>(*self);
|
||||
},
|
||||
R"DOC(
|
||||
Creates a guard that manages a single unit of work.
|
||||
)DOC");
|
||||
}
|
||||
|
||||
} // namespace monitor
|
||||
|
@ -1,7 +1,8 @@
|
||||
from torch._C._monitor import * # noqa: F403
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torch._C._monitor import _WaitCounter # type: ignore[attr-defined]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
Reference in New Issue
Block a user