Fix multiprocessing on OS X

This commit is contained in:
Adam Paszke
2016-09-16 18:14:41 -04:00
parent 7847d77405
commit e223564a55
4 changed files with 61 additions and 24 deletions

View File

@ -137,16 +137,18 @@ class TestMultiprocessing(TestCase):
with leak_checker(self) as lc:
do_test()
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
def test_fd_sharing(self):
self._test_sharing()
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
def test_fd_preserve_sharing(self):
self._test_preserve_sharing()
@unittest.skipIf(platform == 'darwin', "file descriptor strategy is not supported on OS X")
def test_fd_pool(self):
self._test_pool()
@unittest.skipIf(platform == "darwin", "file_system sharing strategy doesn't work in OSX")
def test_fs_sharing(self):
with fs_sharing():
self._test_sharing()

View File

@ -85,7 +85,7 @@ int main(int argc, char *argv[]) {
try {
char tmpfile[L_tmpnam];
if (std::tmpnam(tmpfile) == NULL)
throw std::exception();
throw std::runtime_error("could not generate a random filename for manager socket");
// TODO: better strategy for generating tmp names
// TODO: retry on collisions - this can easily fail
srv_socket.reset(new ManagerServerSocket(std::string(tmpfile)));

View File

@ -17,7 +17,7 @@ public:
protected:
Socket() {
SYSCHECK(socket_fd = socket(AF_UNIX, SOCK_SEQPACKET, 0));
SYSCHECK(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
}
Socket(const Socket& other) = delete;
Socket(Socket&& other): socket_fd(other.socket_fd) { other.socket_fd = -1; };
@ -40,6 +40,40 @@ protected:
return strlen(address.sun_path) + sizeof(address.sun_family);
}
void recv(void *_buffer, size_t num_bytes) {
char *buffer = (char*)_buffer;
size_t bytes_recieved = 0;
ssize_t step_recieved;
struct pollfd pfd = {0};
pfd.fd = socket_fd;
pfd.events = POLLIN;
while (bytes_recieved < num_bytes) {
SYSCHECK(poll(&pfd, 1, 1000));
if (pfd.revents & POLLIN) {
SYSCHECK(step_recieved = ::read(socket_fd, buffer, num_bytes - bytes_recieved));
if (step_recieved == 0)
throw std::runtime_error("Other end has closed the connection");
bytes_recieved += step_recieved;
buffer += step_recieved;
} else if (pfd.revents & (POLLERR | POLLHUP)) {
throw std::runtime_error("An error occured while waiting for the data");
} else {
throw std::runtime_error("Shared memory manager connection has timed out");
}
}
}
void send(const void *_buffer, size_t num_bytes) {
const char *buffer = (const char*)_buffer;
size_t bytes_sent = 0;
ssize_t step_sent;
while (bytes_sent < num_bytes) {
SYSCHECK(step_sent = ::write(socket_fd, buffer, num_bytes));
bytes_sent += step_sent;
buffer += step_sent;
}
}
};
@ -49,12 +83,12 @@ public:
AllocInfo recieve() {
AllocInfo info;
SYSCHECK(::recv(socket_fd, &info, sizeof(info), 0));
recv(&info, sizeof(info));
return info;
}
void confirm() {
SYSCHECK(::send(socket_fd, "OK", 2, 0));
send("OK", 2);
}
};
@ -81,7 +115,9 @@ public:
ManagerSocket accept() {
int client_fd;
SYSCHECK(client_fd = ::accept(socket_fd, NULL, NULL));
struct sockaddr_un addr;
socklen_t addr_len = sizeof(addr);
SYSCHECK(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
return ManagerSocket(client_fd);
}
@ -102,26 +138,16 @@ public:
}
void register_allocation(AllocInfo &info) {
char buffer[10];
char buffer[3] = {0, 0, 0};
ssize_t bytes_read;
SYSCHECK(::send(socket_fd, &info, sizeof(info), 0));
struct pollfd pfd = {0};
pfd.fd = socket_fd;
pfd.events = POLLIN;
SYSCHECK(poll(&pfd, 1, 1000));
if (pfd.revents & POLLIN) {
SYSCHECK(bytes_read = ::recv(socket_fd, buffer, sizeof(buffer)-1, 0));
buffer[bytes_read] = 0;
if (strcmp(buffer, "OK") != 0)
throw std::exception();
} else {
// no data arrived before the timeout
throw std::exception();
}
send(&info, sizeof(info));
recv(buffer, 2);
if (strcmp(buffer, "OK") != 0)
throw std::runtime_error("Shared memory manager didn't respond with an OK");
}
void register_deallocation(AllocInfo &info) {
::send(socket_fd, &info, sizeof(info), 0);
send(&info, sizeof(info));
}
};

View File

@ -1,18 +1,27 @@
from sys import platform as _platform
from multiprocessing import *
_sharing_strategy = 'file_descriptor'
if _platform == 'darwin':
_sharing_strategy = 'file_system'
_all_sharing_strategies = {'file_system'}
else:
_sharing_strategy = 'file_descriptor'
_all_sharing_strategies = {'file_descriptor', 'file_system'}
def set_sharing_strategy(new_stragegy):
global _sharing_strategy
assert new_stragegy in {'file_descriptor', 'file_system'}
assert new_stragegy in _all_sharing_strategies
_sharing_strategy = new_stragegy
def get_sharing_strategy():
return _sharing_strategy
def get_all_sharing_strategies():
return _all_sharing_strategies
def Queue(*args, **kwargs):
from .queue import Queue, FdQueue