mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34497 Use a thread_local table to intercept UserRRefs created during user function args deserialization, and then wait for confirmations of those UserRRefs before launching the given user function. Differential Revision: D20347464 Test Plan: Imported from OSS Pulled By: mrshenli fbshipit-source-id: 087484a2d2f03fbfb156752ab25653f39b412a07
25 lines
826 B
C++
25 lines
826 B
C++
#include <torch/csrc/distributed/rpc/request_callback.h>
|
|
|
|
#include <torch/csrc/distributed/autograd/context/container.h>
|
|
#include <torch/csrc/distributed/autograd/utils.h>
|
|
|
|
namespace torch {
|
|
namespace distributed {
|
|
namespace rpc {
|
|
|
|
using namespace torch::distributed::autograd;
|
|
|
|
std::shared_ptr<FutureMessage> RequestCallback::operator()(
|
|
Message& request) const {
|
|
// NB: cannot clear autograd context id here because the processMessage method
|
|
// might pause waiting for all RRefs in the arguments to be confirmed by their
|
|
// owners and resumne processing in a different thread. Hence, the
|
|
// thread_local context id needs to be set and cleared in the thread that
|
|
// indeed carries out the processing logic.
|
|
return processMessage(request);
|
|
}
|
|
|
|
} // namespace rpc
|
|
} // namespace distributed
|
|
} // namespace torch
|