mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	move THCP-related utils to cuda/utils.cpp. (#8221)
These files don't follow the usual pattern: In general the files torch/csrc/X torch/csrc/cuda/X both include the generic file torch/csrc/generic/X, where torch/csrc/X includes the cpu implementations and torch/csrc/cuda/X includes the cuda implementations. (Aside: this is probably not the best structure, the torch/csrc/X fiels should probably be moved to torch/csrc/cpu/X). utils.cpp combines these so that torch/csrc/utils.cpp has cuda specific code. This makes it impossible to declare a single THTensor and THCTensor template type (i.e. THPPointer<_THTensor>, THPointer<_THCTensor>).
This commit is contained in:
		| @ -7,3 +7,30 @@ | ||||
|  | ||||
| #define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp" | ||||
| #include <THC/THCGenerateAllTypes.h> | ||||
|  | ||||
| #ifdef WITH_CUDA | ||||
| std::vector <THCStream*> THPUtils_PySequence_to_THCStreamList(PyObject *obj) { | ||||
|   if (!PySequence_Check(obj)) { | ||||
|     throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_THCStreamList"); | ||||
|   } | ||||
|   THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, NULL)); | ||||
|   if (seq.get() == NULL) { | ||||
|     throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj))); | ||||
|   } | ||||
|  | ||||
|   std::vector<THCStream*> streams; | ||||
|   Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); | ||||
|   for (Py_ssize_t i = 0; i < length; i++) { | ||||
|     PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i); | ||||
|  | ||||
|     if (PyObject_IsInstance(stream, THCPStreamClass)) { | ||||
|       streams.push_back( ((THCPStream *)stream)->cdata); | ||||
|     } else if (stream == Py_None) { | ||||
|       streams.push_back(NULL); | ||||
|     } else { | ||||
|       std::runtime_error("Unknown data type found in stream list. Need THCStream or None"); | ||||
|     } | ||||
|   } | ||||
|   return streams; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| @ -17,10 +17,6 @@ | ||||
| #include "generic/utils.cpp" | ||||
| #include <TH/THGenerateHalfType.h> | ||||
|  | ||||
| #ifdef WITH_CUDA | ||||
| #include "torch/csrc/cuda/THCP.h" | ||||
| #endif | ||||
|  | ||||
| int THPUtils_getCallable(PyObject *arg, PyObject **result) { | ||||
|   if (!PyCallable_Check(arg)) | ||||
|     return 0; | ||||
| @ -231,30 +227,3 @@ bool maybeThrowBackCompatKeepdimWarn(char *func) { | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| #ifdef WITH_CUDA | ||||
| std::vector <THCStream*> THPUtils_PySequence_to_THCStreamList(PyObject *obj) { | ||||
|   if (!PySequence_Check(obj)) { | ||||
|     throw std::runtime_error("Expected a sequence in THPUtils_PySequence_to_THCStreamList"); | ||||
|   } | ||||
|   THPObjectPtr seq = THPObjectPtr(PySequence_Fast(obj, NULL)); | ||||
|   if (seq.get() == NULL) { | ||||
|     throw std::runtime_error("expected PySequence, but got " + std::string(THPUtils_typename(obj))); | ||||
|   } | ||||
|  | ||||
|   std::vector<THCStream*> streams; | ||||
|   Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); | ||||
|   for (Py_ssize_t i = 0; i < length; i++) { | ||||
|     PyObject *stream = PySequence_Fast_GET_ITEM(seq.get(), i); | ||||
|  | ||||
|     if (PyObject_IsInstance(stream, THCPStreamClass)) { | ||||
|       streams.push_back( ((THCPStream *)stream)->cdata); | ||||
|     } else if (stream == Py_None) { | ||||
|       streams.push_back(NULL); | ||||
|     } else { | ||||
|       std::runtime_error("Unknown data type found in stream list. Need THCStream or None"); | ||||
|     } | ||||
|   } | ||||
|   return streams; | ||||
| } | ||||
| #endif | ||||
|  | ||||
		Reference in New Issue
	
	Block a user