mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Initial torchbind prototype (#21098)
Summary: I have some test code in there as well, along with a script "test_libtorch" to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. I currently require that `pybind11` is included as a subdirectory of the test, but added it to the `.gitignore` to make this reviewable. Currently, something like this works: ```cpp struct Foo { int x, y; Foo(): x(2), y(5){} Foo(int x_, int y_) : x(x_), y(y_) {} void display() { cout<<"x: "<<x<<' '<<"y: "<<y<<endl; } int64_t add(int64_t z) { return (x+y)*z; } }; static auto test = torch::jit::class_<Foo>("Foo") .def(torch::jit::init<int64_t, int64_t>()) .def("display", &Foo::display) .def("add", &Foo::add) .def("combine", &Foo::combine); ``` with ```py torch.jit.script def f(x): val = torch._C.Foo(5, 3) val.display() print(val.add(3)) ``` results in ``` x: 5 y: 3 24 ``` Current issues: - [x] The python class created by torchscript doesn't interactly properly with the surrounding code. ``` torch.jit.script def f(x): val = torch._C.Foo(5, 3) return val ``` - [x] Doesn't properly take in non-pointer classes. Can't define this function signature in cpp (We don't want to support this I believe). ```cpp void combine(Foo x) { ``` - [x] Has some issues with memory for blobs when constructing multiple objects (fix constant propagation pass to not treat capsules as the same object). ```py torch.jit.script def f(x): val = torch._C.Foo(5, 3) val2 = torch._C.Foo(100, 0) val.display() print(val.add(3)) ``` - [ ] Can't define multiple constructors (need to define overload string. Currently not possible since we don't support overloaded methods). - [x] `init` is a little bit different syntax than `pybind`. `.init<...>()` instead of `.def(py::init<>())` - [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually. - [ ] Currently, the conversion from Python into Torchscript doesn't work. - [ ] Torchbind also currently requires Python/Pybind dependency. Fixing this would probably involve some kind of macro to bind into Python when possible. - [ ] We pass back into Python by value, currently. There's no way of passing by reference. - [x] Currently can only register one method with the same type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature. Somewhat blocked on https://github.com/pytorch/pytorch/pull/21177. We currently use some structures that will be refactored by his PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/21098 Differential Revision: D16634872 Pulled By: Chillee fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de
This commit is contained in:
committed by
Facebook Github Bot
parent
4e6e11c139
commit
f81db8afb8
@ -116,6 +116,7 @@ test_custom_script_ops() {
|
||||
|
||||
# Run tests Python-side and export a script module.
|
||||
python test_custom_ops.py -v
|
||||
python test_custom_classes.py -v
|
||||
python model.py --export-script-module=model.pt
|
||||
# Run tests C++-side and load the exported script module.
|
||||
build/test_custom_ops ./model.pt
|
||||
|
@ -162,6 +162,7 @@ test_custom_script_ops() {
|
||||
cp -a "$CUSTOM_OP_BUILD" build
|
||||
# Run tests Python-side and export a script module.
|
||||
python test_custom_ops.py -v
|
||||
python test_custom_classes.py -v
|
||||
python model.py --export-script-module=model.pt
|
||||
# Run tests C++-side and load the exported script module.
|
||||
build/test_custom_ops ./model.pt
|
||||
|
@ -1,5 +1,6 @@
|
||||
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
|
||||
|
||||
git submodule update --init --recursive third_party/pybind11
|
||||
cd test\custom_operator
|
||||
|
||||
:: Build the custom operator library.
|
||||
@ -23,6 +24,7 @@ popd
|
||||
|
||||
:: Run tests Python-side and export a script module.
|
||||
python test_custom_ops.py -v
|
||||
python test_custom_classes.py -v
|
||||
python model.py --export-script-module="build/model.pt"
|
||||
:: Run tests C++-side and load the exported script module.
|
||||
cd build
|
||||
|
@ -102,6 +102,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
return printList(out, v.toTensorList(), "[", "]");
|
||||
case IValue::Tag::Blob:
|
||||
return out << *v.toBlob();
|
||||
case IValue::Tag::Capsule:
|
||||
return out << "Capsule";
|
||||
case IValue::Tag::GenericList:
|
||||
return printList(out, v.toGenericList(), "[", "]");
|
||||
case IValue::Tag::Future:
|
||||
@ -170,4 +172,15 @@ std::vector<std::pair<IValue, IValue>> iterationOrder(const c10::Dict<IValue, IV
|
||||
return ordered;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap() {
|
||||
static std::unordered_map<std::string, c10::StrongTypePtr> tmap;
|
||||
return tmap;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::function<PyObject*(void*)>>&
|
||||
getClassConverter() {
|
||||
static std::unordered_map<std::string, std::function<PyObject*(void*)>>
|
||||
classConverter;
|
||||
return classConverter;
|
||||
}
|
||||
} // namespace c10
|
||||
|
@ -3,9 +3,11 @@
|
||||
#include <ATen/core/blob.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
class CustomClassHolder : public c10::intrusive_ptr_target {};
|
||||
struct Function;
|
||||
namespace script {
|
||||
struct CompilationUnit;
|
||||
@ -49,8 +51,10 @@ struct Object;
|
||||
_(GenericDict) \
|
||||
_(Future) \
|
||||
_(Device) \
|
||||
_(Object) \
|
||||
_(Uninitialized) \
|
||||
_(Object)
|
||||
_(Capsule) \
|
||||
|
||||
|
||||
struct CAFFE2_API IValue final {
|
||||
IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
|
||||
@ -148,6 +152,14 @@ struct CAFFE2_API IValue final {
|
||||
c10::intrusive_ptr<caffe2::Blob> toBlob() &&;
|
||||
c10::intrusive_ptr<caffe2::Blob> toBlob() const &;
|
||||
|
||||
// Capsule
|
||||
IValue(intrusive_ptr<torch::jit::CustomClassHolder> blob);
|
||||
bool isCapsule() const {
|
||||
return Tag::Capsule == tag;
|
||||
}
|
||||
c10::intrusive_ptr<torch::jit::CustomClassHolder> toCapsule() &&;
|
||||
c10::intrusive_ptr<torch::jit::CustomClassHolder> toCapsule() const &;
|
||||
|
||||
// Tuple
|
||||
IValue(c10::intrusive_ptr<ivalue::Tuple> v);
|
||||
bool isTuple() const { return Tag::Tuple == tag; }
|
||||
@ -564,6 +576,26 @@ struct StrongTypePtr {
|
||||
std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
|
||||
std::shared_ptr<ClassType> type_;
|
||||
};
|
||||
|
||||
TORCH_API std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap();
|
||||
template<typename T>
|
||||
c10::StrongTypePtr getCustomClassType() {
|
||||
auto tmap = c10::getCustomClassTypeMap();
|
||||
auto res = tmap.find(typeid(T).name());
|
||||
if (res == tmap.end()) {
|
||||
throw c10::Error("Can't find class id in custom class type map", "");
|
||||
}
|
||||
return res->second;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool isCustomClassRegistered() {
|
||||
auto tmap = c10::getCustomClassTypeMap();
|
||||
return tmap.find(typeid(T).name()) != tmap.end();
|
||||
}
|
||||
|
||||
TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&
|
||||
getClassConverter();
|
||||
}
|
||||
|
||||
#include <ATen/core/ivalue_inl.h>
|
||||
|
@ -24,6 +24,21 @@ struct IValue;
|
||||
struct ClassType;
|
||||
struct TupleType;
|
||||
|
||||
// For custom class __init__ registration, we need to pass in a function
|
||||
// that looks like this: [](IValue x, args...)
|
||||
|
||||
// However, kernel_functor.h automatically sets the input types of the function
|
||||
// by introspecting the types of the functor (which is IValue in this case).
|
||||
// However, we need the type it binds to be Foo.
|
||||
|
||||
// Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from
|
||||
// which getTypePtr can recover the original class pointer.
|
||||
|
||||
template <typename TaggedCapsuleType>
|
||||
struct tagged_capsule {
|
||||
IValue ivalue;
|
||||
};
|
||||
|
||||
template<class T, class NullType>
|
||||
c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
|
||||
auto t = c10::intrusive_ptr<T, NullType>::reclaim(static_cast<T*>(payload.as_intrusive_ptr));
|
||||
@ -38,6 +53,11 @@ c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
|
||||
return p;
|
||||
}
|
||||
|
||||
template<class T, class U>
|
||||
intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
|
||||
return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
|
||||
}
|
||||
|
||||
inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
|
||||
AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
|
||||
return moveToIntrusivePtr<ivalue::Future>();
|
||||
@ -78,6 +98,14 @@ inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const & {
|
||||
AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
|
||||
return toIntrusivePtr<caffe2::Blob>();;
|
||||
}
|
||||
inline c10::intrusive_ptr<torch::jit::CustomClassHolder> IValue::toCapsule() && {
|
||||
TORCH_INTERNAL_ASSERT(isCapsule());
|
||||
return moveToIntrusivePtr<torch::jit::CustomClassHolder>();
|
||||
}
|
||||
inline c10::intrusive_ptr<torch::jit::CustomClassHolder> IValue::toCapsule() const & {
|
||||
TORCH_INTERNAL_ASSERT(isCapsule());
|
||||
return toIntrusivePtr<torch::jit::CustomClassHolder>();
|
||||
}
|
||||
|
||||
namespace ivalue {
|
||||
|
||||
@ -430,6 +458,23 @@ std::vector<Elem> generic_to(
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T generic_to(
|
||||
IValue ivalue,
|
||||
_fake_type<T>) {
|
||||
using ElemType = typename std::remove_pointer<T>::type::element_type;
|
||||
auto obj = ivalue.toObject();
|
||||
auto capsule = obj->getSlot(0);
|
||||
return c10::static_intrusive_pointer_cast<ElemType>(capsule.toCapsule());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
tagged_capsule<T> generic_to(
|
||||
IValue ivalue,
|
||||
_fake_type<tagged_capsule<T>>) {
|
||||
return tagged_capsule<T>{ivalue};
|
||||
}
|
||||
|
||||
template <typename Elem>
|
||||
c10::List<Elem> generic_to(
|
||||
IValue ivalue,
|
||||
@ -640,6 +685,10 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
|
||||
: tag(Tag::Object), is_intrusive_ptr(true) {
|
||||
payload.as_intrusive_ptr = v.release();
|
||||
}
|
||||
inline IValue::IValue(c10::intrusive_ptr<torch::jit::CustomClassHolder> v)
|
||||
: tag(Tag::Capsule), is_intrusive_ptr(true) {
|
||||
payload.as_intrusive_ptr = v.release();
|
||||
}
|
||||
inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
|
||||
: tag(Tag::Future), is_intrusive_ptr(true) {
|
||||
payload.as_intrusive_ptr = v.release();
|
||||
@ -687,4 +736,50 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const {
|
||||
}
|
||||
}
|
||||
|
||||
namespace ivalue {
|
||||
namespace detail {
|
||||
// This code allows us to template on a function based on whether IValue has a
|
||||
// constructor for it. Specifically, has_constructor<T>{} inherits from std::true_type if
|
||||
// IValue(T) compiles, and inherits from std::false_type if IValue(T) doesn't.
|
||||
// We use it for calling the IValue constructor for `from` if it exists, and otherwise
|
||||
// attempt to use our custom class code.
|
||||
template<class> struct type_sink { typedef void type; };
|
||||
template<class T> using type_sink_t = typename type_sink<T>::type;
|
||||
template<class T, class=void> struct has_constructor : std::false_type {}; \
|
||||
template<class T> struct has_constructor<
|
||||
T,
|
||||
type_sink_t< decltype( IValue(std::declval<T>())) >
|
||||
>: std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
IValue from_(T x, std::true_type) {
|
||||
return IValue(x);
|
||||
}
|
||||
template <typename T>
|
||||
IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
|
||||
using inputType = c10::intrusive_ptr<T>;
|
||||
if (!isCustomClassRegistered<inputType>()) {
|
||||
throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", "");
|
||||
}
|
||||
auto res = getCustomClassType<inputType>();
|
||||
auto retObject = ivalue::Object::create(res->second, 1);
|
||||
auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(x);
|
||||
|
||||
retObject->setSlot(0, IValue(objPtr));
|
||||
auto resIVal = IValue(std::move(retObject));
|
||||
return resIVal;
|
||||
}
|
||||
template <typename T>
|
||||
IValue from_(T x, std::false_type) {
|
||||
static_assert(guts::false_t<T>::value, "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
|
||||
return IValue();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
IValue from(T x) {
|
||||
return detail::from_(x, detail::has_constructor<T>{});
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace c10
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
struct ClassType;
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
struct Function;
|
||||
@ -48,7 +49,8 @@ using OptNameList = c10::optional<std::vector<std::string>>;
|
||||
_(ProfiledTensorType) \
|
||||
_(DeviceObjType) \
|
||||
_(FunctionType) \
|
||||
_(ClassType)
|
||||
_(ClassType) \
|
||||
_(CapsuleType)
|
||||
|
||||
enum class TypeKind {
|
||||
#define DEFINE_TYPE(T) T,
|
||||
@ -1304,6 +1306,28 @@ struct VarType : public Type {
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
struct CapsuleType;
|
||||
using CapsuleTypePtr = std::shared_ptr<CapsuleType>;
|
||||
// This type represents a Python Capsule
|
||||
struct CAFFE2_API CapsuleType : public Type {
|
||||
static CapsuleTypePtr create() {
|
||||
return CapsuleTypePtr(new CapsuleType()); // NOLINT(modernize-make-shared)
|
||||
}
|
||||
DEFINE_IS_SUBCLASS(CapsuleType);
|
||||
bool operator==(const Type& rhs) const override {
|
||||
return rhs.kind() == kind();
|
||||
}
|
||||
std::string str() const override {
|
||||
return "Capsule";
|
||||
}
|
||||
static const TypeKind Kind = TypeKind::CapsuleType;
|
||||
// global singleton
|
||||
static CapsuleTypePtr get();
|
||||
private:
|
||||
CapsuleType()
|
||||
: Type(TypeKind::CapsuleType) {}
|
||||
};
|
||||
|
||||
CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
|
||||
CAFFE2_API std::ostream& operator<<(std::ostream& out, const VaryingShape& t);
|
||||
// what is the type, ignoring extra size/shape information?
|
||||
@ -1359,9 +1383,13 @@ CAFFE2_API c10::optional<TypePtr> unifyTypes(
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
struct getTypePtr_ final {
|
||||
static_assert(
|
||||
guts::false_t<T>::value,
|
||||
"Type could not be converted to any of the known types.");
|
||||
static TypePtr call() {
|
||||
if (!isCustomClassRegistered<T>()) {
|
||||
throw c10::Error("Type could not be converted to any of the known types.", "");
|
||||
}
|
||||
auto res = getCustomClassType<T>();
|
||||
return std::dynamic_pointer_cast<Type>(res.type_);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@ -1633,4 +1661,5 @@ struct CAFFE2_API ClassType : public NamedType {
|
||||
// List of methods associated with this class.
|
||||
std::vector<Function*> methods_;
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/op_registration/infer_schema.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
|
||||
namespace c10 {
|
||||
/**
|
||||
@ -37,7 +38,10 @@ namespace detail {
|
||||
>;
|
||||
|
||||
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {
|
||||
static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported input type.");
|
||||
assert_is_valid_input_type() {
|
||||
auto tmap = c10::getCustomClassTypeMap();
|
||||
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as input argument");
|
||||
}
|
||||
};
|
||||
|
||||
template<class T, bool AllowDeprecatedTypes>
|
||||
@ -98,7 +102,10 @@ namespace detail {
|
||||
};
|
||||
|
||||
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_output_type {
|
||||
static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported output type.");
|
||||
assert_is_valid_output_type() {
|
||||
auto tmap = getCustomClassTypeMap();
|
||||
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as output");
|
||||
}
|
||||
};
|
||||
|
||||
template<class T, bool AllowDeprecatedTypes>
|
||||
@ -170,7 +177,7 @@ namespace detail {
|
||||
template<class T, bool AllowDeprecatedTypes>
|
||||
IValue return_to_ivalue(T&& v) {
|
||||
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
|
||||
return IValue(std::move(v));
|
||||
return c10::ivalue::from(v);
|
||||
}
|
||||
|
||||
template<class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices>
|
||||
|
@ -119,6 +119,10 @@ OptionalTypePtr OptionalType::ofTensor() {
|
||||
static auto value = OptionalType::create(TensorType::get());
|
||||
return value;
|
||||
}
|
||||
CapsuleTypePtr CapsuleType::get() {
|
||||
static auto value = CapsuleType::create();
|
||||
return value;
|
||||
}
|
||||
ListTypePtr ListType::ofTensors() {
|
||||
static auto value = ListType::create(TensorType::get());
|
||||
return value;
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
/*
|
||||
@ -229,6 +230,21 @@ constexpr auto apply(F&& f, Tuple&& t) -> decltype(detail::apply_impl(
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename Functor, typename... Args>
|
||||
typename std::enable_if<
|
||||
std::is_member_pointer<typename std::decay<Functor>::type>::value,
|
||||
typename std::result_of<Functor && (Args && ...)>::type>::type
|
||||
invoke(Functor&& f, Args&&... args) {
|
||||
return std::mem_fn(f)(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename Functor, typename... Args>
|
||||
typename std::enable_if<
|
||||
!std::is_member_pointer<typename std::decay<Functor>::type>::value,
|
||||
typename std::result_of<Functor && (Args && ...)>::type>::type
|
||||
invoke(Functor&& f, Args&&... args) {
|
||||
return std::forward<Functor>(f)(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -243,6 +259,7 @@ namespace std {
|
||||
// std::to_string() call, then you're calling std::to_string() but should be calling
|
||||
// c10::guts::to_string().
|
||||
inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; }
|
||||
|
||||
}
|
||||
namespace c10 { namespace guts { namespace detail {
|
||||
|
||||
|
@ -716,7 +716,7 @@ ENDIF()
|
||||
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
|
||||
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
|
||||
FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h"
|
||||
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h"
|
||||
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)
|
||||
|
||||
|
||||
|
@ -5,6 +5,11 @@ project(custom_ops)
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
add_library(custom_ops SHARED op.cpp)
|
||||
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11/ ./pybind11)
|
||||
pybind11_add_module(custom_class SHARED classes.cpp)
|
||||
target_link_libraries(custom_class PRIVATE "${TORCH_LIBRARIES}")
|
||||
|
||||
target_compile_features(custom_ops PUBLIC cxx_range_for)
|
||||
target_link_libraries(custom_ops "${TORCH_LIBRARIES}")
|
||||
target_compile_definitions(custom_ops PRIVATE custom_ops_EXPORTS)
|
||||
|
65
test/custom_operator/classes.cpp
Normal file
65
test/custom_operator/classes.cpp
Normal file
@ -0,0 +1,65 @@
|
||||
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <list>
|
||||
#include <torch/script.h>
|
||||
#include <torch/custom_class.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
struct Foo : torch::jit::CustomClassHolder {
|
||||
int x, y;
|
||||
Foo(): x(0), y(0){}
|
||||
Foo(int x_, int y_) : x(x_), y(y_) {}
|
||||
int64_t info() {
|
||||
return this->x * this->y;
|
||||
}
|
||||
int64_t add(int64_t z) {
|
||||
return (x+y)*z;
|
||||
}
|
||||
void increment(int64_t z) {
|
||||
this->x+=z;
|
||||
this->y+=z;
|
||||
}
|
||||
int64_t combine(c10::intrusive_ptr<Foo> b) {
|
||||
return this->info() + b->info();
|
||||
}
|
||||
~Foo() {
|
||||
// std::cout<<"Destroying object with values: "<<x<<' '<<y<<std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T> struct Stack : torch::jit::CustomClassHolder {
|
||||
std::vector<T> stack_;
|
||||
Stack(std::vector<T> init): stack_(init.begin(), init.end()) {}
|
||||
|
||||
void push(T x) {
|
||||
stack_.push_back(x);
|
||||
}
|
||||
T pop() {
|
||||
auto val = stack_.back();
|
||||
stack_.pop_back();
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
static auto test = torch::jit::class_<Foo>("Foo")
|
||||
.def(torch::jit::init<int64_t, int64_t>())
|
||||
// .def(torch::jit::init<>())
|
||||
.def("info", &Foo::info)
|
||||
.def("increment", &Foo::increment)
|
||||
// .def("add", &Foo::add);
|
||||
.def("combine", &Foo::combine)
|
||||
;
|
||||
|
||||
static auto testStack = torch::jit::class_<Stack<std::string>>("StackString")
|
||||
.def(torch::jit::init<std::vector<std::string>>())
|
||||
.def("push", &Stack<std::string>::push)
|
||||
.def("pop", &Stack<std::string>::pop)
|
||||
;
|
80
test/custom_operator/test_custom_classes.py
Normal file
80
test/custom_operator/test_custom_classes.py
Normal file
@ -0,0 +1,80 @@
|
||||
import unittest
|
||||
import torch
|
||||
from torch import ops
|
||||
import torch.jit as jit
|
||||
import glob
|
||||
import os
|
||||
|
||||
def get_custom_class_library_path():
|
||||
library_filename = glob.glob("build/*custom_class*")
|
||||
assert (len(library_filename) == 1)
|
||||
library_filename = library_filename[0]
|
||||
path = os.path.abspath(library_filename)
|
||||
assert os.path.exists(path), path
|
||||
return path
|
||||
|
||||
def test_equality(f, cmp_key):
|
||||
obj1 = f()
|
||||
obj2 = jit.script(f)()
|
||||
return (cmp_key(obj1), cmp_key(obj2))
|
||||
|
||||
class TestCustomOperators(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ops.load_library(get_custom_class_library_path())
|
||||
|
||||
def test_no_return_class(self):
|
||||
def f():
|
||||
val = torch.classes.Foo(5, 3)
|
||||
return val.info()
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
def test_constructor_with_args(self):
|
||||
def f():
|
||||
val = torch.classes.Foo(5, 3)
|
||||
return val
|
||||
self.assertEqual(*test_equality(f, lambda x: x.info()))
|
||||
|
||||
def test_function_call_with_args(self):
|
||||
def f():
|
||||
val = torch.classes.Foo(5, 3)
|
||||
val.increment(1)
|
||||
return val
|
||||
|
||||
self.assertEqual(*test_equality(f, lambda x: x.info()))
|
||||
|
||||
def test_function_method_wrong_type(self):
|
||||
def f():
|
||||
val = torch.classes.Foo(5, 3)
|
||||
val.increment("asdf")
|
||||
return val
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected"):
|
||||
jit.script(f)()
|
||||
|
||||
@unittest.skip("We currently don't support passing custom classes to custom methods.")
|
||||
def test_input_class_type(self):
|
||||
def f():
|
||||
val = torch.classes.Foo(1, 2)
|
||||
val2 = torch.classes.Foo(2, 3)
|
||||
val.combine(val2)
|
||||
return val
|
||||
|
||||
self.assertEqual(*test_equality(f, lambda x: x.info()))
|
||||
|
||||
def test_stack_string(self):
|
||||
def f():
|
||||
val = torch.classes.StackString(["asdf", "bruh"])
|
||||
return val.pop()
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
def test_stack_push_pop(self):
|
||||
def f():
|
||||
val = torch.classes.StackString(["asdf", "bruh"])
|
||||
val2 = torch.classes.StackString(["111", "222"])
|
||||
val.push(val2.pop())
|
||||
return val.pop() + val2.pop()
|
||||
self.assertEqual(*test_equality(f, lambda x: x))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -244,7 +244,7 @@ if (USE_NCCL)
|
||||
endif()
|
||||
|
||||
# In the most recent CMake versions, a new 'TRANSFORM' subcommand of 'list' allows much of the boilerplate of defining the lists
|
||||
# of type stub files to be omitted.
|
||||
# of type stub files to be omitted.
|
||||
# For comptability with older CMake versions, we omit it for now, but leave it as a comment in case comptability with the older
|
||||
# CMake versions is eventually dropped.
|
||||
# set(Modules
|
||||
|
@ -336,6 +336,7 @@ def compiled_with_cxx11_abi():
|
||||
|
||||
# Import the ops "namespace"
|
||||
from torch._ops import ops # noqa: F401
|
||||
from torch._classes import classes # noqa: F401
|
||||
|
||||
# Import the quasi random sampler
|
||||
import torch.quasirandom
|
||||
|
9
torch/_classes.py
Normal file
9
torch/_classes.py
Normal file
@ -0,0 +1,9 @@
|
||||
import types
|
||||
|
||||
class _Classes(types.ModuleType):
|
||||
def __init__(self):
|
||||
super(_Classes, self).__init__('torch.classes')
|
||||
|
||||
|
||||
# The classes "namespace"
|
||||
classes = _Classes()
|
@ -397,6 +397,10 @@ def _qualified_name(obj):
|
||||
name = obj.__name__
|
||||
module_name = obj.__module__
|
||||
|
||||
# If the module is actually a torchbind module, then we should short circuit
|
||||
if module_name == "torch._classes":
|
||||
return obj.qualified_name
|
||||
|
||||
# The Python docs are very clear that `__module__` can be None, but I can't
|
||||
# figure out when it actually would be.
|
||||
if module_name is None:
|
||||
|
@ -106,6 +106,5 @@ class _Ops(types.ModuleType):
|
||||
ctypes.CDLL(path)
|
||||
self.loaded_libraries.add(path)
|
||||
|
||||
|
||||
# The ops "namespace"
|
||||
ops = _Ops()
|
||||
|
@ -1,22 +1,16 @@
|
||||
#ifndef THP_EXPORT_H
|
||||
#define THP_EXPORT_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
# define THP_EXTERNC extern "C"
|
||||
#else
|
||||
# define THP_EXTERNC extern
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
# ifdef _THP_CORE
|
||||
# define THP_API THP_EXTERNC __declspec(dllexport)
|
||||
# define THP_API extern __declspec(dllexport)
|
||||
# define THP_CLASS __declspec(dllexport)
|
||||
# else
|
||||
# define THP_API THP_EXTERNC __declspec(dllimport)
|
||||
# define THP_API extern __declspec(dllimport)
|
||||
# define THP_CLASS __declspec(dllimport)
|
||||
# endif
|
||||
#else
|
||||
# define THP_API THP_EXTERNC
|
||||
# define THP_API extern
|
||||
# define THP_CLASS
|
||||
#endif
|
||||
|
||||
|
@ -112,6 +112,8 @@ bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
|
||||
for (size_t i = 0; i < lhs_outputs.size(); ++i) {
|
||||
if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
|
||||
return false;
|
||||
if (lhs_outputs[i]->type() == CapsuleType::get())
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check whether the inputs are the same.
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <torch/csrc/Device.h>
|
||||
#include <torch/csrc/Dtype.h>
|
||||
#include <torch/csrc/Layout.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/script/module.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
@ -448,6 +449,8 @@ inline IValue toIValue(
|
||||
break;
|
||||
case TypeKind::FunctionType:
|
||||
AT_ERROR("Function Values aren't yet supported");
|
||||
case TypeKind::CapsuleType:
|
||||
AT_ERROR("Capsule Values aren't supported");
|
||||
}
|
||||
AT_ERROR(
|
||||
"Missing cases in toIValue for type: ",
|
||||
@ -510,6 +513,17 @@ inline IValue returnToIValue(const TypePtr& type, py::handle object) {
|
||||
}
|
||||
}
|
||||
|
||||
inline c10::optional<py::object> tryToConvertToCustomClass(
|
||||
const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
|
||||
if (obj->name().find("__torch__.torch.classes") == 0) {
|
||||
auto objPtr = (void*)obj->getSlot(0).toCapsule().release();
|
||||
auto classConverter = c10::getClassConverter()[obj->name()];
|
||||
py::handle rawPyObj = classConverter(objPtr);
|
||||
auto o = py::reinterpret_steal<py::object>(rawPyObj);
|
||||
return o;
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
inline py::object toPyObject(IValue&& ivalue) {
|
||||
if (ivalue.isNone()) {
|
||||
return py::none();
|
||||
@ -573,6 +587,10 @@ inline py::object toPyObject(IValue&& ivalue) {
|
||||
} else if (ivalue.isObject()) {
|
||||
const auto obj = std::move(ivalue).toObject();
|
||||
auto pyCu = get_python_cu();
|
||||
auto res = tryToConvertToCustomClass(obj);
|
||||
if (res.has_value()) {
|
||||
return res.value();
|
||||
}
|
||||
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
|
||||
AT_ASSERT(classType);
|
||||
auto pyClass =
|
||||
|
@ -19,6 +19,7 @@ using c10::GeneratorType;
|
||||
using c10::IntType;
|
||||
using c10::ListType;
|
||||
using c10::NoneType;
|
||||
using c10::CapsuleType;
|
||||
using c10::NumberType;
|
||||
using c10::OptionalType;
|
||||
using c10::StringType;
|
||||
@ -45,6 +46,7 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
|
||||
{"int", IntType::get()},
|
||||
{"bool", BoolType::get()},
|
||||
{"None", NoneType::get()},
|
||||
{"Capsule", CapsuleType::get()},
|
||||
};
|
||||
auto tok = L.cur();
|
||||
if (!L.nextIf(TK_NONE)) {
|
||||
|
207
torch/custom_class.h
Normal file
207
torch/custom_class.h
Normal file
@ -0,0 +1,207 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/function_schema.h>
|
||||
#include <ATen/core/ivalue.h>
|
||||
#include <ATen/core/jit_type.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/core/stack.h>
|
||||
#include <c10/util/C++17.h>
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
#include <c10/util/TypeList.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/csrc/jit/operator.h>
|
||||
#include <torch/csrc/jit/pybind_utils.h>
|
||||
#include <torch/csrc/jit/script/compilation_unit.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
#include <torch/csrc/utils/variadic.h>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
static std::vector<c10::RegisterOperators> registeredOps;
|
||||
|
||||
namespace detail {
|
||||
template <class R, class...>
|
||||
struct types {
|
||||
constexpr static bool hasRet = true;
|
||||
using type = types;
|
||||
};
|
||||
template <class... args>
|
||||
struct types<void, args...> {
|
||||
constexpr static bool hasRet = false;
|
||||
using type = types;
|
||||
};
|
||||
template <class Sig>
|
||||
struct args;
|
||||
template <class R, class CurClass, class... Args>
|
||||
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};
|
||||
template <class Sig>
|
||||
using args_t = typename args<Sig>::type;
|
||||
} // namespace detail
|
||||
template <class... Types>
|
||||
detail::types<void, Types...> init() { return detail::types<void, Types...>{}; }
|
||||
|
||||
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
|
||||
// Currently exposes one class `torch::jit::class_<T>` and 2 methods.
|
||||
// - Constructing `torch::jit::class_<Foo>` registers `Foo` in Python and
|
||||
// Torchscript, and puts it under `torch.classes.Foo` in Python.
|
||||
// - torch::jit::class_<Foo>.def("method1", &Foo::method1) does some template
|
||||
// metaprogramming to introspect the function types and register the operator
|
||||
// for use in Torchscript.
|
||||
// - torch::jit::class_<Foo>.def(torch::jit::init<int64_t, int64_t>()) registers
|
||||
// the Foo(int, int) constructor.
|
||||
// see test/custom_operator/classes.cpp and
|
||||
// test/custom_operator/test_custom_classes.py for example usages
|
||||
|
||||
template <class CurClass>
|
||||
class class_ {
|
||||
std::string className;
|
||||
std::string qualClassName;
|
||||
c10::optional<py::class_<CurClass>> pyClass = c10::nullopt;
|
||||
std::shared_ptr<script::CompilationUnit> classCu = nullptr;
|
||||
ClassTypePtr classTypePtr;
|
||||
|
||||
const std::string parentModule = "classes";
|
||||
const std::string topModule = "__torch__.torch";
|
||||
|
||||
public:
|
||||
class_(string className_) : className(std::move(className_)) {
|
||||
// Currently we register everything as a python class just for convenience.
|
||||
// We'll want to remove this at some point to get rid of the python
|
||||
// dependency. It would require significant changes to class registration,
|
||||
// (I think)?
|
||||
qualClassName = topModule + "." + parentModule + "." + className;
|
||||
|
||||
auto obj = py::module::import("torch").attr(parentModule.c_str());
|
||||
pyClass = py::class_<CurClass>(obj, className.c_str());
|
||||
pyClass->attr("qualified_name") = py::str(qualClassName);
|
||||
auto newClass =
|
||||
py::module::import("torch.jit")
|
||||
.attr("_add_script_class")(*pyClass, qualClassName.c_str());
|
||||
|
||||
auto castToPython = [](void* objPtr) -> PyObject* {
|
||||
CurClass x = *static_cast<CurClass*>(objPtr);
|
||||
auto py_object = py::cast(x);
|
||||
PyObject* rawPyObj = py_object.release().ptr();
|
||||
return rawPyObj;
|
||||
};
|
||||
getClassConverter()[qualClassName] = castToPython;
|
||||
|
||||
// We currently represent custom classes as torchscript classes with a
|
||||
// capsule attribute
|
||||
classCu = torch::jit::get_python_cu();
|
||||
classTypePtr =
|
||||
ClassType::create(c10::QualifiedName(qualClassName), classCu);
|
||||
classTypePtr->addAttribute("capsule", CapsuleType::get());
|
||||
|
||||
c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(),
|
||||
StrongTypePtr(classCu, classTypePtr)});
|
||||
c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(),
|
||||
StrongTypePtr(classCu, classTypePtr)});
|
||||
|
||||
classCu->register_class(classTypePtr);
|
||||
}
|
||||
|
||||
template <typename... Types>
|
||||
class_& def(detail::types<void, Types...>) { // Used in combination with
|
||||
// torch::jit::init<...>()
|
||||
pyClass->def(py::init<Types...>());
|
||||
|
||||
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
|
||||
auto classObj = c10::make_intrusive<CurClass>(args...);
|
||||
auto genericPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(classObj);
|
||||
auto capsule = IValue(genericPtr);
|
||||
auto object = self.ivalue.toObject();
|
||||
object->setSlot(0, capsule);
|
||||
};
|
||||
|
||||
defineMethod<void>("__init__", std::move(func), false);
|
||||
return *this;
|
||||
}
|
||||
template <typename Func>
|
||||
class_& def(string name, Func f) {
|
||||
auto res = def_(name, f, detail::args_t<decltype(f)>{});
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
struct addInput {
|
||||
static Value* call(std::shared_ptr<Graph> graph) {
|
||||
return graph->addInput()->setType(getTypePtr<T>());
|
||||
}
|
||||
};
|
||||
template <class Func, size_t... arg_indices>
|
||||
std::vector<Value*> addInputs_(
|
||||
Func f,
|
||||
std::shared_ptr<Graph> graph,
|
||||
guts::index_sequence<arg_indices...>) {
|
||||
using argTypes =
|
||||
typename guts::infer_function_traits_t<Func>::parameter_types;
|
||||
std::vector<Value*> res = {
|
||||
addInput<guts::typelist::element_t<arg_indices, argTypes>>::call(
|
||||
graph)...};
|
||||
return res;
|
||||
}
|
||||
template <class Func>
|
||||
std::vector<Value*> addInputs(Func f, std::shared_ptr<Graph> graph) {
|
||||
constexpr auto numArgs =
|
||||
guts::infer_function_traits_t<Func>::number_of_parameters;
|
||||
return addInputs_(f, graph, guts::make_index_sequence<numArgs>());
|
||||
}
|
||||
|
||||
template <typename Last>
|
||||
std::string type_name() {
|
||||
return std::string(typeid(Last).name());
|
||||
}
|
||||
template <typename First, typename Second, typename... Rest>
|
||||
std::string type_name() {
|
||||
return type_name<First>() + "_" + type_name<Second, Rest...>();
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void addType(Value* v) {
|
||||
v->setType(getTypePtr<T>());
|
||||
}
|
||||
template<typename R, typename Func>
|
||||
void defineMethod(std::string name, Func func, bool hasRet) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
auto qualFuncName = className + "::" + name;
|
||||
registeredOps.push_back(
|
||||
torch::RegisterOperators().op(qualFuncName, std::move(func)));
|
||||
|
||||
|
||||
std::vector<Value*> inputs = addInputs(func, graph);
|
||||
auto methodCall = graph->insertNode(graph->create(
|
||||
Symbol::fromQualString(qualFuncName), inputs, hasRet));
|
||||
Value* res;
|
||||
if (hasRet) {
|
||||
res = methodCall->output();
|
||||
addType<R>(res);
|
||||
} else {
|
||||
res = graph->insertConstant(IValue())->setType(NoneType::get());
|
||||
}
|
||||
graph->registerOutput(res);
|
||||
|
||||
classCu->create_function(qualClassName + "." + name, graph);
|
||||
}
|
||||
template <typename Func, typename R, typename... Types>
|
||||
class_& def_(string name, Func f, detail::types<R, Types...> funcInfo) {
|
||||
pyClass->def(name.c_str(), f);
|
||||
|
||||
auto func = [f](c10::intrusive_ptr<CurClass> cur, Types... args) {
|
||||
return guts::invoke(f, *cur, args...);
|
||||
};
|
||||
defineMethod<R>(name, std::move(func), funcInfo.hasRet);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
||||
} // namespace torch
|
Reference in New Issue
Block a user