mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
Add CUDA device selection
This commit is contained in:
@ -466,12 +466,20 @@ static PyObject * THPModule_cat(PyObject *_unused, PyObject *args)
|
||||
return PyObject_Call(method, args, NULL);
|
||||
}
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
extern PyObject * THCPModule_initExtension(PyObject *self);
|
||||
extern PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg);
|
||||
extern PyObject * THCPModule_getDevice_wrap(PyObject *self);
|
||||
extern PyObject * THCPModule_getDeviceCount_wrap(PyObject *self);
|
||||
#endif
|
||||
|
||||
static PyMethodDef TorchMethods[] = {
|
||||
{"_initExtension", (PyCFunction)THPModule_initExtension, METH_NOARGS, NULL},
|
||||
#ifdef WITH_CUDA
|
||||
{"_initCuda", (PyCFunction)THCPModule_initExtension, METH_NOARGS, NULL},
|
||||
{"_cuda_init", (PyCFunction)THCPModule_initExtension, METH_NOARGS, NULL},
|
||||
{"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, NULL},
|
||||
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, NULL},
|
||||
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, NULL},
|
||||
#endif
|
||||
{"_tensorCopy", (PyCFunction)THPModule_tensorCopyWrapper, METH_VARARGS, NULL},
|
||||
{"_storageCopy", (PyCFunction)THPModule_storageCopyWrapper, METH_VARARGS, NULL},
|
||||
|
||||
Reference in New Issue
Block a user