mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR continues to fix clang-tidy warnings for headers in c10/core and c10/util. Pull Request resolved: https://github.com/pytorch/pytorch/pull/115495 Approved by: https://github.com/malfet
58 lines
1.7 KiB
C++
58 lines
1.7 KiB
C++
#pragma once
|
|
|
|
#include <c10/util/TypeTraits.h>
|
|
#include <type_traits>
|
|
|
|
namespace c10 {
|
|
|
|
/**
|
|
* Represent a function pointer as a C++ type.
|
|
* This allows using the function pointer as a type
|
|
* in a template and calling it from inside the template
|
|
* allows the compiler to inline the call because it
|
|
* knows the function pointer at compile time.
|
|
*
|
|
* Example 1:
|
|
* int add(int a, int b) {return a + b;}
|
|
* using Add = TORCH_FN_TYPE(add);
|
|
* template<class Func> struct Executor {
|
|
* int execute(int a, int b) {
|
|
* return Func::func_ptr()(a, b);
|
|
* }
|
|
* };
|
|
* Executor<Add> executor;
|
|
* EXPECT_EQ(3, executor.execute(1, 2));
|
|
*
|
|
* Example 2:
|
|
* int add(int a, int b) {return a + b;}
|
|
* template<class Func> int execute(Func, int a, int b) {
|
|
* return Func::func_ptr()(a, b);
|
|
* }
|
|
* EXPECT_EQ(3, execute(TORCH_FN(add), 1, 2));
|
|
*/
|
|
template <class FuncType_, FuncType_* func_ptr_>
|
|
struct CompileTimeFunctionPointer final {
|
|
static_assert(
|
|
guts::is_function_type<FuncType_>::value,
|
|
"TORCH_FN can only wrap function types.");
|
|
using FuncType = FuncType_;
|
|
|
|
static constexpr FuncType* func_ptr() {
|
|
return func_ptr_;
|
|
}
|
|
};
|
|
|
|
template <class T>
|
|
struct is_compile_time_function_pointer : std::false_type {};
|
|
template <class FuncType, FuncType* func_ptr>
|
|
struct is_compile_time_function_pointer<
|
|
CompileTimeFunctionPointer<FuncType, func_ptr>> : std::true_type {};
|
|
|
|
} // namespace c10
|
|
|
|
#define TORCH_FN_TYPE(func) \
|
|
::c10::CompileTimeFunctionPointer< \
|
|
std::remove_pointer_t<std::remove_reference_t<decltype(func)>>, \
|
|
func>
|
|
#define TORCH_FN(func) TORCH_FN_TYPE(func)()
|