Add CUDA device selection

This commit is contained in:
Adam Paszke
2016-08-12 07:46:46 -07:00
parent fa6e5c5bff
commit 12bed8dc0d
10 changed files with 214 additions and 18 deletions

View File

@ -466,12 +466,20 @@ static PyObject * THPModule_cat(PyObject *_unused, PyObject *args)
return PyObject_Call(method, args, NULL);
}
#ifdef WITH_CUDA
extern PyObject * THCPModule_initExtension(PyObject *self);
extern PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg);
extern PyObject * THCPModule_getDevice_wrap(PyObject *self);
extern PyObject * THCPModule_getDeviceCount_wrap(PyObject *self);
#endif
static PyMethodDef TorchMethods[] = {
{"_initExtension", (PyCFunction)THPModule_initExtension, METH_NOARGS, NULL},
#ifdef WITH_CUDA
{"_initCuda", (PyCFunction)THCPModule_initExtension, METH_NOARGS, NULL},
{"_cuda_init", (PyCFunction)THCPModule_initExtension, METH_NOARGS, NULL},
{"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, NULL},
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, NULL},
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, NULL},
#endif
{"_tensorCopy", (PyCFunction)THPModule_tensorCopyWrapper, METH_VARARGS, NULL},
{"_storageCopy", (PyCFunction)THPModule_storageCopyWrapper, METH_VARARGS, NULL},