mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 08:00:58 +08:00 
			
		
		
		
	Parallelize backwards
This commit is contained in:
		@ -3,6 +3,7 @@
 | 
			
		||||
#include "torch/csrc/autograd/engine.h"
 | 
			
		||||
#include "torch/csrc/THP.h"
 | 
			
		||||
#include "torch/csrc/DynamicTypes.h"
 | 
			
		||||
#include "torch/csrc/utils/auto_gil.h"
 | 
			
		||||
 | 
			
		||||
using namespace torch::autograd;
 | 
			
		||||
 | 
			
		||||
@ -10,6 +11,27 @@ struct THPEngine {
 | 
			
		||||
    PyObject_HEAD
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct PythonEngine : public Engine {
 | 
			
		||||
  virtual void thread_main(ReadyQueue& queue) override {
 | 
			
		||||
    // Create a PyThreadState, but release the GIL. This lets AutoGIL calls
 | 
			
		||||
    // inside thread_main acquire the GIL without having to create a new
 | 
			
		||||
    // PyThreadState each time.
 | 
			
		||||
    AutoGIL gil;
 | 
			
		||||
    AutoNoGIL no_gil;
 | 
			
		||||
    Engine::thread_main(queue);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual void thread_on_exception(FunctionTask& task, std::exception& e) override {
 | 
			
		||||
    auto python_err = dynamic_cast<python_error*>(&e);
 | 
			
		||||
    if (python_err) {
 | 
			
		||||
      python_err->persist();
 | 
			
		||||
    }
 | 
			
		||||
    Engine::thread_on_exception(task, e);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static PythonEngine engine;
 | 
			
		||||
 | 
			
		||||
PyObject *THPEngineClass = NULL;
 | 
			
		||||
 | 
			
		||||
// Main backward function
 | 
			
		||||
@ -58,10 +80,12 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  try {
 | 
			
		||||
    Engine::backward(vars, grads, retain_variables);
 | 
			
		||||
    AutoNoGIL no_gil;
 | 
			
		||||
    engine.backward(vars, grads, retain_variables);
 | 
			
		||||
  } catch (python_error &e) {
 | 
			
		||||
    e.restore();
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  } catch (std::exception &e) {
 | 
			
		||||
  } catch (const std::exception &e) {
 | 
			
		||||
    PyErr_SetString(PyExc_RuntimeError, e.what());
 | 
			
		||||
    return nullptr;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user