mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00

commit 2ccdc39dce
parent a05406ea56
committed by Facebook GitHub Bot

Revert D21089648: Put TORCH_LIBRARY in torch/library.h; add custom class API

Test Plan: revert-hammer

Differential Revision: D21089648

Original commit changeset: 8d54329c1252

fbshipit-source-id: 636e8a11afc628a4cdae9d44824985c10c70555e
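
For context: this revert swaps files back from the pybind11-style torch::Library registration API (TORCH_LIBRARY and friends, documented in the op_registration.h hunk below) to the headers and names under ATen/core/op_registration. A minimal sketch of the two registration styles involved; the `myops` namespace and `myadd_cpu` kernel are illustrative names, not part of this commit:

    #include <ATen/core/op_registration/op_registration.h>

    at::Tensor myadd_cpu(const at::Tensor& self, const at::Tensor& other) {
      return self + other;
    }

    // New-style API (the one being reverted): define a schema, then attach
    // per-backend implementations on the library handle.
    //   TORCH_LIBRARY(myops, m) {
    //     m.def("myadd(Tensor self, Tensor other) -> Tensor");
    //     m.impl("myadd", torch::kCPU, &myadd_cpu);
    //   }

    // Old-style API, which this revert leaves in place:
    static auto registry = c10::RegisterOperators().op(
        "myops::myadd(Tensor self, Tensor other) -> Tensor",
        c10::RegisterOperators::options()
            .kernel<decltype(myadd_cpu), &myadd_cpu>(c10::DispatchKey::CPU));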
@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/autocast_mode.h>
 
@@ -373,7 +373,7 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
 Explicit registration for out-of-place ops
 *****************************************/
 TORCH_LIBRARY_IMPL(_, Autocast, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(c10::CppFunction::makeFallthrough());
 }
 
 TORCH_LIBRARY_IMPL(aten, Autocast, m) {
@@ -1,5 +1,9 @@
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
+namespace {
+
 TORCH_LIBRARY_IMPL(_, BackendSelect, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(c10::CppFunction::makeFallthrough());
+}
+
 }
@@ -1,6 +1,6 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/core/LegacyTypeDispatch.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 /*
  * This file implements a variable fallback kernel for custom operators.
@@ -65,9 +65,9 @@ TORCH_LIBRARY_IMPL(_, Autograd, m) {
   //
   // We can remove this `fallthrough` kernel when all kernels support boxed
   // call.
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(c10::CppFunction::makeFallthrough());
 #else
-  m.fallback(torch::CppFunction::makeFromBoxedFunction<&variable_fallback_kernel>());
+  m.fallback(c10::CppFunction::makeFromBoxedFunction<&variable_fallback_kernel>());
 #endif
 }
 
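
The Autograd fallback above is a boxed kernel: it receives the operator handle and a stack of IValues instead of typed arguments, which is what makeFromBoxedFunction expects. A rough sketch of the shape such a kernel takes (the body is illustrative; the real variable_fallback_kernel lives in this file, and the exact redispatch call may differ by version):

    #include <ATen/core/dispatch/Dispatcher.h>

    void variable_fallback_kernel(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
      // Boxed kernels work on the IValue stack directly; here we drop autograd
      // handling and redispatch to the backend kernel underneath.
      at::AutoNonVariableTypeMode non_var_guard(true);
      c10::Dispatcher::singleton().callBoxed(op, stack);
    }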
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <ATen/core/ivalue.h>
-#include <ATen/core/stack.h>
 #include <c10/util/Metaprogramming.h>
 
 namespace c10 {
@@ -1,232 +0,0 @@
-#include <torch/library.h>
-
-namespace torch {
-
-namespace {
-  // TODO: Consider representing debug info as a struct instead so you
-  // don't have to allocate strings all the time
-  std::string debugString(std::string debug, const char* file, uint32_t line) {
-#ifdef STRIP_ERROR_MESSAGES
-    return "";
-#else
-    if (debug.empty()) {
-      return c10::str("registered at ", file, ":", line);
-    } else {
-      return debug;
-    }
-#endif
-  }
-
-  std::ostream& operator<<(std::ostream& os, Library::Kind kind) {
-    switch (kind) {
-      case Library::DEF:
-        os << "TORCH_LIBRARY";
-        break;
-      case Library::IMPL:
-        os << "TORCH_LIBRARY_IMPL";
-        break;
-      case Library::FRAGMENT:
-        os << "TORCH_LIBRARY_FRAGMENT";
-        break;
-    }
-    return os;
-  }
-}
-
-CppFunction::CppFunction(c10::KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema)
-  : func_(std::move(func))
-  , schema_(std::move(schema))
-  , debug_()
-  {}
-
-#define ERROR_CONTEXT "(Error occurred while processing ", kind_, " block at ", file_, ":", line_, ")"
-
-Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
-  : kind_(kind)
-  , ns_(ns == "_" ? c10::nullopt : c10::make_optional(std::move(ns)))
-  , dispatch_key_((!k.has_value() || *k == c10::DispatchKey::CatchAll) ? c10::nullopt : k)
-  , file_(file)
-  , line_(line)
-  {
-    switch (kind_) {
-      case DEF:
-        // Only DEFs require library uniqueness; fragments
-        // don't register a library
-        registrars_.emplace_back(
-          c10::Dispatcher::singleton().registerLibrary(
-            *ns_, debugString("", file_, line_)
-          )
-        );
-        // fallthrough
-      case FRAGMENT:
-        TORCH_CHECK(
-          ns_.has_value(),
-          kind_, ": cannot define ", kind_, " with the wildcard namespace _ "
-          "(every ", kind_, " defines operators for a distinct namespace!)"
-          "Did you mean to use TORCH_LIBRARY_IMPL instead? "
-          ERROR_CONTEXT
-        );
-        TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
-        break;
-      case IMPL:
-        // Nothing to do, everything is OK
-        break;
-    }
-  }
-
-// TODO: Error if an operator is def'ed multiple times. Right now we just
-// merge everything
-
-#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
-Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
-  TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
-    DEF_PRELUDE,
-    "Cannot define an operator inside of a ", kind_, " block. "
-    "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ",
-    ERROR_CONTEXT
-  );
-  TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
-  TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
-  auto ns_opt = schema.getNamespace();
-  if (ns_opt.has_value()) {
-    // Note [Redundancy in registration code is OK]
-    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-    // In an earlier version of this code, I made it an error to explicitly
-    // specify the namespace, even when the namespaces match. I've decided
-    // to relax this constraint because sometimes we code generate registrations
-    // and you cannot conveniently tell what the enclosing context will be;
-    // in these cases, it is simpler (and less error prone) to place all
-    // of the information in the registration site, which will be cross-checked
-    // in the end in any case (and if it turns out you DON'T have the right
-    // information at the site, as is the case with backend specific
-    // per-op registrations, you will get the right behavior!)
-    TORCH_CHECK(false,
-      *ns_opt == *ns_,
-      "Explicitly provided namespace (", *ns_opt, ") in schema string "
-      "does not match namespace of enclsing ", kind_, " block (", *ns_, "). "
-      "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
-      "(and consider deleting the namespace from your schema string.) ",
-      ERROR_CONTEXT
-    );
-  } else {
-    bool b = schema.setNamespaceIfNotSet(ns_->c_str());
-    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
-  }
-  if (out_name) {
-    *out_name = schema.operator_name(); // copy!
-  }
-  registrars_.emplace_back(
-    c10::Dispatcher::singleton().registerDef(
-      std::move(schema),
-      debugString("", file_, line_)
-    )
-  );
-  return *this;
-}
-#undef DEF_PRELUDE
-
-Library& Library::_def(c10::either<c10::OperatorName, c10::FunctionSchema>&& name_or_schema, CppFunction&& f) & {
-  c10::FunctionSchema schema = [&] {
-    if (name_or_schema.is_right()) {
-      return std::move(name_or_schema).right();
-    } else {
-      // it's a name; use the inferred schema
-      c10::OperatorName name = std::move(name_or_schema).left();
-      TORCH_CHECK(f.schema_,
-        "def(\"", name, "\"): "
-        "Full schema string was not specified, and we couldn't infer schema either. ",
-        "Please explicitly provide a schema string. ",
-        ERROR_CONTEXT
-      );
-      c10::FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name));
-      s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
-      return s;
-    }
-  }();
-  c10::OperatorName name("", ""); // Get the namespaced name for the impl call
-  // First define the schema...
-  _def(std::move(schema), &name);
-  // Then register the implementation...
-  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
-  registrars_.emplace_back(
-    c10::Dispatcher::singleton().registerImpl(
-      std::move(name),
-      dispatch_key,
-      std::move(f.func_),
-      std::move(f.schema_),
-      debugString(std::move(f.debug_), file_, line_)
-    )
-  );
-  return *this;
-}
-
-#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
-Library& Library::_impl(const char* name_str, CppFunction&& f) & {
-  auto name = torch::jit::parseName(name_str);
-  auto ns_opt = name.getNamespace();
-  // This is kind of similar to the checking in def(), but the error
-  // messages are a little different for this call site
-  if (ns_opt.has_value()) {
-    // See Note [Redundancy in registration code is OK]
-    TORCH_CHECK(*ns_opt == *ns_,
-      IMPL_PRELUDE,
-      "Explicitly provided namespace (", *ns_opt, ") in operator name "
-      "does not match namespace of enclosing ", kind_, " block (", *ns_, "). "
-      "Move this definition to the ", kind_, " block corresponding to this namespace "
-      "(and consider deleting the namespace from your schema string.) ",
-      ERROR_CONTEXT
-    );
-  } else {
-    bool b = name.setNamespaceIfNotSet(ns_->c_str());
-    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
-  }
-  // See Note [Redundancy in registration code is OK]
-  TORCH_CHECK(!(f.dispatch_key_.has_value() &&
-                dispatch_key_.has_value() &&
-                *f.dispatch_key_ != *dispatch_key_),
-    IMPL_PRELUDE,
-    "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
-    "with the dispatch key of the enclosing ", kind_, " block (", *dispatch_key_, "). "
-    "Please declare a separate ", kind_, " block for this dispatch key and "
-    "move your impl() there. "
-    ERROR_CONTEXT
-  );
-  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
-  registrars_.emplace_back(
-    c10::Dispatcher::singleton().registerImpl(
-      std::move(name),
-      dispatch_key,
-      std::move(f.func_),
-      std::move(f.schema_),
-      debugString(std::move(f.debug_), file_, line_)
-    )
-  );
-  return *this;
-}
-#undef IMPL_PRELUDE
-
-Library& Library::_fallback(CppFunction&& f) & {
-  TORCH_CHECK(kind_ == IMPL,
-    "fallback(...): Cannot define an operator inside of a ", kind_, " block. "
-    "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block? ",
-    ERROR_CONTEXT);
-  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
-  TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT);
-  TORCH_CHECK(!ns_.has_value(),
-    "fallback(...): Fallback functions which apply to only a single namespace ",
-    "(you specified ", *ns_, ") are not supported. If you intended to apply ",
-    "this fallback function globally, please define a separate block:\n\n",
-    "    TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n",
-    ERROR_CONTEXT);
-  registrars_.emplace_back(
-    c10::Dispatcher::singleton().registerFallback(
-      *dispatch_key,
-      std::move(f.func_),
-      debugString(std::move(f.debug_), file_, line_)
-    )
-  );
-  return *this;
-}
-
-} // namespace torch
@@ -7,6 +7,37 @@
 
 namespace c10 {
 
+namespace {
+  // TODO: Consider representing debug info as a struct instead so you
+  // don't have to allocate strings all the time
+  std::string debugString(std::string debug, const char* file, uint32_t line) {
+#ifdef STRIP_ERROR_MESSAGES
+    return "";
+#else
+    if (debug.empty()) {
+      return c10::str("registered at ", file, ":", line);
+    } else {
+      return debug;
+    }
+#endif
+  }
+
+  std::ostream& operator<<(std::ostream& os, Library::Kind kind) {
+    switch (kind) {
+      case Library::DEF:
+        os << "TORCH_LIBRARY";
+        break;
+      case Library::IMPL:
+        os << "TORCH_LIBRARY_IMPL";
+        break;
+      case Library::FRAGMENT:
+        os << "TORCH_LIBRARY_FRAGMENT";
+        break;
+    }
+    return os;
+  }
+}
+
 static_assert(std::is_nothrow_move_constructible<c10::optional<RegistrationHandleRAII>>::value, "");
 static_assert(std::is_nothrow_move_assignable<c10::optional<RegistrationHandleRAII>>::value, "");
 
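
The debugString helper added above is what stamps registration records with their origin; unless STRIP_ERROR_MESSAGES is defined, it falls back to a "registered at file:line" string when no explicit debug text is given. Illustrative calls (the file names here are hypothetical):

    debugString("", "qadd.cpp", 42);         // -> "registered at qadd.cpp:42"
    debugString("custom note", "x.cpp", 7);  // -> "custom note" (explicit text wins)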
@@ -109,4 +140,200 @@ RegisterOperators::~RegisterOperators() = default;
 RegisterOperators::RegisterOperators(RegisterOperators&&) noexcept = default;
 RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) noexcept = default;
 
-} // namespace c10
+CppFunction::CppFunction(KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema)
+  : func_(std::move(func))
+  , schema_(std::move(schema))
+  , debug_()
+  {}
+
+#define ERROR_CONTEXT "(Error occurred while processing ", kind_, " block at ", file_, ":", line_, ")"
+
+Library::Library(Kind kind, std::string ns, c10::optional<DispatchKey> k, const char* file, uint32_t line)
+  : kind_(kind)
+  , ns_(ns == "_" ? c10::nullopt : c10::make_optional(std::move(ns)))
+  , dispatch_key_((!k.has_value() || *k == DispatchKey::CatchAll) ? c10::nullopt : k)
+  , file_(file)
+  , line_(line)
+  {
+    switch (kind_) {
+      case DEF:
+        // Only DEFs require library uniqueness; fragments
+        // don't register a library
+        registrars_.emplace_back(
+          Dispatcher::singleton().registerLibrary(
+            *ns_, debugString("", file_, line_)
+          )
+        );
+        // fallthrough
+      case FRAGMENT:
+        TORCH_CHECK(
+          ns_.has_value(),
+          kind_, ": cannot define ", kind_, " with the wildcard namespace _ "
+          "(every ", kind_, " defines operators for a distinct namespace!)"
+          "Did you mean to use TORCH_LIBRARY_IMPL instead? "
+          ERROR_CONTEXT
+        );
+        TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
+        break;
+      case IMPL:
+        // Nothing to do, everything is OK
+        break;
+    }
+  }
+
+// TODO: Error if an operator is def'ed multiple times. Right now we just
+// merge everything
+
+#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
+Library& Library::_def(FunctionSchema&& schema, OperatorName* out_name) & {
+  TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
+    DEF_PRELUDE,
+    "Cannot define an operator inside of a ", kind_, " block. "
+    "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ",
+    ERROR_CONTEXT
+  );
+  TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
+  TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
+  auto ns_opt = schema.getNamespace();
+  if (ns_opt.has_value()) {
+    // Note [Redundancy in registration code is OK]
+    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    // In an earlier version of this code, I made it an error to explicitly
+    // specify the namespace, even when the namespaces match. I've decided
+    // to relax this constraint because sometimes we code generate registrations
+    // and you cannot conveniently tell what the enclosing context will be;
+    // in these cases, it is simpler (and less error prone) to place all
+    // of the information in the registration site, which will be cross-checked
+    // in the end in any case (and if it turns out you DON'T have the right
+    // information at the site, as is the case with backend specific
+    // per-op registrations, you will get the right behavior!)
+    TORCH_CHECK(false,
+      *ns_opt == *ns_,
+      "Explicitly provided namespace (", *ns_opt, ") in schema string "
+      "does not match namespace of enclsing ", kind_, " block (", *ns_, "). "
+      "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
+      "(and consider deleting the namespace from your schema string.) ",
+      ERROR_CONTEXT
+    );
+  } else {
+    bool b = schema.setNamespaceIfNotSet(ns_->c_str());
+    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
+  }
+  if (out_name) {
+    *out_name = schema.operator_name(); // copy!
+  }
+  registrars_.emplace_back(
+    Dispatcher::singleton().registerDef(
+      std::move(schema),
+      debugString("", file_, line_)
+    )
+  );
+  return *this;
+}
+#undef DEF_PRELUDE
+
+Library& Library::_def(c10::either<OperatorName, FunctionSchema>&& name_or_schema, CppFunction&& f) & {
+  FunctionSchema schema = [&] {
+    if (name_or_schema.is_right()) {
+      return std::move(name_or_schema).right();
+    } else {
+      // it's a name; use the inferred schema
+      OperatorName name = std::move(name_or_schema).left();
+      TORCH_CHECK(f.schema_,
+        "def(\"", name, "\"): "
+        "Full schema string was not specified, and we couldn't infer schema either. ",
+        "Please explicitly provide a schema string. ",
+        ERROR_CONTEXT
+      );
+      FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name));
+      s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
+      return s;
+    }
+  }();
+  OperatorName name("", ""); // Get the namespaced name for the impl call
+  // First define the schema...
+  _def(std::move(schema), &name);
+  // Then register the implementation...
+  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
+  registrars_.emplace_back(
+    Dispatcher::singleton().registerImpl(
+      std::move(name),
+      dispatch_key,
+      std::move(f.func_),
+      std::move(f.schema_),
+      debugString(std::move(f.debug_), file_, line_)
+    )
+  );
+  return *this;
+}
+
+#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
+Library& Library::_impl(const char* name_str, CppFunction&& f) & {
+  auto name = torch::jit::parseName(name_str);
+  auto ns_opt = name.getNamespace();
+  // This is kind of similar to the checking in def(), but the error
+  // messages are a little different for this call site
+  if (ns_opt.has_value()) {
+    // See Note [Redundancy in registration code is OK]
+    TORCH_CHECK(*ns_opt == *ns_,
+      IMPL_PRELUDE,
+      "Explicitly provided namespace (", *ns_opt, ") in operator name "
+      "does not match namespace of enclosing ", kind_, " block (", *ns_, "). "
+      "Move this definition to the ", kind_, " block corresponding to this namespace "
+      "(and consider deleting the namespace from your schema string.) ",
+      ERROR_CONTEXT
+    );
+  } else {
+    bool b = name.setNamespaceIfNotSet(ns_->c_str());
+    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
+  }
+  // See Note [Redundancy in registration code is OK]
+  TORCH_CHECK(!(f.dispatch_key_.has_value() &&
+                dispatch_key_.has_value() &&
+                *f.dispatch_key_ != *dispatch_key_),
+    IMPL_PRELUDE,
+    "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
+    "with the dispatch key of the enclosing ", kind_, " block (", *dispatch_key_, "). "
+    "Please declare a separate ", kind_, " block for this dispatch key and "
+    "move your impl() there. "
+    ERROR_CONTEXT
+  );
+  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
+  registrars_.emplace_back(
+    Dispatcher::singleton().registerImpl(
+      std::move(name),
+      dispatch_key,
+      std::move(f.func_),
+      std::move(f.schema_),
+      debugString(std::move(f.debug_), file_, line_)
+    )
+  );
+  return *this;
+}
+#undef IMPL_PRELUDE
+
+Library& Library::_fallback(CppFunction&& f) & {
+  TORCH_CHECK(kind_ == IMPL,
+    "fallback(...): Cannot define an operator inside of a ", kind_, " block. "
+    "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block? ",
+    ERROR_CONTEXT);
+  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
+  TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT);
+  TORCH_CHECK(!ns_.has_value(),
+    "fallback(...): Fallback functions which apply to only a single namespace ",
+    "(you specified ", *ns_, ") are not supported. If you intended to apply ",
+    "this fallback function globally, please define a separate block:\n\n",
+    "    TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n",
+    ERROR_CONTEXT);
+  registrars_.emplace_back(
+    Dispatcher::singleton().registerFallback(
+      *dispatch_key,
+      std::move(f.func_),
+      debugString(std::move(f.debug_), file_, line_)
+    )
+  );
+  return *this;
+}
+
+}
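
The _def and _impl implementations above cross-check any namespace written in a schema string against the namespace of the enclosing block, and fill it in when it is omitted (see Note [Redundancy in registration code is OK]). A sketch of what that means at a registration site, with an illustrative `myops` library:

    TORCH_LIBRARY(myops, m) {
      // No namespace in the schema string: myops:: is filled in automatically.
      m.def("myadd(Tensor self, Tensor other) -> Tensor");
      // A schema string naming a different namespace would trip the TORCH_CHECK
      // in _def and report the enclosing block's file and line:
      //   m.def("otherns::mymul(Tensor self) -> Tensor");  // registration error
    }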
@@ -594,9 +594,407 @@ private:
   std::vector<RegistrationHandleRAII> registrars_;
 };
 
+// --------------------------------------------------------------------------
+//
+// New style API
+//
+// --------------------------------------------------------------------------
+//
+// The basic concept behind the new style API is to be as similar to pybind11's
+// API as possible.
+//
+// A quick tour of a few usage examples:
+//
+// // Define a library whose operators live in the namespace 'aten'.
+// // You must define all of the operators for this library in
+// // this namespace.
+// TORCH_LIBRARY(aten, m) {
+//   // Define a schema for an operator, but provide no implementation
+//   m.def("mul(Tensor self, Tensor other) -> Tensor");
+//
+//   // Define a operator with exactly one implementation for all backends.
+//   m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
+//
+//   // Provide an implementation for a defined operator (you can
+//   // provide multiple; one per backend). We'll take care of calling
+//   // the correct implementation depending on if we get a CPU
+//   // tensor or a CUDA tensor
+//   m.impl("mul", torch::kCPU, &mul_cpu_impl);
+//   m.impl("mul", torch::kCUDA, &mul_cuda_impl);
+// }
+//
+// // Define implementations for operators for a non-standard backend,
+// // e.g., XLA (valid values are entries of DispatchKey). These
+// // operator names are not namespaced; you can define implementations
+// // for any namespace.
+// TORCH_LIBRARY_IMPL(aten, XLA, m) {
+//   m.impl("mul", &mul_xla_impl);
+// }
+
+// Represents a C++ function that implements an operator. Most users won't
+// interact directly with this class, except via error messages: the
+// constructors this function define the set of permissible "function"-like
+// things you can bind via the interface.
+//
+// This class erases the type of the passed in function, but durably records
+// the type via an inferred schema for the function.
+//
+// TODO: This is morally the same thing as KernelRegistrationConfig, but it's
+// opaque to the user.
+class CAFFE2_API CppFunction final {
+public:
+  // This overload accepts function pointers, e.g., CppFunction(&add_impl)
+  template <typename Func>
+  explicit CppFunction(Func* f, std::enable_if_t<guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)
+    : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f))
+    // TODO: Don't go through WrapRuntimeKernelFunctor
+    , schema_(detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Func>>>())
+    , debug_()
+    {}
+
+  // This overload accepts lambdas, e.g., CppFunction([](const Tensor& self) { ... })
+  template <typename Lambda>
+  explicit CppFunction(Lambda&& f, std::enable_if_t<guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)
+    : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(f)))
+    // TODO: Don't go through WrapRuntimeKernelFunctor
+    , schema_(detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>())
+    , debug_()
+    {}
+
+  // This static factory lets you create CppFunctions that (1) don't have boxing
+  // wrappers (because we don't support it yet) and (2) don't have schema
+  // inference (because some ops don't support it).
+  //
+  // TODO: Eliminate the necessity for this function entirely.
+  template <typename Func>
+  static CppFunction makeUnboxedOnly(Func* f) {
+    return CppFunction(
+      c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f),
+      /* schema */ nullptr
+    );
+  }
+
+  // TODO: more user friendly API
+  static CppFunction makeFallthrough() {
+    return CppFunction(
+      c10::KernelFunction::makeFallthrough(),
+      /* schema */ nullptr
+    );
+  }
+
+  // TODO: more user friendly API
+  template<KernelFunction::BoxedKernelFunction* func>
+  static CppFunction makeFromBoxedFunction() {
+    return CppFunction(
+      c10::KernelFunction::makeFromBoxedFunction<func>(),
+      /* schema */ nullptr
+    );
+  }
+
+  CppFunction&& debug(std::string d) && {
+    debug_ = std::move(d);
+    return std::move(*this);
+  }
+
+private:
+  c10::optional<c10::DispatchKey> dispatch_key_;
+  c10::KernelFunction func_;
+  std::unique_ptr<c10::FunctionSchema> schema_;
+  std::string debug_;
+
+  // The "setter" for dispatch_key_
+  template <typename Func>
+  friend CppFunction dispatch(c10::DispatchKey, Func&&);
+
+  // The only class which actually pulls out values from CppFunction (does so
+  // destructively, felt too lazy to write accessors that I don't even
+  // want users to use)
+  friend class Library;
+
+  CppFunction(KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema);
+};
+
+// Create a CppFunction which is associated with a specific dispatch key.
+// CppFunctions that are tagged with a DispatchKey don't get invoked /unless/
+// the dispatcher determines that the DispatchKey is the best choice for
+// a function
+template <typename Func>
+inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
+  CppFunction f(std::forward<Func>(raw_f));
+  if (k == c10::DispatchKey::CatchAll) {
+    f.dispatch_key_ = c10::nullopt;
+  } else {
+    f.dispatch_key_ = k;
+  }
+  return f;
+}
+
+// Convenience overload of dispatch which accepts DeviceType
+template <typename Func>
+inline CppFunction dispatch(DeviceType type, Func&& raw_f) {
+  auto deviceTypeToDispatchKey = [](DeviceType t){
+    switch (t) {
+      // This list is synchronized with the k-constants in c10/core/DeviceType.h
+      case DeviceType::CPU:
+        return c10::DispatchKey::CPU;
+      case DeviceType::CUDA:
+        return c10::DispatchKey::CUDA;
+      case DeviceType::XLA:
+        return c10::DispatchKey::XLA;
+      case DeviceType::HIP:
+        return c10::DispatchKey::HIP;
+      case DeviceType::MSNPU:
+        return c10::DispatchKey::MSNPU;
+      default:
+        TORCH_CHECK(false,
+          "Device type ", t, " cannot be overloaded at dispatch time, "
+          "please file a bug report explaining what you were trying to do.");
+    }
+  };
+  return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
+}
+
+inline FunctionSchema schema(const char* str, AliasAnalysisKind k) {
+  FunctionSchema s = torch::jit::parseSchema(str);
+  s.setAliasAnalysis(k);
+  return s;
+}
+inline FunctionSchema schema(const char* s) {
+  return schema(s, AliasAnalysisKind::FROM_SCHEMA);
+}
+inline FunctionSchema&& schema(FunctionSchema&& s) { return std::move(s); }
+
+namespace detail {
+
+  inline c10::either<OperatorName, FunctionSchema> constructSchemaOrName(FunctionSchema&& s) {
+    return c10::make_right<OperatorName, FunctionSchema>(std::move(s));
+  }
+  inline c10::either<OperatorName, FunctionSchema> constructSchemaOrName(OperatorName&& n) {
+    return c10::make_left<OperatorName, FunctionSchema>(std::move(n));
+  }
+  inline c10::either<OperatorName, FunctionSchema> constructSchemaOrName(const char* str) {
+    auto s = torch::jit::parseSchemaOrName(str);
+    if (s.is_right()) {
+      s.right().setAliasAnalysis(AliasAnalysisKind::FROM_SCHEMA);
+    }
+    return s;
+  }
+
+}
+
+namespace detail {
+  class TorchLibraryInit;
+}
+
+// This is the "handle" by which functions defined in TORCH_LIBRARY
+// and TORCH_LIBRARY_IMPL can define operators and override implementations
+// at certain backends.
+//
+// Conventionally, you get access to it using those two macros:
+//
+// TORCH_LIBRARY(torchvision, m) {
+//    // m is a c10::Library
+//    m.def("roi_align", ...);
+//    ...
+// }
+//
+// TORCH_LIBRARY_IMPL(aten, XLA, m) {
+//    // m is a c10::Library
+//    m.impl("add", ...);
+//    ...
+// }
+//
+// In some cases, you need to define something that applies to all namespaces,
+// not just one namespace (usually a fallback). In that case, use the reserved
+// namespace _, e.g.,
+//
+// TORCH_LIBRARY_IMPL(_, XLA, m) {
+//    m.fallback(xla_fallback);
+// }
+//
+class CAFFE2_API Library final {
+public:
+  // Which type of macro produced this Library
+  enum Kind {
+    DEF, // from TORCH_LIBRARY (no qualifier)
+    IMPL,
+    FRAGMENT,
+  };
+
+  // Use TORCH_LIBRARY/TORCH_LIBRARY_IMPL instead of these constructors directly
+  Library(Kind kind, std::string ns, c10::optional<DispatchKey> k, const char* file, uint32_t line);
+
+  Library(const Library&) = delete;
+  Library& operator=(const Library&) = delete;
+  Library(Library&&) = default;
+  Library& operator=(Library&&) = default;
+
+  // Some notes about the API design here. We had the following constraints:
+  //
+  //  - We need to support multiple "types" of arguments for schema and
+  //    functions (e.g., unnamed lambda types, regular functions, const char*,
+  //    fully instantiated schemas)
+  //  - We don't want to write exponentially many overloads
+  //  - We don't want to rely on implicit conversion to a common type,
+  //    because the C++ compiler will only be willing to do a single
+  //    implicit conversion (reducing the set of valid types which you
+  //    can invoke with); also error messages are worse when an implicit
+  //    conversion is not selected (as the compiler will not explain
+  //    why it didn't select an implicit conversion; this is different
+  //    from overloads where it will explain each candidate overload and
+  //    why it didn't apply)
+  //
+  // To solve all of these constraints at the same time, we use a trick taken
+  // from the pybind11 library: template over the argument in the user visible
+  // API, and inside of the templated function explicitly call an overloaded
+  // function to resolve the argument to a real type. You get the good error
+  // messages from overloads, but at the same time you only need to write the
+  // overload for any given argument type once.
+
+  // Declare an operator with a schema, but don't provide any implementations
+  // for it. You're expected to then provide implementations using the
+  // impl() method.
+  template <typename Schema>
+  Library& def(Schema&& raw_schema) & {
+    FunctionSchema s = schema(std::forward<Schema>(raw_schema));
+    return _def(std::move(s));
+  }
+
+  // Convenience method to define an operator for a schema and then register
+  // an implementation for it. def(n, f) is almost equivalent to def(n).impl(f),
+  // except that if n is not a schema, then the schema is inferred from the
+  // static type of f.
+  template <typename NameOrSchema, typename Func>
+  Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
+    CppFunction f(std::forward<Func>(raw_f));
+    auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
+    return _def(std::move(name_or_schema), std::move(f));
+  }
+
+  // Register an implementation for an operator. You may register multiple
+  // implementations for a single operator at different dispatch keys
+  // (see torch::dispatch). Implementations must have a corresponding
+  // declaration (from def), otherwise they are invalid.
+  template <typename Func>
+  Library& impl(const char* name, Func&& raw_f) & {
+    CppFunction f(std::forward<Func>(raw_f));
+    return _impl(name, std::move(f));
+  }
+
+  // Convenience overload for directly specifying the dispatch key. Dispatch
+  // can validly be either DeviceType or DispatchKey; check torch::dispatch for
+  // the canonical list of accepted overloads.
+  template <typename Dispatch, typename Func>
+  Library& impl(const char* name, Dispatch&& key, Func&& raw_f) & {
+    return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
+  }
+
+  // Convenience overload for unboxed only kernels. These are quite common
+  // but will be eventually eliminated; this function makes it easy to grep for
+  // them.
+  //
+  // TODO: Remove this overload once the makeUnboxedOnly incidence rate
+  // goes way down
+  template <typename Func>
+  Library& impl_UNBOXED(const char* name, Func* raw_f) & {
+    return impl(name, CppFunction::makeUnboxedOnly(raw_f));
+  }
+
+  // Register a fallback implementation for all operators which will be used
+  // if there is not a specific implementation for an operator available.
+  // Providing a DispatchKey is MANDATORY for fallback at the moment; e.g.,
+  // only call this from TORCH_LIBRARY_IMPL
+  template <typename Func>
+  Library& fallback(Func&& raw_f) & {
+    CppFunction f((std::forward<Func>(raw_f)));
+    return _fallback(std::move(f));
+  }
+
+private:
+  Kind kind_;
+  c10::optional<std::string> ns_;
+  c10::optional<DispatchKey> dispatch_key_;
+  const char* file_;
+  uint32_t line_;
+
+  std::vector<RegistrationHandleRAII> registrars_;
+
+  friend detail::TorchLibraryInit;
+
+  // Non-user visible actual implementations of functions. These aren't
+  // public because we only implement & qualifier and not && qualifier
+  Library& _def(FunctionSchema&& schema, OperatorName* out_name = nullptr) &;
+  Library& _def(c10::either<OperatorName, FunctionSchema>&&, CppFunction&& f) &;
+  Library& _impl(const char* name, CppFunction&& f) &;
+  Library& _fallback(CppFunction&& f) &;
+};
+
+namespace detail {
+
+  class TorchLibraryInit final {
+  private:
+    using InitFn = void(Library&);
+    Library lib_;
+  public:
+    TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<DispatchKey> k, const char* file, uint32_t line)
+      : lib_(kind, ns, k, file, line) {
+      fn(lib_);
+    }
+  };
+
+} // namespace detail
+
 } // namespace c10
 
+// NB: The EXACT NAMING of the initializer functions (e.g.,
+// TORCH_LIBRARY_init_aten) matters for the code analyzer;
+// see the regexes at tools/code_analyzer/run_analyzer.sh
+
+#define TORCH_LIBRARY(ns, m) \
+  static void TORCH_LIBRARY_init_ ## ns (c10::Library&); \
+  static c10::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
+    c10::Library::DEF, \
+    &TORCH_LIBRARY_init_ ## ns, \
+    #ns, c10::nullopt, __FILE__, __LINE__ \
+  ); \
+  void TORCH_LIBRARY_init_ ## ns (c10::Library& m)
+
+// This macro is a version of TORCH_LIBRARY that doesn't enforce that there
+// is only one library (it is a "fragment"). This should ONLY be used
+// with PerOpRegistration (as its name suggests).
+#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \
+  static void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (c10::Library&); \
+  static c10::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _ ## k ( \
+    c10::Library::FRAGMENT, \
+    &TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k, \
+    #ns, c10::nullopt, __FILE__, __LINE__ \
+  ); \
+  void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (c10::Library& m)
+
+#define TORCH_LIBRARY_IMPL(ns, k, m) \
+  static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (c10::Library&); \
+  static c10::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \
+    c10::Library::IMPL, \
+    & TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k, \
+    #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__ \
+  ); \
+  void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (c10::Library& m)
+
+// These are variants of the macros above which are to be used for testing (they
+// don't setup the static initializer, so you can control the visibility of
+// the allocated library yourself).
+//
+// DO NOT use these in production code, they are NOT understood by the
+// code analyzer and will be incorrectly analyzed in those situations.
+#define MAKE_TORCH_LIBRARY(ns) Library(Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
+#define MAKE_TORCH_LIBRARY_IMPL(ns, k) Library(Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__)
+
 namespace torch {
 // Old-style API
 using RegisterOperators = c10::RegisterOperators;
+
+// New-style API
+using c10::dispatch;
+using c10::schema;
 }
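
Unlike the TORCH_LIBRARY* macros, which register through a static initializer, the MAKE_TORCH_LIBRARY* variants above construct the Library object directly, so its RegistrationHandleRAII handles tear the registrations down when the object goes out of scope. That is the pattern the test hunks below rely on; a sketch:

    {
      auto m = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
      m.fallback(c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
      // ... exercise operators that hit the CPU fallback ...
    }  // `m` is destroyed here and the fallback registration is removed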
@@ -12,7 +12,6 @@
 
 #include <ATen/core/boxing/impl/test_helpers.h>
 #include <ATen/core/op_registration/op_registration.h>
-#include <torch/library.h>
 #include <ATen/core/Tensor.h>
 #include <functional>
 
@@ -22,10 +21,7 @@ using c10::OperatorHandle;
 using c10::Dispatcher;
 using c10::IValue;
 using c10::DispatchKey;
-using torch::Library;
-using torch::CppFunction;
-
+using c10::Library;
 using at::Tensor;
 
 namespace {
@@ -1446,7 +1442,7 @@ TEST(NewOperatorRegistrationTest, dispatchMultiple) {
 
 TEST(NewOperatorRegistrationTest, fallback) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
-  m.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
+  m.fallback(c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
 
   auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
 
@@ -1499,9 +1495,9 @@ TEST(NewOperatorRegistrationTest, CppFunction) {
   m.def("fn2", dummy_fn);
   m.def("fn3", [](const Tensor& x) { return x; });
   // These require explicit schema
-  m.def("fn4(Tensor x) -> Tensor", CppFunction::makeFallthrough());
-  m.def("fn5(Tensor x) -> Tensor", CppFunction::makeUnboxedOnly(dummy_fn));
-  m.def("fn6(Tensor x) -> Tensor", CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
+  m.def("fn4(Tensor x) -> Tensor", c10::CppFunction::makeFallthrough());
+  m.def("fn5(Tensor x) -> Tensor", c10::CppFunction::makeUnboxedOnly(dummy_fn));
+  m.def("fn6(Tensor x) -> Tensor", c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
 }
 
 // Some internal tests that have to be done from C++
@@ -120,7 +120,7 @@ m.def("${unqual_schema_string}");
 # TORCH_LIBRARY macro invocation
 DEFAULT_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
 m.impl("${unqual_operator_name_with_overload}",
-       torch::CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name}));
+       CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name}));
 """)
 
 DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\
@@ -137,7 +137,7 @@ m.impl("${unqual_operator_name_with_overload}", &TypeDefault::${type_wrapper_nam
 BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
 m.impl("${unqual_operator_name_with_overload}",
        torch::dispatch(DispatchKey::${Backend},
-         torch::CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name}))
+         CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name}))
 );
 """)
 
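
These CodeTemplate strings are expanded by the ATen code generator into per-operator registration calls. After substitution, the BACKEND_UNBOXEDONLY template above produces code of roughly this shape (the operator name and wrapper names here are illustrative, not taken from this diff):

    m.impl("add_.Tensor",
           torch::dispatch(DispatchKey::CPU,
             CppFunction::makeUnboxedOnly(CPUType::add__Tensor))
    );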
@@ -8,7 +8,7 @@
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/MemoryOverlap.h>
 #include <ATen/NamedTensorUtils.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 namespace {
 
@@ -1,7 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/ResizeCommon.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/core/TensorOptions.h>
 
 namespace at { namespace native {
@@ -3,7 +3,7 @@
 #include <ATen/WrapDimUtils.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/NamedTensorUtils.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 #include <ATen/Config.h>
 namespace at {
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/cuda/Resize.cuh>
 #include <ATen/native/ResizeCommon.h>
 
@@ -12,7 +12,7 @@
 #include <ATen/Config.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 namespace at {
 namespace native {
@@ -116,7 +116,7 @@ The final file `ATen/native/quantized/cpu/qxand.cpp` would look as follows
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h> // Need that for the `native_functions.yaml`
 #include <ATen/core/Type.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 
@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
@@ -1,7 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 
 #include <algorithm>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/c10_utils.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
@@ -5,7 +5,7 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
 #include <ATen/SmallVector.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -2,7 +2,7 @@
 #include <vector>
 
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
@@ -2,7 +2,7 @@
 #include <vector>
 
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
@@ -1,7 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/Pool.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/Quantizer.h>
@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/Quantizer.h>
@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/SortingUtils.h>
|
||||||
#include <ATen/native/TensorIterator.h>
|
#include <ATen/native/TensorIterator.h>
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/NativeFunctions.h>
|
#include <ATen/NativeFunctions.h>
|
||||||
#include <torch/library.h>
|
#include <ATen/core/op_registration/op_registration.h>
|
||||||
#include <ATen/native/TensorIterator.h>
|
#include <ATen/native/TensorIterator.h>
|
||||||
#include <ATen/native/cpu/Loops.h>
|
#include <ATen/native/cpu/Loops.h>
|
||||||
#include <ATen/quantized/Quantizer.h>
|
#include <ATen/quantized/Quantizer.h>
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
#include <ATen/NativeFunctions.h>
|
#include <ATen/NativeFunctions.h>
|
||||||
#include <ATen/native/Resize.h>
|
#include <ATen/native/Resize.h>
|
||||||
#include <ATen/quantized/Quantizer.h>
|
#include <ATen/quantized/Quantizer.h>
|
||||||
#include <torch/library.h>
|
#include <ATen/core/op_registration/op_registration.h>
|
||||||
#include <c10/core/QScheme.h>
|
#include <c10/core/QScheme.h>
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
|
@@ -1,4 +1,4 @@
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 TORCH_LIBRARY(quantized, m) {
   m.def("add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc");

@@ -1,6 +1,6 @@
 #ifdef USE_XNNPACK
 
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/xnnpack/Convolution.h>
 #include <ATen/native/xnnpack/Linear.h>
 #include <ATen/native/xnnpack/OpContext.h>

@@ -6,7 +6,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/core/TensorOptions.h>
 
 namespace at {

@@ -1,7 +1,7 @@
 // ${generated_comment}
 
 #include <ATen/Config.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/TypeDefault.h>
 $extra_headers
 

@@ -1,7 +1,7 @@
 // ${generated_comment}
 
 #include <c10/core/TensorOptions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 using namespace at;
 

@@ -17,7 +17,7 @@
 #include <c10/util/Half.h>
 #include <c10/core/UndefinedTensorImpl.h>
 #include <c10/util/Optional.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 #include <cstddef>
 #include <functional>

@@ -13,7 +13,7 @@
 #include <c10/core/TensorOptions.h>
 #include <ATen/DeviceGuard.h>
 #include <ATen/SparseTensorUtils.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 namespace {
 static const char* named_tensors_unsupported_error =

@@ -27,7 +27,7 @@ $storage_tensor_headers
 #include <utility>
 
 #include <ATen/Config.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 $extra_cuda_headers
 $legacy_th_headers
 

@@ -2,7 +2,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <torch/csrc/jit/runtime/operator.h>
 
 using namespace at;
@@ -110,7 +110,7 @@ void generic_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack*
 
 TEST(BackendFallbackTest, TestBackendFallbackWithMode) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
-  m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());
+  m.fallback(CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());
 
   c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);
 

@@ -122,7 +122,7 @@ TEST(BackendFallbackTest, TestBackendFallbackWithMode) {
 
 TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericWrapper);
-  m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>());
+  m.fallback(CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>());
 
   override_call_count = 0;
   Tensor a = at::detail::make_tensor<GenericWrapperTensorImpl>(ones({5, 5}, kDouble));

@@ -132,10 +132,10 @@ TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {
 
 TEST(BackendFallbackTest, TestFallthroughBackendFallback) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(aten, TESTING_ONLY_GenericMode);
-  m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());
+  m.impl("mul.Tensor", c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());
 
   auto gm = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
-  gm.fallback(torch::CppFunction::makeFallthrough());
+  gm.fallback(c10::CppFunction::makeFallthrough());
 
   c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);
 
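The three backend-fallback test hunks above change only the namespace qualifier on CppFunction; the boxed-fallback pattern they exercise is untouched by the revert. A condensed sketch of that pattern, assembled from the lines above (the fallback body is elided in this excerpt, so only its signature is shown):

// Boxed kernel signature used for fallbacks: it receives the operator handle
// and an IValue stack instead of typed arguments.
void generic_mode_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

TEST(BackendFallbackTest, Sketch) {
  // Test-only macro: yields a Library object whose lifetime the test controls.
  auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
  m.fallback(c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

  // Route calls through the testing key so the fallback actually fires.
  c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);
}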
@@ -5,7 +5,7 @@
 #include <ATen/Tensor.h>
 #include <ATen/native/DistributionTemplates.h>
 #include <ATen/native/cpu/DistributionTemplates.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/util/Optional.h>
 #include <torch/all.h>
 #include <stdexcept>

@@ -2,7 +2,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 #include <torch/csrc/jit/runtime/operator.h>
 

@@ -2,7 +2,7 @@
 #include <ATen/Generator.h>
 #include <ATen/Tensor.h>
 #include <ATen/native/TensorIterator.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/util/Optional.h>
 #include <torch/all.h>
 #include <stdexcept>
@@ -699,12 +699,7 @@ endif()
 install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
         DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
         FILES_MATCHING PATTERN "*.h")
-install(FILES
-  "${TORCH_SRC_DIR}/script.h"
-  "${TORCH_SRC_DIR}/extension.h"
-  "${TORCH_SRC_DIR}/custom_class.h"
-  "${TORCH_SRC_DIR}/library.h"
-  "${TORCH_SRC_DIR}/custom_class_detail.h"
+install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" "${TORCH_SRC_DIR}/custom_class_detail.h"
   DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)
 
 
@@ -64,7 +64,6 @@ INPUT = ../../../aten/src/ATen/ATen.h \
         ../../../torch/csrc/jit/runtime/custom_operator.h \
         ../../../torch/csrc/jit/serialization/import.h \
         ../../../torch/csrc/jit/api/module.h \
-        ../../../torch/library.h \
         ../../../torch/custom_class.h
 # Don't include .cpp files!
 FILE_PATTERNS = *.h
@@ -66,71 +66,78 @@ struct PickleTester : torch::CustomClassHolder {
   std::vector<int64_t> vals;
 };
 
+static auto test = torch::class_<Foo>("_TorchScriptTesting", "_Foo")
+                       .def(torch::init<int64_t, int64_t>())
+                       // .def(torch::init<>())
+                       .def("info", &Foo::info)
+                       .def("increment", &Foo::increment)
+                       .def("add", &Foo::add)
+                       .def("combine", &Foo::combine);
+
+static auto testStack =
+    torch::class_<MyStackClass<std::string>>(
+        "_TorchScriptTesting",
+        "_StackString")
+        .def(torch::init<std::vector<std::string>>())
+        .def("push", &MyStackClass<std::string>::push)
+        .def("pop", &MyStackClass<std::string>::pop)
+        .def("clone", &MyStackClass<std::string>::clone)
+        .def("merge", &MyStackClass<std::string>::merge)
+        .def_pickle(
+            [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
+              return self->stack_;
+            },
+            [](std::vector<std::string> state) { // __setstate__
+              return c10::make_intrusive<MyStackClass<std::string>>(
+                  std::vector<std::string>{"i", "was", "deserialized"});
+            })
+        .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
+        .def(
+            "top",
+            [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
+                -> std::string { return self->stack_.back(); });
+// clang-format off
+// The following will fail with a static assert telling you you have to
+// take an intrusive_ptr<MyStackClass> as the first argument.
+// .def("foo", [](int64_t a) -> int64_t{ return 3;});
+// clang-format on
+
+static auto testPickle =
+    torch::class_<PickleTester>("_TorchScriptTesting", "_PickleTester")
+        .def(torch::init<std::vector<int64_t>>())
+        .def_pickle(
+            [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
+              return std::vector<int64_t>{1, 3, 3, 7};
+            },
+            [](std::vector<int64_t> state) { // __setstate__
+              return c10::make_intrusive<PickleTester>(std::move(state));
+            })
+        .def(
+            "top",
+            [](const c10::intrusive_ptr<PickleTester>& self) {
+              return self->vals.back();
+            })
+        .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
+          auto val = self->vals.back();
+          self->vals.pop_back();
+          return val;
+        });
 
 at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
   return torch::zeros({instance->vals.back(), 4});
 }
 
-TORCH_LIBRARY(_TorchScriptTesting, m) {
-  m.class_<Foo>("_Foo")
-      .def(torch::init<int64_t, int64_t>())
-      // .def(torch::init<>())
-      .def("info", &Foo::info)
-      .def("increment", &Foo::increment)
-      .def("add", &Foo::add)
-      .def("combine", &Foo::combine);
-
-  m.class_<MyStackClass<std::string>>("_StackString")
-      .def(torch::init<std::vector<std::string>>())
-      .def("push", &MyStackClass<std::string>::push)
-      .def("pop", &MyStackClass<std::string>::pop)
-      .def("clone", &MyStackClass<std::string>::clone)
-      .def("merge", &MyStackClass<std::string>::merge)
-      .def_pickle(
-          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
-            return self->stack_;
-          },
-          [](std::vector<std::string> state) { // __setstate__
-            return c10::make_intrusive<MyStackClass<std::string>>(
-                std::vector<std::string>{"i", "was", "deserialized"});
-          })
-      .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
-      .def(
-          "top",
-          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
-              -> std::string { return self->stack_.back(); });
-  // clang-format off
-  // The following will fail with a static assert telling you you have to
-  // take an intrusive_ptr<MyStackClass> as the first argument.
-  // .def("foo", [](int64_t a) -> int64_t{ return 3;});
-  // clang-format on
-
-  m.class_<PickleTester>("_PickleTester")
-      .def(torch::init<std::vector<int64_t>>())
-      .def_pickle(
-          [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
-            return std::vector<int64_t>{1, 3, 3, 7};
-          },
-          [](std::vector<int64_t> state) { // __setstate__
-            return c10::make_intrusive<PickleTester>(std::move(state));
-          })
-      .def(
-          "top",
-          [](const c10::intrusive_ptr<PickleTester>& self) {
-            return self->vals.back();
-          })
-      .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
-        auto val = self->vals.back();
-        self->vals.pop_back();
-        return val;
-      });
-
-  m.def(
-      "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y",
-      take_an_instance);
-  // test that schema inference is ok too
-  m.def("take_an_instance_inferred", take_an_instance);
+torch::RegisterOperators& register_take_instance() {
+  static auto instance_registry = torch::RegisterOperators().op(
+      torch::RegisterOperators::options()
+          .schema(
+              "_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y")
+          .catchAllKernel<decltype(take_an_instance), &take_an_instance>());
+  return instance_registry;
 }
 
+static auto& ensure_take_instance_registered = register_take_instance();
 
 } // namespace
 
 void testTorchbindIValueAPI() {
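To condense the long hunk above: this revert swaps the custom-class test registrations from the TORCH_LIBRARY-scoped style back to file-scope static registration objects. The two styles for the same class, side by side (condensed from the lines above, not new code in this commit):

// Style removed by the revert: registration inside a TORCH_LIBRARY block,
// where the namespace comes from the macro argument.
// TORCH_LIBRARY(_TorchScriptTesting, m) {
//   m.class_<Foo>("_Foo")
//       .def(torch::init<int64_t, int64_t>());
// }

// Style restored by the revert: a static object registered at library load
// time, with the namespace named explicitly.
static auto test = torch::class_<Foo>("_TorchScriptTesting", "_Foo")
                       .def(torch::init<int64_t, int64_t>());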
@@ -1,5 +1,6 @@
 #include <torch/extension.h>
-#include <torch/library.h>
+
+#include <ATen/core/op_registration/op_registration.h>
 
 using namespace at;
 

@@ -75,7 +75,7 @@ namespace {
 // cares about the name
 TORCH_LIBRARY(_test, m) {
   m.def("AA(Tensor self) -> Tensor");
-  m.impl("AA", torch::CppFunction::makeUnboxedOnly(AA_op));
+  m.impl("AA", CppFunction::makeUnboxedOnly(AA_op));
 
   m.def("BB(Tensor self) -> Tensor");
   m.impl("BB", &BB_op);

@@ -93,7 +93,7 @@ TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(_test, m) {
 
 TORCH_LIBRARY_IMPL(_test, CPU, m) {
   m.impl_UNBOXED("EE", EE_op);
-  m.impl("FF", torch::CppFunction::makeUnboxedOnly(FF_op));
+  m.impl("FF", CppFunction::makeUnboxedOnly(FF_op));
   m.impl("GG",
     [] (Tensor a) -> Tensor {
       return call_FF_op(a);
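Aside on the two registration forms in this hunk: per the torch/library.h listing at the end of this diff, impl_UNBOXED(name, f) is simply a greppable shorthand for impl(name, CppFunction::makeUnboxedOnly(f)), i.e. a kernel registered without a boxing wrapper and without schema inference. A condensed sketch of that equivalence (op names taken from the hunk above; EE_op and FF_op are defined elsewhere in that test file):

TORCH_LIBRARY_IMPL(_test, CPU, m) {
  // These two registrations are equivalent in mechanism; only the spelling
  // differs. Both bypass boxing wrappers and schema inference.
  m.impl_UNBOXED("EE", EE_op);
  m.impl("FF", CppFunction::makeUnboxedOnly(FF_op));
}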
@@ -1,7 +1,7 @@
 #include "torch/csrc/autograd/VariableTypeUtils.h"
 
 #include <ATen/TypeDefault.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 
 // ${generated_comment}
 
@@ -15,7 +15,7 @@ echo "Analyze: ${INPUT}"
 # to operate, so for safety we match a more expansive set.
 "${ANALYZER_BIN}" \
   -op_schema_pattern="^(_aten|_prim|aten|quantized|profiler|_test)::[a-zA-Z0-9_.]+(\(.*)?$" \
-  -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl|impl_UNBOXED)|torch::Library::(_?def|_?impl|_?impl_UNBOXED)" \
+  -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl|impl_UNBOXED)|c10::Library::(_?def|_?impl|_?impl_UNBOXED)" \
   -op_invoke_pattern="c10::Dispatcher::findSchema|callOp" \
   -root_symbol_pattern="torch::jit::[^(]" \
   -torch_library_init_pattern="^.*TORCH_LIBRARY_init_([^(]+)(\(.*)?$" \
@@ -1,9 +1,9 @@
 #include "function.h"
+#include <ATen/core/op_registration/op_registration.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/runtime/operator.h>
 #include <torch/csrc/jit/runtime/vararg_functions.h>
 #include <torch/custom_class_detail.h>
-#include <torch/library.h>
 #include "interpreter.h"
 
 namespace torch {

@@ -1,8 +1,8 @@
 #include <ATen/ATen.h>
 #include <ATen/TypeDefault.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/stack.h>
 #include <torch/csrc/autograd/function.h>
-#include <torch/library.h>
 
 using Stack = std::vector<c10::IValue>;
 using at::Scalar;
@@ -104,12 +104,12 @@ void log_softmax_kernel(const c10::OperatorHandle& op, Stack* stack) {
 TORCH_LIBRARY_IMPL(_aten, Autograd, m) {
   m.impl("add.Scalar", torch::autograd::VariableType::add_Scalar);
   m.impl("mul.Tensor", torch::autograd::VariableType::mul_Tensor);
-  m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction<conv2d_kernel>());
+  m.impl("conv2d", CppFunction::makeFromBoxedFunction<conv2d_kernel>());
   m.impl("dropout", VariableType::dropout);
   m.impl("feature_dropout", VariableType::feature_dropout);
   m.impl(
       "log_softmax.int",
-      torch::CppFunction::makeFromBoxedFunction<log_softmax_kernel>());
+      CppFunction::makeFromBoxedFunction<log_softmax_kernel>());
   m.impl(
       "max_pool2d",
       [](const Tensor& self,

@@ -127,7 +127,7 @@ TORCH_LIBRARY_IMPL(_aten, Autograd, m) {
           ceil_mode);
       });
   m.impl("relu", VariableType::relu);
-  m.impl("view", torch::CppFunction::makeFromBoxedFunction<view_kernel>());
+  m.impl("view", CppFunction::makeFromBoxedFunction<view_kernel>());
   m.impl("t", VariableType::t);
   m.impl("addmm", VariableType::addmm);
 }

@@ -1,7 +1,7 @@
 #include <torch/csrc/utils/python_dispatch.h>
 #include <torch/csrc/jit/frontend/function_schema_parser.h>
 
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/ATen.h>
 
 #include <pybind11/operators.h>
@@ -45,12 +45,12 @@ c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
 
 
 template <typename Func>
-inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
+inline c10::CppFunction dispatch_str(const char* key, Func&& raw_f) {
   auto mb_key = parseDispatchKey(key);
   if (mb_key) {
-    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
+    return c10::dispatch(*mb_key, std::move(raw_f));
   } else {
-    torch::CppFunction f(std::forward<Func>(raw_f));
+    c10::CppFunction f(std::forward<Func>(raw_f));
     return f;
   }
 }

@@ -62,16 +62,16 @@ void initDispatchBindings(PyObject* module) {
     .def("schema", &c10::OperatorHandle::schema);
 
   // TODO: figure out how to do chaining
-  py::class_<torch::Library>(m, "_DispatchModule")
+  py::class_<c10::Library>(m, "_DispatchModule")
     .def("def_", [](py::object self, const char* schema, const char* alias) {
-      self.cast<torch::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias)));
+      self.cast<c10::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias)));
       return self;
     }, "", py::arg("schema"), py::arg("alias") = "")
     // Simulated "legacy" def where alias analysis kind is not set.
     // Ordinarily this can only be exercised from RegisterOperators() API
     // but I am not going to bind that here
     .def("def_legacy", [](py::object self, const char* schema) {
-      self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
+      self.cast<c10::Library&>().def(torch::jit::parseSchema(schema));
       return self;
     }, "", py::arg("schema"))
     // We can't conveniently turn Python functions into valid functions

@@ -83,7 +83,7 @@ void initDispatchBindings(PyObject* module) {
     // Mangling scheme: args_rets. One character per.
     // t = Tensor
     .def("def_name_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
-      self.cast<torch::Library&>().def(
+      self.cast<c10::Library&>().def(
         name,
         dispatch_str(dispatch, [](const at::Tensor& a) {
           return a;

@@ -94,7 +94,7 @@ void initDispatchBindings(PyObject* module) {
       py::arg("dispatch") = "",
       py::arg("debug") = "default_def_name_t_t")
     .def("def_schema_t_t", [](py::object self, const char* schema, const char* dispatch, const char* alias, const char* debug) {
-      self.cast<torch::Library&>().def(
+      self.cast<c10::Library&>().def(
        torch::schema(schema, parseAliasAnalysisKind(alias)),
         dispatch_str(dispatch, [](const at::Tensor& a) {
           return a;

@@ -108,7 +108,7 @@ void initDispatchBindings(PyObject* module) {
     // TODO: maybe consider deduplicating the definitions here, it's getting
     // pretty long
     .def("impl_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
-      self.cast<torch::Library&>().impl(
+      self.cast<c10::Library&>().impl(
         name,
         dispatch_str(dispatch, [](const at::Tensor& a) {
           return a;

@@ -119,7 +119,7 @@ void initDispatchBindings(PyObject* module) {
       py::arg("dispatch") = "",
       py::arg("debug") = "impl_t_t")
     .def("impl_tt_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
-      self.cast<torch::Library&>().impl(
+      self.cast<c10::Library&>().impl(
         name,
         dispatch_str(dispatch, [](const at::Tensor& a, const at::Tensor& b) {
           return a;

@@ -133,9 +133,9 @@ void initDispatchBindings(PyObject* module) {
     // This is a wee bit dodgy right now, but the "underlying" API is much
     // easier to test than the high level (using TORCH_LIBRARY, e.g.)
     if (name.empty()) {
-      return std::make_unique<torch::Library>(torch::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0);
+      return std::make_unique<c10::Library>(c10::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0);
     } else {
-      return std::make_unique<torch::Library>(torch::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0);
+      return std::make_unique<c10::Library>(c10::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0);
     }
   });
 
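The python_dispatch hunks above re-point the test bindings from torch::Library and torch::dispatch to their c10 homes; the mechanism itself is unchanged. What dispatch_str ultimately builds is a CppFunction tagged with a dispatch key, sketched here with the key spelled out (the identity lambda is a stand-in kernel, as in the bindings above):

// Build a kernel the dispatcher will only select when CPU is the computed
// dispatch key; mirrors the call the dispatch_str helper forwards to.
auto cpu_only = c10::dispatch(
    c10::DispatchKey::CPU,
    [](const at::Tensor& a) { return a; });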
@@ -5,12 +5,12 @@
 #include <ATen/core/ivalue.h>
 #include <ATen/core/jit_type.h>
 #include <ATen/core/op_registration/infer_schema.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/core/stack.h>
 #include <c10/util/C++17.h>
 #include <c10/util/Metaprogramming.h>
 #include <c10/util/TypeList.h>
 #include <c10/util/TypeTraits.h>
-#include <torch/library.h>
 #include <torch/custom_class_detail.h>
 #include <iostream>
 #include <sstream>
@@ -270,14 +270,4 @@ using ::torch::class_;
 
 } // namespace jit
 
-template <class CurClass>
-inline class_<CurClass> Library::class_(const std::string& className) {
-  TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
-    "class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
-    "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
-    "(Error occurred at ", file_, ":", line_, ")");
-  TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
-  return torch::class_<CurClass>(*ns_, className);
-}
-
-}
+} // namespace torch
torch/library.h (406 lines removed)
@@ -1,406 +0,0 @@
-#pragma once
-
-#include <c10/core/DispatchKey.h>
-#include <ATen/core/dispatch/Dispatcher.h>
-#include <ATen/core/op_registration/infer_schema.h>
-#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
-#include <torch/csrc/jit/frontend/function_schema_parser.h>
-#endif
-
-// Just for inferFunctionSchemaFromFunctor
-#include <ATen/core/op_registration/op_registration.h>
-
-namespace torch {
-
-template <class CurClass>
-class class_;
-
-// A quick tour of a few usage examples:
-//
-//  // Define a library whose operators live in the namespace 'aten'.
-//  // You must define all of the operators for this library in
-//  // this namespace.
-//  TORCH_LIBRARY(aten, m) {
-//    // Define a schema for an operator, but provide no implementation
-//    m.def("mul(Tensor self, Tensor other) -> Tensor");
-//
-//    // Define a operator with exactly one implementation for all backends.
-//    m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
-//
-//    // Provide an implementation for a defined operator (you can
-//    // provide multiple; one per backend).  We'll take care of calling
-//    // the correct implementation depending on if we get a CPU
-//    // tensor or a CUDA tensor
-//    m.impl("mul", torch::kCPU, &mul_cpu_impl);
-//    m.impl("mul", torch::kCUDA, &mul_cuda_impl);
-//  }
-//
-//  // Define implementations for operators for a non-standard backend,
-//  // e.g., XLA (valid values are entries of DispatchKey).  These
-//  // operator names are not namespaced; you can define implementations
-//  // for any namespace.
-//  TORCH_LIBRARY_IMPL(aten, XLA, m) {
-//    m.impl("mul", &mul_xla_impl);
-//  }
-
-
-// Represents a C++ function that implements an operator.  Most users won't
-// interact directly with this class, except via error messages: the
-// constructors this function define the set of permissible "function"-like
-// things you can bind via the interface.
-//
-// This class erases the type of the passed in function, but durably records
-// the type via an inferred schema for the function.
-//
-// TODO: This is morally the same thing as KernelRegistrationConfig, but it's
-// opaque to the user.
-class CAFFE2_API CppFunction final {
-public:
-  // This overload accepts function pointers, e.g., CppFunction(&add_impl)
-  template <typename Func>
-  explicit CppFunction(Func* f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)
-    : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f))
-    // TODO: Don't go through WrapRuntimeKernelFunctor
-    , schema_(c10::detail::inferFunctionSchemaFromFunctor<c10::impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Func>>>())
-    , debug_()
-    {}
-
-  // This overload accepts lambdas, e.g., CppFunction([](const Tensor& self) { ... })
-  template <typename Lambda>
-  explicit CppFunction(Lambda&& f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)
-    : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(f)))
-    // TODO: Don't go through WrapRuntimeKernelFunctor
-    , schema_(c10::detail::inferFunctionSchemaFromFunctor<c10::impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>())
-    , debug_()
-    {}
-
-  // This static factory lets you create CppFunctions that (1) don't have boxing
-  // wrappers (because we don't support it yet) and (2) don't have schema
-  // inference (because some ops don't support it).
-  //
-  // TODO: Eliminate the necessity for this function entirely.
-  template <typename Func>
-  static CppFunction makeUnboxedOnly(Func* f) {
-    return CppFunction(
-      c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f),
-      /* schema */ nullptr
-    );
-  }
-
-  // TODO: more user friendly API
-  static CppFunction makeFallthrough() {
-    return CppFunction(
-      c10::KernelFunction::makeFallthrough(),
-      /* schema */ nullptr
-    );
-  }
-
-  // TODO: more user friendly API
-  template<c10::KernelFunction::BoxedKernelFunction* func>
-  static CppFunction makeFromBoxedFunction() {
-    return CppFunction(
-      c10::KernelFunction::makeFromBoxedFunction<func>(),
-      /* schema */ nullptr
-    );
-  }
-
-  CppFunction&& debug(std::string d) && {
-    debug_ = std::move(d);
-    return std::move(*this);
-  }
-
-private:
-  c10::optional<c10::DispatchKey> dispatch_key_;
-  c10::KernelFunction func_;
-  std::unique_ptr<c10::FunctionSchema> schema_;
-  std::string debug_;
-
-  // The "setter" for dispatch_key_
-  template <typename Func>
-  friend CppFunction dispatch(c10::DispatchKey, Func&&);
-
-  // The only class which actually pulls out values from CppFunction (does so
-  // destructively, felt too lazy to write accessors that I don't even
-  // want users to use)
-  friend class Library;
-
-  CppFunction(c10::KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema);
-};
-
-// Create a CppFunction which is associated with a specific dispatch key.
-// CppFunctions that are tagged with a DispatchKey don't get invoked /unless/
-// the dispatcher determines that the DispatchKey is the best choice for
-// a function
-template <typename Func>
-inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
-  CppFunction f(std::forward<Func>(raw_f));
-  if (k == c10::DispatchKey::CatchAll) {
-    f.dispatch_key_ = c10::nullopt;
-  } else {
-    f.dispatch_key_ = k;
-  }
-  return f;
-}
-
-// Convenience overload of dispatch which accepts DeviceType
-template <typename Func>
-inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
-  auto deviceTypeToDispatchKey = [](c10::DeviceType t){
-    switch (t) {
-      // This list is synchronized with the k-constants in c10/core/DeviceType.h
-      case c10::DeviceType::CPU:
-        return c10::DispatchKey::CPU;
-      case c10::DeviceType::CUDA:
-        return c10::DispatchKey::CUDA;
-      case c10::DeviceType::XLA:
-        return c10::DispatchKey::XLA;
-      case c10::DeviceType::HIP:
-        return c10::DispatchKey::HIP;
-      case c10::DeviceType::MSNPU:
-        return c10::DispatchKey::MSNPU;
-      default:
-        TORCH_CHECK(false,
-          "Device type ", t, " cannot be overloaded at dispatch time, "
-          "please file a bug report explaining what you were trying to do.");
-    }
-  };
-  return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
-}
-
-inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
-  c10::FunctionSchema s = torch::jit::parseSchema(str);
-  s.setAliasAnalysis(k);
-  return s;
-}
-inline c10::FunctionSchema schema(const char* s) {
-  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
-}
-inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) { return std::move(s); }
-
-namespace detail {
-
-  inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::FunctionSchema&& s) {
-    return c10::make_right<c10::OperatorName, c10::FunctionSchema>(std::move(s));
-  }
-  inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::OperatorName&& n) {
-    return c10::make_left<c10::OperatorName, c10::FunctionSchema>(std::move(n));
-  }
-  inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(const char* str) {
-    auto s = torch::jit::parseSchemaOrName(str);
-    if (s.is_right()) {
-      s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
-    }
-    return s;
-  }
-
-  class TorchLibraryInit;
-
-}
-
-// This is the "handle" by which functions defined in TORCH_LIBRARY
-// and TORCH_LIBRARY_IMPL can define operators and override implementations
-// at certain backends.
-//
-// Conventionally, you get access to it using those two macros:
-//
-// TORCH_LIBRARY(torchvision, m) {
-//    // m is a torch::Library
-//    m.def("roi_align", ...);
-//    ...
-// }
-//
-// TORCH_LIBRARY_IMPL(aten, XLA, m) {
-//    // m is a torch::Library
-//    m.impl("add", ...);
-//    ...
-// }
-//
-// In some cases, you need to define something that applies to all namespaces,
-// not just one namespace (usually a fallback).  In that case, use the reserved
-// namespace _, e.g.,
-//
-// TORCH_LIBRARY_IMPL(_, XLA, m) {
-//    m.fallback(xla_fallback);
-// }
-//
-class CAFFE2_API Library final {
-public:
-  // Which type of macro produced this Library
-  enum Kind {
-    DEF, // from TORCH_LIBRARY (no qualifier)
-    IMPL,
-    FRAGMENT,
-  };
-
-  // Use TORCH_LIBRARY/TORCH_LIBRARY_IMPL instead of these constructors directly
-  Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line);
-
-  Library(const Library&) = delete;
-  Library& operator=(const Library&) = delete;
-  Library(Library&&) = default;
-  Library& operator=(Library&&) = default;
-
-  // Some notes about the API design here.  We had the following constraints:
-  //
-  // - We need to support multiple "types" of arguments for schema and
-  //   functions (e.g., unnamed lambda types, regular functions, const char*,
-  //   fully instantiated schemas)
-  // - We don't want to write exponentially many overloads
-  // - We don't want to rely on implicit conversion to a common type,
-  //   because the C++ compiler will only be willing to do a single
-  //   implicit conversion (reducing the set of valid types which you
-  //   can invoke with); also error messages are worse when an implicit
-  //   conversion is not selected (as the compiler will not explain
-  //   why it didn't select an implicit conversion; this is different
-  //   from overloads where it will explain each candidate overload and
-  //   why it didn't apply)
-  //
-  // To solve all of these constraints at the same time, we use a trick taken
-  // from the pybind11 library: template over the argument in the user visible
-  // API, and inside of the templated function explicitly call an overloaded
-  // function to resolve the argument to a real type.  You get the good error
-  // messages from overloads, but at the same time you only need to write the
-  // overload for any given argument type once.
-
-  // Declare an operator with a schema, but don't provide any implementations
-  // for it.  You're expected to then provide implementations using the
-  // impl() method.
-  template <typename Schema>
-  Library& def(Schema&& raw_schema) & {
-    c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
-    return _def(std::move(s));
-  }
-
-  // Convenience method to define an operator for a schema and then register
-  // an implementation for it.  def(n, f) is almost equivalent to def(n).impl(f),
-  // except that if n is not a schema, then the schema is inferred from the
-  // static type of f.
-  template <typename NameOrSchema, typename Func>
-  Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
-    CppFunction f(std::forward<Func>(raw_f));
-    auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
-    return _def(std::move(name_or_schema), std::move(f));
-  }
-
-  // Register an implementation for an operator.  You may register multiple
-  // implementations for a single operator at different dispatch keys
-  // (see torch::dispatch).  Implementations must have a corresponding
-  // declaration (from def), otherwise they are invalid.
-  template <typename Func>
-  Library& impl(const char* name, Func&& raw_f) & {
-    CppFunction f(std::forward<Func>(raw_f));
-    return _impl(name, std::move(f));
-  }
-
-  // Convenience overload for directly specifying the dispatch key.  Dispatch
-  // can validly be either DeviceType or DispatchKey; check torch::dispatch for
-  // the canonical list of accepted overloads.
-  template <typename Dispatch, typename Func>
-  Library& impl(const char* name, Dispatch&& key, Func&& raw_f) & {
-    return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
-  }
-
-  // Convenience overload for unboxed only kernels.  These are quite common
-  // but will be eventually eliminated; this function makes it easy to grep for
-  // them.
-  //
-  // TODO: Remove this overload once the makeUnboxedOnly incidence rate
-  // goes way down
-  template <typename Func>
-  Library& impl_UNBOXED(const char* name, Func* raw_f) & {
-    return impl(name, CppFunction::makeUnboxedOnly(raw_f));
-  }
-
-  // Register a fallback implementation for all operators which will be used
-  // if there is not a specific implementation for an operator available.
-  // Providing a DispatchKey is MANDATORY for fallback at the moment; e.g.,
-  // only call this from TORCH_LIBRARY_IMPL
-  template <typename Func>
-  Library& fallback(Func&& raw_f) & {
-    CppFunction f((std::forward<Func>(raw_f)));
-    return _fallback(std::move(f));
-  }
-
-  template <class CurClass>
-  inline class_<CurClass> class_(const std::string& className);
-
-private:
-  Kind kind_;
-  c10::optional<std::string> ns_;
-  c10::optional<c10::DispatchKey> dispatch_key_;
-  const char* file_;
-  uint32_t line_;
-
-  std::vector<c10::RegistrationHandleRAII> registrars_;
-
-  friend class detail::TorchLibraryInit;
-
-  // Non-user visible actual implementations of functions.  These aren't
-  // public because we only implement & qualifier and not && qualifier
-  Library& _def(c10::FunctionSchema&& schema, c10::OperatorName* out_name = nullptr) &;
-  Library& _def(c10::either<c10::OperatorName, c10::FunctionSchema>&&, CppFunction&& f) &;
-  Library& _impl(const char* name, CppFunction&& f) &;
-  Library& _fallback(CppFunction&& f) &;
-};
-
-namespace detail {
-
-class TorchLibraryInit final {
-private:
-  using InitFn = void(Library&);
-  Library lib_;
-public:
-  TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
-    : lib_(kind, ns, k, file, line) {
-    fn(lib_);
-  }
-};
-
-} // namespace detail
-
-} // namespace torch
-
-// NB: The EXACT NAMING of the initializer functions (e.g.,
-// TORCH_LIBRARY_init_aten) matters for the code analyzer;
-// see the regexes at tools/code_analyzer/run_analyzer.sh
-
-#define TORCH_LIBRARY(ns, m) \
-  static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
-  static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
-    torch::Library::DEF, \
-    &TORCH_LIBRARY_init_ ## ns, \
-    #ns, c10::nullopt, __FILE__, __LINE__ \
-  ); \
-  void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
-
-// This macro is a version of TORCH_LIBRARY that doesn't enforce that there
-// is only one library (it is a "fragment").  This should ONLY be used
-// with PerOpRegistration (as its name suggests).
-#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \
-  static void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library&); \
-  static torch::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _ ## k ( \
-    torch::Library::FRAGMENT, \
-    &TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k, \
-    #ns, c10::nullopt, __FILE__, __LINE__ \
-  ); \
-  void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library& m)
-
-#define TORCH_LIBRARY_IMPL(ns, k, m) \
-  static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library&); \
-  static torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \
-    torch::Library::IMPL, \
-    & TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k, \
-    #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__ \
-  ); \
-  void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library& m)
-
-// These are variants of the macros above which are to be used for testing (they
-// don't setup the static initializer, so you can control the visibility of
-// the allocated library yourself).
-//
-// DO NOT use these in production code, they are NOT understood by the
-// code analyzer and will be incorrectly analyzed in those situations.
-#define MAKE_TORCH_LIBRARY(ns) torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
-#define MAKE_TORCH_LIBRARY_IMPL(ns, k) torch::Library(torch::Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__)
-
-#include <torch/custom_class.h>
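For downstream code, the practical effect of this revert is that torch/library.h no longer exists and is no longer installed; the TORCH_LIBRARY macro family and the Library and CppFunction types are reached through the op-registration header again, with the types in the c10 namespace. A minimal sketch under that assumption (the myops namespace and my_relu kernel are illustrative, not part of this commit):

#include <ATen/core/op_registration/op_registration.h>

// Hypothetical kernel; any function with an inferable schema works here.
at::Tensor my_relu(const at::Tensor& self) {
  return self.clamp_min(0);
}

TORCH_LIBRARY(myops, m) {
  // def(name, fn) infers the full schema from the C++ signature of my_relu.
  m.def("my_relu", &my_relu);
}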