Move schema inference to c10 (#18090)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18090

This schema inference is needed by the c10 operator registration mechanism. Move it to c10.
It is going to be used by diffs stacked on top.

Reviewed By: ezyang

Differential Revision: D14491454

fbshipit-source-id: 0f8ddcdbd91467c8347d315dd443a1ca8b216481
This commit is contained in:
Sebastian Messmer
2019-03-21 14:51:38 -07:00
committed by Facebook Github Bot
parent 1877087df2
commit daa77c6e26
4 changed files with 95 additions and 72 deletions

View File

@ -0,0 +1,88 @@
#pragma once
/**
* This file contains functionality to take a C++ function and infer its
* c10::FunctionSchema.
*/
#include <ATen/core/function_schema.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
namespace c10 {
namespace detail {
/// Checks the static C++ type `T` for correctness to catch common error cases.
template <typename T>
void checkStaticTypes() {
// Give nice error messages for some of the common error cases.
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
static_assert(
!std::is_integral<T>::value || std::is_same<T, int64_t>::value,
"INVALID TYPE: Only int64_t is supported as an integral argument type");
static_assert(
!std::is_same<T, float>::value,
"INVALID TYPE: float is not supported as an argument type, use double instead");
}
template <typename First, typename Second, typename... Rest>
void checkStaticTypes() {
checkStaticTypes<First>();
checkStaticTypes<Second, Rest...>();
}
template <typename... Ts, size_t... Is>
::std::vector<Argument> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
checkStaticTypes<guts::decay_t<Ts>...>();
// Arguments are named "_<index>"
return {Argument("_" + std::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
}
template <typename... Ts, size_t... Is>
::std::vector<Argument> createReturns(guts::index_sequence<Is...>) {
return createArgumentVectorFromTypes<Ts..., Is...>();
}
/// Unpack a tuple return type into a vector of return types, one per tuple
/// element.
template <typename... Ts>
::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
return createReturns<Ts...>(guts::make_index_sequence<sizeof...(Ts)>());
}
/// Create a single-element `vector` for simple (non-tuple) return types.
template <typename ReturnType>
::std::vector<Argument> createReturns(ReturnType*) {
checkStaticTypes<guts::decay_t<ReturnType>>();
return {Argument("_1", getTypePtr<guts::decay_t<ReturnType>>())};
}
/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
/// into the argument list.
template <typename FunctionTraits, size_t... Is>
::std::vector<Argument> createArgumentVectorFromTraits(guts::index_sequence<Is...> indices) {
using ArgumentTypes = typename FunctionTraits::parameter_types;
return createArgumentVectorFromTypes<
c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
}
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
/// function.
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) {
using ReturnType = typename FunctionTraits::return_type;
auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
guts::make_index_sequence<FunctionTraits::number_of_parameters>());
auto returns = createReturns(static_cast<ReturnType*>(nullptr));
return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)};
}
}
template<class FuncType>
FunctionSchema inferFunctionSchema(std::string name, std::string overload_name) {
return detail::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
}
}

View File

@ -29,7 +29,7 @@ file(GLOB C10_SRCS
*.cpp *.cpp
core/*.cpp core/*.cpp
core/dispatch/*.cpp core/dispatch/*.cpp
core/opschema/*.cpp core/op_registration/*.cpp
core/impl/*.cpp core/impl/*.cpp
macros/*.cpp macros/*.cpp
util/*.cpp util/*.cpp

View File

@ -744,8 +744,8 @@ if __name__ == '__main__':
'include/c10/macros/*.h', 'include/c10/macros/*.h',
'include/c10/core/*.h', 'include/c10/core/*.h',
'include/ATen/core/dispatch/*.h', 'include/ATen/core/dispatch/*.h',
'include/ATen/core/op_registration/*.h',
'include/c10/core/impl/*.h', 'include/c10/core/impl/*.h',
'include/ATen/core/opschema/*.h',
'include/c10/util/*.h', 'include/c10/util/*.h',
'include/c10/cuda/*.h', 'include/c10/cuda/*.h',
'include/c10/cuda/impl/*.h', 'include/c10/cuda/impl/*.h',

View File

@ -2,6 +2,7 @@
#include <torch/csrc/jit/operator.h> #include <torch/csrc/jit/operator.h>
#include <ATen/core/stack.h> #include <ATen/core/stack.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <torch/csrc/jit/tracer.h> #include <torch/csrc/jit/tracer.h>
#include <torch/csrc/utils/variadic.h> #include <torch/csrc/utils/variadic.h>
@ -16,73 +17,7 @@ namespace detail {
using ::c10::Argument; using ::c10::Argument;
using ::c10::FunctionSchema; using ::c10::FunctionSchema;
/// Checks the static C++ type `T` for correctness to catch common error cases.
template <typename T>
void checkStaticTypes() {
// Give nice error messages for some of the common error cases.
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
static_assert(
!std::is_integral<T>::value || std::is_same<T, int64_t>::value,
"INVALID TYPE: Only int64_t is supported as an integral argument type");
static_assert(
!std::is_same<T, float>::value,
"INVALID TYPE: float is not supported as an argument type, use double instead");
}
template <typename First, typename Second, typename... Rest>
void checkStaticTypes() {
checkStaticTypes<First>();
checkStaticTypes<Second, Rest...>();
}
template <typename... Ts, size_t... Is>
::std::vector<Argument> createArgumentVectorFromTypes(Indices<Is...> indices) {
checkStaticTypes<decay_t<Ts>...>();
// Arguments are named "_<index>"
return {Argument("_" + std::to_string(Is), getTypePtr<decay_t<Ts>>())...};
}
template <typename... Ts, size_t... Is>
::std::vector<Argument> createReturns(Indices<Is...> indices) {
return createArgumentVectorFromTypes<Ts..., Is...>();
}
/// Unpack a tuple return type into a vector of return types, one per tuple
/// element.
template <typename... Ts>
::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
// Create an index pack so we can call `get<Indices>` on the tuple next.
return createReturns<Ts...>(typename MakeIndices<sizeof...(Ts)>::indices{});
}
/// Create a single-element `vector` for simple (non-tuple) return types.
template <typename ReturnType>
::std::vector<Argument> createReturns(ReturnType*) {
checkStaticTypes<decay_t<ReturnType>>();
return {Argument("_1", getTypePtr<decay_t<ReturnType>>())};
}
/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
/// into the argument list.
template <typename FunctionTraits, size_t... Is>
::std::vector<Argument> createArgumentVectorFromTraits(Indices<Is...> indices) {
using ArgumentTypes = typename FunctionTraits::parameter_types;
return createArgumentVectorFromTypes<
c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
}
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
/// function.
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraits(const std::string& name) {
using ReturnType = typename FunctionTraits::return_type;
auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
typename MakeIndices<FunctionTraits::number_of_parameters>::indices{});
auto returns = createReturns(static_cast<ReturnType*>(nullptr));
return {name, "", arguments, returns};
}
/// Adds the elements of the `tuple` as input nodes to the traced graph. /// Adds the elements of the `tuple` as input nodes to the traced graph.
template <size_t... Is, typename... Types> template <size_t... Is, typename... Types>
@ -178,8 +113,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
const auto bracketIndex = schemaOrName.find('('); const auto bracketIndex = schemaOrName.find('(');
if (bracketIndex == std::string::npos) { if (bracketIndex == std::string::npos) {
// Infer the full schema and we're good. // Infer the full schema and we're good.
return torch::jit::detail::createFunctionSchemaFromTraits<Traits>( return c10::detail::createFunctionSchemaFromTraits<Traits>(
/*name=*/schemaOrName); /*name=*/schemaOrName, "");
} }
// If the user provided her own schema, we need to infer it nevertheless and // If the user provided her own schema, we need to infer it nevertheless and
@ -189,8 +124,8 @@ FunctionSchema inferAndCheckSchema(const std::string& schemaOrName) {
auto providedSchema = parseSchema(schemaOrName); auto providedSchema = parseSchema(schemaOrName);
const auto inferredSchema = const auto inferredSchema =
torch::jit::detail::createFunctionSchemaFromTraits<Traits>( c10::detail::createFunctionSchemaFromTraits<Traits>(
providedSchema.name()); providedSchema.name(), providedSchema.overload_name());
checkArgumentVector( checkArgumentVector(
"argument", "argument",
inferredSchema.arguments(), inferredSchema.arguments(),