mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Update Module.cpp for THC changes
This commit is contained in:
@ -7,9 +7,7 @@
|
||||
|
||||
#include "THCP.h"
|
||||
|
||||
THCState _state;
|
||||
THCState *state = &_state;
|
||||
THCDeviceAllocator deviceAllocator;
|
||||
THCState *state;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Class pointer cache
|
||||
@ -84,11 +82,6 @@ static bool THCPModule_assignStateless()
|
||||
void THCPModule_setDevice(int device)
|
||||
{
|
||||
THCudaCheck(cudaSetDevice(device));
|
||||
THCRandom_setGenerator(state, device);
|
||||
|
||||
// The stream is per device, so update the stream as well
|
||||
THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
|
||||
THCState_setBlasHandle(state, device, THCState_getCurrentBlasHandleIndex(state));
|
||||
}
|
||||
|
||||
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg)
|
||||
@ -213,7 +206,8 @@ PyObject * THCPModule_initialSeed(PyObject *_unused)
|
||||
|
||||
bool THCPModule_initCuda(PyObject *module_dict) {
|
||||
#define ASSERT_TRUE(cond) if (!(cond)) { return false; }
|
||||
THCCachingAllocator_init(&state->cudaDeviceAllocator);
|
||||
state = THCState_alloc();
|
||||
THCCachingAllocator_init(THCState_getDeviceAllocator(state));
|
||||
THCudaInit(state);
|
||||
|
||||
#ifdef USE_MAGMA
|
||||
|
Reference in New Issue
Block a user