mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fix multiprocessing on OS X
This commit is contained in:
@ -137,16 +137,18 @@ class TestMultiprocessing(TestCase):
|
||||
with leak_checker(self) as lc:
|
||||
do_test()
|
||||
|
||||
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
|
||||
def test_fd_sharing(self):
|
||||
self._test_sharing()
|
||||
|
||||
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
|
||||
def test_fd_preserve_sharing(self):
|
||||
self._test_preserve_sharing()
|
||||
|
||||
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
|
||||
def test_fd_pool(self):
|
||||
self._test_pool()
|
||||
|
||||
@unittest.skipIf(platform == "darwin", "file_system sharing strategy doesn't work in OSX")
|
||||
def test_fs_sharing(self):
|
||||
with fs_sharing():
|
||||
self._test_sharing()
|
||||
|
@ -85,7 +85,7 @@ int main(int argc, char *argv[]) {
|
||||
try {
|
||||
char tmpfile[L_tmpnam];
|
||||
if (std::tmpnam(tmpfile) == NULL)
|
||||
throw std::exception();
|
||||
throw std::runtime_error("could not generate a random filename for manager socket");
|
||||
// TODO: better strategy for generating tmp names
|
||||
// TODO: retry on collisions - this can easily fail
|
||||
srv_socket.reset(new ManagerServerSocket(std::string(tmpfile)));
|
||||
|
@ -17,7 +17,7 @@ public:
|
||||
|
||||
protected:
|
||||
Socket() {
|
||||
SYSCHECK(socket_fd = socket(AF_UNIX, SOCK_SEQPACKET, 0));
|
||||
SYSCHECK(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
|
||||
}
|
||||
Socket(const Socket& other) = delete;
|
||||
Socket(Socket&& other): socket_fd(other.socket_fd) { other.socket_fd = -1; };
|
||||
@ -40,6 +40,40 @@ protected:
|
||||
return strlen(address.sun_path) + sizeof(address.sun_family);
|
||||
}
|
||||
|
||||
void recv(void *_buffer, size_t num_bytes) {
|
||||
char *buffer = (char*)_buffer;
|
||||
size_t bytes_recieved = 0;
|
||||
ssize_t step_recieved;
|
||||
struct pollfd pfd = {0};
|
||||
pfd.fd = socket_fd;
|
||||
pfd.events = POLLIN;
|
||||
while (bytes_recieved < num_bytes) {
|
||||
SYSCHECK(poll(&pfd, 1, 1000));
|
||||
if (pfd.revents & POLLIN) {
|
||||
SYSCHECK(step_recieved = ::read(socket_fd, buffer, num_bytes - bytes_recieved));
|
||||
if (step_recieved == 0)
|
||||
throw std::runtime_error("Other end has closed the connection");
|
||||
bytes_recieved += step_recieved;
|
||||
buffer += step_recieved;
|
||||
} else if (pfd.revents & (POLLERR | POLLHUP)) {
|
||||
throw std::runtime_error("An error occured while waiting for the data");
|
||||
} else {
|
||||
throw std::runtime_error("Shared memory manager connection has timed out");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void send(const void *_buffer, size_t num_bytes) {
|
||||
const char *buffer = (const char*)_buffer;
|
||||
size_t bytes_sent = 0;
|
||||
ssize_t step_sent;
|
||||
while (bytes_sent < num_bytes) {
|
||||
SYSCHECK(step_sent = ::write(socket_fd, buffer, num_bytes));
|
||||
bytes_sent += step_sent;
|
||||
buffer += step_sent;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
@ -49,12 +83,12 @@ public:
|
||||
|
||||
AllocInfo recieve() {
|
||||
AllocInfo info;
|
||||
SYSCHECK(::recv(socket_fd, &info, sizeof(info), 0));
|
||||
recv(&info, sizeof(info));
|
||||
return info;
|
||||
}
|
||||
|
||||
void confirm() {
|
||||
SYSCHECK(::send(socket_fd, "OK", 2, 0));
|
||||
send("OK", 2);
|
||||
}
|
||||
|
||||
};
|
||||
@ -81,7 +115,9 @@ public:
|
||||
|
||||
ManagerSocket accept() {
|
||||
int client_fd;
|
||||
SYSCHECK(client_fd = ::accept(socket_fd, NULL, NULL));
|
||||
struct sockaddr_un addr;
|
||||
socklen_t addr_len = sizeof(addr);
|
||||
SYSCHECK(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
|
||||
return ManagerSocket(client_fd);
|
||||
}
|
||||
|
||||
@ -102,26 +138,16 @@ public:
|
||||
}
|
||||
|
||||
void register_allocation(AllocInfo &info) {
|
||||
char buffer[10];
|
||||
char buffer[3] = {0, 0, 0};
|
||||
ssize_t bytes_read;
|
||||
SYSCHECK(::send(socket_fd, &info, sizeof(info), 0));
|
||||
struct pollfd pfd = {0};
|
||||
pfd.fd = socket_fd;
|
||||
pfd.events = POLLIN;
|
||||
SYSCHECK(poll(&pfd, 1, 1000));
|
||||
if (pfd.revents & POLLIN) {
|
||||
SYSCHECK(bytes_read = ::recv(socket_fd, buffer, sizeof(buffer)-1, 0));
|
||||
buffer[bytes_read] = 0;
|
||||
if (strcmp(buffer, "OK") != 0)
|
||||
throw std::exception();
|
||||
} else {
|
||||
// no data arrived before the timeout
|
||||
throw std::exception();
|
||||
}
|
||||
send(&info, sizeof(info));
|
||||
recv(buffer, 2);
|
||||
if (strcmp(buffer, "OK") != 0)
|
||||
throw std::runtime_error("Shared memory manager didn't respond with an OK");
|
||||
}
|
||||
|
||||
void register_deallocation(AllocInfo &info) {
|
||||
::send(socket_fd, &info, sizeof(info), 0);
|
||||
send(&info, sizeof(info));
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -1,18 +1,27 @@
|
||||
from sys import platform as _platform
|
||||
from multiprocessing import *
|
||||
|
||||
|
||||
_sharing_strategy = 'file_descriptor'
|
||||
if _platform == 'darwin':
|
||||
_sharing_strategy = 'file_system'
|
||||
_all_sharing_strategies = {'file_system'}
|
||||
else:
|
||||
_sharing_strategy = 'file_descriptor'
|
||||
_all_sharing_strategies = {'file_descriptor', 'file_system'}
|
||||
|
||||
|
||||
def set_sharing_strategy(new_stragegy):
|
||||
global _sharing_strategy
|
||||
assert new_stragegy in {'file_descriptor', 'file_system'}
|
||||
assert new_stragegy in _all_sharing_strategies
|
||||
_sharing_strategy = new_stragegy
|
||||
|
||||
|
||||
def get_sharing_strategy():
|
||||
return _sharing_strategy
|
||||
|
||||
def get_all_sharing_strategies():
|
||||
return _all_sharing_strategies
|
||||
|
||||
|
||||
def Queue(*args, **kwargs):
|
||||
from .queue import Queue, FdQueue
|
||||
|
Reference in New Issue
Block a user