mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move schema inference to c10 (#18090)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18090 This schema inference is needed by the c10 operator registration mechanism. Move it to c10. It is going to be used by diffs stacked on top. Reviewed By: ezyang Differential Revision: D14491454 fbshipit-source-id: 0f8ddcdbd91467c8347d315dd443a1ca8b216481
This commit is contained in:
committed by
Facebook Github Bot
parent
1877087df2
commit
daa77c6e26
88
aten/src/ATen/core/op_registration/infer_schema.h
Normal file
88
aten/src/ATen/core/op_registration/infer_schema.h
Normal file
@ -0,0 +1,88 @@
|
||||
#pragma once
|
||||
|
||||
/**
|
||||
* This file contains functionality to take a C++ function and infer its
|
||||
* c10::FunctionSchema.
|
||||
*/
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <c10/util/C++17.h>
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
/// Checks the static C++ type `T` for correctness to catch common error cases.
|
||||
template <typename T>
|
||||
void checkStaticTypes() {
|
||||
// Give nice error messages for some of the common error cases.
|
||||
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
|
||||
static_assert(
|
||||
!std::is_integral<T>::value || std::is_same<T, int64_t>::value,
|
||||
"INVALID TYPE: Only int64_t is supported as an integral argument type");
|
||||
static_assert(
|
||||
!std::is_same<T, float>::value,
|
||||
"INVALID TYPE: float is not supported as an argument type, use double instead");
|
||||
}
|
||||
|
||||
template <typename First, typename Second, typename... Rest>
|
||||
void checkStaticTypes() {
|
||||
checkStaticTypes<First>();
|
||||
checkStaticTypes<Second, Rest...>();
|
||||
}
|
||||
|
||||
template <typename... Ts, size_t... Is>
|
||||
::std::vector<Argument> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
|
||||
checkStaticTypes<guts::decay_t<Ts>...>();
|
||||
// Arguments are named "_<index>"
|
||||
return {Argument("_" + std::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
|
||||
}
|
||||
|
||||
template <typename... Ts, size_t... Is>
|
||||
::std::vector<Argument> createReturns(guts::index_sequence<Is...>) {
|
||||
return createArgumentVectorFromTypes<Ts..., Is...>();
|
||||
}
|
||||
|
||||
/// Unpack a tuple return type into a vector of return types, one per tuple
|
||||
/// element.
|
||||
template <typename... Ts>
|
||||
::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
|
||||
return createReturns<Ts...>(guts::make_index_sequence<sizeof...(Ts)>());
|
||||
}
|
||||
|
||||
/// Create a single-element `vector` for simple (non-tuple) return types.
|
||||
template <typename ReturnType>
|
||||
::std::vector<Argument> createReturns(ReturnType*) {
|
||||
checkStaticTypes<guts::decay_t<ReturnType>>();
|
||||
return {Argument("_1", getTypePtr<guts::decay_t<ReturnType>>())};
|
||||
}
|
||||
|
||||
/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
|
||||
/// into the argument list.
|
||||
template <typename FunctionTraits, size_t... Is>
|
||||
::std::vector<Argument> createArgumentVectorFromTraits(guts::index_sequence<Is...> indices) {
|
||||
using ArgumentTypes = typename FunctionTraits::parameter_types;
|
||||
return createArgumentVectorFromTypes<
|
||||
c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
|
||||
}
|
||||
|
||||
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
|
||||
/// function.
|
||||
template <typename FunctionTraits>
|
||||
FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) {
|
||||
using ReturnType = typename FunctionTraits::return_type;
|
||||
|
||||
auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
|
||||
guts::make_index_sequence<FunctionTraits::number_of_parameters>());
|
||||
auto returns = createReturns(static_cast<ReturnType*>(nullptr));
|
||||
|
||||
return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)};
|
||||
}
|
||||
}
|
||||
|
||||
template<class FuncType>
|
||||
FunctionSchema inferFunctionSchema(std::string name, std::string overload_name) {
|
||||
return detail::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
|
||||
}
|
||||
|
||||
}
|
@ -29,7 +29,7 @@ file(GLOB C10_SRCS
|
||||
*.cpp
|
||||
core/*.cpp
|
||||
core/dispatch/*.cpp
|
||||
core/opschema/*.cpp
|
||||
core/op_registration/*.cpp
|
||||
core/impl/*.cpp
|
||||
macros/*.cpp
|
||||
util/*.cpp
|
||||
|
2
setup.py
2
setup.py
@ -744,8 +744,8 @@ if __name__ == '__main__':
|
||||
'include/c10/macros/*.h',
|
||||
'include/c10/core/*.h',
|
||||
'include/ATen/core/dispatch/*.h',
|
||||
'include/ATen/core/op_registration/*.h',
|
||||
'include/c10/core/impl/*.h',
|
||||
'include/ATen/core/opschema/*.h',
|
||||
'include/c10/util/*.h',
|
||||
'include/c10/cuda/*.h',
|
||||
'include/c10/cuda/impl/*.h',
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <ATen/core/op_registration/infer_schema.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
|
||||
@ -16,73 +17,7 @@ namespace detail {
|
||||
using ::c10::Argument;
|
||||
using ::c10::FunctionSchema;
|
||||
|
||||
/// Checks the static C++ type `T` for correctness to catch common error cases.
|
||||
template <typename T>
|
||||
void checkStaticTypes() {
|
||||
// Give nice error messages for some of the common error cases.
|
||||
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
|
||||
static_assert(
|
||||
!std::is_integral<T>::value || std::is_same<T, int64_t>::value,
|
||||
"INVALID TYPE: Only int64_t is supported as an integral argument type");
|
||||
static_assert(
|
||||
!std::is_same<T, float>::value,
|
||||
"INVALID TYPE: float is not supported as an argument type, use double instead");
|
||||
}
|
||||
|
||||
template <typename First, typename Second, typename... Rest>
|
||||
void checkStaticTypes() {
|
||||
checkStaticTypes<First>();
|
||||
checkStaticTypes<Second, Rest...>();
|
||||
}
|
||||
|
||||
template <typename... Ts, size_t... Is>
|
||||
::std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
|
||||
checkStaticTypes<decay_t<Ts>...>();
|
||||
// Arguments are named "_<index>"
|
||||
return {Argument("_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
|
||||
}
|
||||
|
||||
template <typename... Ts, size_t... Is>
|
||||
::std::vector<Argument> createReturns(Indices<Is...> indices) {
|
||||
return createArgumentVectorFromTypes<Ts..., Is...>();
|
||||
}
|
||||
|
||||
/// Unpack a tuple return type into a vector of return types, one per tuple
|
||||
/// element.
|
||||
template <typename... Ts>
|
||||
::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
|
||||
// Create an index pack so we can call `get<Indices>` on the tuple next.
|
||||
return createReturns<Ts...>(typename MakeIndices<sizeof...(Ts)>::indices{});
|
||||
}
|
||||
|
||||
/// Create a single-element `vector` for simple (non-tuple) return types.
|
||||
template <typename ReturnType>
|
||||
::std::vector<Argument> createReturns(ReturnType*) {
|
||||
checkStaticTypes<decay_t<ReturnType>>();
|
||||
return {Argument("_1", getTypePtr<decay_t<ReturnType>>())};
|
||||
}
|
||||
|
||||
/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
|
||||
/// into the argument list.
|
||||
template <typename FunctionTraits, size_t... Is>
|
||||
::std::vector<Argument> createArgumentVectorFromTraits(Indices<Is...> indices) {
|
||||
using ArgumentTypes = typename FunctionTraits::parameter_types;
|
||||
return createArgumentVectorFromTypes<
|
||||
c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
|
||||
}
|
||||
|
||||
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
|
||||
/// function.
|
||||
template <typename FunctionTraits>
|
||||
FunctionSchema createFunctionSchemaFromTraits(const std::string& name) {
|
||||
using ReturnType = typename FunctionTraits::return_type;
|
||||
|
||||
auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
|
||||
typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
|
||||
auto returns = createReturns(static_cast<ReturnType*>(nullptr));
|
||||
|
||||
return {name, "", arguments, returns};
|
||||
}
|
||||
|
||||
/// Adds the elements of the `tuple` as input nodes to the traced graph.
|
||||
template <size_t... Is, typename... Types>
|
||||
@ -178,8 +113,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
|
||||
const auto bracketIndex = schemaOrName.find('(');
|
||||
if (bracketIndex == std::string::npos) {
|
||||
// Infer the full schema and we're good.
|
||||
return torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
|
||||
/*name=*/schemaOrName);
|
||||
return c10::detail::createFunctionSchemaFromTraits<Traits>(
|
||||
/*name=*/schemaOrName, "");
|
||||
}
|
||||
|
||||
// If the user provided her own schema, we need to infer it nevertheless and
|
||||
@ -189,8 +124,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
|
||||
auto providedSchema = parseSchema(schemaOrName);
|
||||
|
||||
const auto inferredSchema =
|
||||
torch::jit::detail::createFunctionSchemaFromTraits<Traits>(
|
||||
providedSchema.name());
|
||||
c10::detail::createFunctionSchemaFromTraits<Traits>(
|
||||
providedSchema.name(), providedSchema.overload_name());
|
||||
checkArgumentVector(
|
||||
"argument",
|
||||
inferredSchema.arguments(),
|
||||
|
Reference in New Issue
Block a user