Parallelize backwards

Sam Gross
2017-02-28 12:20:25 -08:00
parent c238ee3681
commit 34ce58c909
31 changed files with 1408 additions and 279 deletions


@@ -3,6 +3,7 @@
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/THP.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/utils/auto_gil.h"
using namespace torch::autograd;
@@ -10,6 +11,27 @@ struct THPEngine {
  PyObject_HEAD
};
struct PythonEngine : public Engine {
  virtual void thread_main(ReadyQueue& queue) override {
    // Create a PyThreadState, but release the GIL. This lets AutoGIL calls
    // inside thread_main acquire the GIL without having to create a new
    // PyThreadState each time.
    AutoGIL gil;
    AutoNoGIL no_gil;
    Engine::thread_main(queue);
  }
  virtual void thread_on_exception(FunctionTask& task, std::exception& e) override {
    auto python_err = dynamic_cast<python_error*>(&e);
    if (python_err) {
      python_err->persist();
    }
    Engine::thread_on_exception(task, e);
  }
};
static PythonEngine engine;
PyObject *THPEngineClass = NULL;
// Main backward function
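
The AutoGIL/AutoNoGIL guards come from the auto_gil.h header included above, which is not shown in this diff. As a rough sketch (assuming the usual RAII wrappers over CPython's GIL-state API; the actual header in this commit may differ in detail), they look roughly like this:

#include <Python.h>

// Acquires the GIL on construction and releases it on destruction.
// PyGILState_Ensure() also creates a PyThreadState for the calling thread
// if it does not already have one, which is what the thread_main comment
// above relies on.
struct AutoGIL {
  AutoGIL() : gstate(PyGILState_Ensure()) {}
  ~AutoGIL() { PyGILState_Release(gstate); }
  PyGILState_STATE gstate;
};

// Releases the GIL on construction and re-acquires it on destruction
// (the RAII form of Py_BEGIN_ALLOW_THREADS / Py_END_ALLOW_THREADS).
struct AutoNoGIL {
  AutoNoGIL() : save(PyEval_SaveThread()) {}
  ~AutoNoGIL() { PyEval_RestoreThread(save); }
  PyThreadState* save;
};

Stacking AutoGIL gil; AutoNoGIL no_gil; at the top of thread_main therefore leaves each worker thread with a cached PyThreadState but without the GIL held, so AutoGIL guards taken deeper in the backward pass only re-acquire the lock instead of creating a fresh thread state every time.
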
@@ -58,10 +80,12 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
  }
  try {
    Engine::backward(vars, grads, retain_variables);
    AutoNoGIL no_gil;
    engine.backward(vars, grads, retain_variables);
  } catch (python_error &e) {
    e.restore();
    return nullptr;
  } catch (std::exception &e) {
  } catch (const std::exception &e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return nullptr;
  }
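
On the worker threads, thread_on_exception persists a python_error so that THPEngine_run_backward can restore and re-raise it on the calling thread. The persist/restore pair is defined outside this diff; purely as a hedged sketch, the idiom typically amounts to fetching and re-installing the CPython error indicator, with the names here standing in for the real implementation:

#include <Python.h>
#include <exception>

// Hypothetical stand-in for the real python_error type, reduced to the
// persist/restore idiom used in the diff above.
struct python_error : std::exception {
  PyObject *type = nullptr;
  PyObject *value = nullptr;
  PyObject *traceback = nullptr;

  // Take ownership of the pending Python exception (requires the GIL,
  // e.g. under an AutoGIL guard) so it can safely cross threads.
  void persist() {
    PyErr_Fetch(&type, &value, &traceback);
    PyErr_NormalizeException(&type, &value, &traceback);
  }

  // Re-install the saved exception on the current thread; returning
  // nullptr from run_backward then propagates it back into Python.
  void restore() {
    PyErr_Restore(type, value, traceback);
    type = value = traceback = nullptr;
  }
};
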