Added nullptr check for pthradpool_get_threads_count (#34087)

Summary:
We get seg fault without this in using XNNPACK.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34087

Differential Revision: D20199787

Pulled By: kimishpatel

fbshipit-source-id: d3d274e7bb197461632b21688820cd4c10dcd819
This commit is contained in:
Kimish Patel
2020-03-04 11:08:02 -08:00
committed by Facebook Github Bot
parent ac6e75a165
commit 8269c4f3d3

View File

@ -28,7 +28,19 @@ void pthreadpool_compute_1d(
}
size_t pthreadpool_get_threads_count(pthreadpool_t threadpool) {
// The current fix only useful when XNNPACK calls pthreadpool_get_threads_count with nullptr.
if (threadpool == nullptr) {
return 1;
}
return reinterpret_cast<caffe2::ThreadPool*>(threadpool)->getNumThreads();
// TODO: Future fix: If we keep maintaining two different threadpools.
// Old C2 and new one for XNNPACK, then the we have two different pthreadpool pointer
// types. One is caffe2::Thredpool*, the other is pthreadpool* (pthreadpool_new_if_impl.c)
// XNNPACK calls pthreadpool_get_threads_count during op setup using pthreadpool*, and
// uses _parallelize_ interface for for actual work.
// While NNPACK uses caffe2::Threadpool*.
// Thus if pthreadpool_get_threads_count is getting called from XNNPACK we cannot
// reinterpret_cast it to ThreadPool. It will seg fault or worse will have unedfined behavior.
}
pthreadpool_t pthreadpool_create(size_t threads_count) {