From 0cad2c061555123f5ce08553388f6cf79a2f80ea Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 8 Oct 2021 09:04:21 -0700 Subject: [PATCH] Move intraop_launch_future from Parallel.h (#64166) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64166 Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D30728585 Pulled By: dagitses fbshipit-source-id: 75a41418ae9218bec9bac27597051295222b6eee --- aten/src/ATen/Parallel-inl.h | 1 + aten/src/ATen/Parallel.h | 6 +----- aten/src/ATen/ParallelFuture.h | 13 +++++++++++++ aten/src/ATen/ParallelNative.cpp | 1 + aten/src/ATen/ParallelNative.h | 2 ++ aten/src/ATen/ParallelNativeTBB.cpp | 1 + aten/src/ATen/ParallelNativeTBB.h | 2 ++ aten/src/ATen/ParallelOpenMP.cpp | 1 + aten/src/ATen/test/test_parallel.cpp | 1 + binaries/at_launch_benchmark.cc | 1 + 10 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 aten/src/ATen/ParallelFuture.h diff --git a/aten/src/ATen/Parallel-inl.h b/aten/src/ATen/Parallel-inl.h index fc07455f25c7..e336c0192cd2 100644 --- a/aten/src/ATen/Parallel-inl.h +++ b/aten/src/ATen/Parallel-inl.h @@ -1,5 +1,6 @@ #pragma once +#include #include namespace at { diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index ed797dbfa37c..5a71be1742f8 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -1,7 +1,7 @@ #pragma once #include -#include #include +#include namespace at { @@ -144,10 +144,6 @@ void launch_no_thread_state(std::function fn); // Launches intra-op parallel task TORCH_API void intraop_launch(std::function func); -// Launches intra-op parallel task, returns a future -TORCH_API c10::intrusive_ptr intraop_launch_future( - std::function func); - // Returns number of intra-op threads used by default TORCH_API int intraop_default_num_threads(); diff --git a/aten/src/ATen/ParallelFuture.h b/aten/src/ATen/ParallelFuture.h new file mode 100644 index 000000000000..4a6a5870610f --- /dev/null +++ b/aten/src/ATen/ParallelFuture.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +namespace at { + +// Launches intra-op parallel task, returns a future +TORCH_API c10::intrusive_ptr intraop_launch_future( + std::function func); + +} // namespace at diff --git a/aten/src/ATen/ParallelNative.cpp b/aten/src/ATen/ParallelNative.cpp index ab759e6f6dde..753dbdb751e6 100644 --- a/aten/src/ATen/ParallelNative.cpp +++ b/aten/src/ATen/ParallelNative.cpp @@ -1,6 +1,7 @@ #include #if AT_PARALLEL_NATIVE #include +#include #include #ifndef C10_MOBILE diff --git a/aten/src/ATen/ParallelNative.h b/aten/src/ATen/ParallelNative.h index 74ebeb75351f..420112602904 100644 --- a/aten/src/ATen/ParallelNative.h +++ b/aten/src/ATen/ParallelNative.h @@ -4,6 +4,8 @@ #include #include +#include + #define INTRA_OP_PARALLEL namespace at { diff --git a/aten/src/ATen/ParallelNativeTBB.cpp b/aten/src/ATen/ParallelNativeTBB.cpp index c38dcb64f81b..5d878add3225 100644 --- a/aten/src/ATen/ParallelNativeTBB.cpp +++ b/aten/src/ATen/ParallelNativeTBB.cpp @@ -1,6 +1,7 @@ #include #if AT_PARALLEL_NATIVE_TBB #include +#include #include #include diff --git a/aten/src/ATen/ParallelNativeTBB.h b/aten/src/ATen/ParallelNativeTBB.h index 625355a1de95..01dda99990c8 100644 --- a/aten/src/ATen/ParallelNativeTBB.h +++ b/aten/src/ATen/ParallelNativeTBB.h @@ -4,6 +4,8 @@ #include #include +#include + #ifdef _WIN32 #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp index ab459fb55cb7..800549a89ab2 100644 --- a/aten/src/ATen/ParallelOpenMP.cpp +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -2,6 +2,7 @@ #include #if AT_PARALLEL_OPENMP #include +#include #include diff --git a/aten/src/ATen/test/test_parallel.cpp b/aten/src/ATen/test/test_parallel.cpp index 95a7c281f13e..84d27fc0042d 100644 --- a/aten/src/ATen/test/test_parallel.cpp +++ b/aten/src/ATen/test/test_parallel.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include // NOLINTNEXTLINE(modernize-deprecated-headers) diff --git a/binaries/at_launch_benchmark.cc b/binaries/at_launch_benchmark.cc index 8c7c23c5f7ee..e1f05d727da8 100644 --- a/binaries/at_launch_benchmark.cc +++ b/binaries/at_launch_benchmark.cc @@ -3,6 +3,7 @@ #include "c10/util/Flags.h" #include "caffe2/core/init.h" +#include #include #include #include