torch/distributed: move WorkerInfo registration into libtorch instead of libtorch_python (#78028)

Summary:
This moves torch::class_<WorkerInfo> into `rpc_agent.cpp` so it gets registered in libtorch instead of libtorch_python. This is intermediate work to getting torch::deploy to load an unmodified copy of libtorch. Current RPC is incompatible due to duplicate registrations.

```
unknown file: Failure
C++ exception with description "Exception Caught inside torch::deploy embedded library:
Custom class with name __torch__.torch.classes.dist_rpc.WorkerInfo is already registered. Ensure that registration with torch::class_ is only called once.
Exception raised from registerCustomClass at ../aten/src/ATen/core/custom_class.cpp:61 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7f3bd9adb92e in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5c (0x7f3bd9ab7068 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: torch::registerCustomClass(std::shared_ptr<c10::ClassType>) + 0x110 (0x7f3bc2258980 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::detail::class_base::class_base(std::string const&, std::string const&, std::string, std::type_info const&, std::type_info const&) + 0x3b9 (0x7f3bc225a419 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: [0x7f3ba45cfea1]
frame #5: <unknown function> + 0x1b5334 (0x5652bdab9334 in ./test_deploy)
frame #6: <unknown function> + 0x1b4f3e (0x5652bdab8f3e in ./test_deploy)
frame #7: <unknown function> + 0x1b519b (0x5652bdab919b in ./test_deploy)
frame #8: loadSearchFile(char const*) + 0x23e (0x7f3ba62f37f8 in /tmp/torch_deploy9ATEFg)
frame #9: deploy_set_self + 0x51 (0x7f3ba62f38f9 in /tmp/torch_deploy9ATEFg)
frame #10: torch::deploy::Interpreter::Interpreter(torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>) + 0x274 (0x5652bdaaa790 in ./test_deploy)
frame #11: void __gnu_cxx::new_allocator<torch::deploy::Interpreter>::construct<torch::deploy::Interpreter, torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(torch::deploy::Interpreter*, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0x81 (0x5652bdaaf58b in ./test_deploy)
frame #12: void std::allocator_traits<std::allocator<torch::deploy::Interpreter> >::construct<torch::deploy::Interpreter, torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(std::allocator<torch::deploy::Interpreter>&, torch::deploy::Interpreter*, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0x4a (0x5652bdaae320 in ./test_deploy)
frame #13: void std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> >::_M_realloc_insert<torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(__gnu_cxx::__normal_iterator<torch::deploy::Interpreter*, std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> > >, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0xee (0x5652bdaae4a0 in ./test_deploy)
frame #14: void std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> >::emplace_back<torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0xb6 (0x5652bdaad258 in ./test_deploy)
frame #15: torch::deploy::InterpreterManager::InterpreterManager(unsigned long, std::shared_ptr<torch::deploy::Environment>) + 0x123 (0x5652bdaa83b1 in ./test_deploy)
frame #16: TorchpyTest_InitTwice_Test::TestBody() + 0x65 (0x5652bda075a9 in ./test_deploy)
frame #17: void testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x65 (0x5652bda944b7 in ./test_deploy)
frame #18: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x5a (0x5652bda8cfe7 in ./test_deploy)
frame #19: testing::Test::Run() + 0x100 (0x5652bda68622 in ./test_deploy)
frame #20: testing::TestInfo::Run() + 0x10f (0x5652bda68fb3 in ./test_deploy)
frame #21: testing::TestSuite::Run() + 0x121 (0x5652bda6980d in ./test_deploy)
frame #22: testing::internal::UnitTestImpl::RunAllTests() + 0x38e (0x5652bda756e6 in ./test_deploy)
frame #23: bool testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x65 (0x5652bda9586b in ./test_deploy)
frame #24: bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x5a (0x5652bda8e0f7 in ./test_deploy)
frame #25: testing::UnitTest::Run() + 0xc9 (0x5652bda73fd1 in ./test_deploy)
frame #26: RUN_ALL_TESTS() + 0x11 (0x5652bda169fa in ./test_deploy)
frame #27: main + 0x27 (0x5652bda10ce2 in ./test_deploy)
frame #28: <unknown function> + 0x2d310 (0x7f3bc0431310 in /usr/lib/libc.so.6)
frame #29: __libc_start_main + 0x81 (0x7f3bc04313c1 in /usr/lib/libc.so.6)
frame #30: _start + 0x25 (0x5652bda063b5 in ./test_deploy)
```

Test Plan: CI

Differential Revision: D36564258

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78028
Approved by: https://github.com/rohan-varma
This commit is contained in:
Tristan Rice
2022-05-25 17:46:39 +00:00
committed by PyTorch MergeBot
parent 8412f209f0
commit ebba4219ae
3 changed files with 49 additions and 28 deletions

View File

@ -5,8 +5,50 @@ namespace torch {
namespace distributed {
namespace rpc {
namespace {
// WorkerInfo needs to be registered exactly once. Since the op registration
// happens in libtorch_python we wrap the class registration in a helper to make
// sure that if there's multiple copies of Python such as used in torch::deploy
// we only ever register it once.
static std::once_flag workerInfoFlag;
static c10::optional<torch::class_<WorkerInfo>> workerInfo;
} // namespace
RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
std::call_once(workerInfoFlag, []() {
workerInfo = torch::class_<WorkerInfo>("dist_rpc", "WorkerInfo")
.def(torch::init<std::string, int64_t>());
});
}
constexpr size_t WorkerInfo::MAX_NAME_LEN;
WorkerInfo::WorkerInfo(std::string name, int64_t id)
: WorkerInfo(std::move(name), (worker_id_t)id) {
TORCH_CHECK(
id <= std::numeric_limits<worker_id_t>::max(),
"RPC worker id ",
id,
" out of bound of int16_t.");
}
WorkerInfo::WorkerInfo(std::string name, worker_id_t id)
: name_(std::move(name)), id_(id) {
bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
bool validChar =
std::find_if(name_.begin(), name_.end(), [](char c) {
return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
}) == name_.end();
TORCH_CHECK(
validSize && validChar,
"Worker name must match ^[A-Za-z0-9-_:]*$, "
"and must be non-empty and shorter than ",
MAX_NAME_LEN,
" chars, "
"but got ",
name_);
}
// Large Time Duration for waiting on the condition variable until the map is
// population. Cannot use
// std::chrono::time_point<std::chrono::steady_clock>::max() due to a known

View File

@ -52,31 +52,9 @@ struct RpcBackendOptions {
// A globally unique ID to identify an RpcAgent
struct TORCH_API WorkerInfo : torch::CustomClassHolder {
WorkerInfo(std::string name, int64_t id)
: WorkerInfo(std::move(name), (worker_id_t)id) {
TORCH_CHECK(
id <= std::numeric_limits<worker_id_t>::max(),
"RPC worker id ",
id,
" out of bound of int16_t.");
}
WorkerInfo(std::string name, int64_t id);
WorkerInfo(std::string name, worker_id_t id)
: name_(std::move(name)), id_(id) {
bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
bool validChar =
std::find_if(name_.begin(), name_.end(), [](char c) {
return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
}) == name_.end();
TORCH_CHECK(
validSize && validChar,
"Worker name must match ^[A-Za-z0-9-_:]*$, "
"and must be non-empty and shorter than ",
MAX_NAME_LEN,
" chars, "
"but got ",
name_);
}
WorkerInfo(std::string name, worker_id_t id);
bool operator==(const WorkerInfo& rhs) {
return (id_ == rhs.id_) && (name_ == rhs.name_);
@ -88,6 +66,10 @@ struct TORCH_API WorkerInfo : torch::CustomClassHolder {
const worker_id_t id_;
};
struct TORCH_API RegisterWorkerInfoOnce {
RegisterWorkerInfoOnce();
};
TORCH_API std::ostream& operator<<(
std::ostream& os,
const WorkerInfo& workerInfo);

View File

@ -22,10 +22,7 @@ namespace torch {
namespace jit {
namespace {
static auto workerInfo =
torch::class_<dist_rpc::WorkerInfo>("dist_rpc", "WorkerInfo")
.def(torch::init<std::string, int64_t>());
distributed::rpc::RegisterWorkerInfoOnce workerInfo{};
// prepare the rpc input arguments and call the C++ impls
void prepare_and_call_rpc_op(