Revert "Delete functorch C extension entirely. (#163340)"

This reverts commit 1faf6367e396b1d0894e8735912a47ac465f469d.

Reverted https://github.com/pytorch/pytorch/pull/163340 on behalf of https://github.com/wdvr due to temporary revert to pull out #162659 ([comment](https://github.com/pytorch/pytorch/pull/163340#issuecomment-3317105243))
This commit is contained in:
PyTorch MergeBot
2025-09-22 06:20:04 +00:00
parent f0078941cf
commit ae5be038a6
12 changed files with 4883 additions and 0 deletions

View File

@ -3,6 +3,7 @@ dist/
functorch.egg-info/
*__pycache__*
functorch/version.py
functorch/_C.so
.gdbinit
t.py
.vscode/

45
functorch/CMakeLists.txt Normal file
View File

@ -0,0 +1,45 @@
cmake_minimum_required(VERSION 3.18)
project(functorch)
set(CMAKE_CXX_STANDARD 17)
include(GNUInstallDirs)
include(CMakePackageConfigHelpers)
set(FT_DIR csrc)
file(GLOB_RECURSE FT_SOURCES ${FT_DIR}/*.cpp ${FT_DIR}/*.c)
add_library(${PROJECT_NAME} MODULE ${FT_SOURCES})
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_compile_definitions(${PROJECT_NAME} PRIVATE FUNCTORCH_BUILD_MAIN_LIB)
target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_EXTENSION_NAME=_C)
target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
target_compile_options(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})
target_compile_options_if_supported(${PROJECT_NAME} "-Wmissing-prototypes")
target_compile_options_if_supported(${PROJECT_NAME} "-Werror=missing-prototypes")
if(BUILD_LIBTORCHLESS)
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIB} torch_python)
else()
# functorch cannot use the alias to build on windows
target_link_libraries(${PROJECT_NAME} PRIVATE torch torch_python)
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE pybind::pybind11)
set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
${CMAKE_BINARY_DIR}/functorch)
set_target_properties(${PROJECT_NAME} PROPERTIES INSTALL_RPATH "${_rpath_portable_origin}/../torch/lib")
# Copy-pasted prefix/suffix logic for Python extensions from
# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L1975
# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L2022
# TODO: It would be good to be able to use Python3_add_library target, but it does not work in many cases
set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" DEBUG_POSTFIX "")
if(WIN32)
set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".pyd")
else()
set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".so")
endif()
# Needed to link functorch on MacOS
if(NOT ${TORCH_PYTHON_LINK_FLAGS} STREQUAL "")
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS})
endif()
install(TARGETS ${PROJECT_NAME} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}")

332
functorch/csrc/dim/arena.h Normal file
View File

@ -0,0 +1,332 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/ATen.h>
#include "minpybind.h"
#if defined(_MSC_VER) && !defined(__clang__)
#include <intrin.h>
// https://stackoverflow.com/questions/355967/how-to-use-msvc-intrinsics-to-get-the-equivalent-of-this-gcc-code
inline unsigned int __builtin_clz(unsigned int x) {
unsigned long r = 0;
_BitScanReverse(&r, x);
return (31 - r);
}
#endif
inline int round2min8(int num) {
int nzeros = __builtin_clz((num - 1)|4);
return 1 << (32 - nzeros);
}
struct Arena;
template<typename T>
struct OwnedSlice;
template<typename T>
struct Slice {
Slice()
: begin_(nullptr), size_(0), capacity_(0) {}
template<typename... Args>
Slice(Arena& arena, Args&&... args);
T* begin() const {
return begin_;
}
T* end() const {
return begin_ + size_;
}
int size() const {
return size_;
}
int capacity() const {
return capacity_;
}
T& back(int i=-1) {
return begin_[size_ + i];
}
T& operator[](int i) const {
return begin_[i];
}
std::optional<int> index(const T& value) {
for (int i : enumerate()) {
if (begin_[i] == value) {
return i;
}
}
return std::nullopt;
}
bool contains(const T& value) {
return index(value).has_value();
}
void insert(Arena& arena, Slice where, Slice to_insert);
void insert(Arena& arena, Slice where, T v) {
return insert(arena, where, Slice(&v, &v + 1));
}
void insert(Arena& arena, int where, T v) {
return insert(arena, slice(where, where), v);
}
void append(Arena& arena, T value);
void extend(Arena& arena, Slice to_insert);
void extend(Arena& arena, const T* begin, const T* end) {
return extend(arena, Slice<T>((T*)begin, (T*)end));
}
bool remove(Arena& A, T value) {
auto idx = index(value);
if (idx) {
insert(A, slice(*idx, *idx + 1), Slice());
}
return idx.has_value();
}
Slice slice(int begin) {
return slice(begin, size_);
}
Slice slice(int begin, int end) {
if (begin < 0) {
begin += size_;
}
if (end < 0) {
end += size_;
}
Slice result;
result.begin_ = begin_ + begin;
result.size_ = end - begin;
result.capacity_ = result.size_;
return result;
}
bool inside(Slice where) {
return begin() <= where.begin() && where.end() <= end();
}
irange enumerate() const {
return irange(size_);
}
irange reversed_enumerate() const {
return irange(size_ - 1, -1, -1);
}
bool operator==(const Slice<T>& rhs) const {
if (size() != rhs.size()) {
return false;
}
return std::equal(begin(), end(), rhs.begin());
}
Slice(T* begin, T* end)
: begin_(begin), size_(end - begin), capacity_(size_) {}
protected:
static int _length(const T& t) {
return 1;
}
static int _length(Slice t) {
return t.size_;
}
static T* _insert(T*& dst, T t) {
*dst = std::move(t);
return ++dst;
}
static T* _insert(T*& dst, Slice t) {
std::memcpy(dst, t.begin_, sizeof(T)*t.size_);
dst += t.size_;
return dst;
}
T* begin_;
int size_;
int capacity_;
friend struct OwnedSlice<T>;
};
template<typename T>
struct OwnedSlice {
typedef void (*deleter_t)(Slice<T>);
static void _no_delete(Slice<T>) {}
OwnedSlice()
: deleter_(_no_delete) {}
OwnedSlice(const OwnedSlice&) = delete;
OwnedSlice& operator=(const OwnedSlice&) = delete;
~OwnedSlice() {
deleter_(slice_);
if (slice_.size_ > 8) {
delete [] slice_.begin_;
}
}
void set(Slice<T> to_own, deleter_t deleter = _no_delete) {
slice_.size_ = slice_.capacity_ = to_own.size();
slice_.begin_ = (slice_.size_ > 8) ? new T[slice_.size_] : &small_buf[0];
std::memcpy(slice_.begin_, to_own.begin(), slice_.size_ * sizeof(T));
deleter_ = deleter;
}
Slice<T> slice() const {
return slice_;
}
private:
Slice<T> slice_;
deleter_t deleter_;
T small_buf[8];
};
template<typename T>
inline std::ostream& operator<<(std::ostream& s, const Slice<T>& v) {
s << "[";
for (int i : v.enumerate()) {
if (i > 0) {
s << ", ";
}
s << v[i];
}
s << "]";
return s;
}
struct TensorRef {
TensorRef()
: impl_(nullptr){}
TensorRef(const at::Tensor& t)
: impl_(t.unsafeGetTensorImpl()) {}
const at::Tensor& operator*() const {
return *(at::Tensor*)this;
}
at::Tensor* operator->() const {
return (at::Tensor*)this;
}
operator bool() const {
return impl_ != nullptr;
}
private:
at::TensorImpl* impl_;
};
constexpr int ARENA_MAX_SIZE = 4096;
constexpr int ALIGNMENT = 8;
struct Arena {
Arena()
: allocated_(0) {}
template<typename T>
T* allocate(int n) {
if (!n) {
return nullptr;
}
int to_allocate = sizeof(T)*n;
int to_allocate_rounded = ALIGNMENT * ((to_allocate - 1) / ALIGNMENT + 1);
auto prev_allocated = allocated_;
allocated_ += to_allocate_rounded;
if (C10_UNLIKELY_OR_CONST(allocated_ > ARENA_MAX_SIZE)) {
overflow_.emplace_back(new char[to_allocate]);
return (T*) &overflow_.back()[0];
}
return (T*) (buffer_ + prev_allocated);
}
TensorRef autorelease(at::Tensor s) {
auto ref = TensorRef(s);
s.unsafeReleaseTensorImpl();
ar_tensors_.append(*this, ref);
return ref;
}
mpy::handle autorelease(mpy::object obj) {
ar_objects_.append(*this, obj);
obj.release();
return ar_objects_.back();
}
~Arena() {
for(TensorRef t: ar_tensors_) {
c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(t->unsafeGetTensorImpl());
}
for(mpy::handle h: ar_objects_) {
mpy::object::steal(h);
}
}
private:
int64_t allocated_;
char buffer_[ARENA_MAX_SIZE];
Slice<TensorRef> ar_tensors_;
Slice<mpy::handle> ar_objects_;
std::vector<std::unique_ptr<char[]>> overflow_;
};
template<typename T>
inline void Slice<T>::insert(Arena& arena, Slice where, Slice to_insert) {
AT_ASSERT(inside(where));
Slice result = *this;
/// b------sb---se-----e, 0----n
T* body_dest = where.begin();
if (where.size() != to_insert.size()) {
int new_size = size() - where.size() + to_insert.size();
T* tail_dest = where.begin() + to_insert.size();
if (new_size >= capacity_) {
int new_capacity = new_size ? round2min8(new_size) : 0;
result.capacity_ = new_capacity;
result.begin_ = arena.allocate<T>(new_capacity);
body_dest = result.begin_ + (where.begin() - begin());
tail_dest = body_dest + to_insert.size();
//std::memcpy(result.begin_, begin_, sizeof(T)*(where.begin() - begin()));
std::copy(begin_, begin_ + (where.begin() - begin()), result.begin_);
}
std::memmove(tail_dest, where.end(), sizeof(T)*(end() - where.end()));
result.size_ = new_size;
}
//std::memcpy(body_dest, to_insert.begin(), sizeof(T)*to_insert.size());
std::copy(to_insert.begin(), to_insert.end(), body_dest);
*this = result;
}
template<typename T>
inline void Slice<T>::append(Arena& arena, T value) {
Slice result = *this;
if (size_ == capacity_) {
int new_size = size_ ? round2min8(size_)*2 : 8;
T* n = arena.allocate<T>(new_size);
//memcpy(n, begin_, size_*sizeof(T));
std::copy(begin_, begin_ + size_, n);
result.begin_ = n;
result.capacity_ = new_size;
}
result[result.size_++] = std::move(value);
*this = result;
}
template<typename T>
inline void Slice<T>::extend(Arena& arena, Slice<T> rhs) {
Slice result = *this;
result.size_ = size_ + rhs.size();
if (result.size_ > capacity_) {
int new_size = round2min8(result.size_);
T* n = arena.allocate<T>(new_size);
//memcpy(n, begin_, size_*sizeof(T));
std::copy(begin_, begin_+size_, n);
result.begin_ = n;
result.capacity_ = new_size;
}
//memcpy(result.begin_ + size_, rhs.begin(), sizeof(T)*rhs.size());
std::copy(rhs.begin(), rhs.end(), result.begin_ + size_);
*this = result;
}
template<typename T>
template<typename... Args>
Slice<T>::Slice(Arena& arena, Args&&... args) {
int lens[] = {_length(args)...};
size_ = 0;
for (auto i : lens) {
size_ += i;
}
capacity_ = size_ ? round2min8(size_) : 0;
begin_ = arena.allocate<T>(capacity_);
T* dst_ = begin_;
T* unused[] = {_insert(dst_, args)...};
(void) unused;
}

3656
functorch/csrc/dim/dim.cpp Normal file

File diff suppressed because it is too large Load Diff

8
functorch/csrc/dim/dim.h Normal file
View File

@ -0,0 +1,8 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <Python.h>
PyObject* Dim_init();

View File

@ -0,0 +1,17 @@
#include <torch/csrc/utils/python_compat.h>
#if defined(_WIN32) && IS_PYTHON_3_11_PLUS
#define Py_BUILD_CORE
#define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt, _PyOpcode_Caches
#if IS_PYTHON_3_13_PLUS
#include <cpython/code.h> // To get PyUnstable_Code_GetFirstFree
#define NEED_OPCODE_METADATA
#include "internal/pycore_opcode_metadata.h"
#undef NEED_OPCODE_METADATA
#else
#include "internal/pycore_opcode.h"
#endif
#undef NEED_OPCODE_TABLES
#undef Py_BUILD_CORE
#endif

View File

@ -0,0 +1,692 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <utility>
#include <ostream>
#include <memory>
#define PY_BEGIN try {
#define PY_END(v) } catch(mpy::exception_set & err) { return (v); }
#if PY_VERSION_HEX < 0x03080000
#define PY_VECTORCALL _PyObject_FastCallKeywords
#else
#define PY_VECTORCALL _PyObject_Vectorcall
#endif
struct irange {
public:
irange(int64_t end)
: irange(0, end, 1) {}
irange(int64_t begin, int64_t end, int64_t step = 1)
: begin_(begin), end_(end), step_(step) {}
int64_t operator*() const {
return begin_;
}
irange& operator++() {
begin_ += step_;
return *this;
}
bool operator!=(const irange& other) {
return begin_ != other.begin_;
}
irange begin() {
return *this;
}
irange end() {
return irange {end_, end_, step_};
}
private:
int64_t begin_;
int64_t end_;
int64_t step_;
};
namespace mpy {
struct exception_set {
};
struct object;
struct vector_args;
struct handle {
handle(PyObject* ptr)
: ptr_(ptr) {}
handle() = default;
PyObject* ptr() const {
return ptr_;
}
object attr(const char* key);
bool hasattr(const char* key);
handle type() const {
return (PyObject*) Py_TYPE(ptr());
}
template<typename... Args>
object call(Args&&... args);
object call_object(mpy::handle args);
object call_object(mpy::handle args, mpy::handle kwargs);
object call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames);
object call_vector(vector_args args);
bool operator==(handle rhs) {
return ptr_ == rhs.ptr_;
}
static handle checked(PyObject* ptr) {
if (!ptr) {
throw exception_set();
}
return ptr;
}
protected:
PyObject* ptr_ = nullptr;
};
template<typename T>
struct obj;
template<typename T>
struct hdl : public handle {
T* ptr() {
return (T*) handle::ptr();
}
T* operator->() {
return ptr();
}
hdl(T* ptr)
: hdl((PyObject*) ptr) {}
hdl(const obj<T>& o)
: hdl(o.ptr()) {}
private:
hdl(handle h) : handle(h) {}
};
struct object : public handle {
object() = default;
object(const object& other)
: handle(other.ptr_) {
Py_XINCREF(ptr_);
}
object(object&& other) noexcept
: handle(other.ptr_) {
other.ptr_ = nullptr;
}
object& operator=(const object& other) {
return *this = object(other);
}
object& operator=(object&& other) noexcept {
PyObject* tmp = ptr_;
ptr_ = other.ptr_;
other.ptr_ = tmp;
return *this;
}
~object() {
Py_XDECREF(ptr_);
}
static object steal(handle o) {
return object(o.ptr());
}
static object checked_steal(handle o) {
if (!o.ptr()) {
throw exception_set();
}
return steal(o);
}
static object borrow(handle o) {
Py_XINCREF(o.ptr());
return steal(o);
}
PyObject* release() {
auto tmp = ptr_;
ptr_ = nullptr;
return tmp;
}
protected:
explicit object(PyObject* ptr)
: handle(ptr) {}
};
template<typename T>
struct obj : public object {
obj() = default;
obj(const obj& other)
: object(other.ptr_) {
Py_XINCREF(ptr_);
}
obj(obj&& other) noexcept
: object(other.ptr_) {
other.ptr_ = nullptr;
}
obj& operator=(const obj& other) {
return *this = obj(other);
}
obj& operator=(obj&& other) noexcept {
PyObject* tmp = ptr_;
ptr_ = other.ptr_;
other.ptr_ = tmp;
return *this;
}
static obj steal(hdl<T> o) {
return obj(o.ptr());
}
static obj checked_steal(hdl<T> o) {
if (!o.ptr()) {
throw exception_set();
}
return steal(o);
}
static obj borrow(hdl<T> o) {
Py_XINCREF(o.ptr());
return steal(o);
}
T* ptr() const {
return (T*) object::ptr();
}
T* operator->() {
return ptr();
}
protected:
explicit obj(T* ptr)
: object((PyObject*)ptr) {}
};
static bool isinstance(handle h, handle c) {
return PyObject_IsInstance(h.ptr(), c.ptr());
}
[[ noreturn ]] inline void raise_error(handle exception, const char *format, ...) {
va_list args;
va_start(args, format);
PyErr_FormatV(exception.ptr(), format, args);
va_end(args);
throw exception_set();
}
template<typename T>
struct base {
PyObject_HEAD
PyObject* ptr() const {
return (PyObject*) this;
}
static obj<T> alloc(PyTypeObject* type = nullptr) {
if (!type) {
type = &T::Type;
}
auto self = (T*) type->tp_alloc(type, 0);
if (!self) {
throw mpy::exception_set();
}
new (self) T;
return obj<T>::steal(self);
}
template<typename ... Args>
static obj<T> create(Args ... args) {
auto self = alloc();
self->init(std::forward<Args>(args)...);
return self;
}
static bool check(handle v) {
return isinstance(v, (PyObject*)&T::Type);
}
static hdl<T> unchecked_wrap(handle self_) {
return hdl<T>((T*)self_.ptr());
}
static hdl<T> wrap(handle self_) {
if (!check(self_)) {
raise_error(PyExc_ValueError, "not an instance of %S", &T::Type);
}
return unchecked_wrap(self_);
}
static obj<T> unchecked_wrap(object self_) {
return obj<T>::steal(unchecked_wrap(self_.release()));
}
static obj<T> wrap(object self_) {
return obj<T>::steal(wrap(self_.release()));
}
static PyObject* new_stub(PyTypeObject *type, PyObject *args, PyObject *kwds) {
PY_BEGIN
return (PyObject*) alloc(type).release();
PY_END(nullptr)
}
static void dealloc_stub(PyObject *self) {
((T*)self)->~T();
Py_TYPE(self)->tp_free(self);
}
static void ready(mpy::handle mod, const char* name) {
if (PyType_Ready(&T::Type)) {
throw exception_set();
}
if(PyModule_AddObject(mod.ptr(), name, (PyObject*) &T::Type) < 0) {
throw exception_set();
}
}
};
inline object handle::attr(const char* key) {
return object::checked_steal(PyObject_GetAttrString(ptr(), key));
}
inline bool handle::hasattr(const char* key) {
return PyObject_HasAttrString(ptr(), key);
}
inline object import(const char* module) {
return object::checked_steal(PyImport_ImportModule(module));
}
template<typename... Args>
inline object handle::call(Args&&... args) {
return object::checked_steal(PyObject_CallFunctionObjArgs(ptr_, args.ptr()..., nullptr));
}
inline object handle::call_object(mpy::handle args) {
return object::checked_steal(PyObject_CallObject(ptr(), args.ptr()));
}
inline object handle::call_object(mpy::handle args, mpy::handle kwargs) {
return object::checked_steal(PyObject_Call(ptr(), args.ptr(), kwargs.ptr()));
}
inline object handle::call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames) {
return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) begin, nargs, kwnames.ptr()));
}
struct tuple : public object {
void set(int i, object v) {
PyTuple_SET_ITEM(ptr_, i, v.release());
}
tuple(int size)
: object(checked_steal(PyTuple_New(size))) {}
};
struct list : public object {
void set(int i, object v) {
PyList_SET_ITEM(ptr_, i, v.release());
}
list(int size)
: object(checked_steal(PyList_New(size))) {}
};
namespace{
mpy::object unicode_from_format(const char* format, ...) {
va_list args;
va_start(args, format);
auto r = PyUnicode_FromFormatV(format, args);
va_end(args);
return mpy::object::checked_steal(r);
}
mpy::object unicode_from_string(const char * str) {
return mpy::object::checked_steal(PyUnicode_FromString(str));
}
mpy::object from_int(Py_ssize_t s) {
return mpy::object::checked_steal(PyLong_FromSsize_t(s));
}
mpy::object from_bool(bool b) {
return mpy::object::borrow(b ? Py_True : Py_False);
}
bool is_sequence(handle h) {
return PySequence_Check(h.ptr());
}
}
struct sequence_view : public handle {
sequence_view(handle h)
: handle(h) {}
Py_ssize_t size() const {
auto r = PySequence_Size(ptr());
if (r == -1 && PyErr_Occurred()) {
throw mpy::exception_set();
}
return r;
}
irange enumerate() const {
return irange(size());
}
static sequence_view wrap(handle h) {
if (!is_sequence(h)) {
raise_error(PyExc_ValueError, "expected a sequence");
}
return sequence_view(h);
}
mpy::object operator[](Py_ssize_t i) const {
return mpy::object::checked_steal(PySequence_GetItem(ptr(), i));
}
};
namespace {
mpy::object repr(handle h) {
return mpy::object::checked_steal(PyObject_Repr(h.ptr()));
}
mpy::object str(handle h) {
return mpy::object::checked_steal(PyObject_Str(h.ptr()));
}
bool is_int(handle h) {
return PyLong_Check(h.ptr());
}
bool is_none(handle h) {
return h.ptr() == Py_None;
}
Py_ssize_t to_int(handle h) {
Py_ssize_t r = PyLong_AsSsize_t(h.ptr());
if (r == -1 && PyErr_Occurred()) {
throw mpy::exception_set();
}
return r;
}
bool to_bool(handle h) {
return PyObject_IsTrue(h.ptr()) != 0;
}
}
struct slice_view {
slice_view(handle h, Py_ssize_t size) {
if(PySlice_Unpack(h.ptr(), &start, &stop, &step) == -1) {
throw mpy::exception_set();
}
slicelength = PySlice_AdjustIndices(size, &start, &stop, step);
}
Py_ssize_t start, stop, step, slicelength;
};
static bool is_slice(handle h) {
return PySlice_Check(h.ptr());
}
inline std::ostream& operator<<(std::ostream& ss, handle h) {
ss << PyUnicode_AsUTF8(str(h).ptr());
return ss;
}
struct tuple_view : public handle {
tuple_view() = default;
tuple_view(handle h) : handle(h) {}
Py_ssize_t size() const {
return PyTuple_GET_SIZE(ptr());
}
irange enumerate() const {
return irange(size());
}
handle operator[](Py_ssize_t i) {
return PyTuple_GET_ITEM(ptr(), i);
}
static bool check(handle h) {
return PyTuple_Check(h.ptr());
}
};
struct list_view : public handle {
list_view() = default;
list_view(handle h) : handle(h) {}
Py_ssize_t size() const {
return PyList_GET_SIZE(ptr());
}
irange enumerate() const {
return irange(size());
}
handle operator[](Py_ssize_t i) {
return PyList_GET_ITEM(ptr(), i);
}
static bool check(handle h) {
return PyList_Check(h.ptr());
}
};
struct dict_view : public handle {
dict_view() = default;
dict_view(handle h) : handle(h) {}
object keys() const {
return mpy::object::checked_steal(PyDict_Keys(ptr()));
}
object values() const {
return mpy::object::checked_steal(PyDict_Values(ptr()));
}
object items() const {
return mpy::object::checked_steal(PyDict_Items(ptr()));
}
bool contains(handle k) const {
return PyDict_Contains(ptr(), k.ptr());
}
handle operator[](handle k) {
return mpy::handle::checked(PyDict_GetItem(ptr(), k.ptr()));
}
static bool check(handle h) {
return PyDict_Check(h.ptr());
}
bool next(Py_ssize_t* pos, mpy::handle* key, mpy::handle* value) {
PyObject *k = nullptr, *v = nullptr;
auto r = PyDict_Next(ptr(), pos, &k, &v);
*key = k;
*value = v;
return r;
}
void set(handle k, handle v) {
if (-1 == PyDict_SetItem(ptr(), k.ptr(), v.ptr())) {
throw exception_set();
}
}
};
struct kwnames_view : public handle {
kwnames_view() = default;
kwnames_view(handle h) : handle(h) {}
Py_ssize_t size() const {
return PyTuple_GET_SIZE(ptr());
}
irange enumerate() const {
return irange(size());
}
const char* operator[](Py_ssize_t i) const {
PyObject* obj = PyTuple_GET_ITEM(ptr(), i);
return PyUnicode_AsUTF8(obj);
}
static bool check(handle h) {
return PyTuple_Check(h.ptr());
}
};
inline mpy::object funcname(mpy::handle func) {
if (func.hasattr("__name__")) {
return func.attr("__name__");
} else {
return mpy::str(func);
}
}
struct vector_args {
vector_args(PyObject *const *a,
Py_ssize_t n,
PyObject *k)
: vector_args((mpy::handle*)a, n, k) {}
vector_args(mpy::handle* a,
Py_ssize_t n,
mpy::handle k)
: args((mpy::handle*)a), nargs(n), kwnames(k) {}
mpy::handle* args;
Py_ssize_t nargs;
kwnames_view kwnames;
mpy::handle* begin() {
return args;
}
mpy::handle* end() {
return args + size();
}
mpy::handle operator[](int64_t i) const {
return args[i];
}
bool has_keywords() const {
return kwnames.ptr();
}
irange enumerate_positional() {
return irange(nargs);
}
irange enumerate_all() {
return irange(size());
}
int64_t size() const {
return nargs + (has_keywords() ? kwnames.size() : 0);
}
// bind a test function so this can be tested, first two args for required/kwonly, then return what was parsed...
// provide write kwarg
// don't provide a required arg
// don't provide an optional arg
// provide a kwarg that is the name of already provided positional
// provide a kwonly argument positionally
// provide keyword arguments in the wrong order
// provide only keyword arguments
void parse(const char * fname_cstr, std::initializer_list<const char*> names, std::initializer_list<mpy::handle*> values, int required, int kwonly=0) {
auto error = [&]() {
// rather than try to match the slower infrastructure with error messages exactly, once we have detected an error, just use that
// infrastructure to format it and throw it
// have to leak this, because python expects these to last
const char** names_buf = new const char*[names.size() + 1];
std::copy(names.begin(), names.end(), &names_buf[0]);
names_buf[names.size()] = nullptr;
#if PY_VERSION_HEX < 0x03080000
char* format_str = new char[names.size() + 3];
int i = 0;
char* format_it = format_str;
for (auto it = names.begin(); it != names.end(); ++it, ++i) {
if (i == required) {
*format_it++ = '|';
}
if (i == (int)names.size() - kwonly) {
*format_it++ = '$';
}
*format_it++ = 'O';
}
*format_it++ = '\0';
_PyArg_Parser* _parser = new _PyArg_Parser{format_str, &names_buf[0], fname_cstr, 0};
PyObject *dummy = NULL;
_PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy);
#else
_PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0};
auto buf = std::make_unique<PyObject*[]>(names.size());
_PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]);
#endif
throw exception_set();
};
auto values_it = values.begin();
auto names_it = names.begin();
auto npositional = values.size() - kwonly;
if (nargs > (Py_ssize_t)npositional) {
// TOO MANY ARGUMENTS
error();
}
for (auto i : irange(nargs)) {
*(*values_it++) = args[i];
++names_it;
}
if (!kwnames.ptr()) {
if (nargs < required) {
// not enough positional arguments
error();
}
} else {
int consumed = 0;
for (auto i : irange(nargs, values.size())) {
bool success = i >= required;
const char* target_name = *(names_it++);
for (auto j : kwnames.enumerate()) {
if (!strcmp(target_name,kwnames[j])) {
*(*values_it) = args[nargs + j];
++consumed;
success = true;
break;
}
}
++values_it;
if (!success) {
// REQUIRED ARGUMENT NOT SPECIFIED
error();
}
}
if (consumed != kwnames.size()) {
// NOT ALL KWNAMES ARGUMENTS WERE USED
error();
}
}
}
int index(const char* name, int pos) {
if (pos < nargs) {
return pos;
}
if (kwnames.ptr()) {
for (auto j : kwnames.enumerate()) {
if (!strcmp(name, kwnames[j])) {
return nargs + j;
}
}
}
return -1;
}
};
inline object handle::call_vector(vector_args args) {
return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) args.args, args.nargs, args.kwnames.ptr()));
}
}
#define MPY_ARGS_NAME(typ, name) #name ,
#define MPY_ARGS_DECLARE(typ, name) typ name;
#define MPY_ARGS_POINTER(typ, name) &name ,
#define MPY_PARSE_ARGS_KWARGS(fmt, FORALL_ARGS) \
static char* kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
FORALL_ARGS(MPY_ARGS_DECLARE) \
if (!PyArg_ParseTupleAndKeywords(args, kwargs, fmt, kwlist, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
throw mpy::exception_set(); \
}
#define MPY_PARSE_ARGS_KWNAMES(fmt, FORALL_ARGS) \
static const char * const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
FORALL_ARGS(MPY_ARGS_DECLARE) \
static _PyArg_Parser parser = {fmt, kwlist, 0}; \
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
throw mpy::exception_set(); \
}

View File

@ -0,0 +1,49 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
// note: pytorch's python variable simple includes pybind which conflicts with minpybind
// so this file just reproduces the minimal API needed to extract Tensors from python objects.
#include <torch/csrc/python_headers.h>
#include <ATen/core/Tensor.h>
#include <torch/csrc/Export.h>
// Python object that backs torch.autograd.Variable
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPVariable {
PyObject_HEAD;
// Payload
c10::MaybeOwned<at::Tensor> cdata;
// Hooks to be run on backwards pass (corresponds to Python attr
// '_backwards_hooks', set by 'register_hook')
PyObject* backward_hooks = nullptr;
};
TORCH_PYTHON_API extern PyObject *THPVariableClass;
TORCH_PYTHON_API extern PyObject *ParameterClass;
TORCH_PYTHON_API PyObject * THPVariable_Wrap(const at::TensorBase& var);
inline bool THPVariable_Check(PyObject *obj)
{
if (!THPVariableClass)
return false;
const auto result = PyObject_IsInstance(obj, THPVariableClass);
AT_ASSERT(result != -1);
return result;
}
inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
return *var->cdata;
}
inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
}
TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();

View File

@ -0,0 +1,22 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/extension.h>
#include <functorch/csrc/dim/dim.h>
namespace at {
namespace functorch {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// initialize first-class dims and install it as a submodule on _C
auto dim = Dim_init();
if (!dim) {
throw py::error_already_set();
}
py::setattr(m, "dim", py::reinterpret_steal<py::object>(dim));
}
}}