Update Module.cpp for THC changes

This commit is contained in:
Sam Gross
2016-09-30 11:13:14 -07:00
parent 3e1c88e3e0
commit c20828478e

View File

@ -7,9 +7,7 @@
#include "THCP.h" #include "THCP.h"
THCState _state; THCState *state;
THCState *state = &_state;
THCDeviceAllocator deviceAllocator;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Class pointer cache // Class pointer cache
@ -84,11 +82,6 @@ static bool THCPModule_assignStateless()
void THCPModule_setDevice(int device) void THCPModule_setDevice(int device)
{ {
THCudaCheck(cudaSetDevice(device)); THCudaCheck(cudaSetDevice(device));
THCRandom_setGenerator(state, device);
// The stream is per device, so update the stream as well
THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
THCState_setBlasHandle(state, device, THCState_getCurrentBlasHandleIndex(state));
} }
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg) PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg)
@ -213,7 +206,8 @@ PyObject * THCPModule_initialSeed(PyObject *_unused)
bool THCPModule_initCuda(PyObject *module_dict) { bool THCPModule_initCuda(PyObject *module_dict) {
#define ASSERT_TRUE(cond) if (!(cond)) { return false; } #define ASSERT_TRUE(cond) if (!(cond)) { return false; }
THCCachingAllocator_init(&state->cudaDeviceAllocator); state = THCState_alloc();
THCCachingAllocator_init(THCState_getDeviceAllocator(state));
THCudaInit(state); THCudaInit(state);
#ifdef USE_MAGMA #ifdef USE_MAGMA