mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Added nullptr check for pthradpool_get_threads_count (#34087)
Summary: We get seg fault without this in using XNNPACK. Pull Request resolved: https://github.com/pytorch/pytorch/pull/34087 Differential Revision: D20199787 Pulled By: kimishpatel fbshipit-source-id: d3d274e7bb197461632b21688820cd4c10dcd819
This commit is contained in:
committed by
Facebook Github Bot
parent
ac6e75a165
commit
8269c4f3d3
@ -28,7 +28,19 @@ void pthreadpool_compute_1d(
|
||||
}
|
||||
|
||||
size_t pthreadpool_get_threads_count(pthreadpool_t threadpool) {
|
||||
// The current fix only useful when XNNPACK calls pthreadpool_get_threads_count with nullptr.
|
||||
if (threadpool == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
return reinterpret_cast<caffe2::ThreadPool*>(threadpool)->getNumThreads();
|
||||
// TODO: Future fix: If we keep maintaining two different threadpools.
|
||||
// Old C2 and new one for XNNPACK, then the we have two different pthreadpool pointer
|
||||
// types. One is caffe2::Thredpool*, the other is pthreadpool* (pthreadpool_new_if_impl.c)
|
||||
// XNNPACK calls pthreadpool_get_threads_count during op setup using pthreadpool*, and
|
||||
// uses _parallelize_ interface for for actual work.
|
||||
// While NNPACK uses caffe2::Threadpool*.
|
||||
// Thus if pthreadpool_get_threads_count is getting called from XNNPACK we cannot
|
||||
// reinterpret_cast it to ThreadPool. It will seg fault or worse will have unedfined behavior.
|
||||
}
|
||||
|
||||
pthreadpool_t pthreadpool_create(size_t threads_count) {
|
||||
|
Reference in New Issue
Block a user