mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Update Module.cpp for THC changes
This commit is contained in:
@ -7,9 +7,7 @@
|
|||||||
|
|
||||||
#include "THCP.h"
|
#include "THCP.h"
|
||||||
|
|
||||||
THCState _state;
|
THCState *state;
|
||||||
THCState *state = &_state;
|
|
||||||
THCDeviceAllocator deviceAllocator;
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// Class pointer cache
|
// Class pointer cache
|
||||||
@ -84,11 +82,6 @@ static bool THCPModule_assignStateless()
|
|||||||
void THCPModule_setDevice(int device)
|
void THCPModule_setDevice(int device)
|
||||||
{
|
{
|
||||||
THCudaCheck(cudaSetDevice(device));
|
THCudaCheck(cudaSetDevice(device));
|
||||||
THCRandom_setGenerator(state, device);
|
|
||||||
|
|
||||||
// The stream is per device, so update the stream as well
|
|
||||||
THCState_setStream(state, device, THCState_getCurrentStreamIndex(state));
|
|
||||||
THCState_setBlasHandle(state, device, THCState_getCurrentBlasHandleIndex(state));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg)
|
PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg)
|
||||||
@ -213,7 +206,8 @@ PyObject * THCPModule_initialSeed(PyObject *_unused)
|
|||||||
|
|
||||||
bool THCPModule_initCuda(PyObject *module_dict) {
|
bool THCPModule_initCuda(PyObject *module_dict) {
|
||||||
#define ASSERT_TRUE(cond) if (!(cond)) { return false; }
|
#define ASSERT_TRUE(cond) if (!(cond)) { return false; }
|
||||||
THCCachingAllocator_init(&state->cudaDeviceAllocator);
|
state = THCState_alloc();
|
||||||
|
THCCachingAllocator_init(THCState_getDeviceAllocator(state));
|
||||||
THCudaInit(state);
|
THCudaInit(state);
|
||||||
|
|
||||||
#ifdef USE_MAGMA
|
#ifdef USE_MAGMA
|
||||||
|
Reference in New Issue
Block a user