mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
move THCP-related utils to cuda/utils.cpp. (#8221)
These files don't follow the usual pattern: In general the files torch/csrc/X torch/csrc/cuda/X both include the generic file torch/csrc/generic/X, where torch/csrc/X includes the cpu implementations and torch/csrc/cuda/X includes the cuda implementations. (Aside: this is probably not the best structure, the torch/csrc/X fiels should probably be moved to torch/csrc/cpu/X). utils.cpp combines these so that torch/csrc/utils.cpp has cuda specific code. This makes it impossible to declare a single THTensor and THCTensor template type (i.e. THPPointer<_THTensor>, THPointer<_THCTensor>).
This commit is contained in:
@ -7,3 +7,30 @@
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
|
||||
#include <THC/THCGenerateAllTypes.h>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
std::vector <THCStream*> THPUtils_PySequence_to_THCStreamList(PyObject *obj) {
|
||||
if (!PySequence_Check(obj)) {
|
||||
throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_THCStreamList");
|
||||
}
|
||||
THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, NULL));
|
||||
if (seq.get() == NULL) {
|
||||
throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj)));
|
||||
}
|
||||
|
||||
std::vector<THCStream*> streams;
|
||||
Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
|
||||
for (Py_ssize_t i = 0; i < length; i++) {
|
||||
PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i);
|
||||
|
||||
if (PyObject_IsInstance(stream, THCPStreamClass)) {
|
||||
streams.push_back( ((THCPStream *)stream)->cdata);
|
||||
} else if (stream == Py_None) {
|
||||
streams.push_back(NULL);
|
||||
} else {
|
||||
std::runtime_error("Unknown data type found in stream list. Need THCStream or None");
|
||||
}
|
||||
}
|
||||
return streams;
|
||||
}
|
||||
#endif
|
||||
|
@ -17,10 +17,6 @@
|
||||
#include "generic/utils.cpp"
|
||||
#include <TH/THGenerateHalfType.h>
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
#include "torch/csrc/cuda/THCP.h"
|
||||
#endif
|
||||
|
||||
int THPUtils_getCallable(PyObject *arg, PyObject **result) {
|
||||
if (!PyCallable_Check(arg))
|
||||
return 0;
|
||||
@ -231,30 +227,3 @@ bool maybeThrowBackCompatKeepdimWarn(char *func) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
std::vector <THCStream*> THPUtils_PySequence_to_THCStreamList(PyObject *obj) {
|
||||
if (!PySequence_Check(obj)) {
|
||||
throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_THCStreamList");
|
||||
}
|
||||
THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, NULL));
|
||||
if (seq.get() == NULL) {
|
||||
throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj)));
|
||||
}
|
||||
|
||||
std::vector<THCStream*> streams;
|
||||
Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
|
||||
for (Py_ssize_t i = 0; i < length; i++) {
|
||||
PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i);
|
||||
|
||||
if (PyObject_IsInstance(stream, THCPStreamClass)) {
|
||||
streams.push_back( ((THCPStream *)stream)->cdata);
|
||||
} else if (stream == Py_None) {
|
||||
streams.push_back(NULL);
|
||||
} else {
|
||||
std::runtime_error("Unknown data type found in stream list. Need THCStream or None");
|
||||
}
|
||||
}
|
||||
return streams;
|
||||
}
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user