mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
push magma init into lazyInitCUDA (#18527)
Summary: Tries to fix C++ API's usage of MAGMA-based functions. Attempts to Fix https://github.com/pytorch/pytorch/issues/18074 Pull Request resolved: https://github.com/pytorch/pytorch/pull/18527 Differential Revision: D14691694 Pulled By: soumith fbshipit-source-id: dd04e74418e486d73ea4a92193ddf79352ed71ba
This commit is contained in:
committed by
Facebook Github Bot
parent
ed9724f385
commit
b5d8844bbe
@ -33,6 +33,9 @@ std::unique_ptr<THCState, void (*)(THCState*)> CUDAHooks::initCUDA() const {
|
||||
THCState* thc_state = THCState_alloc();
|
||||
|
||||
THCudaInit(thc_state);
|
||||
#ifdef USE_MAGMA
|
||||
THCMagma_init(thc_state);
|
||||
#endif
|
||||
return std::unique_ptr<THCState, void (*)(THCState*)>(
|
||||
thc_state, [](THCState* p) {
|
||||
if (p)
|
||||
|
@ -113,3 +113,11 @@ TEST(TensorTest, ToDeviceAndDtype_MultiCUDA) {
|
||||
tensor = tensor.to(at::kCPU, at::kInt);
|
||||
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);
|
||||
}
|
||||
|
||||
TEST(TensorTest, MagmaInitializesCorrectly_CUDA) {
|
||||
auto tensor = at::arange(1, 17, at::TensorOptions(at::kFloat).device(at::Device("cuda")));
|
||||
tensor = tensor.view({4, 4});
|
||||
if (at::hasMAGMA()) {
|
||||
at::inverse(tensor);
|
||||
}
|
||||
}
|
||||
|
@ -380,11 +380,6 @@ static PyObject * THCPModule_initExtension(PyObject *self)
|
||||
THCPByteStorage_postInit(m);
|
||||
THCPBoolStorage_postInit(m);
|
||||
|
||||
bool has_magma = at::hasMAGMA();
|
||||
if (has_magma) {
|
||||
THCMagma_init(state);
|
||||
}
|
||||
|
||||
bool has_half = true;
|
||||
|
||||
auto set_module_attr = [&](const char* name, PyObject* v) {
|
||||
@ -394,7 +389,7 @@ static PyObject * THCPModule_initExtension(PyObject *self)
|
||||
}
|
||||
};
|
||||
|
||||
set_module_attr("has_magma", has_magma ? Py_True : Py_False);
|
||||
set_module_attr("has_magma", at::hasMAGMA() ? Py_True : Py_False);
|
||||
set_module_attr("has_half", has_half ? Py_True : Py_False);
|
||||
|
||||
auto _state_cdata = THPObjectPtr(PyLong_FromVoidPtr(state));
|
||||
|
Reference in New Issue
Block a user