mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch/distributed: move WorkerInfo registration into libtorch instead of libtorch_python (#78028)
Summary: This moves torch::class_<WorkerInfo> into `rpc_agent.cpp` so it gets registered in libtorch instead of libtorch_python. This is intermediate work to getting torch::deploy to load an unmodified copy of libtorch. Current RPC is incompatible due to duplicate registrations. ``` unknown file: Failure C++ exception with description "Exception Caught inside torch::deploy embedded library: Custom class with name __torch__.torch.classes.dist_rpc.WorkerInfo is already registered. Ensure that registration with torch::class_ is only called once. Exception raised from registerCustomClass at ../aten/src/ATen/core/custom_class.cpp:61 (most recent call first): frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7f3bd9adb92e in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libc10.so) frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5c (0x7f3bd9ab7068 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libc10.so) frame #2: torch::registerCustomClass(std::shared_ptr<c10::ClassType>) + 0x110 (0x7f3bc2258980 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so) frame #3: torch::detail::class_base::class_base(std::string const&, std::string const&, std::string, std::type_info const&, std::type_info const&) + 0x3b9 (0x7f3bc225a419 in /home/tristanr/venvs/multipy/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so) frame #4: [0x7f3ba45cfea1] frame #5: <unknown function> + 0x1b5334 (0x5652bdab9334 in ./test_deploy) frame #6: <unknown function> + 0x1b4f3e (0x5652bdab8f3e in ./test_deploy) frame #7: <unknown function> + 0x1b519b (0x5652bdab919b in ./test_deploy) frame #8: loadSearchFile(char const*) + 0x23e (0x7f3ba62f37f8 in /tmp/torch_deploy9ATEFg) frame #9: deploy_set_self + 0x51 (0x7f3ba62f38f9 in /tmp/torch_deploy9ATEFg) frame #10: torch::deploy::Interpreter::Interpreter(torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>) + 0x274 (0x5652bdaaa790 in ./test_deploy) frame #11: void __gnu_cxx::new_allocator<torch::deploy::Interpreter>::construct<torch::deploy::Interpreter, torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(torch::deploy::Interpreter*, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0x81 (0x5652bdaaf58b in ./test_deploy) frame #12: void std::allocator_traits<std::allocator<torch::deploy::Interpreter> >::construct<torch::deploy::Interpreter, torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(std::allocator<torch::deploy::Interpreter>&, torch::deploy::Interpreter*, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0x4a (0x5652bdaae320 in ./test_deploy) frame #13: void std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> >::_M_realloc_insert<torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(__gnu_cxx::__normal_iterator<torch::deploy::Interpreter*, std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> > >, torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0xee (0x5652bdaae4a0 in ./test_deploy) frame #14: void std::vector<torch::deploy::Interpreter, std::allocator<torch::deploy::Interpreter> >::emplace_back<torch::deploy::InterpreterManager*, std::shared_ptr<torch::deploy::Environment>&>(torch::deploy::InterpreterManager*&&, std::shared_ptr<torch::deploy::Environment>&) + 0xb6 (0x5652bdaad258 in ./test_deploy) frame #15: torch::deploy::InterpreterManager::InterpreterManager(unsigned long, std::shared_ptr<torch::deploy::Environment>) + 0x123 (0x5652bdaa83b1 in ./test_deploy) frame #16: TorchpyTest_InitTwice_Test::TestBody() + 0x65 (0x5652bda075a9 in ./test_deploy) frame #17: void testing::internal::HandleSehExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x65 (0x5652bda944b7 in ./test_deploy) frame #18: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x5a (0x5652bda8cfe7 in ./test_deploy) frame #19: testing::Test::Run() + 0x100 (0x5652bda68622 in ./test_deploy) frame #20: testing::TestInfo::Run() + 0x10f (0x5652bda68fb3 in ./test_deploy) frame #21: testing::TestSuite::Run() + 0x121 (0x5652bda6980d in ./test_deploy) frame #22: testing::internal::UnitTestImpl::RunAllTests() + 0x38e (0x5652bda756e6 in ./test_deploy) frame #23: bool testing::internal::HandleSehExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x65 (0x5652bda9586b in ./test_deploy) frame #24: bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x5a (0x5652bda8e0f7 in ./test_deploy) frame #25: testing::UnitTest::Run() + 0xc9 (0x5652bda73fd1 in ./test_deploy) frame #26: RUN_ALL_TESTS() + 0x11 (0x5652bda169fa in ./test_deploy) frame #27: main + 0x27 (0x5652bda10ce2 in ./test_deploy) frame #28: <unknown function> + 0x2d310 (0x7f3bc0431310 in /usr/lib/libc.so.6) frame #29: __libc_start_main + 0x81 (0x7f3bc04313c1 in /usr/lib/libc.so.6) frame #30: _start + 0x25 (0x5652bda063b5 in ./test_deploy) ``` Test Plan: CI Differential Revision: D36564258 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78028 Approved by: https://github.com/rohan-varma
This commit is contained in:
committed by
PyTorch MergeBot
parent
8412f209f0
commit
ebba4219ae
@ -5,8 +5,50 @@ namespace torch {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
|
||||
namespace {
|
||||
// WorkerInfo needs to be registered exactly once. Since the op registration
|
||||
// happens in libtorch_python we wrap the class registration in a helper to make
|
||||
// sure that if there's multiple copies of Python such as used in torch::deploy
|
||||
// we only ever register it once.
|
||||
static std::once_flag workerInfoFlag;
|
||||
static c10::optional<torch::class_<WorkerInfo>> workerInfo;
|
||||
} // namespace
|
||||
|
||||
RegisterWorkerInfoOnce::RegisterWorkerInfoOnce() {
|
||||
std::call_once(workerInfoFlag, []() {
|
||||
workerInfo = torch::class_<WorkerInfo>("dist_rpc", "WorkerInfo")
|
||||
.def(torch::init<std::string, int64_t>());
|
||||
});
|
||||
}
|
||||
|
||||
constexpr size_t WorkerInfo::MAX_NAME_LEN;
|
||||
|
||||
WorkerInfo::WorkerInfo(std::string name, int64_t id)
|
||||
: WorkerInfo(std::move(name), (worker_id_t)id) {
|
||||
TORCH_CHECK(
|
||||
id <= std::numeric_limits<worker_id_t>::max(),
|
||||
"RPC worker id ",
|
||||
id,
|
||||
" out of bound of int16_t.");
|
||||
}
|
||||
|
||||
WorkerInfo::WorkerInfo(std::string name, worker_id_t id)
|
||||
: name_(std::move(name)), id_(id) {
|
||||
bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
|
||||
bool validChar =
|
||||
std::find_if(name_.begin(), name_.end(), [](char c) {
|
||||
return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
|
||||
}) == name_.end();
|
||||
TORCH_CHECK(
|
||||
validSize && validChar,
|
||||
"Worker name must match ^[A-Za-z0-9-_:]*$, "
|
||||
"and must be non-empty and shorter than ",
|
||||
MAX_NAME_LEN,
|
||||
" chars, "
|
||||
"but got ",
|
||||
name_);
|
||||
}
|
||||
|
||||
// Large Time Duration for waiting on the condition variable until the map is
|
||||
// population. Cannot use
|
||||
// std::chrono::time_point<std::chrono::steady_clock>::max() due to a known
|
||||
|
||||
@ -52,31 +52,9 @@ struct RpcBackendOptions {
|
||||
|
||||
// A globally unique ID to identify an RpcAgent
|
||||
struct TORCH_API WorkerInfo : torch::CustomClassHolder {
|
||||
WorkerInfo(std::string name, int64_t id)
|
||||
: WorkerInfo(std::move(name), (worker_id_t)id) {
|
||||
TORCH_CHECK(
|
||||
id <= std::numeric_limits<worker_id_t>::max(),
|
||||
"RPC worker id ",
|
||||
id,
|
||||
" out of bound of int16_t.");
|
||||
}
|
||||
WorkerInfo(std::string name, int64_t id);
|
||||
|
||||
WorkerInfo(std::string name, worker_id_t id)
|
||||
: name_(std::move(name)), id_(id) {
|
||||
bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0;
|
||||
bool validChar =
|
||||
std::find_if(name_.begin(), name_.end(), [](char c) {
|
||||
return !(std::isalnum(c) || c == '-' || c == '_' || c == ':');
|
||||
}) == name_.end();
|
||||
TORCH_CHECK(
|
||||
validSize && validChar,
|
||||
"Worker name must match ^[A-Za-z0-9-_:]*$, "
|
||||
"and must be non-empty and shorter than ",
|
||||
MAX_NAME_LEN,
|
||||
" chars, "
|
||||
"but got ",
|
||||
name_);
|
||||
}
|
||||
WorkerInfo(std::string name, worker_id_t id);
|
||||
|
||||
bool operator==(const WorkerInfo& rhs) {
|
||||
return (id_ == rhs.id_) && (name_ == rhs.name_);
|
||||
@ -88,6 +66,10 @@ struct TORCH_API WorkerInfo : torch::CustomClassHolder {
|
||||
const worker_id_t id_;
|
||||
};
|
||||
|
||||
struct TORCH_API RegisterWorkerInfoOnce {
|
||||
RegisterWorkerInfoOnce();
|
||||
};
|
||||
|
||||
TORCH_API std::ostream& operator<<(
|
||||
std::ostream& os,
|
||||
const WorkerInfo& workerInfo);
|
||||
|
||||
@ -22,10 +22,7 @@ namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
|
||||
static auto workerInfo =
|
||||
torch::class_<dist_rpc::WorkerInfo>("dist_rpc", "WorkerInfo")
|
||||
.def(torch::init<std::string, int64_t>());
|
||||
distributed::rpc::RegisterWorkerInfoOnce workerInfo{};
|
||||
|
||||
// prepare the rpc input arguments and call the C++ impls
|
||||
void prepare_and_call_rpc_op(
|
||||
|
||||
Reference in New Issue
Block a user