Revert D21089648: Put TORCH_LIBRARY in torch/library.h; add custom class API

Test Plan: revert-hammer

Differential Revision: D21089648

Original commit changeset: 8d54329c1252

fbshipit-source-id: 636e8a11afc628a4cdae9d44824985c10c70555e

Committed by: Facebook GitHub Bot
Parent commit: a05406ea56
This commit: 2ccdc39dce
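For extension and backend code, the practical effect of this revert is mechanical: registration blocks keep their shape, but the torch/library.h header goes away and CppFunction moves back under the c10 namespace. A minimal before/after sketch, using the Autocast fallback from the first hunk below:

    // Reverted style (removed by this commit):
    //   #include <torch/library.h>
    //   m.fallback(torch::CppFunction::makeFallthrough());

    // Restored style (after this commit):
    #include <ATen/core/op_registration/op_registration.h>

    TORCH_LIBRARY_IMPL(_, Autocast, m) {
      m.fallback(c10::CppFunction::makeFallthrough());
    }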
@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/autocast_mode.h>

@@ -373,7 +373,7 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
 Explicit registration for out-of-place ops
 *****************************************/
 TORCH_LIBRARY_IMPL(_, Autocast, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(c10::CppFunction::makeFallthrough());
 }

 TORCH_LIBRARY_IMPL(aten, Autocast, m) {
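For context, the two fallback factories seen in these hunks behave differently: makeFallthrough() installs a kernel that tells the dispatcher to skip this dispatch key and continue to the next one, while makeFromBoxedFunction<&fn>() installs a boxed kernel that actually handles every operator at that key. A minimal sketch of the contrast (the handler below is hypothetical; its signature matches the boxed-kernel tests later in this diff):

    #include <ATen/core/op_registration/op_registration.h>

    // Hypothetical boxed handler: receives the operator and its stack of IValues.
    void my_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

    TORCH_LIBRARY_IMPL(_, BackendSelect, m) {
      // Transparent: dispatch proceeds to the next key in line.
      m.fallback(c10::CppFunction::makeFallthrough());
      // Or, handle everything at this key with one boxed kernel:
      // m.fallback(c10::CppFunction::makeFromBoxedFunction<&my_fallback>());
    }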
@@ -1,5 +1,9 @@
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 namespace {

 TORCH_LIBRARY_IMPL(_, BackendSelect, m) {
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(c10::CppFunction::makeFallthrough());
 }

 }

@@ -1,6 +1,6 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <ATen/core/LegacyTypeDispatch.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 /*
  * This file implements a variable fallback kernel for custom operators.

@@ -65,9 +65,9 @@ TORCH_LIBRARY_IMPL(_, Autograd, m) {
 //
 // We can remove this `fallthrough` kernel when all kernels support boxed
 // call.
-  m.fallback(torch::CppFunction::makeFallthrough());
+  m.fallback(c10::CppFunction::makeFallthrough());
 #else
-  m.fallback(torch::CppFunction::makeFromBoxedFunction<&variable_fallback_kernel>());
+  m.fallback(c10::CppFunction::makeFromBoxedFunction<&variable_fallback_kernel>());
 #endif
 }

@@ -1,7 +1,6 @@
 #pragma once

 #include <ATen/core/ivalue.h>
 #include <ATen/core/stack.h>
 #include <c10/util/Metaprogramming.h>

 namespace c10 {
@@ -1,232 +0,0 @@
#include <torch/library.h>

namespace torch {

namespace {
  // TODO: Consider representing debug info as a struct instead so you
  // don't have to allocate strings all the time
  std::string debugString(std::string debug, const char* file, uint32_t line) {
#ifdef STRIP_ERROR_MESSAGES
    return "";
#else
    if (debug.empty()) {
      return c10::str("registered at ", file, ":", line);
    } else {
      return debug;
    }
#endif
  }

  std::ostream& operator<<(std::ostream& os, Library::Kind kind) {
    switch (kind) {
      case Library::DEF:
        os << "TORCH_LIBRARY";
        break;
      case Library::IMPL:
        os << "TORCH_LIBRARY_IMPL";
        break;
      case Library::FRAGMENT:
        os << "TORCH_LIBRARY_FRAGMENT";
        break;
    }
    return os;
  }
}

CppFunction::CppFunction(c10::KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema)
  : func_(std::move(func))
  , schema_(std::move(schema))
  , debug_()
  {}

#define ERROR_CONTEXT "(Error occurred while processing ", kind_, " block at ", file_, ":", line_, ")"

Library::Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
  : kind_(kind)
  , ns_(ns == "_" ? c10::nullopt : c10::make_optional(std::move(ns)))
  , dispatch_key_((!k.has_value() || *k == c10::DispatchKey::CatchAll) ? c10::nullopt : k)
  , file_(file)
  , line_(line)
  {
    switch (kind_) {
      case DEF:
        // Only DEFs require library uniqueness; fragments
        // don't register a library
        registrars_.emplace_back(
          c10::Dispatcher::singleton().registerLibrary(
            *ns_, debugString("", file_, line_)
          )
        );
        // fallthrough
      case FRAGMENT:
        TORCH_CHECK(
          ns_.has_value(),
          kind_, ": cannot define ", kind_, " with the wildcard namespace _ "
          "(every ", kind_, " defines operators for a distinct namespace!) "
          "Did you mean to use TORCH_LIBRARY_IMPL instead? "
          ERROR_CONTEXT
        );
        TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
        break;
      case IMPL:
        // Nothing to do, everything is OK
        break;
    }
  }

// TODO: Error if an operator is def'ed multiple times. Right now we just
// merge everything

#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
  TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
    DEF_PRELUDE,
    "Cannot define an operator inside of a ", kind_, " block. "
    "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ",
    ERROR_CONTEXT
  );
  TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
  TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
  auto ns_opt = schema.getNamespace();
  if (ns_opt.has_value()) {
    // Note [Redundancy in registration code is OK]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // In an earlier version of this code, I made it an error to explicitly
    // specify the namespace, even when the namespaces match. I've decided
    // to relax this constraint because sometimes we code generate registrations
    // and you cannot conveniently tell what the enclosing context will be;
    // in these cases, it is simpler (and less error prone) to place all
    // of the information in the registration site, which will be cross-checked
    // in the end in any case (and if it turns out you DON'T have the right
    // information at the site, as is the case with backend specific
    // per-op registrations, you will get the right behavior!)
    TORCH_CHECK(
      *ns_opt == *ns_,
      "Explicitly provided namespace (", *ns_opt, ") in schema string "
      "does not match namespace of enclosing ", kind_, " block (", *ns_, "). "
      "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
      "(and consider deleting the namespace from your schema string.) ",
      ERROR_CONTEXT
    );
  } else {
    bool b = schema.setNamespaceIfNotSet(ns_->c_str());
    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
  }
  if (out_name) {
    *out_name = schema.operator_name(); // copy!
  }
  registrars_.emplace_back(
    c10::Dispatcher::singleton().registerDef(
      std::move(schema),
      debugString("", file_, line_)
    )
  );
  return *this;
}
#undef DEF_PRELUDE

Library& Library::_def(c10::either<c10::OperatorName, c10::FunctionSchema>&& name_or_schema, CppFunction&& f) & {
  c10::FunctionSchema schema = [&] {
    if (name_or_schema.is_right()) {
      return std::move(name_or_schema).right();
    } else {
      // it's a name; use the inferred schema
      c10::OperatorName name = std::move(name_or_schema).left();
      TORCH_CHECK(f.schema_,
        "def(\"", name, "\"): "
        "Full schema string was not specified, and we couldn't infer schema either. ",
        "Please explicitly provide a schema string. ",
        ERROR_CONTEXT
      );
      c10::FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name));
      s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
      return s;
    }
  }();
  c10::OperatorName name("", ""); // Get the namespaced name for the impl call
  // First define the schema...
  _def(std::move(schema), &name);
  // Then register the implementation...
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  registrars_.emplace_back(
    c10::Dispatcher::singleton().registerImpl(
      std::move(name),
      dispatch_key,
      std::move(f.func_),
      std::move(f.schema_),
      debugString(std::move(f.debug_), file_, line_)
    )
  );
  return *this;
}

#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
Library& Library::_impl(const char* name_str, CppFunction&& f) & {
  auto name = torch::jit::parseName(name_str);
  auto ns_opt = name.getNamespace();
  // This is kind of similar to the checking in def(), but the error
  // messages are a little different for this call site
  if (ns_opt.has_value()) {
    // See Note [Redundancy in registration code is OK]
    TORCH_CHECK(*ns_opt == *ns_,
      IMPL_PRELUDE,
      "Explicitly provided namespace (", *ns_opt, ") in operator name "
      "does not match namespace of enclosing ", kind_, " block (", *ns_, "). "
      "Move this definition to the ", kind_, " block corresponding to this namespace "
      "(and consider deleting the namespace from your schema string.) ",
      ERROR_CONTEXT
    );
  } else {
    bool b = name.setNamespaceIfNotSet(ns_->c_str());
    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
  }
  // See Note [Redundancy in registration code is OK]
  TORCH_CHECK(!(f.dispatch_key_.has_value() &&
                dispatch_key_.has_value() &&
                *f.dispatch_key_ != *dispatch_key_),
    IMPL_PRELUDE,
    "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
    "with the dispatch key of the enclosing ", kind_, " block (", *dispatch_key_, "). "
    "Please declare a separate ", kind_, " block for this dispatch key and "
    "move your impl() there. "
    ERROR_CONTEXT
  );
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  registrars_.emplace_back(
    c10::Dispatcher::singleton().registerImpl(
      std::move(name),
      dispatch_key,
      std::move(f.func_),
      std::move(f.schema_),
      debugString(std::move(f.debug_), file_, line_)
    )
  );
  return *this;
}
#undef IMPL_PRELUDE

Library& Library::_fallback(CppFunction&& f) & {
  TORCH_CHECK(kind_ == IMPL,
    "fallback(...): Cannot define an operator inside of a ", kind_, " block. "
    "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block? ",
    ERROR_CONTEXT);
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT);
  TORCH_CHECK(!ns_.has_value(),
    "fallback(...): Fallback functions which apply to only a single namespace ",
    "(you specified ", *ns_, ") are not supported. If you intended to apply ",
    "this fallback function globally, please define a separate block:\n\n",
    "  TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n",
    ERROR_CONTEXT);
  registrars_.emplace_back(
    c10::Dispatcher::singleton().registerFallback(
      *dispatch_key,
      std::move(f.func_),
      debugString(std::move(f.debug_), file_, line_)
    )
  );
  return *this;
}

} // namespace torch
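For orientation, the three Library::Kind values printed above map one-to-one onto the registration macros (TORCH_LIBRARY, TORCH_LIBRARY_IMPL, TORCH_LIBRARY_FRAGMENT), and the checks in the constructor and in _def/_impl/_fallback enforce the division of labor between them. A minimal sketch with a hypothetical namespace and kernel:

    // DEF: the single block that owns the 'myops' namespace and declares schemas.
    TORCH_LIBRARY(myops, m) {
      m.def("myadd(Tensor self, Tensor other) -> Tensor");
    }

    // IMPL: one block per (namespace, dispatch key) pair supplying kernels.
    TORCH_LIBRARY_IMPL(myops, CPU, m) {
      m.impl("myadd", &myadd_cpu);  // myadd_cpu: hypothetical CPU kernel
    }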
@@ -7,6 +7,37 @@

namespace c10 {

namespace {
  // TODO: Consider representing debug info as a struct instead so you
  // don't have to allocate strings all the time
  std::string debugString(std::string debug, const char* file, uint32_t line) {
#ifdef STRIP_ERROR_MESSAGES
    return "";
#else
    if (debug.empty()) {
      return c10::str("registered at ", file, ":", line);
    } else {
      return debug;
    }
#endif
  }

  std::ostream& operator<<(std::ostream& os, Library::Kind kind) {
    switch (kind) {
      case Library::DEF:
        os << "TORCH_LIBRARY";
        break;
      case Library::IMPL:
        os << "TORCH_LIBRARY_IMPL";
        break;
      case Library::FRAGMENT:
        os << "TORCH_LIBRARY_FRAGMENT";
        break;
    }
    return os;
  }
}

static_assert(std::is_nothrow_move_constructible<c10::optional<RegistrationHandleRAII>>::value, "");
static_assert(std::is_nothrow_move_assignable<c10::optional<RegistrationHandleRAII>>::value, "");

@@ -109,4 +140,200 @@ RegisterOperators::~RegisterOperators() = default;
RegisterOperators::RegisterOperators(RegisterOperators&&) noexcept = default;
RegisterOperators& RegisterOperators::operator=(RegisterOperators&&) noexcept = default;

} // namespace c10

CppFunction::CppFunction(KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema)
  : func_(std::move(func))
  , schema_(std::move(schema))
  , debug_()
  {}

#define ERROR_CONTEXT "(Error occurred while processing ", kind_, " block at ", file_, ":", line_, ")"

Library::Library(Kind kind, std::string ns, c10::optional<DispatchKey> k, const char* file, uint32_t line)
  : kind_(kind)
  , ns_(ns == "_" ? c10::nullopt : c10::make_optional(std::move(ns)))
  , dispatch_key_((!k.has_value() || *k == DispatchKey::CatchAll) ? c10::nullopt : k)
  , file_(file)
  , line_(line)
  {
    switch (kind_) {
      case DEF:
        // Only DEFs require library uniqueness; fragments
        // don't register a library
        registrars_.emplace_back(
          Dispatcher::singleton().registerLibrary(
            *ns_, debugString("", file_, line_)
          )
        );
        // fallthrough
      case FRAGMENT:
        TORCH_CHECK(
          ns_.has_value(),
          kind_, ": cannot define ", kind_, " with the wildcard namespace _ "
          "(every ", kind_, " defines operators for a distinct namespace!) "
          "Did you mean to use TORCH_LIBRARY_IMPL instead? "
          ERROR_CONTEXT
        );
        TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
        break;
      case IMPL:
        // Nothing to do, everything is OK
        break;
    }
  }

// TODO: Error if an operator is def'ed multiple times. Right now we just
// merge everything

#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(FunctionSchema&& schema, OperatorName* out_name) & {
  TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
    DEF_PRELUDE,
    "Cannot define an operator inside of a ", kind_, " block. "
    "All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ",
    ERROR_CONTEXT
  );
  TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
  TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
  auto ns_opt = schema.getNamespace();
  if (ns_opt.has_value()) {
    // Note [Redundancy in registration code is OK]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // In an earlier version of this code, I made it an error to explicitly
    // specify the namespace, even when the namespaces match. I've decided
    // to relax this constraint because sometimes we code generate registrations
    // and you cannot conveniently tell what the enclosing context will be;
    // in these cases, it is simpler (and less error prone) to place all
    // of the information in the registration site, which will be cross-checked
    // in the end in any case (and if it turns out you DON'T have the right
    // information at the site, as is the case with backend specific
    // per-op registrations, you will get the right behavior!)
    TORCH_CHECK(
      *ns_opt == *ns_,
      "Explicitly provided namespace (", *ns_opt, ") in schema string "
      "does not match namespace of enclosing ", kind_, " block (", *ns_, "). "
      "Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
      "(and consider deleting the namespace from your schema string.) ",
      ERROR_CONTEXT
    );
  } else {
    bool b = schema.setNamespaceIfNotSet(ns_->c_str());
    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
  }
  if (out_name) {
    *out_name = schema.operator_name(); // copy!
  }
  registrars_.emplace_back(
    Dispatcher::singleton().registerDef(
      std::move(schema),
      debugString("", file_, line_)
    )
  );
  return *this;
}
#undef DEF_PRELUDE

Library& Library::_def(c10::either<OperatorName, FunctionSchema>&& name_or_schema, CppFunction&& f) & {
  FunctionSchema schema = [&] {
    if (name_or_schema.is_right()) {
      return std::move(name_or_schema).right();
    } else {
      // it's a name; use the inferred schema
      OperatorName name = std::move(name_or_schema).left();
      TORCH_CHECK(f.schema_,
        "def(\"", name, "\"): "
        "Full schema string was not specified, and we couldn't infer schema either. ",
        "Please explicitly provide a schema string. ",
        ERROR_CONTEXT
      );
      FunctionSchema s = f.schema_->cloneWithName(std::move(name.name), std::move(name.overload_name));
      s.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
      return s;
    }
  }();
  OperatorName name("", ""); // Get the namespaced name for the impl call
  // First define the schema...
  _def(std::move(schema), &name);
  // Then register the implementation...
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  registrars_.emplace_back(
    Dispatcher::singleton().registerImpl(
      std::move(name),
      dispatch_key,
      std::move(f.func_),
      std::move(f.schema_),
      debugString(std::move(f.debug_), file_, line_)
    )
  );
  return *this;
}

#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
Library& Library::_impl(const char* name_str, CppFunction&& f) & {
  auto name = torch::jit::parseName(name_str);
  auto ns_opt = name.getNamespace();
  // This is kind of similar to the checking in def(), but the error
  // messages are a little different for this call site
  if (ns_opt.has_value()) {
    // See Note [Redundancy in registration code is OK]
    TORCH_CHECK(*ns_opt == *ns_,
      IMPL_PRELUDE,
      "Explicitly provided namespace (", *ns_opt, ") in operator name "
      "does not match namespace of enclosing ", kind_, " block (", *ns_, "). "
      "Move this definition to the ", kind_, " block corresponding to this namespace "
      "(and consider deleting the namespace from your schema string.) ",
      ERROR_CONTEXT
    );
  } else {
    bool b = name.setNamespaceIfNotSet(ns_->c_str());
    TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
  }
  // See Note [Redundancy in registration code is OK]
  TORCH_CHECK(!(f.dispatch_key_.has_value() &&
                dispatch_key_.has_value() &&
                *f.dispatch_key_ != *dispatch_key_),
    IMPL_PRELUDE,
    "Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
    "with the dispatch key of the enclosing ", kind_, " block (", *dispatch_key_, "). "
    "Please declare a separate ", kind_, " block for this dispatch key and "
    "move your impl() there. "
    ERROR_CONTEXT
  );
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  registrars_.emplace_back(
    Dispatcher::singleton().registerImpl(
      std::move(name),
      dispatch_key,
      std::move(f.func_),
      std::move(f.schema_),
      debugString(std::move(f.debug_), file_, line_)
    )
  );
  return *this;
}
#undef IMPL_PRELUDE

Library& Library::_fallback(CppFunction&& f) & {
  TORCH_CHECK(kind_ == IMPL,
    "fallback(...): Cannot define an operator inside of a ", kind_, " block. "
    "Did you mean to call this function inside a TORCH_LIBRARY_IMPL block? ",
    ERROR_CONTEXT);
  auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
  TORCH_INTERNAL_ASSERT(dispatch_key.has_value(), ERROR_CONTEXT);
  TORCH_CHECK(!ns_.has_value(),
    "fallback(...): Fallback functions which apply to only a single namespace ",
    "(you specified ", *ns_, ") are not supported. If you intended to apply ",
    "this fallback function globally, please define a separate block:\n\n",
    "  TORCH_LIBRARY_IMPL(_, ", *dispatch_key, ", m) { m.fallback(...); }\n\n",
    ERROR_CONTEXT);
  registrars_.emplace_back(
    Dispatcher::singleton().registerFallback(
      *dispatch_key,
      std::move(f.func_),
      debugString(std::move(f.debug_), file_, line_)
    )
  );
  return *this;
}

}
@@ -594,9 +594,407 @@ private:
  std::vector<RegistrationHandleRAII> registrars_;
};

// --------------------------------------------------------------------------
//
// New style API
//
// --------------------------------------------------------------------------
//
// The basic concept behind the new style API is to be as similar to pybind11's
// API as possible.
//
// A quick tour of a few usage examples:
//
//   // Define a library whose operators live in the namespace 'aten'.
//   // You must define all of the operators for this library in
//   // this namespace.
//   TORCH_LIBRARY(aten, m) {
//     // Define a schema for an operator, but provide no implementation
//     m.def("mul(Tensor self, Tensor other) -> Tensor");
//
//     // Define an operator with exactly one implementation for all backends.
//     m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
//
//     // Provide an implementation for a defined operator (you can
//     // provide multiple; one per backend). We'll take care of calling
//     // the correct implementation depending on if we get a CPU
//     // tensor or a CUDA tensor
//     m.impl("mul", torch::kCPU, &mul_cpu_impl);
//     m.impl("mul", torch::kCUDA, &mul_cuda_impl);
//   }
//
//   // Define implementations for operators for a non-standard backend,
//   // e.g., XLA (valid values are entries of DispatchKey). These
//   // operator names are not namespaced; you can define implementations
//   // for any namespace.
//   TORCH_LIBRARY_IMPL(aten, XLA, m) {
//     m.impl("mul", &mul_xla_impl);
//   }

// Represents a C++ function that implements an operator. Most users won't
// interact directly with this class, except via error messages: the
// constructors of this class define the set of permissible "function"-like
// things you can bind via the interface.
//
// This class erases the type of the passed in function, but durably records
// the type via an inferred schema for the function.
//
// TODO: This is morally the same thing as KernelRegistrationConfig, but it's
// opaque to the user.
class CAFFE2_API CppFunction final {
public:
  // This overload accepts function pointers, e.g., CppFunction(&add_impl)
  template <typename Func>
  explicit CppFunction(Func* f, std::enable_if_t<guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)
    : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f))
    // TODO: Don't go through WrapRuntimeKernelFunctor
    , schema_(detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Func>>>())
    , debug_()
    {}

  // This overload accepts lambdas, e.g., CppFunction([](const Tensor& self) { ... })
  template <typename Lambda>
  explicit CppFunction(Lambda&& f, std::enable_if_t<guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)
    : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(f)))
    // TODO: Don't go through WrapRuntimeKernelFunctor
    , schema_(detail::inferFunctionSchemaFromFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>())
    , debug_()
    {}

  // This static factory lets you create CppFunctions that (1) don't have boxing
  // wrappers (because we don't support it yet) and (2) don't have schema
  // inference (because some ops don't support it).
  //
  // TODO: Eliminate the necessity for this function entirely.
  template <typename Func>
  static CppFunction makeUnboxedOnly(Func* f) {
    return CppFunction(
      c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f),
      /* schema */ nullptr
    );
  }

  // TODO: more user friendly API
  static CppFunction makeFallthrough() {
    return CppFunction(
      c10::KernelFunction::makeFallthrough(),
      /* schema */ nullptr
    );
  }

  // TODO: more user friendly API
  template <KernelFunction::BoxedKernelFunction* func>
  static CppFunction makeFromBoxedFunction() {
    return CppFunction(
      c10::KernelFunction::makeFromBoxedFunction<func>(),
      /* schema */ nullptr
    );
  }

  CppFunction&& debug(std::string d) && {
    debug_ = std::move(d);
    return std::move(*this);
  }

private:
  c10::optional<c10::DispatchKey> dispatch_key_;
  c10::KernelFunction func_;
  std::unique_ptr<c10::FunctionSchema> schema_;
  std::string debug_;

  // The "setter" for dispatch_key_
  template <typename Func>
  friend CppFunction dispatch(c10::DispatchKey, Func&&);

  // The only class which actually pulls out values from CppFunction (does so
  // destructively, felt too lazy to write accessors that I don't even
  // want users to use)
  friend class Library;

  CppFunction(KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema);
};

// Create a CppFunction which is associated with a specific dispatch key.
// CppFunctions that are tagged with a DispatchKey don't get invoked /unless/
// the dispatcher determines that the DispatchKey is the best choice for
// a function
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
  CppFunction f(std::forward<Func>(raw_f));
  if (k == c10::DispatchKey::CatchAll) {
    f.dispatch_key_ = c10::nullopt;
  } else {
    f.dispatch_key_ = k;
  }
  return f;
}

// Convenience overload of dispatch which accepts DeviceType
template <typename Func>
inline CppFunction dispatch(DeviceType type, Func&& raw_f) {
  auto deviceTypeToDispatchKey = [](DeviceType t) {
    switch (t) {
      // This list is synchronized with the k-constants in c10/core/DeviceType.h
      case DeviceType::CPU:
        return c10::DispatchKey::CPU;
      case DeviceType::CUDA:
        return c10::DispatchKey::CUDA;
      case DeviceType::XLA:
        return c10::DispatchKey::XLA;
      case DeviceType::HIP:
        return c10::DispatchKey::HIP;
      case DeviceType::MSNPU:
        return c10::DispatchKey::MSNPU;
      default:
        TORCH_CHECK(false,
          "Device type ", t, " cannot be overloaded at dispatch time, "
          "please file a bug report explaining what you were trying to do.");
    }
  };
  return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
}
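Concretely, this overload lets a call site pass a device constant instead of a raw dispatch key; both lines below request the same CPU registration (the kernel name is hypothetical):

    m.impl("mul", dispatch(c10::DispatchKey::CPU, &mul_cpu_impl));
    m.impl("mul", dispatch(c10::kCPU, &mul_cpu_impl));  // same effect via DeviceType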

inline FunctionSchema schema(const char* str, AliasAnalysisKind k) {
  FunctionSchema s = torch::jit::parseSchema(str);
  s.setAliasAnalysis(k);
  return s;
}
inline FunctionSchema schema(const char* s) {
  return schema(s, AliasAnalysisKind::FROM_SCHEMA);
}
inline FunctionSchema&& schema(FunctionSchema&& s) { return std::move(s); }
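Taken together, these overloads mean def() and friends can accept a plain schema string (which defaults to FROM_SCHEMA alias analysis), a string plus an explicit kind, or an already-parsed FunctionSchema. A sketch with a hypothetical operator:

    m.def(schema("myadd(Tensor self, Tensor other) -> Tensor"));  // FROM_SCHEMA by default
    m.def(schema("myadd(Tensor self, Tensor other) -> Tensor",
                 c10::AliasAnalysisKind::CONSERVATIVE));          // explicit kind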

namespace detail {

inline c10::either<OperatorName, FunctionSchema> constructSchemaOrName(FunctionSchema&& s) {
  return c10::make_right<OperatorName, FunctionSchema>(std::move(s));
}
inline c10::either<OperatorName, FunctionSchema> constructSchemaOrName(OperatorName&& n) {
  return c10::make_left<OperatorName, FunctionSchema>(std::move(n));
}
inline c10::either<OperatorName, FunctionSchema> constructSchemaOrName(const char* str) {
  auto s = torch::jit::parseSchemaOrName(str);
  if (s.is_right()) {
    s.right().setAliasAnalysis(AliasAnalysisKind::FROM_SCHEMA);
  }
  return s;
}

}

namespace detail {
class TorchLibraryInit;
}

// This is the "handle" by which functions defined in TORCH_LIBRARY
// and TORCH_LIBRARY_IMPL can define operators and override implementations
// at certain backends.
//
// Conventionally, you get access to it using those two macros:
//
//   TORCH_LIBRARY(torchvision, m) {
//     // m is a c10::Library
//     m.def("roi_align", ...);
//     ...
//   }
//
//   TORCH_LIBRARY_IMPL(aten, XLA, m) {
//     // m is a c10::Library
//     m.impl("add", ...);
//     ...
//   }
//
// In some cases, you need to define something that applies to all namespaces,
// not just one namespace (usually a fallback). In that case, use the reserved
// namespace _, e.g.,
//
//   TORCH_LIBRARY_IMPL(_, XLA, m) {
//     m.fallback(xla_fallback);
//   }
//
class CAFFE2_API Library final {
public:
  // Which type of macro produced this Library
  enum Kind {
    DEF, // from TORCH_LIBRARY (no qualifier)
    IMPL,
    FRAGMENT,
  };

  // Use TORCH_LIBRARY/TORCH_LIBRARY_IMPL instead of these constructors directly
  Library(Kind kind, std::string ns, c10::optional<DispatchKey> k, const char* file, uint32_t line);

  Library(const Library&) = delete;
  Library& operator=(const Library&) = delete;
  Library(Library&&) = default;
  Library& operator=(Library&&) = default;

  // Some notes about the API design here. We had the following constraints:
  //
  //  - We need to support multiple "types" of arguments for schema and
  //    functions (e.g., unnamed lambda types, regular functions, const char*,
  //    fully instantiated schemas)
  //  - We don't want to write exponentially many overloads
  //  - We don't want to rely on implicit conversion to a common type,
  //    because the C++ compiler will only be willing to do a single
  //    implicit conversion (reducing the set of valid types which you
  //    can invoke with); also error messages are worse when an implicit
  //    conversion is not selected (as the compiler will not explain
  //    why it didn't select an implicit conversion; this is different
  //    from overloads where it will explain each candidate overload and
  //    why it didn't apply)
  //
  // To solve all of these constraints at the same time, we use a trick taken
  // from the pybind11 library: template over the argument in the user visible
  // API, and inside of the templated function explicitly call an overloaded
  // function to resolve the argument to a real type. You get the good error
  // messages from overloads, but at the same time you only need to write the
  // overload for any given argument type once.

  // Declare an operator with a schema, but don't provide any implementations
  // for it. You're expected to then provide implementations using the
  // impl() method.
  template <typename Schema>
  Library& def(Schema&& raw_schema) & {
    FunctionSchema s = schema(std::forward<Schema>(raw_schema));
    return _def(std::move(s));
  }

  // Convenience method to define an operator for a schema and then register
  // an implementation for it. def(n, f) is almost equivalent to def(n).impl(f),
  // except that if n is not a schema, then the schema is inferred from the
  // static type of f.
  template <typename NameOrSchema, typename Func>
  Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
    CppFunction f(std::forward<Func>(raw_f));
    auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
    return _def(std::move(name_or_schema), std::move(f));
  }

  // Register an implementation for an operator. You may register multiple
  // implementations for a single operator at different dispatch keys
  // (see torch::dispatch). Implementations must have a corresponding
  // declaration (from def), otherwise they are invalid.
  template <typename Func>
  Library& impl(const char* name, Func&& raw_f) & {
    CppFunction f(std::forward<Func>(raw_f));
    return _impl(name, std::move(f));
  }

  // Convenience overload for directly specifying the dispatch key. Dispatch
  // can validly be either DeviceType or DispatchKey; check torch::dispatch for
  // the canonical list of accepted overloads.
  template <typename Dispatch, typename Func>
  Library& impl(const char* name, Dispatch&& key, Func&& raw_f) & {
    return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
  }

  // Convenience overload for unboxed only kernels. These are quite common
  // but will be eventually eliminated; this function makes it easy to grep for
  // them.
  //
  // TODO: Remove this overload once the makeUnboxedOnly incidence rate
  // goes way down
  template <typename Func>
  Library& impl_UNBOXED(const char* name, Func* raw_f) & {
    return impl(name, CppFunction::makeUnboxedOnly(raw_f));
  }

  // Register a fallback implementation for all operators which will be used
  // if there is not a specific implementation for an operator available.
  // Providing a DispatchKey is MANDATORY for fallback at the moment; e.g.,
  // only call this from TORCH_LIBRARY_IMPL
  template <typename Func>
  Library& fallback(Func&& raw_f) & {
    CppFunction f((std::forward<Func>(raw_f)));
    return _fallback(std::move(f));
  }

private:
  Kind kind_;
  c10::optional<std::string> ns_;
  c10::optional<DispatchKey> dispatch_key_;
  const char* file_;
  uint32_t line_;

  std::vector<RegistrationHandleRAII> registrars_;

  friend detail::TorchLibraryInit;

  // Non-user visible actual implementations of functions. These aren't
  // public because we only implement & qualifier and not && qualifier
  Library& _def(FunctionSchema&& schema, OperatorName* out_name = nullptr) &;
  Library& _def(c10::either<OperatorName, FunctionSchema>&&, CppFunction&& f) &;
  Library& _impl(const char* name, CppFunction&& f) &;
  Library& _fallback(CppFunction&& f) &;
};

namespace detail {

class TorchLibraryInit final {
private:
  using InitFn = void(Library&);
  Library lib_;
public:
  TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<DispatchKey> k, const char* file, uint32_t line)
    : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

} // namespace detail

} // namespace c10

// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/run_analyzer.sh

#define TORCH_LIBRARY(ns, m) \
  static void TORCH_LIBRARY_init_ ## ns (c10::Library&); \
  static c10::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
    c10::Library::DEF, \
    &TORCH_LIBRARY_init_ ## ns, \
    #ns, c10::nullopt, __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_init_ ## ns (c10::Library& m)

// This macro is a version of TORCH_LIBRARY that doesn't enforce that there
// is only one library (it is a "fragment"). This should ONLY be used
// with PerOpRegistration (as its name suggests).
#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \
  static void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (c10::Library&); \
  static c10::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _ ## k ( \
    c10::Library::FRAGMENT, \
    &TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k, \
    #ns, c10::nullopt, __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (c10::Library& m)

#define TORCH_LIBRARY_IMPL(ns, k, m) \
  static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (c10::Library&); \
  static c10::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \
    c10::Library::IMPL, \
    &TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k, \
    #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (c10::Library& m)

// These are variants of the macros above which are to be used for testing (they
// don't setup the static initializer, so you can control the visibility of
// the allocated library yourself).
//
// DO NOT use these in production code, they are NOT understood by the
// code analyzer and will be incorrectly analyzed in those situations.
#define MAKE_TORCH_LIBRARY(ns) Library(Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
#define MAKE_TORCH_LIBRARY_IMPL(ns, k) Library(Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__)

namespace torch {
  // Old-style API
  using RegisterOperators = c10::RegisterOperators;

  // New-style API
  using c10::dispatch;
  using c10::schema;
}
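In tests, the MAKE_ variants are used as ordinary RAII values: the returned Library holds its RegistrationHandleRAIIs, so everything it registered is torn down when the object goes out of scope. The test hunks below follow this pattern; a minimal sketch:

    {
      auto m = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
      m.fallback(c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
      // ... exercise operators; the fallback is deregistered at scope exit.
    }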
@@ -12,7 +12,6 @@
 #include <ATen/core/boxing/impl/test_helpers.h>
 #include <ATen/core/op_registration/op_registration.h>
-#include <torch/library.h>
 #include <ATen/core/Tensor.h>
 #include <functional>

@@ -22,10 +21,7 @@ using c10::OperatorHandle;
 using c10::Dispatcher;
 using c10::IValue;
 using c10::DispatchKey;

-using torch::Library;
-using torch::CppFunction;
+using c10::Library;
 using at::Tensor;

 namespace {

@@ -1446,7 +1442,7 @@ TEST(NewOperatorRegistrationTest, dispatchMultiple) {

 TEST(NewOperatorRegistrationTest, fallback) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(_, CPU);
-  m.fallback(CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
+  m.fallback(c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());

   auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");

@@ -1499,9 +1495,9 @@ TEST(NewOperatorRegistrationTest, CppFunction) {
   m.def("fn2", dummy_fn);
   m.def("fn3", [](const Tensor& x) { return x; });
   // These require explicit schema
-  m.def("fn4(Tensor x) -> Tensor", CppFunction::makeFallthrough());
-  m.def("fn5(Tensor x) -> Tensor", CppFunction::makeUnboxedOnly(dummy_fn));
-  m.def("fn6(Tensor x) -> Tensor", CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
+  m.def("fn4(Tensor x) -> Tensor", c10::CppFunction::makeFallthrough());
+  m.def("fn5(Tensor x) -> Tensor", c10::CppFunction::makeUnboxedOnly(dummy_fn));
+  m.def("fn6(Tensor x) -> Tensor", c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());
 }

 // Some internal tests that have to be done from C++

@@ -120,7 +120,7 @@ m.def("${unqual_schema_string}");
 # TORCH_LIBRARY macro invocation
 DEFAULT_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
 m.impl("${unqual_operator_name_with_overload}",
-       torch::CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name}));
+       CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name}));
 """)

 DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\

@@ -137,7 +137,7 @@ m.impl("${unqual_operator_name_with_overload}", &TypeDefault::${type_wrapper_nam
 BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
 m.impl("${unqual_operator_name_with_overload}",
        torch::dispatch(DispatchKey::${Backend},
-                       torch::CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name}))
+                       CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name}))
 );
 """)
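For reference, these CodeTemplates are expanded once per operator inside the generated TORCH_LIBRARY blocks. A hypothetical expansion of BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION for a CPU kernel (operator and wrapper names invented for illustration) would read:

    m.impl("add.Tensor",
           torch::dispatch(DispatchKey::CPU,
                           CppFunction::makeUnboxedOnly(CPUType::add))
    );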
@@ -8,7 +8,7 @@
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/MemoryOverlap.h>
 #include <ATen/NamedTensorUtils.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 namespace {

@@ -1,7 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/ResizeCommon.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/core/TensorOptions.h>

 namespace at { namespace native {

@@ -3,7 +3,7 @@
 #include <ATen/WrapDimUtils.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/NamedTensorUtils.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 #include <ATen/Config.h>
 namespace at {

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/cuda/Resize.cuh>
 #include <ATen/native/ResizeCommon.h>

@@ -12,7 +12,7 @@
 #include <ATen/Config.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 namespace at {
 namespace native {

@@ -116,7 +116,7 @@ The final file `ATen/native/quantized/cpu/qxand.cpp` would look as follows
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h> // Need that for the `native_functions.yaml`
 #include <ATen/core/Type.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>

@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>

@@ -1,7 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>

 #include <algorithm>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/c10_utils.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>

@@ -5,7 +5,7 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
 #include <ATen/SmallVector.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>

@@ -2,7 +2,7 @@
 #include <vector>

 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>

@@ -2,7 +2,7 @@
 #include <vector>

 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/quantized/Quantizer.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>

@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>

@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpp_custom_type_hack.h>
 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>

@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>

@@ -1,7 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/Parallel.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/Pool.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/Quantizer.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/Quantizer.h>

@@ -1,5 +1,5 @@
 #include <ATen/ATen.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/SortingUtils.h>
 #include <ATen/native/TensorIterator.h>

@@ -1,6 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/quantized/Quantizer.h>

@@ -3,7 +3,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/Resize.h>
 #include <ATen/quantized/Quantizer.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/core/QScheme.h>

 namespace at {

@@ -1,4 +1,4 @@
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 TORCH_LIBRARY(quantized, m) {
   m.def("add(Tensor qa, Tensor qb, float scale, int zero_point) -> Tensor qc");

@@ -1,6 +1,6 @@
 #ifdef USE_XNNPACK

-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/xnnpack/Convolution.h>
 #include <ATen/native/xnnpack/Linear.h>
 #include <ATen/native/xnnpack/OpContext.h>

@@ -6,7 +6,7 @@

 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/core/TensorOptions.h>

 namespace at {

@@ -1,7 +1,7 @@
 // ${generated_comment}

 #include <ATen/Config.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <ATen/TypeDefault.h>
 $extra_headers

@@ -1,7 +1,7 @@
 // ${generated_comment}

 #include <c10/core/TensorOptions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 using namespace at;

@@ -17,7 +17,7 @@
 #include <c10/util/Half.h>
 #include <c10/core/UndefinedTensorImpl.h>
 #include <c10/util/Optional.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 #include <cstddef>
 #include <functional>

@@ -13,7 +13,7 @@
 #include <c10/core/TensorOptions.h>
 #include <ATen/DeviceGuard.h>
 #include <ATen/SparseTensorUtils.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 namespace {
 static const char* named_tensors_unsupported_error =

@@ -27,7 +27,7 @@ $storage_tensor_headers
 #include <utility>

 #include <ATen/Config.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 $extra_cuda_headers
 $legacy_th_headers

@@ -2,7 +2,7 @@

 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <torch/csrc/jit/runtime/operator.h>

 using namespace at;

@@ -110,7 +110,7 @@ void generic_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack*

 TEST(BackendFallbackTest, TestBackendFallbackWithMode) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
-  m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());
+  m.fallback(CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

   c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);

@@ -122,7 +122,7 @@ TEST(BackendFallbackTest, TestBackendFallbackWithMode) {

 TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericWrapper);
-  m.fallback(torch::CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>());
+  m.fallback(CppFunction::makeFromBoxedFunction<&generic_wrapper_fallback>());

   override_call_count = 0;
   Tensor a = at::detail::make_tensor<GenericWrapperTensorImpl>(ones({5, 5}, kDouble));

@@ -132,10 +132,10 @@ TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {

 TEST(BackendFallbackTest, TestFallthroughBackendFallback) {
   auto m = MAKE_TORCH_LIBRARY_IMPL(aten, TESTING_ONLY_GenericMode);
-  m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());
+  m.impl("mul.Tensor", c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

   auto gm = MAKE_TORCH_LIBRARY_IMPL(_, TESTING_ONLY_GenericMode);
-  gm.fallback(torch::CppFunction::makeFallthrough());
+  gm.fallback(c10::CppFunction::makeFallthrough());

   c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericMode);

@@ -5,7 +5,7 @@
 #include <ATen/Tensor.h>
 #include <ATen/native/DistributionTemplates.h>
 #include <ATen/native/cpu/DistributionTemplates.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/util/Optional.h>
 #include <torch/all.h>
 #include <stdexcept>

@@ -2,7 +2,7 @@

 #include <ATen/ATen.h>
 #include <ATen/NativeFunctions.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>

 #include <torch/csrc/jit/runtime/operator.h>

@@ -2,7 +2,7 @@
 #include <ATen/Generator.h>
 #include <ATen/Tensor.h>
 #include <ATen/native/TensorIterator.h>
-#include <torch/library.h>
+#include <ATen/core/op_registration/op_registration.h>
 #include <c10/util/Optional.h>
 #include <torch/all.h>
 #include <stdexcept>
@@ -699,12 +699,7 @@ endif()
 install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
         DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
         FILES_MATCHING PATTERN "*.h")
-install(FILES
-  "${TORCH_SRC_DIR}/script.h"
-  "${TORCH_SRC_DIR}/extension.h"
-  "${TORCH_SRC_DIR}/custom_class.h"
-  "${TORCH_SRC_DIR}/library.h"
-  "${TORCH_SRC_DIR}/custom_class_detail.h"
+install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h" "${TORCH_SRC_DIR}/custom_class_detail.h"
         DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)

@@ -64,7 +64,6 @@ INPUT = ../../../aten/src/ATen/ATen.h \
                          ../../../torch/csrc/jit/runtime/custom_operator.h \
                          ../../../torch/csrc/jit/serialization/import.h \
                          ../../../torch/csrc/jit/api/module.h \
-                         ../../../torch/library.h \
                          ../../../torch/custom_class.h
 # Don't include .cpp files!
 FILE_PATTERNS = *.h
@ -66,71 +66,78 @@ struct PickleTester : torch::CustomClassHolder {
  std::vector<int64_t> vals;
};

static auto test = torch::class_<Foo>("_TorchScriptTesting", "_Foo")
  .def(torch::init<int64_t, int64_t>())
  // .def(torch::init<>())
  .def("info", &Foo::info)
  .def("increment", &Foo::increment)
  .def("add", &Foo::add)
  .def("combine", &Foo::combine);

static auto testStack =
  torch::class_<MyStackClass<std::string>>(
      "_TorchScriptTesting",
      "_StackString")
    .def(torch::init<std::vector<std::string>>())
    .def("push", &MyStackClass<std::string>::push)
    .def("pop", &MyStackClass<std::string>::pop)
    .def("clone", &MyStackClass<std::string>::clone)
    .def("merge", &MyStackClass<std::string>::merge)
    .def_pickle(
        [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
          return self->stack_;
        },
        [](std::vector<std::string> state) { // __setstate__
          return c10::make_intrusive<MyStackClass<std::string>>(
              std::vector<std::string>{"i", "was", "deserialized"});
        })
    .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
    .def(
        "top",
        [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
            -> std::string { return self->stack_.back(); });
// clang-format off
// The following will fail with a static assert telling you you have to
// take an intrusive_ptr<MyStackClass> as the first argument.
// .def("foo", [](int64_t a) -> int64_t{ return 3;});
// clang-format on

static auto testPickle =
  torch::class_<PickleTester>("_TorchScriptTesting", "_PickleTester")
    .def(torch::init<std::vector<int64_t>>())
    .def_pickle(
        [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
          return std::vector<int64_t>{1, 3, 3, 7};
        },
        [](std::vector<int64_t> state) { // __setstate__
          return c10::make_intrusive<PickleTester>(std::move(state));
        })
    .def(
        "top",
        [](const c10::intrusive_ptr<PickleTester>& self) {
          return self->vals.back();
        })
    .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
      auto val = self->vals.back();
      self->vals.pop_back();
      return val;
    });

at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
  return torch::zeros({instance->vals.back(), 4});
}

TORCH_LIBRARY(_TorchScriptTesting, m) {
  m.class_<Foo>("_Foo")
    .def(torch::init<int64_t, int64_t>())
    // .def(torch::init<>())
    .def("info", &Foo::info)
    .def("increment", &Foo::increment)
    .def("add", &Foo::add)
    .def("combine", &Foo::combine);

  m.class_<MyStackClass<std::string>>("_StackString")
    .def(torch::init<std::vector<std::string>>())
    .def("push", &MyStackClass<std::string>::push)
    .def("pop", &MyStackClass<std::string>::pop)
    .def("clone", &MyStackClass<std::string>::clone)
    .def("merge", &MyStackClass<std::string>::merge)
    .def_pickle(
        [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
          return self->stack_;
        },
        [](std::vector<std::string> state) { // __setstate__
          return c10::make_intrusive<MyStackClass<std::string>>(
              std::vector<std::string>{"i", "was", "deserialized"});
        })
    .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
    .def(
        "top",
        [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
            -> std::string { return self->stack_.back(); });
  // clang-format off
  // The following will fail with a static assert telling you you have to
  // take an intrusive_ptr<MyStackClass> as the first argument.
  // .def("foo", [](int64_t a) -> int64_t{ return 3;});
  // clang-format on

  m.class_<PickleTester>("_PickleTester")
    .def(torch::init<std::vector<int64_t>>())
    .def_pickle(
        [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
          return std::vector<int64_t>{1, 3, 3, 7};
        },
        [](std::vector<int64_t> state) { // __setstate__
          return c10::make_intrusive<PickleTester>(std::move(state));
        })
    .def(
        "top",
        [](const c10::intrusive_ptr<PickleTester>& self) {
          return self->vals.back();
        })
    .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
      auto val = self->vals.back();
      self->vals.pop_back();
      return val;
    });

  m.def(
      "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y",
      take_an_instance);
  // test that schema inference is ok too
  m.def("take_an_instance_inferred", take_an_instance);

torch::RegisterOperators& register_take_instance() {
  static auto instance_registry = torch::RegisterOperators().op(
      torch::RegisterOperators::options()
          .schema(
              "_TorchScriptTesting::take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y")
          .catchAllKernel<decltype(take_an_instance), &take_an_instance>());
  return instance_registry;
}

static auto& ensure_take_instance_registered = register_take_instance();

} // namespace

void testTorchbindIValueAPI() {
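
As a hedged usage sketch for the classes registered above (the accessor is an assumption about the IValue custom-class API this test exercises, not quoted from it):

// Round-trip a registered custom class through an IValue.
auto s = c10::make_intrusive<MyStackClass<std::string>>(
    std::vector<std::string>{"a", "b"});
c10::IValue iv(std::move(s));
auto back = iv.toCustomClass<MyStackClass<std::string>>();
back->push("c");
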
@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/core/op_registration/op_registration.h>

using namespace at;

@ -75,7 +75,7 @@ namespace {
// cares about the name
TORCH_LIBRARY(_test, m) {
  m.def("AA(Tensor self) -> Tensor");
  m.impl("AA", torch::CppFunction::makeUnboxedOnly(AA_op));
  m.impl("AA", CppFunction::makeUnboxedOnly(AA_op));

  m.def("BB(Tensor self) -> Tensor");
  m.impl("BB", &BB_op);

@ -93,7 +93,7 @@ TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(_test, m) {

TORCH_LIBRARY_IMPL(_test, CPU, m) {
  m.impl_UNBOXED("EE", EE_op);
  m.impl("FF", torch::CppFunction::makeUnboxedOnly(FF_op));
  m.impl("FF", CppFunction::makeUnboxedOnly(FF_op));
  m.impl("GG",
    [] (Tensor a) -> Tensor {
      return call_FF_op(a);
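
A quick gloss on the two registration styles above: makeUnboxedOnly wraps a kernel that cannot yet go through the boxed calling convention, so it gets neither a boxing wrapper nor an inferred schema, while a plain function pointer such as &BB_op gets both. Per the library header later in this diff, the impl_UNBOXED used for "EE" is sugar for the same thing (AA_op/FF_op signatures are assumptions):

// Roughly equivalent spellings:
// m.impl("FF", CppFunction::makeUnboxedOnly(FF_op));
// m.impl_UNBOXED("FF", FF_op);
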
@ -1,7 +1,7 @@
#include "torch/csrc/autograd/VariableTypeUtils.h"

#include <ATen/TypeDefault.h>
#include <torch/library.h>
#include <ATen/core/op_registration/op_registration.h>

// ${generated_comment}
@ -15,7 +15,7 @@ echo "Analyze: ${INPUT}"
# to operate, so for safety we match a more expansive set.
"${ANALYZER_BIN}" \
  -op_schema_pattern="^(_aten|_prim|aten|quantized|profiler|_test)::[a-zA-Z0-9_.]+(\(.*)?$" \
  -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl|impl_UNBOXED)|torch::Library::(_?def|_?impl|_?impl_UNBOXED)" \
  -op_register_pattern="c10::RegisterOperators::(op|checkSchemaAndRegisterOp_)|c10::Module::(_?def|_?impl|impl_UNBOXED)|c10::Library::(_?def|_?impl|_?impl_UNBOXED)" \
  -op_invoke_pattern="c10::Dispatcher::findSchema|callOp" \
  -root_symbol_pattern="torch::jit::[^(]" \
  -torch_library_init_pattern="^.*TORCH_LIBRARY_init_([^(]+)(\(.*)?$" \
@ -1,9 +1,9 @@
#include "function.h"
#include <ATen/core/op_registration/op_registration.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/custom_class_detail.h>
#include <torch/library.h>
#include "interpreter.h"

namespace torch {

@ -1,8 +1,8 @@
#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/stack.h>
#include <torch/csrc/autograd/function.h>
#include <torch/library.h>

using Stack = std::vector<c10::IValue>;
using at::Scalar;
@ -104,12 +104,12 @@ void log_softmax_kernel(const c10::OperatorHandle& op, Stack* stack) {
TORCH_LIBRARY_IMPL(_aten, Autograd, m) {
  m.impl("add.Scalar", torch::autograd::VariableType::add_Scalar);
  m.impl("mul.Tensor", torch::autograd::VariableType::mul_Tensor);
  m.impl("conv2d", torch::CppFunction::makeFromBoxedFunction<conv2d_kernel>());
  m.impl("conv2d", CppFunction::makeFromBoxedFunction<conv2d_kernel>());
  m.impl("dropout", VariableType::dropout);
  m.impl("feature_dropout", VariableType::feature_dropout);
  m.impl(
      "log_softmax.int",
      torch::CppFunction::makeFromBoxedFunction<log_softmax_kernel>());
      CppFunction::makeFromBoxedFunction<log_softmax_kernel>());
  m.impl(
      "max_pool2d",
      [](const Tensor& self,

@ -127,7 +127,7 @@ TORCH_LIBRARY_IMPL(_aten, Autograd, m) {
            ceil_mode);
      });
  m.impl("relu", VariableType::relu);
  m.impl("view", torch::CppFunction::makeFromBoxedFunction<view_kernel>());
  m.impl("view", CppFunction::makeFromBoxedFunction<view_kernel>());
  m.impl("t", VariableType::t);
  m.impl("addmm", VariableType::addmm);
}
@ -1,7 +1,7 @@
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <torch/library.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/ATen.h>

#include <pybind11/operators.h>

@ -45,12 +45,12 @@ c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
inline c10::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  auto mb_key = parseDispatchKey(key);
  if (mb_key) {
    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
    return c10::dispatch(*mb_key, std::move(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    c10::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}
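
In other words, dispatch_str is a thin string-keyed front end over the dispatch() helper; a hedged sketch of the equivalence (the lambda is illustrative):

// dispatch_str("CPU", f)  ~  c10::dispatch(c10::DispatchKey::CPU, f)
// dispatch_str("", f)     ~  c10::CppFunction(f)   // no key: catch-all
auto fn = dispatch_str("CPU", [](const at::Tensor& a) { return a; });
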
@ -62,16 +62,16 @@ void initDispatchBindings(PyObject* module) {
    .def("schema", &c10::OperatorHandle::schema);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
  py::class_<c10::Library>(m, "_DispatchModule")
    .def("def_", [](py::object self, const char* schema, const char* alias) {
      self.cast<torch::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias)));
      self.cast<c10::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias)));
      return self;
    }, "", py::arg("schema"), py::arg("alias") = "")
    // Simulated "legacy" def where alias analysis kind is not set.
    // Ordinarily this can only be exercised from RegisterOperators() API
    // but I am not going to bind that here
    .def("def_legacy", [](py::object self, const char* schema) {
      self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
      self.cast<c10::Library&>().def(torch::jit::parseSchema(schema));
      return self;
    }, "", py::arg("schema"))
    // We can't conveniently turn Python functions into valid functions

@ -83,7 +83,7 @@ void initDispatchBindings(PyObject* module) {
    // Mangling scheme: args_rets. One character per.
    // t = Tensor
    .def("def_name_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
      self.cast<torch::Library&>().def(
      self.cast<c10::Library&>().def(
          name,
          dispatch_str(dispatch, [](const at::Tensor& a) {
            return a;

@ -94,7 +94,7 @@ void initDispatchBindings(PyObject* module) {
      py::arg("dispatch") = "",
      py::arg("debug") = "default_def_name_t_t")
    .def("def_schema_t_t", [](py::object self, const char* schema, const char* dispatch, const char* alias, const char* debug) {
      self.cast<torch::Library&>().def(
      self.cast<c10::Library&>().def(
          torch::schema(schema, parseAliasAnalysisKind(alias)),
          dispatch_str(dispatch, [](const at::Tensor& a) {
            return a;

@ -108,7 +108,7 @@ void initDispatchBindings(PyObject* module) {
    // TODO: maybe consider deduplicating the definitions here, it's getting
    // pretty long
    .def("impl_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
      self.cast<torch::Library&>().impl(
      self.cast<c10::Library&>().impl(
          name,
          dispatch_str(dispatch, [](const at::Tensor& a) {
            return a;

@ -119,7 +119,7 @@ void initDispatchBindings(PyObject* module) {
      py::arg("dispatch") = "",
      py::arg("debug") = "impl_t_t")
    .def("impl_tt_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
      self.cast<torch::Library&>().impl(
      self.cast<c10::Library&>().impl(
          name,
          dispatch_str(dispatch, [](const at::Tensor& a, const at::Tensor& b) {
            return a;

@ -133,9 +133,9 @@ void initDispatchBindings(PyObject* module) {
    // This is a wee bit dodgy right now, but the "underlying" API is much
    // easier to test than the high level (using TORCH_LIBRARY, e.g.)
    if (name.empty()) {
      return std::make_unique<torch::Library>(torch::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0);
      return std::make_unique<c10::Library>(c10::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0);
    } else {
      return std::make_unique<torch::Library>(torch::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0);
      return std::make_unique<c10::Library>(c10::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0);
    }
  });
@ -5,12 +5,12 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/TypeTraits.h>
#include <torch/library.h>
#include <torch/custom_class_detail.h>
#include <iostream>
#include <sstream>

@ -270,14 +270,4 @@ using ::torch::class_;

} // namespace jit

template <class CurClass>
inline class_<CurClass> Library::class_(const std::string& className) {
  TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
    "class_(\"", className, "\"): Cannot define a class inside of a TORCH_LIBRARY_IMPL block. "
    "All class_()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. "
    "(Error occurred at ", file_, ":", line_, ")");
  TORCH_INTERNAL_ASSERT(ns_.has_value(), file_, ":", line_);
  return torch::class_<CurClass>(*ns_, className);
}

}
} // namespace torch
406 torch/library.h

@ -1,406 +0,0 @@
#pragma once

#include <c10/core/DispatchKey.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/infer_schema.h>
#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#endif

// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/op_registration/op_registration.h>

namespace torch {

template <class CurClass>
class class_;

// A quick tour of a few usage examples:
//
// // Define a library whose operators live in the namespace 'aten'.
// // You must define all of the operators for this library in
// // this namespace.
// TORCH_LIBRARY(aten, m) {
//   // Define a schema for an operator, but provide no implementation
//   m.def("mul(Tensor self, Tensor other) -> Tensor");
//
//   // Define an operator with exactly one implementation for all backends.
//   m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
//
//   // Provide an implementation for a defined operator (you can
//   // provide multiple; one per backend).  We'll take care of calling
//   // the correct implementation depending on if we get a CPU
//   // tensor or a CUDA tensor
//   m.impl("mul", torch::kCPU, &mul_cpu_impl);
//   m.impl("mul", torch::kCUDA, &mul_cuda_impl);
// }
//
// // Define implementations for operators for a non-standard backend,
// // e.g., XLA (valid values are entries of DispatchKey).  These
// // operator names are not namespaced; you can define implementations
// // for any namespace.
// TORCH_LIBRARY_IMPL(aten, XLA, m) {
//   m.impl("mul", &mul_xla_impl);
// }

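// A minimal end-to-end sketch of the pattern above for a custom namespace
// (the kernel signatures here are assumptions for illustration):
//
//   at::Tensor myadd_cpu(const at::Tensor& self, const at::Tensor& other);
//   at::Tensor myadd_cuda(const at::Tensor& self, const at::Tensor& other);
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("myadd(Tensor self, Tensor other) -> Tensor");
//   }
//   TORCH_LIBRARY_IMPL(myops, CPU, m) {
//     m.impl("myadd", &myadd_cpu);
//   }
//   TORCH_LIBRARY_IMPL(myops, CUDA, m) {
//     m.impl("myadd", &myadd_cuda);
//   }
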
// Represents a C++ function that implements an operator.  Most users won't
// interact directly with this class, except via error messages: the
// constructors of this class define the set of permissible "function"-like
// things you can bind via the interface.
//
// This class erases the type of the passed-in function, but durably records
// the type via an inferred schema for the function.
//
// TODO: This is morally the same thing as KernelRegistrationConfig, but it's
// opaque to the user.
class CAFFE2_API CppFunction final {
public:
  // This overload accepts function pointers, e.g., CppFunction(&add_impl)
  template <typename Func>
  explicit CppFunction(Func* f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)
    : func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f))
    // TODO: Don't go through WrapRuntimeKernelFunctor
    , schema_(c10::detail::inferFunctionSchemaFromFunctor<c10::impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Func>>>())
    , debug_()
    {}

  // This overload accepts lambdas, e.g., CppFunction([](const Tensor& self) { ... })
  template <typename Lambda>
  explicit CppFunction(Lambda&& f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)
    : func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(f)))
    // TODO: Don't go through WrapRuntimeKernelFunctor
    , schema_(c10::detail::inferFunctionSchemaFromFunctor<c10::impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>())
    , debug_()
    {}

  // This static factory lets you create CppFunctions that (1) don't have boxing
  // wrappers (because we don't support it yet) and (2) don't have schema
  // inference (because some ops don't support it).
  //
  // TODO: Eliminate the necessity for this function entirely.
  template <typename Func>
  static CppFunction makeUnboxedOnly(Func* f) {
    return CppFunction(
      c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f),
      /* schema */ nullptr
    );
  }

  // TODO: more user friendly API
  static CppFunction makeFallthrough() {
    return CppFunction(
      c10::KernelFunction::makeFallthrough(),
      /* schema */ nullptr
    );
  }

  // TODO: more user friendly API
  template<c10::KernelFunction::BoxedKernelFunction* func>
  static CppFunction makeFromBoxedFunction() {
    return CppFunction(
      c10::KernelFunction::makeFromBoxedFunction<func>(),
      /* schema */ nullptr
    );
  }

  CppFunction&& debug(std::string d) && {
    debug_ = std::move(d);
    return std::move(*this);
  }

private:
  c10::optional<c10::DispatchKey> dispatch_key_;
  c10::KernelFunction func_;
  std::unique_ptr<c10::FunctionSchema> schema_;
  std::string debug_;

  // The "setter" for dispatch_key_
  template <typename Func>
  friend CppFunction dispatch(c10::DispatchKey, Func&&);

  // The only class which actually pulls out values from CppFunction (does so
  // destructively; felt too lazy to write accessors that I don't even
  // want users to use)
  friend class Library;

  CppFunction(c10::KernelFunction func, std::unique_ptr<c10::FunctionSchema> schema);
};

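// For orientation, the main ways a CppFunction is obtained, per the
// constructors and factories above (add_impl is an assumed kernel name,
// not part of this file):
//
//   CppFunction f1(&add_impl);                                // function pointer
//   CppFunction f2([](const at::Tensor& t) { return t; });    // lambda
//   auto f3 = CppFunction::makeFallthrough();                 // fallthrough kernel
//   auto f4 = CppFunction::makeUnboxedOnly(&add_impl);        // no boxing, no schema
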
// Create a CppFunction which is associated with a specific dispatch key.
// CppFunctions that are tagged with a DispatchKey don't get invoked /unless/
// the dispatcher determines that the DispatchKey is the best choice for
// a function
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
  CppFunction f(std::forward<Func>(raw_f));
  if (k == c10::DispatchKey::CatchAll) {
    f.dispatch_key_ = c10::nullopt;
  } else {
    f.dispatch_key_ = k;
  }
  return f;
}

// Convenience overload of dispatch() which accepts DeviceType
template <typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
  auto deviceTypeToDispatchKey = [](c10::DeviceType t){
    switch (t) {
      // This list is synchronized with the k-constants in c10/core/DeviceType.h
      case c10::DeviceType::CPU:
        return c10::DispatchKey::CPU;
      case c10::DeviceType::CUDA:
        return c10::DispatchKey::CUDA;
      case c10::DeviceType::XLA:
        return c10::DispatchKey::XLA;
      case c10::DeviceType::HIP:
        return c10::DispatchKey::HIP;
      case c10::DeviceType::MSNPU:
        return c10::DispatchKey::MSNPU;
      default:
        TORCH_CHECK(false,
          "Device type ", t, " cannot be overloaded at dispatch time, "
          "please file a bug report explaining what you were trying to do.");
    }
  };
  return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
}

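// Usage sketch for the two overloads above (mul_cpu_impl is an assumed
// kernel name, as in the tour at the top of this file):
//
//   m.impl("mul", torch::dispatch(c10::DispatchKey::CPU, &mul_cpu_impl));
//   m.impl("mul", torch::dispatch(c10::DeviceType::CPU, &mul_cpu_impl));  // same effect
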
inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
  c10::FunctionSchema s = torch::jit::parseSchema(str);
  s.setAliasAnalysis(k);
  return s;
}
inline c10::FunctionSchema schema(const char* s) {
  return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
}
inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) { return std::move(s); }

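// e.g., a def() call can route a raw string through these helpers explicitly
// when a non-default alias analysis kind is wanted (a hedged sketch; a bare
// const char* already defaults to FROM_SCHEMA via the second overload):
//
//   m.def(torch::schema("myops::mul(Tensor self, Tensor other) -> Tensor",
//                       c10::AliasAnalysisKind::PURE_FUNCTION));
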
namespace detail {

inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::FunctionSchema&& s) {
  return c10::make_right<c10::OperatorName, c10::FunctionSchema>(std::move(s));
}
inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::OperatorName&& n) {
  return c10::make_left<c10::OperatorName, c10::FunctionSchema>(std::move(n));
}
inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(const char* str) {
  auto s = torch::jit::parseSchemaOrName(str);
  if (s.is_right()) {
    s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
  }
  return s;
}

class TorchLibraryInit;

}

// This is the "handle" by which functions defined in TORCH_LIBRARY
|
||||
// and TORCH_LIBRARY_IMPL can define operators and override implementations
|
||||
// at certain backends.
|
||||
//
|
||||
// Conventionally, you get access to it using those two macros:
|
||||
//
|
||||
// TORCH_LIBRARY(torchvision, m) {
|
||||
// // m is a torch::Library
|
||||
// m.def("roi_align", ...);
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// TORCH_LIBRARY_IMPL(aten, XLA, m) {
|
||||
// // m is a torch::Library
|
||||
// m.impl("add", ...);
|
||||
// ...
|
||||
// }
|
||||
//
|
||||
// In some cases, you need to define something that applies to all namespaces,
|
||||
// not just one namespace (usually a fallback). In that case, use the reserved
|
||||
// namespace _, e.g.,
|
||||
//
|
||||
// TORCH_LIBRARY_IMPL(_, XLA, m) {
|
||||
// m.fallback(xla_fallback);
|
||||
// }
|
||||
//
|
||||
class CAFFE2_API Library final {
public:
  // Which type of macro produced this Library
  enum Kind {
    DEF, // from TORCH_LIBRARY (no qualifier)
    IMPL,
    FRAGMENT,
  };

  // Use TORCH_LIBRARY/TORCH_LIBRARY_IMPL instead of these constructors directly
  Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line);

  Library(const Library&) = delete;
  Library& operator=(const Library&) = delete;
  Library(Library&&) = default;
  Library& operator=(Library&&) = default;

  // Some notes about the API design here.  We had the following constraints:
  //
  // - We need to support multiple "types" of arguments for schema and
  //   functions (e.g., unnamed lambda types, regular functions, const char*,
  //   fully instantiated schemas)
  // - We don't want to write exponentially many overloads
  // - We don't want to rely on implicit conversion to a common type,
  //   because the C++ compiler will only be willing to do a single
  //   implicit conversion (reducing the set of valid types which you
  //   can invoke with); also error messages are worse when an implicit
  //   conversion is not selected (as the compiler will not explain
  //   why it didn't select an implicit conversion; this is different
  //   from overloads where it will explain each candidate overload and
  //   why it didn't apply)
  //
  // To solve all of these constraints at the same time, we use a trick taken
  // from the pybind11 library: template over the argument in the user visible
  // API, and inside of the templated function explicitly call an overloaded
  // function to resolve the argument to a real type.  You get the good error
  // messages from overloads, but at the same time you only need to write the
  // overload for any given argument type once.

  // Declare an operator with a schema, but don't provide any implementations
  // for it.  You're expected to then provide implementations using the
  // impl() method.
  template <typename Schema>
  Library& def(Schema&& raw_schema) & {
    c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
    return _def(std::move(s));
  }

  // Convenience method to define an operator for a schema and then register
  // an implementation for it.  def(n, f) is almost equivalent to def(n).impl(f),
  // except that if n is not a schema, then the schema is inferred from the
  // static type of f.
  template <typename NameOrSchema, typename Func>
  Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
    CppFunction f(std::forward<Func>(raw_f));
    auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
    return _def(std::move(name_or_schema), std::move(f));
  }

  // Register an implementation for an operator.  You may register multiple
  // implementations for a single operator at different dispatch keys
  // (see torch::dispatch).  Implementations must have a corresponding
  // declaration (from def), otherwise they are invalid.
  template <typename Func>
  Library& impl(const char* name, Func&& raw_f) & {
    CppFunction f(std::forward<Func>(raw_f));
    return _impl(name, std::move(f));
  }

  // Convenience overload for directly specifying the dispatch key.  Dispatch
  // can validly be either DeviceType or DispatchKey; check torch::dispatch for
  // the canonical list of accepted overloads.
  template <typename Dispatch, typename Func>
  Library& impl(const char* name, Dispatch&& key, Func&& raw_f) & {
    return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
  }

  // Convenience overload for unboxed-only kernels.  These are quite common,
  // but will eventually be eliminated; this function makes it easy to grep for
  // them.
  //
  // TODO: Remove this overload once the makeUnboxedOnly incidence rate
  // goes way down
  template <typename Func>
  Library& impl_UNBOXED(const char* name, Func* raw_f) & {
    return impl(name, CppFunction::makeUnboxedOnly(raw_f));
  }

  // Register a fallback implementation for all operators, which will be used
  // if there is not a specific implementation for an operator available.
  // Providing a DispatchKey is MANDATORY for fallback at the moment; i.e.,
  // only call this from TORCH_LIBRARY_IMPL
  template <typename Func>
  Library& fallback(Func&& raw_f) & {
    CppFunction f((std::forward<Func>(raw_f)));
    return _fallback(std::move(f));
  }

  template <class CurClass>
  inline class_<CurClass> class_(const std::string& className);

private:
  Kind kind_;
  c10::optional<std::string> ns_;
  c10::optional<c10::DispatchKey> dispatch_key_;
  const char* file_;
  uint32_t line_;

  std::vector<c10::RegistrationHandleRAII> registrars_;

  friend class detail::TorchLibraryInit;

  // Non-user-visible actual implementations of functions.  These aren't
  // public because we only implement the & qualifier and not the && qualifier
  Library& _def(c10::FunctionSchema&& schema, c10::OperatorName* out_name = nullptr) &;
  Library& _def(c10::either<c10::OperatorName, c10::FunctionSchema>&&, CppFunction&& f) &;
  Library& _impl(const char* name, CppFunction&& f) &;
  Library& _fallback(CppFunction&& f) &;
};

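// To make the def()/impl() split above concrete, a hedged sketch
// (relu6_kernel is an assumed kernel, not part of this file):
//
//   TORCH_LIBRARY(sketch, m) {
//     m.def("relu6(Tensor self) -> Tensor");     // declaration only
//     m.def("relu6_inferred", &relu6_kernel);    // schema inferred from type
//     m.impl("relu6", c10::DispatchKey::CPU, &relu6_kernel);
//   }
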
namespace detail {

class TorchLibraryInit final {
private:
  using InitFn = void(Library&);
  Library lib_;
public:
  TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
    : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

} // namespace detail

} // namespace torch

// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/run_analyzer.sh

#define TORCH_LIBRARY(ns, m) \
  static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
  static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
    torch::Library::DEF, \
    &TORCH_LIBRARY_init_ ## ns, \
    #ns, c10::nullopt, __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_init_ ## ns (torch::Library& m)

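// For reference, TORCH_LIBRARY(aten, m) expands (roughly) to:
//
//   static void TORCH_LIBRARY_init_aten(torch::Library&);
//   static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_aten(
//       torch::Library::DEF, &TORCH_LIBRARY_init_aten,
//       "aten", c10::nullopt, __FILE__, __LINE__);
//   void TORCH_LIBRARY_init_aten(torch::Library& m)
//
// i.e. the braced body the user writes becomes the definition of
// TORCH_LIBRARY_init_aten, which runs during static initialization.
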
// This macro is a version of TORCH_LIBRARY that doesn't enforce that there
// is only one library (it is a "fragment").  This should ONLY be used
// with PerOpRegistration (as its name suggests).
#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \
  static void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library&); \
  static torch::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ## _ ## k ( \
    torch::Library::FRAGMENT, \
    &TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k, \
    #ns, c10::nullopt, __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_FRAGMENT_init_ ## ns ## _ ## k (torch::Library& m)

#define TORCH_LIBRARY_IMPL(ns, k, m) \
  static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library&); \
  static torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \
    torch::Library::IMPL, \
    &TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k, \
    #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__ \
  ); \
  void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library& m)

// These are variants of the macros above which are to be used for testing (they
// don't set up the static initializer, so you can control the visibility of
// the allocated library yourself).
//
// DO NOT use these in production code, they are NOT understood by the
// code analyzer and will be incorrectly analyzed in those situations.
#define MAKE_TORCH_LIBRARY(ns) torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
#define MAKE_TORCH_LIBRARY_IMPL(ns, k) torch::Library(torch::Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__)

#include <torch/custom_class.h>