Use Future's then() API to fix RPC profiling (#38352)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38352

Fixes RPC profiling by using the `then()` API added in https://github.com/pytorch/pytorch/pull/37311. Instead of adding a regular callback, we return a new future that completes once the profiling callback has finished. This is transparent to the user, because the new future still completes with the value of the original future (i.e. the RPC's return value).
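
The chaining idea described above can be sketched as follows. This is a minimal, single-threaded illustration under stated assumptions, not the PyTorch implementation: `SimpleFuture` and `wrapWithProfiling` are hypothetical names invented for this example, and the real `JitFuture` additionally handles threading and errors.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical, single-threaded stand-in for a future with then().
// It only illustrates the chaining idea; it is not torch's JitFuture.
template <typename T>
class SimpleFuture {
 public:
  // Store the value, then run every callback registered so far.
  void markCompleted(T value) {
    value_ = std::move(value);
    for (auto& cb : callbacks_) {
      cb();
    }
    callbacks_.clear();
  }

  bool completed() const { return value_.has_value(); }

  const T& value() const {
    assert(completed());
    return *value_;
  }

  // Returns a NEW future that completes with cb's result, and only
  // after cb has run on this future's value.
  std::shared_ptr<SimpleFuture<T>> then(std::function<T(const T&)> cb) {
    auto child = std::make_shared<SimpleFuture<T>>();
    auto run = [this, child, cb = std::move(cb)]() {
      child->markCompleted(cb(value()));
    };
    if (completed()) {
      run();  // already done: run the callback right away
    } else {
      callbacks_.push_back(std::move(run));
    }
    return child;
  }

 private:
  std::optional<T> value_;
  std::vector<std::function<void()>> callbacks_;
};

// Hypothetical helper: the returned future completes only after the
// "profiling" callback has run, and carries the original RPC value.
std::shared_ptr<SimpleFuture<int>> wrapWithProfiling(
    SimpleFuture<int>& rpcFuture, bool& profilingFinished) {
  return rpcFuture.then([&profilingFinished](const int& result) {
    profilingFinished = true;  // stands in for finishing the profiler record
    return result;             // pass the RPC's return value through
  });
}
```

A caller that waits on the future returned by `wrapWithProfiling` rather than on `rpcFuture` can therefore assume the profiling callback has already run once the value is available.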

To make this work for RRef, we add `_set_profiling_future` to store the profiling future on the RRef and `_get_profiling_future` to retrieve it, so that the tests can wait on it.
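
The setter/getter pair can be illustrated with an optional-backed member, roughly mirroring the shape of the `PyRRef` change in this commit. The sketch below is an assumption-laden stand-in: `FakeFuture` and `RRefHandle` are hypothetical types, `std::shared_ptr` replaces `c10::intrusive_ptr`, and a thrown `std::logic_error` replaces `TORCH_INTERNAL_ASSERT`.

```cpp
#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>

// Hypothetical stand-in for the future type held by the RRef wrapper.
struct FakeFuture {
  int value = 0;
};

// Hypothetical wrapper that holds an optional profiling future.
class RRefHandle {
 public:
  // Remember the future that completes after profiling has finished.
  void setProfilingFuture(std::shared_ptr<FakeFuture> profilingFuture) {
    profilingFuture_ = std::move(profilingFuture);
  }

  // Return the stored future; fail loudly if none was ever set.
  std::shared_ptr<FakeFuture> getProfilingFuture() const {
    if (!profilingFuture_) {
      throw std::logic_error("Profiling future has not been set!");
    }
    return *profilingFuture_;
  }

 private:
  // Empty until setProfilingFuture() is called.
  std::optional<std::shared_ptr<FakeFuture>> profilingFuture_;
};
```

Keeping the member optional makes "profiling was never enabled for this RRef" distinguishable from "a future was set", which is why the getter checks before dereferencing.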

Re-enabled the profiling tests and stress-tested them 1000 times to verify the fix.
ghstack-source-id: 104086114

Test Plan: Re-enabled profiling tests

Differential Revision: D21506940

fbshipit-source-id: 35cde22f0551c825c9bc98ddc24cca412878a63a
Rohan Varma
2020-05-14 12:50:13 -07:00
committed by Facebook GitHub Bot
parent f178bf10f1
commit 4d4895a62a
9 changed files with 86 additions and 39 deletions


@@ -100,7 +100,8 @@ TypePtr tryInferTypeWithTypeHint(
 /////////////////////////// PyRRef //////////////////////////////////
-PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref) : rref_(std::move(rref)) {
+PyRRef::PyRRef(c10::intrusive_ptr<RRef> rref)
+    : rref_(std::move(rref)), profilingFuture_(c10::nullopt) {
   TORCH_CHECK(rref_, "PyRRef must not wrap nullptr");
 }
@@ -127,6 +128,15 @@ c10::intrusive_ptr<JitFuture> PyRRef::getFuture() const {
       rref_->getOwnerCreationFuture(), false /* hasValue */);
 }
+c10::intrusive_ptr<JitFuture> PyRRef::getProfilingFuture() const {
+  TORCH_INTERNAL_ASSERT(profilingFuture_, "Profiling future has not been set!");
+  return *profilingFuture_;
+}
+void PyRRef::setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture) {
+  profilingFuture_ = std::move(profilingFuture);
+}
 bool PyRRef::isOwner() const {
   return rref_->isOwner();
 }