pytorch/torch/csrc/distributed/rpc/request_callback.cpp
Shen Li 422e348619 Don't run user function until all UserRRefs in the args are confirmed (#34497)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34497

Use a thread_local table to intercept UserRRefs created during user
function args deserialization, and then wait for confirmations of
those UserRRefs before launching the given user function.

Differential Revision: D20347464

Test Plan: Imported from OSS

Pulled By: mrshenli

fbshipit-source-id: 087484a2d2f03fbfb156752ab25653f39b412a07
2020-03-16 18:30:06 -07:00
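
The summary above describes a thread_local table that intercepts UserRRefs created while the user function's arguments are deserialized, then waits for their confirmations before running the function. The following is a minimal, self-contained sketch of that pattern. It rests on some assumptions. `FakeUserRRef`, `createUserRRef`, `deserializeArgs`, `runUserFunction`, `confirmFromOwner` and `tlsPendingUserRRefs` are hypothetical names, not PyTorch APIs. Owner confirmation is modeled with `std::promise`/`std::shared_future`. The real RPC code may pause and resume on another thread, but this sketch simply blocks the calling thread until every confirmation arrives.

```cpp
#include <cassert>
#include <chrono>
#include <future>
#include <memory>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical stand-in for a UserRRef: the user function may only use it
// after its owner has confirmed it.
class FakeUserRRef {
 public:
  explicit FakeUserRRef(int value)
      : value_(value), confirmed_(confirm_.get_future().share()) {}

  int value() const { return value_; }
  void confirm() { confirm_.set_value(); }  // called on behalf of the owner
  void waitConfirmed() const { confirmed_.wait(); }
  bool isConfirmed() const {
    return confirmed_.wait_for(std::chrono::seconds(0)) ==
        std::future_status::ready;
  }

 private:
  int value_;
  std::promise<void> confirm_;          // declared before confirmed_ so it
  std::shared_future<void> confirmed_;  // is constructed first
};

using RRefPtr = std::shared_ptr<FakeUserRRef>;

// Hypothetical thread_local table: non-null only while the current thread is
// deserializing user-function arguments.
thread_local std::vector<RRefPtr>* tlsPendingUserRRefs = nullptr;

// RAII guard that installs the table and always uninstalls it, even on throw.
class InterceptUserRRefs {
 public:
  explicit InterceptUserRRefs(std::vector<RRefPtr>* table) {
    tlsPendingUserRRefs = table;
  }
  ~InterceptUserRRefs() { tlsPendingUserRRefs = nullptr; }
};

// All UserRRef creation goes through here, so it can be intercepted.
RRefPtr createUserRRef(int value) {
  auto rref = std::make_shared<FakeUserRRef>(value);
  if (tlsPendingUserRRefs != nullptr) {
    tlsPendingUserRRefs->push_back(rref);
  }
  return rref;
}

struct DeserializedArgs {
  std::vector<RRefPtr> args;
  std::vector<RRefPtr> pendingUserRRefs;  // UserRRefs created while deserializing
};

// Simulated deserialization: each int in the payload becomes a UserRRef.
DeserializedArgs deserializeArgs(const std::vector<int>& payload) {
  DeserializedArgs out;
  {
    InterceptUserRRefs guard(&out.pendingUserRRefs);
    for (int v : payload) {
      out.args.push_back(createUserRRef(v));
    }
  }  // guard destroyed: later UserRRefs on this thread are not intercepted
  return out;
}

// Simulated owner: confirms every pending UserRRef from another thread.
std::thread confirmFromOwner(std::vector<RRefPtr> pending) {
  return std::thread([pending = std::move(pending)] {
    for (const auto& rref : pending) {
      rref->confirm();
    }
  });
}

// Waits for every intercepted UserRRef, then runs the user function.
template <typename Fn>
auto runUserFunction(const DeserializedArgs& d, Fn&& fn) {
  for (const auto& rref : d.pendingUserRRefs) {
    rref->waitConfirmed();
  }
  return std::forward<Fn>(fn)(d.args);
}
```

The table pointer is thread_local, so concurrent deserializations on different worker threads each collect only their own UserRRefs. A UserRRef created on the same thread after the guard is destroyed is not intercepted.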


#include <torch/csrc/distributed/rpc/request_callback.h>

#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/utils.h>

namespace torch {
namespace distributed {
namespace rpc {

using namespace torch::distributed::autograd;

std::shared_ptr<FutureMessage> RequestCallback::operator()(
    Message& request) const {
  // NB: cannot clear autograd context id here because the processMessage method
  // might pause waiting for all RRefs in the arguments to be confirmed by their
  // owners and resume processing in a different thread. Hence, the
  // thread_local context id needs to be set and cleared in the thread that
  // indeed carries out the processing logic.
  return processMessage(request);
}

} // namespace rpc
} // namespace distributed
} // namespace torch
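
The NB comment in `operator()` explains why the autograd context id is not cleared there. Processing may pause and resume on another thread, and a thread_local value is visible only on the thread that set it. The following is a minimal sketch of that pitfall. It assumes a hypothetical `tlsContextId` variable and `ScopedContextId` guard, which are not PyTorch's actual names. A plain `std::thread` stands in for the thread on which processing resumes.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <thread>

// Hypothetical stand-in for the thread_local autograd context id (-1 = unset).
thread_local std::int64_t tlsContextId = -1;

// RAII guard: sets the id on construction and clears it on destruction, so
// both happen on whichever thread constructs the guard.
class ScopedContextId {
 public:
  explicit ScopedContextId(std::int64_t id) { tlsContextId = id; }
  ~ScopedContextId() { tlsContextId = -1; }
  ScopedContextId(const ScopedContextId&) = delete;
  ScopedContextId& operator=(const ScopedContextId&) = delete;
};

// Runs `work` on a fresh thread (standing in for the thread where processing
// resumes) and returns the context id that `work` observed there.
std::int64_t runOnWorker(const std::function<std::int64_t()>& work) {
  std::int64_t observed = 0;
  std::thread worker([&] { observed = work(); });
  worker.join();
  return observed;
}

// Pattern the comment asks for: set and clear the id in the thread that
// actually carries out the processing logic.
std::int64_t processSettingIdOnWorker(std::int64_t contextId) {
  return runOnWorker([contextId] {
    ScopedContextId guard(contextId);
    return tlsContextId;  // the real processing logic would run here
  });
}

// Pitfall: the id is set on the calling thread, but the work resumes on a
// different thread whose thread_local copy was never set.
std::int64_t processSettingIdOnCaller(std::int64_t contextId) {
  ScopedContextId guard(contextId);
  return runOnWorker([] { return tlsContextId; });
}
```

In `processSettingIdOnCaller` the worker observes -1 rather than the requested id. This is why the id is set and cleared in the processing thread rather than in `operator()`.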