[CUBLAS][CUDA GRAPHS] (re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85292)

re-open of #83461 with fix for 10.2 build

CC @ngimel @malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85292
Approved by: https://github.com/malfet
This commit is contained in:
eqy
2022-09-20 16:31:54 +00:00
committed by PyTorch MergeBot
parent 39f482acdf
commit 4012e623e8
8 changed files with 105 additions and 0 deletions

View File

@ -806,6 +806,15 @@ PyObject* THCPModule_getCurrentBlasHandle_wrap(
END_HANDLE_TH_ERRORS
}
static PyObject* THCPModule_clearBlasWorkspaces_wrap(
PyObject* self,
PyObject* noargs) {
HANDLE_TH_ERRORS
at::cuda::clearCublasWorkspaces();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THCPModule_rocm_is_backward_pass(
PyObject* _unused,
PyObject* noargs) {
@ -895,6 +904,10 @@ static struct PyMethodDef _THCPModule_methods[] = {
THCPModule_getCurrentBlasHandle_wrap,
METH_NOARGS,
nullptr},
{"_cuda_clearCublasWorkspaces",
THCPModule_clearBlasWorkspaces_wrap,
METH_NOARGS,
nullptr},
{"_cuda_isCurrentStreamCapturing",
THCPModule_isCurrentStreamCapturing_wrap,
METH_NOARGS,