Initial torchbind prototype (#21098)

Summary:
This PR includes some test code, along with a `test_libtorch` script to run it. You'll need to modify `test_libtorch` to point to where you have `pytorch` built. `pybind11` currently has to be included as a subdirectory of the test, but it is listed in `.gitignore` to keep this reviewable.

Currently, something like this works:
```cpp
#include <iostream>

struct Foo {
  int x, y;
  Foo() : x(2), y(5) {}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  void display() {
    std::cout << "x: " << x << ' ' << "y: " << y << std::endl;
  }
  int64_t add(int64_t z) {
    return (x + y) * z;
  }
  void combine(Foo x) {
    // signature discussed under "Current issues" below
  }
};

static auto test = torch::jit::class_<Foo>("Foo")
                    .def(torch::jit::init<int64_t, int64_t>())
                    .def("display", &Foo::display)
                    .def("add", &Foo::add)
                    .def("combine", &Foo::combine);

```
with
```py
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val.display()
    print(val.add(3))
```
results in
```
x: 5 y: 3
24
```
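
The tests added in this diff (`test/custom_operator/test_custom_classes.py`) check the same behavior by running a function both eagerly and under `torch.jit.script` and comparing the results. A condensed sketch of that pattern, using the `Foo` registered by `classes.cpp` below (the shared-library path here is illustrative; the real test globs for `build/*custom_class*`):

```py
import torch
import torch.jit as jit
from torch import ops

# Load the shared library that registers torch.classes.Foo.
# (Illustrative path; the test locates it with glob("build/*custom_class*").)
ops.load_library("build/libcustom_class.so")

def f():
    val = torch.classes.Foo(5, 3)
    return val.info()

# Eager and scripted execution should agree.
assert f() == jit.script(f)()
```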

Current issues:
- [x] The Python class created by TorchScript doesn't interact properly with the surrounding code.
```py
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    return val
```
- [x] Doesn't properly take in non-pointer classes. We can't define the function signature below in C++ (I believe we don't want to support this anyway; methods should take other instances by `c10::intrusive_ptr` instead, as in the sketch after this list).
```cpp
  void combine(Foo x) {
```

- [x] Has some memory issues with blobs when constructing multiple objects (fix: make the constant propagation pass not treat capsules as the same object).
```py
@torch.jit.script
def f(x):
    val = torch._C.Foo(5, 3)
    val2 = torch._C.Foo(100, 0)
    val.display()
    print(val.add(3))
```
- [ ] Can't define multiple constructors (we'd need to define an overload string, which isn't currently possible since we don't support overloaded methods).
- [x] `init` has slightly different syntax than `pybind`: `.init<...>()` instead of `.def(py::init<>())`.
- [x] I couldn't figure out how to add some files into the build so they'd be copied to the `include/` directories, so I symlinked them manually.
- [ ] Currently, the conversion from Python into Torchscript doesn't work.
- [ ] Torchbind also currently requires a Python/pybind11 dependency. Fixing this would probably involve some kind of macro to bind into Python when possible.
- [ ] We currently pass objects back into Python by value; there's no way to pass by reference.
- [x] Currently we can only register one method per type signature. This is because we create a `static auto opRegistry`, and the function is templated on the type signature.
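
For the non-pointer-class issue above, the pattern that does register is the one used in `test/custom_operator/classes.cpp` in this diff: the class derives from `torch::jit::CustomClassHolder` and methods take other instances by `c10::intrusive_ptr`. A trimmed sketch (note the corresponding Python test, `test_input_class_type`, is still skipped in this PR):

```cpp
#include <torch/script.h>
#include <torch/custom_class.h>

struct Foo : torch::jit::CustomClassHolder {
  int x, y;
  Foo() : x(0), y(0) {}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  int64_t info() {
    return x * y;
  }
  // Other instances come in as intrusive_ptr rather than by value.
  int64_t combine(c10::intrusive_ptr<Foo> b) {
    return info() + b->info();
  }
};

static auto testFoo = torch::jit::class_<Foo>("Foo")
                       .def(torch::jit::init<int64_t, int64_t>())
                       .def("info", &Foo::info)
                       .def("combine", &Foo::combine);
```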

Somewhat blocked on https://github.com/pytorch/pytorch/pull/21177: we currently use some structures that will be refactored by that PR (namely `return_type_to_ivalue` and `ivalue_to_arg_type`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21098

Differential Revision: D16634872

Pulled By: Chillee

fbshipit-source-id: 1408bb89ea649c27d560df59e2cf9920467fe1de
Author: Horace He (committed by Facebook GitHub Bot)
Date: 2019-08-02 18:41:34 -07:00
Commit: f81db8afb8 (parent: 4e6e11c139)
24 changed files with 607 additions and 20 deletions

@ -116,6 +116,7 @@ test_custom_script_ops() {
# Run tests Python-side and export a script module.
python test_custom_ops.py -v
python test_custom_classes.py -v
python model.py --export-script-module=model.pt
# Run tests C++-side and load the exported script module.
build/test_custom_ops ./model.pt

@ -162,6 +162,7 @@ test_custom_script_ops() {
cp -a "$CUSTOM_OP_BUILD" build
# Run tests Python-side and export a script module.
python test_custom_ops.py -v
python test_custom_classes.py -v
python model.py --export-script-module=model.pt
# Run tests C++-side and load the exported script module.
build/test_custom_ops ./model.pt

@ -1,5 +1,6 @@
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
git submodule update --init --recursive third_party/pybind11
cd test\custom_operator
:: Build the custom operator library.
@ -23,6 +24,7 @@ popd
:: Run tests Python-side and export a script module.
python test_custom_ops.py -v
python test_custom_classes.py -v
python model.py --export-script-module="build/model.pt"
:: Run tests C++-side and load the exported script module.
cd build

@ -102,6 +102,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return printList(out, v.toTensorList(), "[", "]");
case IValue::Tag::Blob:
return out << *v.toBlob();
case IValue::Tag::Capsule:
return out << "Capsule";
case IValue::Tag::GenericList:
return printList(out, v.toGenericList(), "[", "]");
case IValue::Tag::Future:
@ -170,4 +172,15 @@ std::vector<std::pair<IValue, IValue>> iterationOrder(const c10::Dict<IValue, IV
return ordered;
}
std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap() {
static std::unordered_map<std::string, c10::StrongTypePtr> tmap;
return tmap;
}
std::unordered_map<std::string, std::function<PyObject*(void*)>>&
getClassConverter() {
static std::unordered_map<std::string, std::function<PyObject*(void*)>>
classConverter;
return classConverter;
}
} // namespace c10

@ -3,9 +3,11 @@
#include <ATen/core/blob.h>
#include <c10/util/intrusive_ptr.h>
#include <ATen/core/Tensor.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
namespace torch {
namespace jit {
class CustomClassHolder : public c10::intrusive_ptr_target {};
struct Function;
namespace script {
struct CompilationUnit;
@ -49,8 +51,10 @@ struct Object;
_(GenericDict) \
_(Future) \
_(Device) \
_(Object) \
_(Uninitialized) \
_(Object)
_(Capsule) \
struct CAFFE2_API IValue final {
IValue() : payload{0}, tag(Tag::None), is_intrusive_ptr(false) {}
@ -148,6 +152,14 @@ struct CAFFE2_API IValue final {
c10::intrusive_ptr<caffe2::Blob> toBlob() &&;
c10::intrusive_ptr<caffe2::Blob> toBlob() const &;
// Capsule
IValue(intrusive_ptr<torch::jit::CustomClassHolder> blob);
bool isCapsule() const {
return Tag::Capsule == tag;
}
c10::intrusive_ptr<torch::jit::CustomClassHolder> toCapsule() &&;
c10::intrusive_ptr<torch::jit::CustomClassHolder> toCapsule() const &;
// Tuple
IValue(c10::intrusive_ptr<ivalue::Tuple> v);
bool isTuple() const { return Tag::Tuple == tag; }
@ -564,6 +576,26 @@ struct StrongTypePtr {
std::shared_ptr<torch::jit::script::CompilationUnit> cu_;
std::shared_ptr<ClassType> type_;
};
TORCH_API std::unordered_map<std::string, c10::StrongTypePtr>& getCustomClassTypeMap();
template<typename T>
c10::StrongTypePtr getCustomClassType() {
auto tmap = c10::getCustomClassTypeMap();
auto res = tmap.find(typeid(T).name());
if (res == tmap.end()) {
throw c10::Error("Can't find class id in custom class type map", "");
}
return res->second;
}
template<typename T>
inline bool isCustomClassRegistered() {
auto tmap = c10::getCustomClassTypeMap();
return tmap.find(typeid(T).name()) != tmap.end();
}
TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&
getClassConverter();
}
#include <ATen/core/ivalue_inl.h>

@ -24,6 +24,21 @@ struct IValue;
struct ClassType;
struct TupleType;
// For custom class __init__ registration, we need to pass in a function
// that looks like this: [](IValue x, args...)
// However, kernel_functor.h automatically sets the input types of the function
// by introspecting the types of the functor (which is IValue in this case).
// However, we need the type it binds to be Foo.
// Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from
// which getTypePtr can recover the original class pointer.
template <typename TaggedCapsuleType>
struct tagged_capsule {
IValue ivalue;
};
template<class T, class NullType>
c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
auto t = c10::intrusive_ptr<T, NullType>::reclaim(static_cast<T*>(payload.as_intrusive_ptr));
@ -38,6 +53,11 @@ c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
return p;
}
template<class T, class U>
intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
}
inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
return moveToIntrusivePtr<ivalue::Future>();
@ -78,6 +98,14 @@ inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const & {
AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
return toIntrusivePtr<caffe2::Blob>();;
}
inline c10::intrusive_ptr<torch::jit::CustomClassHolder> IValue::toCapsule() && {
TORCH_INTERNAL_ASSERT(isCapsule());
return moveToIntrusivePtr<torch::jit::CustomClassHolder>();
}
inline c10::intrusive_ptr<torch::jit::CustomClassHolder> IValue::toCapsule() const & {
TORCH_INTERNAL_ASSERT(isCapsule());
return toIntrusivePtr<torch::jit::CustomClassHolder>();
}
namespace ivalue {
@ -430,6 +458,23 @@ std::vector<Elem> generic_to(
return result;
}
template <typename T>
T generic_to(
IValue ivalue,
_fake_type<T>) {
using ElemType = typename std::remove_pointer<T>::type::element_type;
auto obj = ivalue.toObject();
auto capsule = obj->getSlot(0);
return c10::static_intrusive_pointer_cast<ElemType>(capsule.toCapsule());
}
template <typename T>
tagged_capsule<T> generic_to(
IValue ivalue,
_fake_type<tagged_capsule<T>>) {
return tagged_capsule<T>{ivalue};
}
template <typename Elem>
c10::List<Elem> generic_to(
IValue ivalue,
@ -640,6 +685,10 @@ inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
: tag(Tag::Object), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(c10::intrusive_ptr<torch::jit::CustomClassHolder> v)
: tag(Tag::Capsule), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
: tag(Tag::Future), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
@ -687,4 +736,50 @@ inline bool IValue::isSameIdentity(const IValue& rhs) const {
}
}
namespace ivalue {
namespace detail {
// This code allows us to template on a function based on whether IValue has a
// constructor for it. Specifically, has_constructor<T>{} inherits from std::true_type if
// IValue(T) compiles, and inherits from std::false_type if IValue(T) doesn't.
// We use it for calling the IValue constructor for `from` if it exists, and otherwise
// attempt to use our custom class code.
template<class> struct type_sink { typedef void type; };
template<class T> using type_sink_t = typename type_sink<T>::type;
template<class T, class=void> struct has_constructor : std::false_type {}; \
template<class T> struct has_constructor<
T,
type_sink_t< decltype( IValue(std::declval<T>())) >
>: std::true_type {};
template <typename T>
IValue from_(T x, std::true_type) {
return IValue(x);
}
template <typename T>
IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
using inputType = c10::intrusive_ptr<T>;
if (!isCustomClassRegistered<inputType>()) {
throw c10::Error("Trying to return a class that we don't support and isn't a registered custom class.", "");
}
auto res = getCustomClassType<inputType>();
auto retObject = ivalue::Object::create(res->second, 1);
auto objPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(x);
retObject->setSlot(0, IValue(objPtr));
auto resIVal = IValue(std::move(retObject));
return resIVal;
}
template <typename T>
IValue from_(T x, std::false_type) {
static_assert(guts::false_t<T>::value, "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
return IValue();
}
}
template <typename T>
IValue from(T x) {
return detail::from_(x, detail::has_constructor<T>{});
}
}
} // namespace c10

@ -13,6 +13,7 @@
#include <memory>
#include <type_traits>
struct ClassType;
namespace torch {
namespace jit {
struct Function;
@ -48,7 +49,8 @@ using OptNameList = c10::optional<std::vector<std::string>>;
_(ProfiledTensorType) \
_(DeviceObjType) \
_(FunctionType) \
_(ClassType)
_(ClassType) \
_(CapsuleType)
enum class TypeKind {
#define DEFINE_TYPE(T) T,
@ -1304,6 +1306,28 @@ struct VarType : public Type {
std::string name_;
};
struct CapsuleType;
using CapsuleTypePtr = std::shared_ptr<CapsuleType>;
// This type represents a Python Capsule
struct CAFFE2_API CapsuleType : public Type {
static CapsuleTypePtr create() {
return CapsuleTypePtr(new CapsuleType()); // NOLINT(modernize-make-shared)
}
DEFINE_IS_SUBCLASS(CapsuleType);
bool operator==(const Type& rhs) const override {
return rhs.kind() == kind();
}
std::string str() const override {
return "Capsule";
}
static const TypeKind Kind = TypeKind::CapsuleType;
// global singleton
static CapsuleTypePtr get();
private:
CapsuleType()
: Type(TypeKind::CapsuleType) {}
};
CAFFE2_API std::ostream& operator<<(std::ostream& out, const Type& t);
CAFFE2_API std::ostream& operator<<(std::ostream& out, const VaryingShape& t);
// what is the type, ignoring extra size/shape information?
@ -1359,9 +1383,13 @@ CAFFE2_API c10::optional<TypePtr> unifyTypes(
namespace detail {
template <typename T>
struct getTypePtr_ final {
static_assert(
guts::false_t<T>::value,
"Type could not be converted to any of the known types.");
static TypePtr call() {
if (!isCustomClassRegistered<T>()) {
throw c10::Error("Type could not be converted to any of the known types.", "");
}
auto res = getCustomClassType<T>();
return std::dynamic_pointer_cast<Type>(res.type_);
}
};
template <>
@ -1633,4 +1661,5 @@ struct CAFFE2_API ClassType : public NamedType {
// List of methods associated with this class.
std::vector<Function*> methods_;
};
} // namespace c10

@ -1,6 +1,7 @@
#pragma once
#include <ATen/core/op_registration/infer_schema.h>
#include <ATen/core/ivalue.h>
namespace c10 {
/**
@ -37,7 +38,10 @@ namespace detail {
>;
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_input_type {
static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported input type.");
assert_is_valid_input_type() {
auto tmap = c10::getCustomClassTypeMap();
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as input argument");
}
};
template<class T, bool AllowDeprecatedTypes>
@ -98,7 +102,10 @@ namespace detail {
};
template<class T, bool AllowDeprecatedTypes, class Enable = void> struct assert_is_valid_output_type {
static_assert(guts::false_t<T>::value, "You tried to register a kernel with an unsupported output type.");
assert_is_valid_output_type() {
auto tmap = getCustomClassTypeMap();
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class as output");
}
};
template<class T, bool AllowDeprecatedTypes>
@ -170,7 +177,7 @@ namespace detail {
template<class T, bool AllowDeprecatedTypes>
IValue return_to_ivalue(T&& v) {
assert_is_valid_output_type<T, AllowDeprecatedTypes>();
return IValue(std::move(v));
return c10::ivalue::from(v);
}
template<class Functor, bool AllowDeprecatedTypes, size_t... ivalue_arg_indices>

@ -119,6 +119,10 @@ OptionalTypePtr OptionalType::ofTensor() {
static auto value = OptionalType::create(TensorType::get());
return value;
}
CapsuleTypePtr CapsuleType::get() {
static auto value = CapsuleType::create();
return value;
}
ListTypePtr ListType::ofTensors() {
static auto value = ListType::create(TensorType::get());
return value;

@ -8,6 +8,7 @@
#include <sstream>
#include <string>
#include <cstdlib>
#include <functional>
#include <c10/macros/Macros.h>
/*
@ -229,6 +230,21 @@ constexpr auto apply(F&& f, Tuple&& t) -> decltype(detail::apply_impl(
#endif
#endif
template <typename Functor, typename... Args>
typename std::enable_if<
std::is_member_pointer<typename std::decay<Functor>::type>::value,
typename std::result_of<Functor && (Args && ...)>::type>::type
invoke(Functor&& f, Args&&... args) {
return std::mem_fn(f)(std::forward<Args>(args)...);
}
template <typename Functor, typename... Args>
typename std::enable_if<
!std::is_member_pointer<typename std::decay<Functor>::type>::value,
typename std::result_of<Functor && (Args && ...)>::type>::type
invoke(Functor&& f, Args&&... args) {
return std::forward<Functor>(f)(std::forward<Args>(args)...);
}
@ -243,6 +259,7 @@ namespace std {
// std::to_string() call, then you're calling std::to_string() but should be calling
// c10::guts::to_string().
inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; }
}
namespace c10 { namespace guts { namespace detail {

@ -716,7 +716,7 @@ ENDIF()
install(DIRECTORY "${TORCH_SRC_DIR}/csrc"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch
FILES_MATCHING PATTERN "*.h")
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h"
install(FILES "${TORCH_SRC_DIR}/script.h" "${TORCH_SRC_DIR}/extension.h" "${TORCH_SRC_DIR}/custom_class.h"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)

@ -5,6 +5,11 @@ project(custom_ops)
find_package(Torch REQUIRED)
add_library(custom_ops SHARED op.cpp)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11/ ./pybind11)
pybind11_add_module(custom_class SHARED classes.cpp)
target_link_libraries(custom_class PRIVATE "${TORCH_LIBRARIES}")
target_compile_features(custom_ops PUBLIC cxx_range_for)
target_link_libraries(custom_ops "${TORCH_LIBRARIES}")
target_compile_definitions(custom_ops PRIVATE custom_ops_EXPORTS)

@ -0,0 +1,65 @@
#include <cassert>
#include <climits>
#include <cstring>
#include <iostream>
#include <iterator>
#include <list>
#include <torch/script.h>
#include <torch/custom_class.h>
#include <pybind11/pybind11.h>
using namespace std;
namespace py = pybind11;
struct Foo : torch::jit::CustomClassHolder {
int x, y;
Foo(): x(0), y(0){}
Foo(int x_, int y_) : x(x_), y(y_) {}
int64_t info() {
return this->x * this->y;
}
int64_t add(int64_t z) {
return (x+y)*z;
}
void increment(int64_t z) {
this->x+=z;
this->y+=z;
}
int64_t combine(c10::intrusive_ptr<Foo> b) {
return this->info() + b->info();
}
~Foo() {
// std::cout<<"Destroying object with values: "<<x<<' '<<y<<std::endl;
}
};
template <class T> struct Stack : torch::jit::CustomClassHolder {
std::vector<T> stack_;
Stack(std::vector<T> init): stack_(init.begin(), init.end()) {}
void push(T x) {
stack_.push_back(x);
}
T pop() {
auto val = stack_.back();
stack_.pop_back();
return val;
}
};
static auto test = torch::jit::class_<Foo>("Foo")
.def(torch::jit::init<int64_t, int64_t>())
// .def(torch::jit::init<>())
.def("info", &Foo::info)
.def("increment", &Foo::increment)
// .def("add", &Foo::add);
.def("combine", &Foo::combine)
;
static auto testStack = torch::jit::class_<Stack<std::string>>("StackString")
.def(torch::jit::init<std::vector<std::string>>())
.def("push", &Stack<std::string>::push)
.def("pop", &Stack<std::string>::pop)
;

@ -0,0 +1,80 @@
import unittest
import torch
from torch import ops
import torch.jit as jit
import glob
import os
def get_custom_class_library_path():
library_filename = glob.glob("build/*custom_class*")
assert (len(library_filename) == 1)
library_filename = library_filename[0]
path = os.path.abspath(library_filename)
assert os.path.exists(path), path
return path
def test_equality(f, cmp_key):
obj1 = f()
obj2 = jit.script(f)()
return (cmp_key(obj1), cmp_key(obj2))
class TestCustomOperators(unittest.TestCase):
def setUp(self):
ops.load_library(get_custom_class_library_path())
def test_no_return_class(self):
def f():
val = torch.classes.Foo(5, 3)
return val.info()
self.assertEqual(*test_equality(f, lambda x: x))
def test_constructor_with_args(self):
def f():
val = torch.classes.Foo(5, 3)
return val
self.assertEqual(*test_equality(f, lambda x: x.info()))
def test_function_call_with_args(self):
def f():
val = torch.classes.Foo(5, 3)
val.increment(1)
return val
self.assertEqual(*test_equality(f, lambda x: x.info()))
def test_function_method_wrong_type(self):
def f():
val = torch.classes.Foo(5, 3)
val.increment("asdf")
return val
with self.assertRaisesRegex(RuntimeError, "Expected"):
jit.script(f)()
@unittest.skip("We currently don't support passing custom classes to custom methods.")
def test_input_class_type(self):
def f():
val = torch.classes.Foo(1, 2)
val2 = torch.classes.Foo(2, 3)
val.combine(val2)
return val
self.assertEqual(*test_equality(f, lambda x: x.info()))
def test_stack_string(self):
def f():
val = torch.classes.StackString(["asdf", "bruh"])
return val.pop()
self.assertEqual(*test_equality(f, lambda x: x))
def test_stack_push_pop(self):
def f():
val = torch.classes.StackString(["asdf", "bruh"])
val2 = torch.classes.StackString(["111", "222"])
val.push(val2.pop())
return val.pop() + val2.pop()
self.assertEqual(*test_equality(f, lambda x: x))
if __name__ == "__main__":
unittest.main()

@ -244,7 +244,7 @@ if (USE_NCCL)
endif()
# In the most recent CMake versions, a new 'TRANSFORM' subcommand of 'list' allows much of the boilerplate of defining the lists
# of type stub files to be omitted.
# of type stub files to be omitted.
# For comptability with older CMake versions, we omit it for now, but leave it as a comment in case comptability with the older
# CMake versions is eventually dropped.
# set(Modules

@ -336,6 +336,7 @@ def compiled_with_cxx11_abi():
# Import the ops "namespace"
from torch._ops import ops # noqa: F401
from torch._classes import classes # noqa: F401
# Import the quasi random sampler
import torch.quasirandom

torch/_classes.py (new file)

@ -0,0 +1,9 @@
import types
class _Classes(types.ModuleType):
def __init__(self):
super(_Classes, self).__init__('torch.classes')
# The classes "namespace"
classes = _Classes()

@ -397,6 +397,10 @@ def _qualified_name(obj):
name = obj.__name__
module_name = obj.__module__
# If the module is actually a torchbind module, then we should short circuit
if module_name == "torch._classes":
return obj.qualified_name
# The Python docs are very clear that `__module__` can be None, but I can't
# figure out when it actually would be.
if module_name is None:

@ -106,6 +106,5 @@ class _Ops(types.ModuleType):
ctypes.CDLL(path)
self.loaded_libraries.add(path)
# The ops "namespace"
ops = _Ops()

@ -1,22 +1,16 @@
#ifndef THP_EXPORT_H
#define THP_EXPORT_H
#ifdef __cplusplus
# define THP_EXTERNC extern "C"
#else
# define THP_EXTERNC extern
#endif
#ifdef _WIN32
# ifdef _THP_CORE
# define THP_API THP_EXTERNC __declspec(dllexport)
# define THP_API extern __declspec(dllexport)
# define THP_CLASS __declspec(dllexport)
# else
# define THP_API THP_EXTERNC __declspec(dllimport)
# define THP_API extern __declspec(dllimport)
# define THP_CLASS __declspec(dllimport)
# endif
#else
# define THP_API THP_EXTERNC
# define THP_API extern
# define THP_CLASS
#endif

@ -112,6 +112,8 @@ bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
for (size_t i = 0; i < lhs_outputs.size(); ++i) {
if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
return false;
if (lhs_outputs[i]->type() == CapsuleType::get())
return false;
}
// Check whether the inputs are the same.

@ -6,6 +6,7 @@
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/tracer.h>
@ -448,6 +449,8 @@ inline IValue toIValue(
break;
case TypeKind::FunctionType:
AT_ERROR("Function Values aren't yet supported");
case TypeKind::CapsuleType:
AT_ERROR("Capsule Values aren't supported");
}
AT_ERROR(
"Missing cases in toIValue for type: ",
@ -510,6 +513,17 @@ inline IValue returnToIValue(const TypePtr& type, py::handle object) {
}
}
inline c10::optional<py::object> tryToConvertToCustomClass(
const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
if (obj->name().find("__torch__.torch.classes") == 0) {
auto objPtr = (void*)obj->getSlot(0).toCapsule().release();
auto classConverter = c10::getClassConverter()[obj->name()];
py::handle rawPyObj = classConverter(objPtr);
auto o = py::reinterpret_steal<py::object>(rawPyObj);
return o;
}
return c10::nullopt;
}
inline py::object toPyObject(IValue&& ivalue) {
if (ivalue.isNone()) {
return py::none();
@ -573,6 +587,10 @@ inline py::object toPyObject(IValue&& ivalue) {
} else if (ivalue.isObject()) {
const auto obj = std::move(ivalue).toObject();
auto pyCu = get_python_cu();
auto res = tryToConvertToCustomClass(obj);
if (res.has_value()) {
return res.value();
}
const auto classType = pyCu->get_class(c10::QualifiedName(obj->name()));
AT_ASSERT(classType);
auto pyClass =

@ -19,6 +19,7 @@ using c10::GeneratorType;
using c10::IntType;
using c10::ListType;
using c10::NoneType;
using c10::CapsuleType;
using c10::NumberType;
using c10::OptionalType;
using c10::StringType;
@ -45,6 +46,7 @@ TypeAndAlias SchemaTypeParser::parseBaseType() {
{"int", IntType::get()},
{"bool", BoolType::get()},
{"None", NoneType::get()},
{"Capsule", CapsuleType::get()},
};
auto tok = L.cur();
if (!L.nextIf(TK_NONE)) {

torch/custom_class.h (new file)

@ -0,0 +1,207 @@
#pragma once
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/stack.h>
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/utils/variadic.h>
#include <iostream>
#include <sstream>
namespace py = pybind11;
namespace torch {
namespace jit {
static std::vector<c10::RegisterOperators> registeredOps;
namespace detail {
template <class R, class...>
struct types {
constexpr static bool hasRet = true;
using type = types;
};
template <class... args>
struct types<void, args...> {
constexpr static bool hasRet = false;
using type = types;
};
template <class Sig>
struct args;
template <class R, class CurClass, class... Args>
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};
template <class Sig>
using args_t = typename args<Sig>::type;
} // namespace detail
template <class... Types>
detail::types<void, Types...> init() { return detail::types<void, Types...>{}; }
// To bind custom classes into Torchscript, use an API very similar to Pybind's.
// Currently exposes one class `torch::jit::class_<T>` and 2 methods.
// - Constructing `torch::jit::class_<Foo>` registers `Foo` in Python and
// Torchscript, and puts it under `torch.classes.Foo` in Python.
// - torch::jit::class_<Foo>.def("method1", &Foo::method1) does some template
// metaprogramming to introspect the function types and register the operator
// for use in Torchscript.
// - torch::jit::class_<Foo>.def(torch::jit::init<int64_t, int64_t>()) registers
// the Foo(int, int) constructor.
// see test/custom_operator/classes.cpp and
// test/custom_operator/test_custom_classes.py for example usages
template <class CurClass>
class class_ {
std::string className;
std::string qualClassName;
c10::optional<py::class_<CurClass>> pyClass = c10::nullopt;
std::shared_ptr<script::CompilationUnit> classCu = nullptr;
ClassTypePtr classTypePtr;
const std::string parentModule = "classes";
const std::string topModule = "__torch__.torch";
public:
class_(string className_) : className(std::move(className_)) {
// Currently we register everything as a python class just for convenience.
// We'll want to remove this at some point to get rid of the python
// dependency. It would require significant changes to class registration,
// (I think)?
qualClassName = topModule + "." + parentModule + "." + className;
auto obj = py::module::import("torch").attr(parentModule.c_str());
pyClass = py::class_<CurClass>(obj, className.c_str());
pyClass->attr("qualified_name") = py::str(qualClassName);
auto newClass =
py::module::import("torch.jit")
.attr("_add_script_class")(*pyClass, qualClassName.c_str());
auto castToPython = [](void* objPtr) -> PyObject* {
CurClass x = *static_cast<CurClass*>(objPtr);
auto py_object = py::cast(x);
PyObject* rawPyObj = py_object.release().ptr();
return rawPyObj;
};
getClassConverter()[qualClassName] = castToPython;
// We currently represent custom classes as torchscript classes with a
// capsule attribute
classCu = torch::jit::get_python_cu();
classTypePtr =
ClassType::create(c10::QualifiedName(qualClassName), classCu);
classTypePtr->addAttribute("capsule", CapsuleType::get());
c10::getCustomClassTypeMap().insert({typeid(c10::intrusive_ptr<CurClass>).name(),
StrongTypePtr(classCu, classTypePtr)});
c10::getCustomClassTypeMap().insert({typeid(c10::tagged_capsule<CurClass>).name(),
StrongTypePtr(classCu, classTypePtr)});
classCu->register_class(classTypePtr);
}
template <typename... Types>
class_& def(detail::types<void, Types...>) { // Used in combination with
// torch::jit::init<...>()
pyClass->def(py::init<Types...>());
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
auto classObj = c10::make_intrusive<CurClass>(args...);
auto genericPtr = c10::static_intrusive_pointer_cast<torch::jit::CustomClassHolder>(classObj);
auto capsule = IValue(genericPtr);
auto object = self.ivalue.toObject();
object->setSlot(0, capsule);
};
defineMethod<void>("__init__", std::move(func), false);
return *this;
}
template <typename Func>
class_& def(string name, Func f) {
auto res = def_(name, f, detail::args_t<decltype(f)>{});
return *this;
}
private:
template <class T>
struct addInput {
static Value* call(std::shared_ptr<Graph> graph) {
return graph->addInput()->setType(getTypePtr<T>());
}
};
template <class Func, size_t... arg_indices>
std::vector<Value*> addInputs_(
Func f,
std::shared_ptr<Graph> graph,
guts::index_sequence<arg_indices...>) {
using argTypes =
typename guts::infer_function_traits_t<Func>::parameter_types;
std::vector<Value*> res = {
addInput<guts::typelist::element_t<arg_indices, argTypes>>::call(
graph)...};
return res;
}
template <class Func>
std::vector<Value*> addInputs(Func f, std::shared_ptr<Graph> graph) {
constexpr auto numArgs =
guts::infer_function_traits_t<Func>::number_of_parameters;
return addInputs_(f, graph, guts::make_index_sequence<numArgs>());
}
template <typename Last>
std::string type_name() {
return std::string(typeid(Last).name());
}
template <typename First, typename Second, typename... Rest>
std::string type_name() {
return type_name<First>() + "_" + type_name<Second, Rest...>();
}
template <class T>
void addType(Value* v) {
v->setType(getTypePtr<T>());
}
template<typename R, typename Func>
void defineMethod(std::string name, Func func, bool hasRet) {
auto graph = std::make_shared<Graph>();
auto qualFuncName = className + "::" + name;
registeredOps.push_back(
torch::RegisterOperators().op(qualFuncName, std::move(func)));
std::vector<Value*> inputs = addInputs(func, graph);
auto methodCall = graph->insertNode(graph->create(
Symbol::fromQualString(qualFuncName), inputs, hasRet));
Value* res;
if (hasRet) {
res = methodCall->output();
addType<R>(res);
} else {
res = graph->insertConstant(IValue())->setType(NoneType::get());
}
graph->registerOutput(res);
classCu->create_function(qualClassName + "." + name, graph);
}
template <typename Func, typename R, typename... Types>
class_& def_(string name, Func f, detail::types<R, Types...> funcInfo) {
pyClass->def(name.c_str(), f);
auto func = [f](c10::intrusive_ptr<CurClass> cur, Types... args) {
return guts::invoke(f, *cur, args...);
};
defineMethod<R>(name, std::move(func), funcInfo.hasRet);
return *this;
}
};
} // namespace jit
} // namespace torch