From ae5be038a6381bb11dcd5d7e7c3321ed84dd3c90 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date: Mon, 22 Sep 2025 06:20:04 +0000
Subject: [PATCH] Revert "Delete functorch C extension entirely. (#163340)"

This reverts commit 1faf6367e396b1d0894e8735912a47ac465f469d.

Reverted https://github.com/pytorch/pytorch/pull/163340 on behalf of https://github.com/wdvr due to temporary revert to pull out #162659 ([comment](https://github.com/pytorch/pytorch/pull/163340#issuecomment-3317105243))
---
 BUILD.bazel                                 |   31 +
 CMakeLists.txt                              |    4 +
 functorch/.gitignore                        |    1 +
 functorch/CMakeLists.txt                    |   45 +
 functorch/csrc/dim/arena.h                  |  332 ++
 functorch/csrc/dim/dim.cpp                  | 3656 +++++++++++++++++++
 functorch/csrc/dim/dim.h                    |    8 +
 functorch/csrc/dim/dim_opcode.c             |   17 +
 functorch/csrc/dim/minpybind.h              |  692 ++++
 functorch/csrc/dim/python_variable_simple.h |   49 +
 functorch/csrc/init_dim_only.cpp            |   22 +
 setup.py                                    |   26 +
 12 files changed, 4883 insertions(+)
 create mode 100644 functorch/CMakeLists.txt
 create mode 100644 functorch/csrc/dim/arena.h
 create mode 100644 functorch/csrc/dim/dim.cpp
 create mode 100644 functorch/csrc/dim/dim.h
 create mode 100644 functorch/csrc/dim/dim_opcode.c
 create mode 100644 functorch/csrc/dim/minpybind.h
 create mode 100644 functorch/csrc/dim/python_variable_simple.h
 create mode 100644 functorch/csrc/init_dim_only.cpp

diff --git a/BUILD.bazel b/BUILD.bazel
index ff8a57ba28ca..d4202e7a2c1e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -833,6 +833,36 @@ pybind_extension(
     ],
 )
 
+cc_library(
+    name = "functorch",
+    hdrs = glob([
+        "functorch/csrc/dim/*.h",
+    ]),
+    srcs = glob([
+        "functorch/csrc/dim/*.cpp",
+    ]),
+    deps = [
+        ":aten_nvrtc",
+        ":torch_python",
+        "@pybind11",
+    ],
+)
+
+pybind_extension(
+    name = "functorch/_C",
+    copts=[
+        "-DTORCH_EXTENSION_NAME=_C"
+    ],
+    srcs = [
+        "functorch/csrc/init_dim_only.cpp",
+    ],
+    deps = [
+        ":functorch",
+        ":torch_python",
+        ":aten_nvrtc",
+    ],
+)
+
 cc_binary(
     name = "torch/bin/torch_shm_manager",
     srcs = [
@@ -873,6 +903,7 @@ py_library(
     ],
     data = [
         ":torch/_C.so",
+        ":functorch/_C.so",
         ":torch/bin/torch_shm_manager",
     ],
 )
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6d37aa25c74f..384dd27f9262 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1391,6 +1391,10 @@ endif()
 include(cmake/Summary.cmake)
 caffe2_print_configuration_summary()
 
+if(BUILD_FUNCTORCH)
+  add_subdirectory(functorch)
+endif()
+
 # Parse custom debug info
 if(DEFINED USE_CUSTOM_DEBINFO)
   string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")
diff --git a/functorch/.gitignore b/functorch/.gitignore
index 58bffff1353d..145ab7d60839 100644
--- a/functorch/.gitignore
+++ b/functorch/.gitignore
@@ -3,6 +3,7 @@ dist/
 functorch.egg-info/
 *__pycache__*
 functorch/version.py
+functorch/_C.so
 .gdbinit
 t.py
 .vscode/
diff --git a/functorch/CMakeLists.txt b/functorch/CMakeLists.txt
new file mode 100644
index 000000000000..bdfa4bfe4550
--- /dev/null
+++ b/functorch/CMakeLists.txt
@@ -0,0 +1,45 @@
+cmake_minimum_required(VERSION 3.18)
+project(functorch)
+set(CMAKE_CXX_STANDARD 17)
+
+include(GNUInstallDirs)
+include(CMakePackageConfigHelpers)
+
+set(FT_DIR csrc)
+file(GLOB_RECURSE FT_SOURCES ${FT_DIR}/*.cpp ${FT_DIR}/*.c)
+
+add_library(${PROJECT_NAME} MODULE ${FT_SOURCES})
+target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
+target_compile_definitions(${PROJECT_NAME} PRIVATE FUNCTORCH_BUILD_MAIN_LIB)
+target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_EXTENSION_NAME=_C)
+target_compile_definitions(${PROJECT_NAME} PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
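# Editor's note (illustrative, not part of the original patch): TORCH_EXTENSION_NAME=_C
# is the define PyTorch extensions conventionally use to pick the module init symbol,
# so the built library presumably exposes PyInit__C and is importable as functorch._C.
# The Bazel pybind_extension rule above passes the same define through copts, keeping
# both build systems in agreement about the module name.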
+
+target_compile_options(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_COMPILE_OPTIONS})
+
+target_compile_options_if_supported(${PROJECT_NAME} "-Wmissing-prototypes")
+target_compile_options_if_supported(${PROJECT_NAME} "-Werror=missing-prototypes")
+
+if(BUILD_LIBTORCHLESS)
+  target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIB} torch_python)
+else()
+  # functorch cannot use the alias to build on windows
+  target_link_libraries(${PROJECT_NAME} PRIVATE torch torch_python)
+endif()
+target_link_libraries(${PROJECT_NAME} PRIVATE pybind::pybind11)
+
+set_target_properties(${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
+                                                 ${CMAKE_BINARY_DIR}/functorch)
+set_target_properties(${PROJECT_NAME} PROPERTIES INSTALL_RPATH "${_rpath_portable_origin}/../torch/lib")
+
+# Copy-pasted prefix/suffix logic for Python extensions from
+# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L1975
+# https://github.com/pytorch/pytorch/blob/33bb8ae350611760139457b85842b1d7edf9aa11/caffe2/CMakeLists.txt#L2022
+# TODO: It would be good to be able to use Python3_add_library target, but it does not work in many cases
+set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" DEBUG_POSTFIX "")
+if(WIN32)
+  set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".pyd")
+else()
+  set_target_properties(${PROJECT_NAME} PROPERTIES SUFFIX ".so")
+endif()
+# Needed to link functorch on MacOS
+if(NOT ${TORCH_PYTHON_LINK_FLAGS} STREQUAL "")
+  set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS ${TORCH_PYTHON_LINK_FLAGS})
+endif()
+install(TARGETS ${PROJECT_NAME} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}")
diff --git a/functorch/csrc/dim/arena.h b/functorch/csrc/dim/arena.h
new file mode 100644
index 000000000000..ec2cfef66895
--- /dev/null
+++ b/functorch/csrc/dim/arena.h
@@ -0,0 +1,332 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <ATen/ATen.h>
+#include "minpybind.h"
+
+#if defined(_MSC_VER) && !defined(__clang__)
+#include <intrin.h>
+// https://stackoverflow.com/questions/355967/how-to-use-msvc-intrinsics-to-get-the-equivalent-of-this-gcc-code
+inline unsigned int __builtin_clz(unsigned int x) {
+  unsigned long r = 0;
+  _BitScanReverse(&r, x);
+  return (31 - r);
+}
+#endif
+
+inline int round2min8(int num) {
+  int nzeros = __builtin_clz((num - 1) | 4);
+  return 1 << (32 - nzeros);
+}
+
+struct Arena;
+template <typename T>
+struct OwnedSlice;
+
+template <typename T>
+struct Slice {
+  Slice() : begin_(nullptr), size_(0), capacity_(0) {}
+
+  template <typename... Args>
+  Slice(Arena& arena, Args&&... args);
+
+  T* begin() const {
+    return begin_;
+  }
+  T* end() const {
+    return begin_ + size_;
+  }
+  int size() const {
+    return size_;
+  }
+  int capacity() const {
+    return capacity_;
+  }
+
+  T& back(int i = -1) {
+    return begin_[size_ + i];
+  }
+
+  T& operator[](int i) const {
+    return begin_[i];
+  }
+  std::optional<int> index(const T& value) {
+    for (int i : enumerate()) {
+      if (begin_[i] == value) {
+        return i;
+      }
+    }
+    return std::nullopt;
+  }
+  bool contains(const T& value) {
+    return index(value).has_value();
+  }
+
+  void insert(Arena& arena, Slice where, Slice to_insert);
+  void insert(Arena& arena, Slice where, T v) {
+    return insert(arena, where, Slice(&v, &v + 1));
+  }
+  void insert(Arena& arena, int where, T v) {
+    return insert(arena, slice(where, where), v);
+  }
+  void append(Arena& arena, T value);
+  void extend(Arena& arena, Slice to_insert);
+  void extend(Arena& arena, const T* begin, const T* end) {
+    return extend(arena, Slice((T*)begin, (T*)end));
+  }
+
+  bool remove(Arena& A, T value) {
+    auto idx = index(value);
+    if (idx) {
+      insert(A, slice(*idx, *idx + 1), Slice());
+    }
+    return idx.has_value();
+  }
+
+  Slice slice(int begin) {
+    return slice(begin, size_);
+  }
+
+  Slice slice(int begin, int end) {
+    if (begin < 0) {
+      begin += size_;
+    }
+    if (end < 0) {
+      end += size_;
+    }
+    Slice result;
+    result.begin_ = begin_ + begin;
+    result.size_ = end - begin;
+    result.capacity_ = result.size_;
+    return result;
+  }
+
+  bool inside(Slice where) {
+    return begin() <= where.begin() && where.end() <= end();
+  }
+
+  irange enumerate() const {
+    return irange(size_);
+  }
+
+  irange reversed_enumerate() const {
+    return irange(size_ - 1, -1, -1);
+  }
+
+  bool operator==(const Slice& rhs) const {
+    if (size() != rhs.size()) {
+      return false;
+    }
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  Slice(T* begin, T* end)
+      : begin_(begin), size_(end - begin), capacity_(size_) {}
+
+ protected:
+  static int _length(const T& t) {
+    return 1;
+  }
+  static int _length(Slice t) {
+    return t.size_;
+  }
+  static T* _insert(T*& dst, T t) {
+    *dst = std::move(t);
+    return ++dst;
+  }
+  static T* _insert(T*& dst, Slice t) {
+    std::memcpy(dst, t.begin_, sizeof(T) * t.size_);
+    dst += t.size_;
+    return dst;
+  }
+  T* begin_;
+  int size_;
+  int capacity_;
+  friend struct OwnedSlice<T>;
+};
+
+template <typename T>
+struct OwnedSlice {
+  typedef void (*deleter_t)(Slice<T>);
+  static void _no_delete(Slice<T>) {}
+  OwnedSlice() : deleter_(_no_delete) {}
+  OwnedSlice(const OwnedSlice&) = delete;
+  OwnedSlice& operator=(const OwnedSlice&) = delete;
+  ~OwnedSlice() {
+    deleter_(slice_);
+    if (slice_.size_ > 8) {
+      delete[] slice_.begin_;
+    }
+  }
+  void set(Slice<T> to_own, deleter_t deleter = _no_delete) {
+    slice_.size_ = slice_.capacity_ = to_own.size();
+    slice_.begin_ = (slice_.size_ > 8) ? new T[slice_.size_] : &small_buf[0];
+    std::memcpy(slice_.begin_, to_own.begin(), slice_.size_ * sizeof(T));
+    deleter_ = deleter;
+  }
+  Slice<T> slice() const {
+    return slice_;
+  }
+
+ private:
+  Slice<T> slice_;
+  deleter_t deleter_;
+  T small_buf[8];
+};
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& s, const Slice<T>& v) {
+  s << "[";
+  for (int i : v.enumerate()) {
+    if (i > 0) {
+      s << ", ";
+    }
+    s << v[i];
+  }
+  s << "]";
+  return s;
+}
+
+struct TensorRef {
+  TensorRef() : impl_(nullptr) {}
+  TensorRef(const at::Tensor& t) : impl_(t.unsafeGetTensorImpl()) {}
+  const at::Tensor& operator*() const {
+    return *(at::Tensor*)this;
+  }
+  at::Tensor* operator->() const {
+    return (at::Tensor*)this;
+  }
+  operator bool() const {
+    return impl_ != nullptr;
+  }
+
+ private:
+  at::TensorImpl* impl_;
+};
+
+constexpr int ARENA_MAX_SIZE = 4096;
+constexpr int ALIGNMENT = 8;
+struct Arena {
+  Arena() : allocated_(0) {}
+  template <typename T>
+  T* allocate(int n) {
+    if (!n) {
+      return nullptr;
+    }
+    int to_allocate = sizeof(T) * n;
+    int to_allocate_rounded = ALIGNMENT * ((to_allocate - 1) / ALIGNMENT + 1);
+    auto prev_allocated = allocated_;
+    allocated_ += to_allocate_rounded;
+    if (C10_UNLIKELY_OR_CONST(allocated_ > ARENA_MAX_SIZE)) {
+      overflow_.emplace_back(new char[to_allocate]);
+      return (T*)&overflow_.back()[0];
+    }
+    return (T*)(buffer_ + prev_allocated);
+  }
+  TensorRef autorelease(at::Tensor s) {
+    auto ref = TensorRef(s);
+    s.unsafeReleaseTensorImpl();
+    ar_tensors_.append(*this, ref);
+    return ref;
+  }
+  mpy::handle autorelease(mpy::object obj) {
+    ar_objects_.append(*this, obj);
+    obj.release();
+    return ar_objects_.back();
+  }
+  ~Arena() {
+    for (TensorRef t : ar_tensors_) {
+      c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(
+          t->unsafeGetTensorImpl());
+    }
+    for (mpy::handle h : ar_objects_) {
+      mpy::object::steal(h);
+    }
+  }
+
+ private:
+  int64_t allocated_;
+  char buffer_[ARENA_MAX_SIZE];
+  Slice<TensorRef> ar_tensors_;
+  Slice<mpy::handle> ar_objects_;
+  std::vector<std::unique_ptr<char[]>> overflow_;
+};
+
+template <typename T>
+inline void Slice<T>::insert(Arena& arena, Slice where, Slice to_insert) {
+  AT_ASSERT(inside(where));
+  Slice result = *this;
+  /// b------sb---se-----e, 0----n
+  T* body_dest = where.begin();
+  if (where.size() != to_insert.size()) {
+    int new_size = size() - where.size() + to_insert.size();
+    T* tail_dest = where.begin() + to_insert.size();
+    if (new_size >= capacity_) {
+      int new_capacity = new_size ? round2min8(new_size) : 0;
+      result.capacity_ = new_capacity;
+      result.begin_ = arena.allocate<T>(new_capacity);
+      body_dest = result.begin_ + (where.begin() - begin());
+      tail_dest = body_dest + to_insert.size();
+      //std::memcpy(result.begin_, begin_, sizeof(T)*(where.begin() - begin()));
+      std::copy(begin_, begin_ + (where.begin() - begin()), result.begin_);
+    }
+    std::memmove(tail_dest, where.end(), sizeof(T) * (end() - where.end()));
+    result.size_ = new_size;
+  }
+
+  //std::memcpy(body_dest, to_insert.begin(), sizeof(T)*to_insert.size());
+  std::copy(to_insert.begin(), to_insert.end(), body_dest);
+  *this = result;
+}
+
+template <typename T>
+inline void Slice<T>::append(Arena& arena, T value) {
+  Slice result = *this;
+  if (size_ == capacity_) {
+    int new_size = size_ ? round2min8(size_) * 2 : 8;
+    T* n = arena.allocate<T>(new_size);
+    //memcpy(n, begin_, size_*sizeof(T));
+    std::copy(begin_, begin_ + size_, n);
+    result.begin_ = n;
+    result.capacity_ = new_size;
+  }
+  result[result.size_++] = std::move(value);
+  *this = result;
+}
+
+template <typename T>
+inline void Slice<T>::extend(Arena& arena, Slice rhs) {
+  Slice result = *this;
+  result.size_ = size_ + rhs.size();
+  if (result.size_ > capacity_) {
+    int new_size = round2min8(result.size_);
+    T* n = arena.allocate<T>(new_size);
+    //memcpy(n, begin_, size_*sizeof(T));
+    std::copy(begin_, begin_ + size_, n);
+    result.begin_ = n;
+    result.capacity_ = new_size;
+  }
+  //memcpy(result.begin_ + size_, rhs.begin(), sizeof(T)*rhs.size());
+  std::copy(rhs.begin(), rhs.end(), result.begin_ + size_);
+  *this = result;
+}
+
+template <typename T>
+template <typename... Args>
+Slice<T>::Slice(Arena& arena, Args&&... args) {
+  int lens[] = {_length(args)...};
+  size_ = 0;
+  for (auto i : lens) {
+    size_ += i;
+  }
+  capacity_ = size_ ? round2min8(size_) : 0;
+  begin_ = arena.allocate<T>(capacity_);
+  T* dst_ = begin_;
+  T* unused[] = {_insert(dst_, args)...};
+  (void)unused;
+}
diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp
new file mode 100644
index 000000000000..5258ba52f99c
--- /dev/null
+++ b/functorch/csrc/dim/dim.cpp
@@ -0,0 +1,3656 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <torch/csrc/utils/python_compat.h>
+
+// Many APIs have changed/don't exist anymore
+#if IS_PYTHON_3_12_PLUS
+
+#include "dim.h"
+
+// Re-enable this some day
+PyObject* Dim_init() {
+  PyErr_SetString(
+      PyExc_RuntimeError, "First class dim doesn't work with python 3.12");
+  return nullptr;
+}
+
+#else
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "minpybind.h"
+// #include
+#include
+#include
+#include
+#include
+#include
+#include "arena.h"
+#include "dim.h"
+#include "python_variable_simple.h"
+
+#if IS_PYTHON_3_11_PLUS
+
+#define Py_BUILD_CORE
+#include "internal/pycore_opcode.h"
+#undef Py_BUILD_CORE
+#endif
+
+// C++ API functions for objects to
+// * construct the object, returning a ref-counted handle
+// * The actual API, with methods that take/return C-typed values
+
+// extend minpybind.h to include
+// * typed handles so that -> can get to their raw API
+// * object/handle distinction for the typed handles
+
+// class Dim: ---------------
+mpy::handle torch_Tensor___mul__;
+mpy::handle _Tensor;
+mpy::handle _Tensor_sum;
+mpy::handle NamedTuple;
+mpy::dict_view pointwise;
+mpy::handle torch_Tensor_expand;
+binaryfunc THPVariable_getitem;
+objobjargproc THPVariable_setitem;
+mpy::handle no_slice;
+PyTypeObject* torch_Tensor;
+mpy::handle torch_Tensor_copy_;
+mpy::handle torch_Tensor_split;
+bool pointwise_optimize = true;
+PyTypeObject* DimType = nullptr;
+
+PyObject* Tensor_getitem(PyObject* self, PyObject* index);
+int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value);
+
+namespace {
+void maybeInitializeGlobals() {
+  // globals that depend on the python dim library,
+  // which we can't lookup until we finish initializing the _C module
+  if (_Tensor.ptr()) {
+    return;
+  }
+  auto dim = mpy::import("functorch.dim");
+  _Tensor = dim.attr("_Tensor");
+  pointwise = dim.attr("pointwise");
+  _Tensor_sum = _Tensor.attr("sum");
+  DimType = (PyTypeObject*)mpy::import("functorch.dim").attr("Dim").ptr();
+}
+
+void replaceMappingIfMatches(mpy::handle tp) {
+  auto T = (PyTypeObject*)tp.ptr();
+  bool recurse = false;
+  if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) {
+    T->tp_as_mapping->mp_subscript = Tensor_getitem;
+    recurse = true;
+  }
+  if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) {
+    T->tp_as_mapping->mp_ass_subscript = Tensor_setitem;
+    recurse = true;
+  }
+  if (recurse) {
+    auto result = tp.attr("__subclasses__").call();
+    mpy::list_view lv(result);
+    for (auto i : lv.enumerate()) {
+      replaceMappingIfMatches(lv[i]);
+    }
+  }
+}
+
+void initializeGlobals(Arena& A) {
+  auto torch = mpy::import("torch");
+  torch_Tensor = (PyTypeObject*)torch.attr("Tensor").ptr();
+  torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__");
+
+  torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand");
+  torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split");
+  torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_");
+  auto py_TensorBase = torch.attr("_C").attr("TensorBase");
+  auto TensorBase = (PyTypeObject*)py_TensorBase.ptr();
+  THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript;
+  THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript;
+  NamedTuple = mpy::import("typing").attr("NamedTuple");
+  no_slice = PySlice_New(NULL, NULL, NULL);
+}
+
+mpy::handle DimensionBindError_;
+mpy::handle DimensionBindError() {
+  if (!DimensionBindError_.ptr()) {
+    DimensionBindError_ =
+        mpy::import("functorch.dim").attr("DimensionBindError");
+  }
+  return DimensionBindError_;
+}
+
+static int64_t n_dims_created = 65;
+
+struct Dim : public mpy::base<Dim> {
+  int64_t level_; // for stable comparisons in prototype
+  mpy::object name_;
+  Dim() : level_(n_dims_created++) {}
+  void init(mpy::object name, int64_t s = -1) {
+    name_ = std::move(name);
+    size_ = s;
+  }
+
+  static bool check_exact(mpy::handle v) {
+    return Py_TYPE(v.ptr()) == DimType;
+  }
+
+  int64_t size() const {
+    if (size_ == -1) {
+      mpy::raise_error(
+          PyExc_ValueError, "dimension %S is unbound", name_.ptr());
+    }
+    return size_;
+  }
+  void set_size(int64_t v) {
+    if (size_ == -1) {
+      size_ = v;
+    } else if (size_ != v) {
+      mpy::raise_error(
+          DimensionBindError(),
+          "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld",
+          this,
+          this->size_,
+          v);
+    }
+  }
+  bool is_bound() const {
+    return size_ != -1;
+  }
+  static mpy::obj<Dim> create(mpy::object name, int64_t s = -1) {
+    if (!DimType) {
+      maybeInitializeGlobals();
+    }
+    auto r = Dim::alloc(DimType);
+    r->init(std::move(name), s);
+    return r;
+  }
+  static PyTypeObject Type;
+  const at::Tensor& range() {
+    if (!range_.defined()) {
+      range_ = at::arange(size());
+    }
+    return range_;
+  }
+  const at::Tensor& batchtensor() {
+    if (!batchtensor_.defined()) {
+      batchtensor_ = at::functorch::addBatchDim(range(), 0, level_);
+    }
+    return batchtensor_;
+  }
+
+ private:
+  int64_t size_{-1};
+  at::Tensor range_;
+  at::Tensor batchtensor_;
+};
+
+struct DimEntry {
+  // union of either a negative number indicating which dimension this is from
+  // the rhs, or a pointer to a first-class dimension. pointers do not have
+  // their highest bit set, so checking the number is negative tells us that it
+  // is not a dim.
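  // Editor's illustration of this encoding (hypothetical values, not in the
  // original file): a heap Dim* such as 0x00007f32'1a2b'3c40 reinterpreted
  // into data_ is non-negative on mainstream 64-bit ABIs, so dim() can round-
  // trip the pointer bits with memcpy; DimEntry(-2) stores the literal -2,
  // meaning "second positional axis from the right"; and the default-
  // constructed entry keeps data_ == 0, which is_none() reports as empty.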
+  bool is_positional() const {
+    return data_ < 0;
+  }
+  bool is_none() const {
+    return data_ == 0;
+  }
+  int64_t position() const {
+    return data_;
+  }
+  mpy::hdl<Dim> dim() const {
+    Dim* result;
+    std::memcpy(&result, &data_, sizeof(Dim*));
+    return mpy::hdl<Dim>(result);
+  }
+
+  DimEntry() : data_(0) {}
+
+  DimEntry(int64_t pos) : data_(pos) {
+    AT_ASSERT(pos < 0);
+  }
+  DimEntry(mpy::hdl<Dim> d) {
+    std::memcpy(&data_, &d, sizeof(int64_t));
+  }
+  bool operator==(const DimEntry& rhs) const {
+    return data_ == rhs.data_;
+  }
+
+ private:
+  int64_t data_;
+};
+
+// Dim wrapper methods
+DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) {
+  if (Dim::check(d)) {
+    if (keepdim) {
+      mpy::raise_error(
+          PyExc_ValueError,
+          "cannot preserve first-class dimensions with keepdim=True");
+    }
+    return Dim::unchecked_wrap(d);
+  } else if (mpy::is_int(d)) {
+    auto i = mpy::to_int(d);
+    while (i >= 0) {
+      i -= N;
+    }
+    return i;
+  } else {
+    return DimEntry();
+  }
+}
+
+int Dim_init(mpy::hdl<Dim> self, PyObject* args, PyObject* kwds) {
+  PY_BEGIN
+  static constexpr const char* kwlist[] = {"name", "size", nullptr};
+  mpy::handle name;
+  mpy::handle size = nullptr;
+  if (!PyArg_ParseTupleAndKeywords(
+          args, kwds, "O|O", const_cast<char**>(kwlist), &name, &size)) {
+    return -1;
+  }
+  self->init(
+      mpy::object::borrow(name),
+      (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1);
+  return 0;
+  PY_END(-1)
+}
+
+PyObject* Dim_repr(Dim* self) {
+  PY_BEGIN
+  mpy::object name = (self->name_.ptr())
+      ? self->name_
+      : mpy::unicode_from_string("<uninitialized dim>");
+  return name.release();
+  PY_END(nullptr)
+}
+
+PyObject* Dim_getsize(Dim* self, void*) {
+  PY_BEGIN
+  return mpy::from_int(self->size()).release();
+  PY_END(nullptr)
+}
+
+int Dim_setsize(Dim* self, PyObject* size, void*) {
+  PY_BEGIN
+  self->set_size(mpy::to_int(size));
+  return 0;
+  PY_END(-1)
+}
+
+PyObject* Dim_getis_bound(Dim* self, void*) {
+  return PyBool_FromLong(self->is_bound());
+}
+
+PyObject* Dim_getlevel(Dim* self, void*) {
+  return PyLong_FromLong(self->level_);
+}
+
+PyObject* Dim_get_levels(Dim* self, void*) {
+  mpy::tuple t(1);
+  t.set(0, mpy::object::borrow(self->ptr()));
+  return t.release();
+}
+
+PyObject* Dim_get_has_device(Dim* self, void*) {
+  Py_RETURN_FALSE;
+}
+
+PyObject* Dim_get_tensor(Dim* self, void*) {
+  return THPVariable_Wrap(self->range());
+}
+
+PyObject* Dim_get_batchtensor(Dim* self, void*) {
+  return THPVariable_Wrap(self->batchtensor());
+}
+
+PyGetSetDef Dim_getsetters[] = {
+    {"size", (getter)Dim_getsize, (setter)Dim_setsize, "Dimension size", NULL},
+    {"is_bound", (getter)Dim_getis_bound, NULL, "is_bound", NULL},
+    {"_level", (getter)Dim_getlevel, NULL, "_level", NULL},
+    {"_levels", (getter)Dim_get_levels, NULL, "_levels", NULL},
+    {"_has_device", (getter)Dim_get_has_device, NULL, "_has_device", NULL},
+    {"_tensor", (getter)Dim_get_tensor, NULL, "_tensor", NULL},
+    {"_batchtensor", (getter)Dim_get_batchtensor, NULL, "_batchtensor", NULL},
+    {"ndim",
+     (getter)[](PyObject* self, void*) -> PyObject* {
+       return mpy::from_int(1).release();
+     },
+     NULL, "ndim", NULL},
+    {NULL} /* Sentinel */
+};
+} // namespace
+
+PyTypeObject Dim::Type = {
+    PyVarObject_HEAD_INIT(NULL, 0)
+    "_C.Dim", /* tp_name */
+    sizeof(Dim), /* tp_basicsize */
+    0, /* tp_itemsize */
+    Dim::dealloc_stub, /* tp_dealloc */
+    0, /* tp_vectorcall_offset */
+    0, /* tp_getattr */
+    0, /* tp_setattr */
+    0, /* tp_as_async */
+    (reprfunc)Dim_repr, /* tp_repr */
+    0, /* tp_as_number */
+    0, /* tp_as_sequence */
+    0, /* tp_as_mapping */
+    0, /* tp_hash */
+    0, /* tp_call
*/ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Dim Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + Dim_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)(void*)static_cast, PyObject*, PyObject*)>( + Dim_init), /* tp_init */ + 0, /* tp_alloc */ + Dim::new_stub, /* tp_new */ +}; + +// class DimList ------------ + +struct DimList : public mpy::base { + mpy::object name_; + std::vector> dims_; + static PyTypeObject Type; + void init(mpy::object name) { + name_ = std::move(name); + } + void set_dims(std::vector> dims) { + bound_ = true; + dims_ = std::move(dims); + } + bool is_bound() { + return bound_; + } + void bind_len(int64_t size) { + if (bound_) { + int64_t b_size = dims_.size(); + if (b_size != size) { + mpy::raise_error( + DimensionBindError(), + "Dimlist has size %lld but it is being bound to size %d", + b_size, + size); + } + } else { + bound_ = true; + dims_.resize(size); + for (Py_ssize_t i = 0; i < size; ++i) { + dims_[i] = + Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); + } + } + } + int64_t size() const { + if (!bound_) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + return dims_.size(); + } + void set_bound(bool b) { + bound_ = b; + } + + private: + bool bound_ = false; +}; + +static int DimList_init(DimList* self, PyObject* args, PyObject* kwds); + +static PyObject* DimList_repr(DimList* self) { + PY_BEGIN + if (self->is_bound()) { + size_t size = self->dims_.size(); + mpy::tuple t(size); + for (size_t i = 0; i < size; ++i) { + t.set(i, self->dims_[i]); + } + return mpy::repr(t).release(); + } else if (!mpy::is_none(self->name_)) { + return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); + } else { + return mpy::unicode_from_string("").release(); + } + PY_END(nullptr) +} + +static PyObject* DimList_bind( + DimList* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + mpy::handle sizes; + static const char* const _keywords[] = {"sizes", nullptr}; + static _PyArg_Parser parser = {"O", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { + return nullptr; + } + if (!mpy::is_sequence(sizes)) { + mpy::raise_error(PyExc_ValueError, "expected a sequence"); + } + mpy::sequence_view seq = sizes; + auto size = seq.size(); + self->bind_len(size); + for (Py_ssize_t i = 0; i < size; ++i) { + self->dims_[i]->set_size(mpy::to_int(seq[i])); + } + Py_RETURN_NONE; + PY_END(nullptr) +} + +static PyObject* DimList_bind_len( + DimList* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + int size; + static const char* const _keywords[] = {"N", nullptr}; + static _PyArg_Parser parser = {"i", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { + return nullptr; + } + self->bind_len(size); + Py_RETURN_NONE; + PY_END(nullptr) +} + +static PyMethodDef DimList_methods[] = { + {"bind", (PyCFunction)(void*)DimList_bind, METH_FASTCALL | METH_KEYWORDS}, + {"bind_len", + (PyCFunction)(void*)DimList_bind_len, + METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +static Py_ssize_t DimList_len(DimList* self) { + PY_BEGIN + 
return self->size(); + PY_END(-1) +} + +static PyObject* DimList_item(DimList* self, Py_ssize_t idx) { + PY_BEGIN + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + if (idx < 0 || (size_t)idx >= self->dims_.size()) { + mpy::raise_error(PyExc_IndexError, "index out of bounds"); + } + mpy::object r = self->dims_[idx]; + return r.release(); + PY_END(nullptr) +} + +PySequenceMethods DimList_seq{ + (lenfunc)DimList_len, // lenfunc sq_length; + 0, // binaryfunc sq_concat; + 0, // ssizeargfunc sq_repeat; + (ssizeargfunc)DimList_item, // ssizeargfunc sq_item; + 0, // void *was_sq_slice; + 0, // ssizeobjargproc sq_ass_item; + 0, // void *was_sq_ass_slice; + 0, // objobjproc sq_contains; + + 0, // binaryfunc sq_inplace_concat; + 0, // ssizeargfunc sq_inplace_repeat; +}; + +static PyObject* DimList_getis_bound(DimList* self, void*) { + return PyBool_FromLong(self->is_bound()); +} + +static PyGetSetDef DimList_getsetters[] = { + {"is_bound", (getter)DimList_getis_bound, NULL, "is_bound", NULL}, + {NULL} /* Sentinel */ +}; + +static PyObject* DimList_subscript(DimList* self, mpy::handle idx) { + PY_BEGIN + if (mpy::is_int(idx)) { + return DimList_item(self, mpy::to_int(idx)); + } else if (mpy::is_slice(idx)) { + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + mpy::slice_view s(idx, self->dims_.size()); + mpy::tuple r(s.slicelength); + for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { + r.set(j++, self->dims_[i]); + } + return r.release(); + } else { + mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); + return nullptr; + } + PY_END(nullptr) +} + +PyMappingMethods DimList_mapping = { + 0, // lenfunc mp_length; + (binaryfunc)(void*)DimList_subscript, // binaryfunc mp_subscript; + 0, // objobjargproc mp_ass_subscript; +}; + +PyTypeObject DimList::Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_C.DimList", /* tp_name */ + sizeof(DimList), /* tp_basicsize */ + 0, /* tp_itemsize */ + DimList::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)DimList_repr, /* tp_repr */ + 0, /* tp_as_number */ + &DimList_seq, /* tp_as_sequence */ + &DimList_mapping, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + 0, /* tp_flags */ + "DimList Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DimList_methods, /* tp_methods */ + 0, /* tp_members */ + DimList_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)DimList_init, /* tp_init */ + 0, /* tp_alloc */ + DimList::new_stub, /* tp_new */ +}; + +static int DimList_init(DimList* self, PyObject* args, PyObject* kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; + mpy::handle len_or_dims = nullptr; + PyObject* name = nullptr; + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { + return -1; + } + self->init(mpy::object::borrow(name ? 
name : Py_None)); + if (len_or_dims.ptr()) { + if (mpy::is_int(len_or_dims)) { + self->bind_len(mpy::to_int(len_or_dims)); + } else if (mpy::is_sequence(len_or_dims)) { + mpy::sequence_view s(len_or_dims); + std::vector> dims; + size_t size = s.size(); + dims.reserve(size); + for (size_t i = 0; i < size; ++i) { + auto r = s[i]; + if (mpy::is_int(r)) { + dims.emplace_back(Dim::create( + mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), + mpy::to_int(r))); + } else { + dims.emplace_back(Dim::wrap(r)); + } + } + self->set_dims(std::move(dims)); + } else { + PyErr_Format( + PyExc_ValueError, "expected a length or a sequence of dimensions"); + return -1; + } + return 0; + } + return 0; + PY_END(-1); +} + +// Tensor ----------------------------- + +PyTypeObject* TensorType = nullptr; // the python wrapper type. +mpy::object run_torch_function( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise); + +namespace { + +at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice levels_) { + auto levels = Slice(); + levels.extend(A, levels_); + while (true) { + int64_t min_real_index = -1; + int64_t min_index = -1; + int64_t min_value = INT_MAX; + int64_t i = 0; + int64_t r = 0; + for (auto l : levels) { + if (!l.is_none()) { + if (!l.is_positional() && l.dim()->level_ < min_value) { + min_value = l.dim()->level_; + min_index = i; + min_real_index = r; + } + ++i; + } + ++r; + } + if (min_index == -1) { + return t; + } + auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); + t = std::move(t2); + levels[min_real_index] = DimEntry(); + } +} + +struct DelayedOperator { + DelayedOperator(mpy::object o, mpy::vector_args a) + : orig(std::move(o)), args(a) { + auto all = a.size(); + // this will outlive the call so + // take ownership of temporaries + // in vector args + auto buf = new mpy::handle[all]; + memcpy(buf, args.args, sizeof(mpy::handle) * all); + args.args = buf; + for (auto i : args.enumerate_all()) { + Py_INCREF(args.args[i].ptr()); + } + Py_XINCREF(args.kwnames.ptr()); + } + ~DelayedOperator() { + for (auto i : args.enumerate_all()) { + Py_DECREF(args[i].ptr()); + } + if (args.has_keywords()) { + Py_XDECREF(args.kwnames.ptr()); + } + delete[] args.args; + } + mpy::object orig; + mpy::vector_args args; +}; + +void free_levels_dims(Slice levels) { + for (auto e : levels) { + if (!e.is_positional()) { + mpy::object::steal(e.dim()); + } + } +} +} // namespace + +struct Tensor : public mpy::base { + private: + at::Tensor tensor_; + at::Tensor batchtensor_; + OwnedSlice levels_; + bool has_device_; + std::unique_ptr delayed_; + + public: + at::Tensor& tensor(Arena& A) { + if (C10_UNLIKELY(!tensor_.defined())) { + AT_ASSERT(delayed_); + auto t = Tensor::wrap( + run_torch_function(A, delayed_->orig, delayed_->args, true)); + tensor_ = t->tensor(A); + delayed_.reset(); + // don't force creation of batch tensor if it wasn't already provided. 
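  // Editor's note (sketch of the surrounding logic, not in the original file):
  // a delayed Tensor is created by __torch_function__ for pointwise
  // expressions such as `a * b` on zero-dim first-class tensors, and is only
  // materialized here when tensor(A) is first called; deferring presumably
  // lets a following reduction consume the delayed multiply (via _Tensor_sum
  // and the dot() path later in this file) without allocating the
  // elementwise intermediate.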
+ batchtensor_ = t->batchtensor_; + AT_ASSERT(levels() == t->levels()); + } + return tensor_; + } + at::Tensor& batchtensor(Arena& A) { + if (C10_UNLIKELY(!batchtensor_.defined())) { + batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); + } + return batchtensor_; + } + Slice levels() { + return levels_.slice(); + } + bool has_device() { + return has_device_; + } + DelayedOperator* delayed() { + return delayed_.get(); + } + static PyTypeObject Type; + + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == TensorType; + } + + static mpy::obj create() { + if (!TensorType) { + TensorType = + (PyTypeObject*)mpy::import("functorch.dim").attr("Tensor").release(); + } + return Tensor::alloc(TensorType); + } + void capture_levels(Slice levels) { + // grab ownership of the dims inside levels + for (auto l : levels) { + if (!l.is_positional()) { + mpy::object::borrow(l.dim()).release(); + } + } + levels_.set(levels, free_levels_dims); + } + static mpy::object from_positional( + Arena& A, + at::Tensor tensor, + Slice levels, + bool has_device); + static mpy::obj create_delayed( + mpy::object op, + mpy::vector_args args, + Slice levels, + bool has_device); + friend struct EnableAllLayers; +}; + +namespace { +// version in header does a unnecessary refcount +/- +at::functorch::BatchedTensorImpl* maybeGetBatchedImpl( + const at::Tensor& tensor) { + if (at::functorch::isBatchedTensor(tensor)) { + return static_cast( + tensor.unsafeGetTensorImpl()); + } + return nullptr; +} + +TensorRef unchecked_tensor_from(mpy::handle p) { + auto v = (THPVariable*)p.ptr(); + return TensorRef(*v->cdata); +} + +static int64_t ndim_of_levels(Slice levels) { + int64_t r = 0; + for (auto l : levels) { + if (l.is_positional()) { + ++r; + } + } + return r; +} + +struct TensorInfo { + TensorRef tensor; + Slice levels; + bool has_device; + TensorRef batchedtensor; + int64_t ndim() const { + return ndim_of_levels(levels); + } + operator bool() const { + return tensor; + } + + static TensorInfo create( + Arena& A, + mpy::handle h, + bool ensure_batched = true, + bool ensure_present = true) { + if (Tensor::check_exact(h)) { + auto t = Tensor::unchecked_wrap(h); + return TensorInfo{ + t->tensor(A), + t->levels(), + t->has_device(), + ensure_batched ? t->batchtensor(A) : TensorRef()}; + } else if (Dim::check_exact(h)) { + auto d = Dim::unchecked_wrap(h); + return TensorInfo{ + d->range(), + Slice(A, DimEntry(d)), + false, + ensure_batched ? 
d->batchtensor() : TensorRef()}; + } else if (THPVariable_Check(h.ptr())) { + TensorRef t = unchecked_tensor_from(h); + Slice levels; + for (auto i : irange(-t->dim(), 0)) { + levels.append(A, i); + } + return TensorInfo{t, levels, true, t}; + } else { + if (ensure_present) { + mpy::raise_error(PyExc_ValueError, "expected a tensor object"); + } + return TensorInfo{}; + } + } +}; + +static PyObject* py_Tensor_from_positional( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN +#define ARGS(_) \ + _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) + MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) +#undef ARGS + + if (!THPVariable_Check(tensor.ptr())) { + mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); + } + + Slice levels; + mpy::sequence_view sq(py_levels); + for (auto i : sq.enumerate()) { + mpy::object v = sq[i]; + if (mpy::is_int(v)) { + auto vi = mpy::to_int(v); + levels.append(A, vi); + } else { + auto dim = Dim::wrap(std::move(v)); + mpy::hdl hdim = dim; + levels.append(A, hdim); + } + } + return Tensor::from_positional( + A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0) + .release(); + PY_END(nullptr) +} +} // namespace + +mpy::object Tensor::from_positional( + Arena& A, + at::Tensor tensor, + Slice levels, + bool has_device) { + size_t seen_dims = 0; + int last = 0; + // auto sz = tensor.sizes(); + for (auto i : levels.enumerate()) { + auto l = levels[i]; + if (l.is_positional()) { + AT_ASSERT(last == 0 || last + 1 == l.position()); + last = l.position(); + } else { + mpy::object::borrow(l.dim()).release(); + // AT_ASSERT(sz[i] == l.dim()->size()); + ++seen_dims; + } + } + AT_ASSERT(last == 0 || last == -1); + if (!seen_dims) { + return mpy::object::steal(THPVariable_Wrap(tensor)); + } + + mpy::obj self = Tensor::create(); + self->tensor_ = std::move(tensor); + AT_ASSERT(self->tensor_.dim() == levels.size()); + self->levels_.set(levels, free_levels_dims); + self->has_device_ = has_device; + mpy::object r = std::move(self); + return r; +} + +mpy::obj Tensor::create_delayed( + mpy::object op, + mpy::vector_args args, + Slice levels, + bool has_device) { + mpy::obj self = Tensor::create(); + self->capture_levels(levels); + self->has_device_ = has_device; + self->delayed_ = std::make_unique(std::move(op), args); + return self; +} + +namespace { +mpy::list slice_to_list(Slice h) { + mpy::list lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; +} + +mpy::tuple slice_to_tuple(Slice h) { + mpy::tuple lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; +} + +enum UType { + U_ELEM, + U_TUPLE_LIKE, + U_DICT, +}; + +struct Unflatten { + mpy::object operator()(Slice& elements) { + mpy::object r; + switch (type) { + case U_ELEM: { + r = mpy::object::borrow(elements[0]); + elements = elements.slice(1); + } break; + case U_TUPLE_LIKE: { + mpy::tuple tup(children.size()); + for (auto i : children.enumerate()) { + tup.set(i, children[i](elements)); + } + r = obj.call(tup); + } break; + case U_DICT: { + r = mpy::object::checked_steal(PyDict_New()); + mpy::dict_view rv(r); + mpy::dict_view d(obj); + Py_ssize_t pos = 0; + mpy::handle k, v; + for (int i = 0; d.next(&pos, &k, &v); ++i) { + rv.set(k, children[i](elements)); + } + } break; + } + return r; + } + UType type; + mpy::handle obj; + Slice children; +}; + +Unflatten tree_flatten( + Arena& A, + mpy::handle agg, + Slice& flat_elements) { + Slice c; + 
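  // Editor's illustration (hypothetical input, not in the original file):
  // flattening {"a": [t1, (t2, t3)], "b": t4} appends t1, t2, t3, t4 to
  // flat_elements in traversal order and returns
  //   U_DICT{ U_TUPLE_LIKE[ U_ELEM, U_TUPLE_LIKE[ U_ELEM, U_ELEM ] ], U_ELEM }
  // which Unflatten::operator() replays to rebuild the same nesting from a
  // (possibly transformed) flat list.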
UType utype; + mpy::handle obj; + if (mpy::list_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + mpy::list_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::tuple_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + // includes named tuples + mpy::tuple_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::dict_view::check(agg)) { + utype = U_DICT; + mpy::dict_view d(agg); + obj = agg; + Py_ssize_t pos = 0; + mpy::handle k, v; + while (d.next(&pos, &k, &v)) { + c.append(A, tree_flatten(A, v, flat_elements)); + } + } else { + utype = U_ELEM; + flat_elements.append(A, agg); + } + return Unflatten{utype, obj, c}; +} + +struct UnflattenVectorArgs { + mpy::vector_args operator()(Arena& A, Slice& elements) { + if (!had_nested) { + auto args = elements.begin(); + elements = Slice(); + return mpy::vector_args(args, nargs, kwnames); + } + Slice args; + for (auto u : children) { + args.append(A, A.autorelease(u(elements))); + } + return mpy::vector_args(args.begin(), nargs, kwnames); + } + Slice children; + Py_ssize_t nargs; + mpy::handle kwnames; + bool had_nested; +}; + +UnflattenVectorArgs tree_flatten( + Arena& A, + mpy::vector_args args, + Slice& flat_elements) { + UnflattenVectorArgs r; + r.kwnames = args.kwnames; + r.nargs = args.nargs; + r.had_nested = false; + auto N = args.size(); + for (auto i : irange(N)) { + auto typ = Py_TYPE(args[i].ptr()); + // fast checks that this thing isn't something that is nested. + bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || + typ == TensorType || typ == DimType; + if (!is_element) { + flat_elements.extend(A, args.args, args.args + i); + for (auto j : irange(i)) { + (void)j; + r.children.append(A, Unflatten{U_ELEM}); + } + for (auto j : irange(i, N)) { + r.children.append(A, tree_flatten(A, args[j], flat_elements)); + if (r.children.back().type != U_ELEM) { + r.had_nested = true; + } + } + return r; + } + } + flat_elements.extend(A, args.args, args.args + N); + return r; +} + +struct UnflattenArena { + Arena A; + Unflatten unflatten; +}; + +PyObject* py_unflatten( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN +#define ARGS(_) _(mpy::handle, ns) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) +#undef ARGS + mpy::sequence_view sv(ns); + // because we do not have a autorelase pool yet... 
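  // Editor's note (hypothetical Python usage, not in the original file):
  //   flat, unflatten = tree_flatten(tree)      # via py_tree_flatten below
  //   rebuilt = unflatten([f(x) for x in flat])
  // The PyCapsule created in py_tree_flatten owns the UnflattenArena, so the
  // unflatten closure stays valid for as long as the callable is alive.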
+ Arena A; + Slice slice; + mpy::handle Tuple = (PyObject*)&PyTuple_Type; + auto inputs = Tuple.call(ns); + mpy::tuple_view tv(inputs); + for (auto i : tv.enumerate()) { + slice.append(A, tv[i]); + } + auto AA = (UnflattenArena*)PyCapsule_GetPointer(self, "arena"); + auto r = AA->unflatten(slice).release(); + AT_ASSERT(r != nullptr); + return r; + PY_END(nullptr) +} + +PyMethodDef py_unflatten_def = { + "unflatten", + (PyCFunction)(void*)py_unflatten, + METH_FASTCALL | METH_KEYWORDS}; + +void free_unflatten_arena(PyObject* pc) { + delete (UnflattenArena*)PyCapsule_GetPointer(pc, "arena"); +} + +PyObject* py_tree_flatten( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN +#define ARGS(_) _(mpy::handle, tree) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) +#undef ARGS + auto A = new UnflattenArena; + Slice elements; + A->unflatten = tree_flatten(A->A, tree, elements); + auto cap = mpy::object::checked_steal( + PyCapsule_New(A, "arena", free_unflatten_arena)); + auto unflatten = mpy::object::checked_steal( + PyCFunction_New(&py_unflatten_def, cap.release())); + mpy::tuple r(2); + r.set(0, slice_to_list(elements)); + r.set(1, std::move(unflatten)); + return r.release(); + PY_END(nullptr) +} + +mpy::object tree_map( + Arena& A, + const std::function& fn, + mpy::handle agg) { + Slice elements; + auto unflatten = tree_flatten(A, agg, elements); + for (auto i : elements.enumerate()) { + elements[i] = fn(elements[i]); + } + return unflatten(elements); +} + +// prereq: isinstance(h, _Tensor) +int64_t _Tensor_ndim(mpy::handle h) { + if (Tensor::check(h)) { + int64_t r = 0; + for (auto l : Tensor::unchecked_wrap(h)->levels()) { + if (l.is_positional()) { + ++r; + } + } + return r; + } + // Dim or DelayedMulTensor + return 0; +} + +mpy::handle handle_from_tensor(Arena& A, TensorRef t) { + // fast case: tensor is live in python + std::optional mb_obj = + t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(); + if (mb_obj.has_value() && + !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { + return *mb_obj; + } + return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); +} +} // namespace +struct EnableAllLayers { + EnableAllLayers(Arena& A, Slice levels) { + std::vector> layers; + layers.reserve(levels.size()); + for (auto l : levels) { + if (!l.is_positional()) { + auto d = l.dim(); + levels_to_dim_.append(A, d); + } + } + std::sort( + levels_to_dim_.begin(), + levels_to_dim_.end(), + [](mpy::hdl lhs, mpy::hdl rhs) { + return lhs->level_ < rhs->level_; + }); + + for (auto i : levels_to_dim_.enumerate()) { + auto batch_size = levels_to_dim_[i]->size(); + auto level = at::functorch::initAndPushDynamicLayer( + at::functorch::TransformType::Vmap, + batch_size, + at::functorch::RandomnessType::Different); + if (i == 0) { + levels_start_ = level; + } + } + } + + ~EnableAllLayers() { + auto to_remove = levels_start_ + levels_to_dim_.size() - 1; + for (auto i : levels_to_dim_.enumerate()) { + AT_ASSERT( + at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == + to_remove - i); + } + } + + mpy::obj from_batched( + Arena& A, + at::Tensor batchedtensor, + bool has_device) { + Slice levels; + for (auto i : irange(-batchedtensor.dim(), 0)) { + levels.append(A, i); + } + TensorRef tensor; + at::functorch::BatchedTensorImpl* impl = maybeGetBatchedImpl(batchedtensor); + while (true) { + auto level = impl->level(); + AT_ASSERT( + level >= levels_start_ && + level < levels_start_ + levels_to_dim_.size()); + mpy::hdl dim = levels_to_dim_[level - 
levels_start_].ptr(); + levels.insert(A, impl->bdim(), dim); + at::functorch::BatchedTensorImpl* nimpl = + maybeGetBatchedImpl(impl->value()); + if (!nimpl) { + tensor = impl->value(); + break; + } + impl = nimpl; + } + + mpy::obj self = Tensor::create(); + // grab ownership of the tensors + self->tensor_ = *tensor; + self->batchtensor_ = std::move(batchedtensor); + self->has_device_ = has_device; + self->capture_levels(levels); + return self; + } + void inplace_update_layers(TensorRef batchtensor, Slice levels) { + // XXX - requires a patch to functorch to att set_level + auto impl = maybeGetBatchedImpl(*batchtensor); + for (auto i : levels_to_dim_.reversed_enumerate()) { + if (!impl) { + break; + } + if (levels.contains(levels_to_dim_[i])) { + impl->_unsafe_set_level(levels_start_ + i); + impl = maybeGetBatchedImpl(impl->value()); + } + } + } + + private: + int64_t levels_start_{}; + Slice> levels_to_dim_; +}; + +namespace { +TensorRef _match_levels( + Arena& A, + TensorRef v, + Slice from_levels, + Slice to_levels, + bool drop_levels = false) { + if (from_levels == to_levels) { + return v; + } + // drop_levels -> if a dim appears in from_levels but not to_levels, it is + // assumed it has stride 0. + at::IntArrayRef sz = v->sizes(); + at::IntArrayRef sd = v->strides(); + AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); + Slice nsz; + Slice nsd; + for (auto l : to_levels) { + auto oidx = from_levels.index(l); + if (!oidx) { + nsz.append(A, l.is_positional() ? 1 : l.dim()->size()); + nsd.append(A, 0); + } else { + auto idx = *oidx; + nsz.append(A, sz[idx]); + nsd.append(A, sd[idx]); + } + } + return A.autorelease(v->as_strided( + at::IntArrayRef(nsz.begin(), nsz.end()), + at::IntArrayRef(nsd.begin(), nsd.end()), + v->storage_offset())); +} +} // namespace +mpy::object run_torch_function( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise) { + if (!pointwise_optimize) { + is_pointwise = false; + } + // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : + // "functorch") << " " << orig << "\n"; + + Slice> all_dims; + Slice flat_args; + auto unflatten_args = tree_flatten(A, args, flat_args); + TensorRef device_holding_tensor; + + Slice infos; + Slice result_levels; + for (auto f : flat_args) { + infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); + if (infos.back()) { + TensorInfo& info = infos.back(); + AT_ASSERT(is_pointwise || info.batchedtensor); + if (!device_holding_tensor && info.has_device) { + device_holding_tensor = infos.back().tensor; + } + for (auto l : info.levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + } + + if (is_pointwise) { + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef tensor = infos[i].tensor; + if (device_holding_tensor && !infos[i].has_device) { + tensor = A.autorelease(tensor->to(device_holding_tensor->device())); + } + auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); + flat_args[i] = handle_from_tensor(A, std::move(ml)); + } + } + + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + + mpy::object result = orig.call_vector(uargs); + + // fast wrap for normal case where operator just returns a tensor. 
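  // Editor's illustration (hypothetical dims, not in the original file): an
  // operand whose levels are [i, -1] matched against result_levels [i, j, -1]
  // by _match_levels above gains the missing level j as an as_strided axis of
  // size j->size() and stride 0, so pointwise ops broadcast over first-class
  // dims without copying data.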
+ if (THPVariable_Check(result.ptr())) { + return Tensor::from_positional( + A, + THPVariable_Unpack(result.ptr()), + result_levels, + device_holding_tensor); + } + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional( + A, + THPVariable_Unpack(h.ptr()), + result_levels, + device_holding_tensor)); + } + return h; + }; + return tree_map(A, wrap, result); + } else { + // std::cout << orig << " calling functorch...\n"; + // std::cout << "rl: " << result_levels << "\n"; + EnableAllLayers guard(A, result_levels); + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef batched = infos[i].batchedtensor; + if (device_holding_tensor && !infos[i].has_device) { + batched = A.autorelease(batched->to(device_holding_tensor->device())); + } + guard.inplace_update_layers(batched, infos[i].levels); + flat_args[i] = handle_from_tensor(A, batched); + } + } + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + AT_ASSERT(flat_it.size() == 0); + mpy::object result = orig.call_vector(uargs); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(guard.from_batched( + A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); + } + return h; + }; + if (THPVariable_Check(result.ptr())) { + return guard.from_batched( + A, THPVariable_Unpack(result.ptr()), device_holding_tensor); + } + return tree_map(A, wrap, result); + } +} + +namespace { + +mpy::object __torch_function__( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise) { + if (orig == torch_Tensor___mul__) { + AT_ASSERT(args.nargs == 2 && !args.has_keywords()); + auto lhs = args[0]; + auto rhs = args[1]; + if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && + _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { + bool has_device = false; + Slice levels; + for (auto i : args.enumerate_positional()) { + auto t = TensorInfo::create(A, args[i], false); + // something like a mask * rhs, which matrix multiplies don't correctly + // promote + if (!t.tensor->is_floating_point()) { + return run_torch_function(A, orig, args, is_pointwise); + } + has_device = has_device || t.has_device; + for (auto l : t.levels) { + if (!levels.contains(l)) { + levels.append(A, l); + } + } + } + // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; + return Tensor::create_delayed( + mpy::object::borrow(orig), args, levels, has_device); + } + } + return run_torch_function(A, orig, args, is_pointwise); +} + +mpy::vector_args as_vector_args( + Arena& A, + mpy::handle args, + mpy::handle kwargs) { + auto pos_args = (mpy::handle*)&PyTuple_GET_ITEM(args.ptr(), 0); + auto pos_n = PyTuple_GET_SIZE(args.ptr()); + if (!kwargs.ptr()) { + return mpy::vector_args(pos_args, pos_n, nullptr); + } + Slice all_args; + Slice kwnames; + all_args.extend(A, pos_args, pos_args + pos_n); + mpy::dict_view dv(kwargs); + Py_ssize_t pos = 0; + mpy::handle key, value; + while (dv.next(&pos, &key, &value)) { + all_args.append(A, value); + kwnames.append(A, key); + } + return mpy::vector_args( + all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); +} + +PyObject* py___torch_function__( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + AT_ASSERT(nargs == 4 || nargs == 5); + auto va = as_vector_args(A, args[3], nargs == 5 ? 
args[4] : nullptr); + bool is_pointwise = pointwise.contains(args[1]); + return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); + PY_END(nullptr) +} + +mpy::object levels_to_tuple(Slice slice) { + mpy::tuple t(slice.size()); + for (auto i : slice.enumerate()) { + t.set( + i, + slice[i].is_positional() ? mpy::from_int(slice[i].position()) + : mpy::object::borrow(slice[i].dim())); + } + mpy::object r = std::move(t); + return r; +} + +PyObject* Tensor_ndim(Tensor* self, void*) { + Py_ssize_t i = 0; + for (auto l : self->levels()) { + if (l.is_positional()) { + ++i; + } + } + return mpy::from_int(i).release(); +} + +PyGetSetDef Tensor_getsetters[] = { + {"_has_device", + (getter)[](PyObject* self, void*) + ->PyObject* { + return mpy::from_bool(((Tensor*)self)->has_device()).release(); +} // namespace +, NULL +} +, {"_tensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; +return THPVariable_Wrap(((Tensor*)self)->tensor(A)); +} +, NULL +} +, {"_batchtensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; +return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); +} +, NULL +} +, + {"_levels", + (getter)[](PyObject* self, void*) + ->PyObject* {PY_BEGIN return levels_to_tuple(((Tensor*)self)->levels()) + .release(); +PY_END(nullptr) +} +} +, {"ndim", (getter)Tensor_ndim, NULL, "ndim", NULL}, { + NULL +} /* Sentinel */ +} +; + +PyMethodDef Tensor_methods[] = { + {NULL, NULL, 0, NULL} /* Sentinel */ +}; +} + +PyTypeObject Tensor::Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_C.Tensor", /* tp_name */ + sizeof(Tensor), /* tp_basicsize */ + 0, /* tp_itemsize */ + Tensor::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Tensor Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Tensor_methods, /* tp_methods */ + 0, /* tp_members */ + Tensor_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + Tensor::new_stub, /* tp_new */ +}; + +// dim() -------------------- + +static bool relevant_op(_Py_CODEUNIT c) { + switch (c) { + case STORE_NAME: + case STORE_GLOBAL: + case STORE_FAST: + case STORE_DEREF: + return true; + default: + return false; + } +} + +static mpy::object create_dim(mpy::object name, mpy::handle size) { + auto d = Dim::create(std::move(name)); + if (!mpy::is_none(size)) { + d->set_size(mpy::to_int(size)); + } + return std::move(d); +} + +static mpy::object create_dimlist(mpy::object name, mpy::handle size) { + auto d = DimList::create(std::move(name)); + if (!mpy::is_none(size)) { + if (mpy::is_int(size)) { + d->bind_len(mpy::to_int(size)); + } else { + mpy::sequence_view s(size); + d->bind_len(s.size()); + for (auto i : irange(d->size())) { + d->dims_[i]->set_size(mpy::to_int(s[i])); + } + } + } + return std::move(d); +} + +// Python wrappers that make new reflection primitives available for older +// runtimes +#if !(IS_PYTHON_3_11_PLUS) +#define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code)) +#endif + +namespace { +struct PyInstDecoder { + 
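  // Editor's note (behavior sketch, not in the original file): for a binding
  // like
  //   x, y = dims()
  // the bytecode following the call site is UNPACK_SEQUENCE(2) and then
  // STORE_FAST x / STORE_FAST y; _dims() below uses this decoder both to
  // count how many dimensions were requested and to recover their variable
  // names through name(), whose oparg indexes the code object's names tuple.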
PyInstDecoder(PyCodeObject* code_object, int lasti) + : code_object_(code_object), + code_(_PyCode_CODE(code_object)), + offset_(lasti / sizeof(_Py_CODEUNIT)) {} + // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols + // See https://github.com/pytorch/pytorch/issues/93854 + void next() { +#if IS_PYTHON_3_11_PLUS + offset_ += _PyOpcode_Caches[opcode()]; +#endif + offset_ += 1; + } + int opcode() { + auto r = _Py_OPCODE(code_[offset_]); +#if IS_PYTHON_3_11_PLUS + r = _PyOpcode_Deopt[r]; +#endif + return r; + } + int oparg() { + return _Py_OPARG(code_[offset_]); + } + + mpy::object name() { + mpy::object names; + switch (opcode()) { + case STORE_NAME: + case STORE_GLOBAL: + names = mpy::object::borrow(code_object_->co_names); + break; + case STORE_FAST: + names = mpy::object::steal(PyCode_GetVarnames(code_object_)); + break; + case STORE_DEREF: + names = mpy::object::steal(PyCode_GetCellvars(code_object_)); + break; + default: + return mpy::object(); + } + return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); + } + + private: + PyCodeObject* code_object_; + _Py_CODEUNIT* code_; + int offset_; +}; + +template +static PyObject* _dims( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + Py_ssize_t specified_ndims = -1; + Py_ssize_t found_ndims = 0; + Py_ssize_t sizes = -1; + mpy::handle n = Py_None; + mpy::handle py_sizes = Py_None; + + if (nargs || kwnames) { + mpy::vector_args va(args, nargs, kwnames); + va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); + if (!mpy::is_none(py_sizes)) { + sizes = mpy::sequence_view(py_sizes).size(); + specified_ndims = sizes; + } + if (!mpy::is_none(n)) { + specified_ndims = mpy::to_int(n); + } + } + + PyThreadState* state = PyThreadState_GET(); + auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); + auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); + auto lasti = PyFrame_GetLasti(f.ptr()); + auto decoder = PyInstDecoder(c.ptr(), lasti); +#if IS_PYTHON_3_11_PLUS + // When py3.11 adapts bytecode lasti points to the precall + // rather than the call instruction after it + if (decoder.opcode() == PRECALL) { + decoder.next(); + } +#endif + decoder.next(); + + if (relevant_op(decoder.opcode())) { + found_ndims = 1; + } else if (decoder.opcode() == UNPACK_SEQUENCE) { + found_ndims = decoder.oparg(); + decoder.next(); + } + + if (specified_ndims == -1) { + if (found_ndims == 0) { + mpy::raise_error( + PyExc_SyntaxError, + "dims() must be assigned to a sequence of variable names or have argument n specified"); + } + specified_ndims = found_ndims; + } + if (found_ndims != specified_ndims) { + found_ndims = 0; // avoid taking the wrong names for dimensions + } + + auto genobject = [&](int i) -> mpy::object { + mpy::object name; + if (i < found_ndims) { + name = decoder.name(); + } + if (!name.ptr()) { + name = mpy::unicode_from_format("d%d", i); + found_ndims = 0; // once we fail at finding a name, we can find any more + } else { + decoder.next(); + } + return create_object( + std::move(name), + sizes != -1 ? 
mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); + }; + if (sizes != -1 && sizes != specified_ndims) { + mpy::raise_error( + PyExc_ValueError, + "expected %d sizes but found %d", + int(specified_ndims), + int(sizes)); + } + if (specified_ndims == 1) { + return genobject(0).release(); + } + mpy::tuple result(specified_ndims); + for (int i = 0; i < specified_ndims; ++i) { + result.set(i, genobject(i)); + } + return result.release(); + PY_END(nullptr) +} + +struct DotPart { + Slice dims; + size_t total_size = 1; + void append(Arena& A, mpy::hdl d) { + total_size *= d->size(); + dims.append(A, d); + } +}; + +template +static at::ArrayRef as_array_ref(Slice t) { + return at::ArrayRef(t.begin(), t.end()); +} + +static TensorRef dot_prepare( + Arena& A, + std::initializer_list parts, + const TensorInfo& t) { + Slice new_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + new_levels.extend(A, p.dims); + } + auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); + if (!needs_reshape) { + return r; + } + Slice view; + for (auto p : parts) { + view.append(A, p.total_size); + } + return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); +} + +static mpy::object dot_finish( + Arena& A, + std::initializer_list parts, + at::Tensor r) { + Slice result_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; + } + result_levels.extend(A, p.dims); + } + if (needs_reshape) { + Slice new_size; + for (auto l : result_levels) { + new_size.append(A, l.dim()->size()); + } + r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); + } + return Tensor::from_positional(A, std::move(r), result_levels, true); +} + +static mpy::object dot( + Arena& A, + TensorInfo lhs, + TensorInfo rhs, + Slice sum) { + auto lhs_strides = lhs.tensor->strides(); + auto rhs_strides = rhs.tensor->strides(); + + DotPart lro_dims; + DotPart lo_dims; + DotPart ro_dims; + DotPart lr_dims; + + auto insert_dim = [&](mpy::hdl d, + std::optional lhs_idx, + std::optional rhs_idx) { + bool reduced = sum.contains(d); + int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; + int64_t rhs_stride = rhs_idx ? 
rhs_strides[*rhs_idx] : 0; + if (reduced) { + // lr + lr_dims.append(A, d); + } else { + if ((lhs_stride == 0) == (rhs_stride == 0)) { + // lro + lro_dims.append(A, d); + } else if (lhs_stride != 0) { + // lo + lo_dims.append(A, d); + } else { + AT_ASSERT(rhs_stride != 0); + ro_dims.append(A, d); + } + } + }; + + auto rhs_seen = A.allocate(rhs.levels.size()); + std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); + + for (auto i : lhs.levels.enumerate()) { + auto d = lhs.levels[i]; + auto rhs_idx = rhs.levels.index(d); + if (rhs_idx) { + rhs_seen[*rhs_idx] = true; + } + insert_dim(d.dim(), i, rhs_idx); + } + + for (auto i : rhs.levels.enumerate()) { + if (rhs_seen[i]) { + continue; + } + auto d = rhs.levels[i]; + insert_dim(d.dim(), std::nullopt, i); + } + + if (lr_dims.dims.size() != sum.size()) { + for (auto& d : sum) { + if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { + mpy::raise_error( + DimensionBindError(), + "summing over non-existent dimension %S", + d.dim().ptr()); + } + } + } + + // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; + // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << + // " " << lr_dims.dims << "\n"; + + // no batch, just call mm + if (lro_dims.dims.size() != 0) { + auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); + return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); + } else { + auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); + return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); + } +} + +static PyObject* test_c( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + + Arena A; + Slice s(A, 3, 4, 5); + AT_ASSERT(s.size() == 3 && s.capacity() == 8); + AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); + s.append(A, 6); + AT_ASSERT(s[3] == 6); + for (int i : irange(10)) { + s.append(A, i); + } + AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); + + Slice s2(A, -1, -2, -3); + AT_ASSERT(s2[1] == -2 && s[0] == 3); + + auto ss = s.slice(1, 2); + AT_ASSERT(ss.size() == 1); + AT_ASSERT(ss[0] == 4); + AT_ASSERT(ss.capacity() == 1); + ss.append(A, -4); + AT_ASSERT(ss.size() == 2 && ss[1] == -4); + ss[0] = 3; + AT_ASSERT(s[1] == 4); + + s.insert(A, s.slice(1, 4), ss); + AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); + + auto sz = s.size(); + s.insert(A, s.slice(1, 1), 4); + AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); + + Slice d(A, 0, 1, 2, 3, 4); + + Slice b(A, 0, 1, 2, 3, 4); + b.insert(A, b.slice(1, 1), d); + AT_ASSERT(b.size() == 10); + AT_ASSERT(b[1] == 0); + AT_ASSERT(b[5] == 4); + AT_ASSERT(b.back() == 4); + + Py_RETURN_NONE; + + PY_END(nullptr); +} + +static PyObject* order( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + if (kwnames) { + mpy::raise_error( + PyExc_TypeError, "unexpected keyword arguments %S", kwnames); + } + AT_ASSERT(nargs-- > 0); + Slice orig_levels; + Slice levels; + TensorRef data; + mpy::handle self = args++[0]; + bool has_device; + if (Tensor::check_exact(self)) { + auto t = Tensor::unchecked_wrap(self); + orig_levels = t->levels(); + data = t->tensor(A); + has_device = t->has_device(); + } else { + auto d = Dim::unchecked_wrap(self); + orig_levels.append(A, d); + data = d->range(); + has_device = false; + } + + Slice flat_positional_dims; + Slice> to_flatten; + levels.extend(A, 
orig_levels); + + int orig_ndim = ndim_of_levels(levels); + auto append = [&](DimEntry d) { + auto midx = levels.index(d); + if (!midx) { + if (d.is_positional()) { + mpy::raise_error( + PyExc_ValueError, + "tensor has %d positional dimensions, but %d specified, or it was specified twice", + int(orig_ndim), + int(d.position() + orig_ndim)); + } else { + mpy::raise_error( + PyExc_ValueError, + "tensor of dimensions %R does not contain dim %R or it was specified twice", + levels_to_tuple(orig_levels).ptr(), + d.dim().ptr()); + } + } + levels[*midx] = DimEntry(); + flat_positional_dims.append(A, d); + }; + + int n_new_positional = 0; + for (auto i : irange(nargs)) { + mpy::handle arg = args[i]; + DimEntry entry = _wrap_dim(arg, orig_ndim, false); + if (!entry.is_none()) { + append(entry); + ++n_new_positional; + } else if (DimList::check(arg)) { + auto dl = DimList::unchecked_wrap(arg); + for (mpy::obj& d : dl->dims_) { + append(mpy::hdl(d)); + ++n_new_positional; + } + } else { + ++n_new_positional; + if (!mpy::is_sequence(arg)) { + mpy::raise_error( + PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); + } + mpy::sequence_view sq(arg); + auto N = sq.size(); + to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); + for (auto j : irange(N)) { + DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); + if (e.is_none()) { + mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); + } + append(e); + } + } + } + + int insert_point = -1; + Slice new_levels; + for (auto l : levels) { + if (l.is_none()) { + continue; + } + if (l.is_positional()) { + if (insert_point == -1) { + insert_point = new_levels.size(); + new_levels.extend(A, flat_positional_dims); + } + } + new_levels.append(A, l); + } + if (insert_point == -1) { + insert_point = new_levels.size(); + new_levels.extend(A, flat_positional_dims); + } + + at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); + + if (to_flatten.size()) { + Slice view; + auto sz = ndata.sizes(); + // before the new positional dims + for (auto i : irange(0, insert_point)) { + view.append(A, sz[i]); + } + int i = 0; + for (auto to_flat : to_flatten) { + for (; i < to_flat.first; ++i) { + view.append(A, sz[insert_point + i]); + } + int64_t new_size = 1; + int last = i + to_flat.second; + for (; i < last; ++i) { + new_size *= sz[insert_point + i]; + } + view.append(A, new_size); + } + for (; i < flat_positional_dims.size(); ++i) { + view.append(A, sz[insert_point + i]); + } + // after the new positional dims + for (auto i : + irange(insert_point + flat_positional_dims.size(), levels.size())) { + view.append(A, sz[i]); + } + // we shorted the number of dimension, so remove them from new levels + // we will renumber them later + auto n_to_remove = flat_positional_dims.size() - n_new_positional; + new_levels.insert( + A, + new_levels.slice(insert_point, insert_point + n_to_remove), + Slice()); + ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); + } + + // renumber the positional dimension + int seen = 0; + for (auto i : new_levels.reversed_enumerate()) { + if (new_levels[i].is_positional() || + (i >= insert_point && i < insert_point + n_new_positional)) { + new_levels[i] = --seen; + } + } + return Tensor::from_positional(A, std::move(ndata), new_levels, has_device) + .release(); + + PY_END(nullptr) +} + +static PyObject* expand( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs-- > 0); + auto info = 
TensorInfo::create(A, args++[0], false); + for (auto i : irange(nargs)) { + if (!Dim::check(args[i])) { + maybeInitializeGlobals(); + mpy::vector_args vargs(args - 1, nargs + 1, kwnames); + if (THPVariable_Check(args[-1])) { + return torch_Tensor_expand.call_vector(vargs).release(); + } else { + return __torch_function__(A, torch_Tensor_expand, vargs, false) + .release(); + } + } + } + const at::Tensor& data = *info.tensor; + auto levels = info.levels; + Slice new_levels; + Slice sz; + Slice sd; + for (auto i : irange(nargs)) { + auto d = Dim::unchecked_wrap(args[i]); + if (levels.contains(d) || new_levels.contains(d)) { + mpy::raise_error( + DimensionBindError(), + "expanding dimension %R already exists in tensor with dims", + d.ptr()); + } + new_levels.append(A, d); + sz.append(A, d->size()); + sd.append(A, 0); + } + new_levels.extend(A, levels); + at::IntArrayRef osz = data.sizes(); + at::IntArrayRef osd = data.strides(); + sz.extend(A, osz.begin(), osz.end()); + sd.extend(A, osd.begin(), osd.end()); + at::Tensor ndata = data.as_strided( + at::IntArrayRef(sz.begin(), sz.end()), + at::IntArrayRef(sd.begin(), sd.end()), + data.storage_offset()); + return Tensor::from_positional( + A, std::move(ndata), new_levels, info.has_device) + .release(); + PY_END(nullptr) +} + +static void _bind_dims_to_size( + Arena& A, + int64_t sz, + int64_t sd, + Slice> dims, + Slice& nsz, + Slice& nsd) { + int64_t rhs_prod = 1; + for (auto i : dims.enumerate()) { + if (!dims[i]->is_bound()) { + for (auto j : irange(i + 1, dims.size())) { + if (!dims[j]->is_bound()) { + mpy::raise_error( + DimensionBindError(), + "cannot infer the sizes of two dimensions at once %R and %R", + dims[i].ptr(), + dims[j].ptr()); + } + rhs_prod *= dims[j]->size(); + } + if (sz % rhs_prod != 0) { + mpy::tuple tup(dims.size()); + for (auto j : dims.enumerate()) { + tup.set( + j, + dims[j]->is_bound() ? 
mpy::from_int(dims[j]->size())
+                                  : mpy::unicode_from_string("?"));
+        }
+        mpy::raise_error(
+            DimensionBindError(),
+            "inferred dimension does not evenly fit into larger dimension: %d vs %R",
+            (int)sz,
+            tup.ptr());
+      }
+      int64_t inferred_size = sz / rhs_prod;
+      dims[i]->set_size(inferred_size);
+      rhs_prod = sz;
+      break;
+    }
+    rhs_prod *= dims[i]->size();
+  }
+  if (rhs_prod != sz) {
+    mpy::tuple tup(dims.size());
+    for (auto j : dims.enumerate()) {
+      tup.set(j, mpy::object::borrow(dims[j]));
+    }
+    mpy::raise_error(
+        DimensionBindError(),
+        "Dimension sizes do not match (%d != %d) when matching dimension pack %R",
+        (int)sz,
+        (int)rhs_prod,
+        tup.ptr());
+  }
+  auto new_strides = A.allocate<int64_t>(dims.size());
+  auto prev_stride = sd;
+  for (auto i : dims.reversed_enumerate()) {
+    new_strides[i] = prev_stride;
+    prev_stride = dims[i]->size() * prev_stride;
+  }
+  for (auto i : dims.enumerate()) {
+    nsd.append(A, new_strides[i]);
+    nsz.append(A, dims[i]->size());
+  }
+}
+
+static bool has_dims(mpy::handle d) {
+  return Dim::check_exact(d) || Tensor::check_exact(d);
+}
+
+struct IndexingInfo {
+  bool can_call_original; // if true, then it is safe to just call getitem or
+                          // setitem, these objects do not need special handling
+  bool advanced_indexing; // requires actual lookup
+  TensorRef self;
+  Slice<mpy::handle> flat_inputs;
+  Slice<DimEntry> result_levels;
+  bool has_device;
+};
+} // namespace
+
+IndexingInfo getsetitem_flat(
+    Arena& A,
+    TensorInfo self_info,
+    Slice<mpy::handle> input,
+    Slice<DimEntry> keys,
+    Slice<mpy::handle> values,
+    bool has_dimpacks_or_none);
+namespace {
+Slice<mpy::handle> as_slice(mpy::tuple_view tv) {
+  PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(), 0);
+  return Slice<mpy::handle>(
+      (mpy::handle*)begin, (mpy::handle*)(begin + tv.size()));
+}
+
+Slice<mpy::handle> as_slice(mpy::list_view tv) {
+  PyObject** begin = &PyList_GET_ITEM(tv.ptr(), 0);
+  return Slice<mpy::handle>(
+      (mpy::handle*)begin, (mpy::handle*)(begin + tv.size()));
+}
+
+bool maybe_dimpack(
+    Slice<mpy::handle>& elements,
+    mpy::handle s,
+    bool check_first = true) {
+  // can we avoid rechecking?
+  if (mpy::list_view::check(s)) {
+    mpy::list_view tv(s);
+    if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) {
+      elements = as_slice(tv);
+      return true;
+    }
+  }
+  // can we avoid rechecking?
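+  // Editorial note, not part of the original source: a "dim pack" is a list
+  // or tuple of first-class Dim objects used where a single index would go.
+  // A hedged sketch (assuming the functorch.dim API; i, j, x are
+  // illustrative names):
+  //
+  //   i, j = dims(2)
+  //   x[[i, j]]  # splits one positional dimension of x across i and j
+  //
+  // Both the list check above and the tuple check below accept that shape.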
+ if (mpy::tuple_view::check(s)) { + mpy::tuple_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; + } + } + return false; +}; + +bool is_dimpack(mpy::handle s) { + Slice e; + return maybe_dimpack(e, s); +} + +mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) { + at::Tensor rtensor; + if (iinfo.advanced_indexing) { + auto self_hdl = handle_from_tensor(A, iinfo.self); + auto tup = slice_to_tuple(iinfo.flat_inputs); + // std::cout << "calling original getindex " << self_hdl << " " << tup << + // "\n"; + auto pytensor = mpy::object::checked_steal( + THPVariable_getitem(self_hdl.ptr(), tup.ptr())); + rtensor = THPVariable_Unpack(pytensor.ptr()); + } else { + // std::cout << "skipping original getindex\n"; + rtensor = *iinfo.self; + } + // std::cout << "returning (from_positional)\n"; + return Tensor::from_positional( + A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); +} + +mpy::object index( + Arena& A, + mpy::handle self, + mpy::handle dims, + mpy::handle indices) { + maybeInitializeGlobals(); + Slice dims_list; + Slice indices_list; + // we allow for matching single dims to multiple dims, + // so we first have to normalize everything into the case where there is a + // list on lhs and the rhs + bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); + bool rhs_list = + mpy::tuple_view::check(indices) || mpy::list_view::check(indices); + if (lhs_list && rhs_list) { + mpy::sequence_view dv(dims); + mpy::sequence_view ind(indices); + Py_ssize_t N = dv.size(); + if (N != ind.size()) { + mpy::raise_error( + PyExc_TypeError, + "dims (%d) and indices (%d) must have the same length", + int(N), + int(ind.size())); + } + for (auto i : irange(N)) { + dims_list.append(A, A.autorelease(dv[i])); + indices_list.append(A, A.autorelease(ind[i])); + } + } else { + dims_list.append(A, dims); + indices_list.append(A, indices); + } + + // dims being indexed can be grouped together into a single index space, and + // we have to flatten them int a single dimension before we can index them... + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + Slice new_levels; + Slice to_flatten; + Slice dims_list_flat; + auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { + auto d = _wrap_dim(s, ndim, false); + if (d.is_none()) { + mpy::raise_error( + PyExc_TypeError, + "expected a dimension specifyer but found %R", + s.ptr()); + } + return d; + }; + auto dim_not_present = [&](DimEntry d) { + if (d.is_positional()) { + mpy::raise_error( + PyExc_TypeError, + "dimension %d not in tensor of %d dimensions", + d.position() + ndim, + ndim); + } else { + mpy::raise_error( + PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); + } + }; + + for (auto i : dims_list.enumerate()) { + Slice m; + if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { + if (m.size() == 0) { + // plausible semantics work for this to have 0 elements (e.g. 
the index
+        // will always be 0)
+        dims_list_flat.append(A, DimEntry()); // value is just dropped
+        continue; // nothing left to parse for an empty pack
+      }
+      auto first = parse_dim_entry(m[0]);
+      dims_list_flat.append(A, first);
+      if (m.size() == 1) {
+        continue;
+      }
+      if (to_flatten.size() == 0) {
+        new_levels.extend(A, self_info.levels);
+      }
+      Slice<DimEntry> rest;
+      for (auto i : irange(1, m.size())) {
+        auto d = parse_dim_entry(m[i]);
+        if (!new_levels.remove(A, d)) {
+          dim_not_present(d);
+        }
+        rest.append(A, d);
+      }
+
+      auto first_idx = new_levels.index(first);
+      if (!first_idx) {
+        dim_not_present(first);
+      }
+      new_levels.insert(
+          A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest);
+      to_flatten.extend(A, rest);
+    } else {
+      dims_list_flat.append(A, parse_dim_entry(dims_list[i]));
+    }
+  }
+  if (to_flatten.size() > 0) {
+    TensorRef rearranged =
+        _match_levels(A, self_info.tensor, self_info.levels, new_levels);
+    at::IntArrayRef sizes = rearranged->sizes();
+    Slice<int64_t> new_sizes;
+    Slice<DimEntry> reshape_levels;
+    for (auto i : new_levels.enumerate()) {
+      if (to_flatten.contains(new_levels[i])) {
+        new_sizes.back() *= sizes[i];
+      } else {
+        new_sizes.append(A, sizes[i]);
+        reshape_levels.append(A, new_levels[i]);
+      }
+    }
+    self_info.tensor = A.autorelease(rearranged->reshape(
+        at::IntArrayRef(new_sizes.begin(), new_sizes.end())));
+
+    self_info.levels =
+        reshape_levels; // note: we are using the first level in a flattened
+                        // group to represent the group for the rest of the op
+                        // we need to be careful not to rely on the dimension's
+                        // size because it doesn't match the size of the whole
+                        // group
+  }
+  bool has_dimpacks = false;
+  for (auto idx : indices_list) {
+    if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) {
+      has_dimpacks = true;
+      break;
+    }
+  }
+  IndexingInfo info = getsetitem_flat(
+      A,
+      self_info,
+      Slice<mpy::handle>(),
+      dims_list_flat,
+      indices_list,
+      has_dimpacks);
+  return invoke_getitem(A, info);
+}
+
+// true -- the indices were flattened out of a tuple, list or sequence...
+
+Slice<mpy::handle> slice_from_sequence(Arena& A, mpy::handle value) {
+  if (mpy::tuple_view::check(value)) {
+    return as_slice(mpy::tuple_view(value));
+  } else if (mpy::list_view::check(value)) {
+    return as_slice(mpy::list_view(value));
+  } else {
+    mpy::sequence_view sv(value);
+    Slice<mpy::handle> r;
+    for (auto i : sv.enumerate()) {
+      r.append(A, A.autorelease(sv[i]));
+    }
+    return r;
+  }
+}
+
+bool extractIndices(Arena& A, mpy::handle index, Slice<mpy::handle>& indices) {
+  if (mpy::tuple_view::check(index)) {
+    indices.extend(A, as_slice(mpy::tuple_view(index)));
+    return true;
+  } else if (THPVariable_Check(index.ptr())) {
+    indices.append(A, index);
+    return false;
+  } else if (!mpy::is_sequence(index)) {
+    indices.append(A, index);
+    return false;
+  }
+  // a copy of treatSequenceAsTuple modified to add Dim and our wrapped
+  // tensors..
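+  // Editorial note, not part of the original source: mirroring that
+  // heuristic, sequences of 32 or more items are always unpacked as an index
+  // tuple, while shorter ones are unpacked only if they contain a tensor,
+  // sequence, slice, Ellipsis, None, or first-class dim.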
+ mpy::sequence_view sv(index); + if (sv.size() >= 32) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + for (auto i : sv.enumerate()) { + mpy::handle item; + try { + item = sv[i]; + } catch (mpy::exception_set& e) { + PyErr_Clear(); + indices.append(A, index); + return false; + } + if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || + PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || + mpy::is_none(item) || has_dims(item)) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + } + indices.append(A, index); + return false; +} + +IndexingInfo getsetitem( + Arena& A, + mpy::handle self, + mpy::handle index, + bool tensors_have_dims) { + bool can_call_original_getitem = !tensors_have_dims; + + Slice input; + if (has_dims(index)) { + input.append(A, index); + } else { + bool is_sequence = extractIndices(A, index, input); + // nothing about first class dims here, fallback to getitem + if (can_call_original_getitem && !is_sequence) { + return {true}; + } + } + + int64_t dims_indexed = 0; + int64_t expanding_object = -1; + DimList* unbound_dim_list = nullptr; + auto check_expanding = [&](int64_t i) { + if (expanding_object != -1) { + mpy::raise_error( + DimensionBindError(), + "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", + (int)expanding_object, + (int)i); + } + expanding_object = i; + }; + Slice dimlists; + + // calculate how many dimensioned have been indexed in order to compute the + // size of ... or expand a potentially unbound dimension list. + + bool has_dimpacks_or_none = false; + for (auto i : input.enumerate()) { + mpy::handle s = input[i]; + if (Dim::check_exact(s) || Tensor::check_exact(s)) { + can_call_original_getitem = false; + ++dims_indexed; + } else if (s.ptr() == Py_Ellipsis) { + check_expanding(i); + } else if (DimList::check(s)) { + can_call_original_getitem = false; + auto dl = DimList::unchecked_wrap(s); + if (!dl->is_bound()) { + check_expanding(i); + unbound_dim_list = dl.ptr(); + } else { + dims_indexed += dl->dims_.size(); + } + dimlists.append(A, i); + } else if (mpy::is_none(s)) { + has_dimpacks_or_none = true; + } else if (is_dimpack(s)) { + can_call_original_getitem = false; + has_dimpacks_or_none = true; + ++dims_indexed; + } else { + ++dims_indexed; + } + } + + // at this point if we haven't seen any Dim objects, we also can fallback to + // the original getitem. + if (can_call_original_getitem) { + return {true}; + } + + // std::cout << "__getitem__ " << self << " " << index << "\n"; + + TensorInfo self_info = TensorInfo::create(A, self, false, true); + auto ndim = self_info.ndim(); + if (dims_indexed > ndim) { + mpy::raise_error( + PyExc_ValueError, + "at least %d indices were supplied but the tensor only has %d dimensions", + (int)dims_indexed, + (int)ndim); + } + // expand any unbound dimension list, or expand ... into individual : slices. + auto expanding_dims = ndim - dims_indexed; + if (expanding_object != -1) { + if (unbound_dim_list) { + unbound_dim_list->bind_len(expanding_dims); + } else { + // ... 
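+      // Editorial note, not part of the original source: a literal `...` is
+      // expanded into one `:` (no_slice) per remaining unindexed dimension,
+      // so x[..., 0] on a 4-d tensor behaves like x[:, :, :, 0].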
+ Slice no_slices; + for (auto i : irange(expanding_dims)) { + (void)i; + no_slices.append(A, no_slice); + } + input.insert( + A, input.slice(expanding_object, expanding_object + 1), no_slices); + } + } + + // flatten out any dimensions stored in dimlist elements directly into the + // inputs std::cout << dimlists << " <- dim lists!\n"; + for (int64_t i = dimlists.size() - 1; i >= 0; --i) { + auto idx = dimlists[i]; + // we added more elements to input because of ... + // so we need to also adjust the index to get back to where the + // dimlist existed + if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { + idx += expanding_dims; + } + auto dl = DimList::unchecked_wrap(input[idx]); + // XXX would be better if we used an OwnedSlice in DimList + Slice more_dims( + (mpy::handle*)&*dl->dims_.begin(), (mpy::handle*)&*dl->dims_.end()); + input.insert(A, input.slice(idx, idx + 1), more_dims); + } + + return getsetitem_flat( + A, + self_info, + input, + Slice(), + Slice(), + has_dimpacks_or_none); +} +} // namespace +IndexingInfo getsetitem_flat( + Arena& A, + TensorInfo self_info, + Slice input, + Slice keys, + Slice values, + bool has_dimpacks_or_none) { + // At this point: + // ..., DimList have been eliminated + // Dim, Tensor, Tuple[Dim,...], int, slice still remain + + // we have to count how many times we see a dimension. + // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires + // advanced indexing. + Slice> seen_dims; + Slice seen_dims_nuses; + auto add_dim = [&](mpy::hdl entry) { + auto midx = seen_dims.index(entry); + if (!midx) { + seen_dims.append(A, entry); + seen_dims_nuses.append(A, 1); + } else { + ++seen_dims_nuses[*midx]; + } + }; + + Slice input_it = input; + + Slice flat_inputs; + // flat inputs will start with an empty mpy::handle if the + // actual value is in the tensor-like object in the tensor info + Slice tensor_inputs; + + auto append_flat_handle = [&](mpy::handle h) { + flat_inputs.append(A, h); + tensor_inputs.append(A, TensorInfo()); + }; + TensorRef device_holding_tensor; + auto append_tensor_input = [&](TensorInfo ti) { + flat_inputs.append(A, mpy::handle()); + tensor_inputs.append(A, ti); + if (ti.has_device && !device_holding_tensor) { + device_holding_tensor = ti.tensor; + } + }; + + Slice nsz; + Slice nsd; + at::IntArrayRef sz = self_info.tensor->sizes(); + at::IntArrayRef sd = self_info.tensor->strides(); + + auto append_size = [&](int i) { + if (has_dimpacks_or_none) { + nsz.append(A, sz[i]); + nsd.append(A, sd[i]); + } + }; + // std::cout << "self levels: " << self_info.levels << "\n"; + + auto parse_nones = [&]() { + while (input_it.size() && mpy::is_none(input_it[0])) { + append_flat_handle(no_slice); + nsz.append(A, 1); + nsd.append(A, 0); + input_it = input_it.slice(1); + } + }; + + auto append_item = [&](int i, mpy::handle arg) { + if (Dim::check_exact(arg)) { + auto d = Dim::unchecked_wrap(arg); + d->set_size(sz[i]); + add_dim(d); + append_size(i); + append_flat_handle(arg); + return; + } + auto info = TensorInfo::create(A, arg, false, false); + if (info) { + append_size(i); + append_tensor_input(info); + for (auto il : info.levels) { + if (!il.is_positional()) { + add_dim(il.dim()); + } + } + return; + } + + if (has_dimpacks_or_none) { + Slice mp; + if (maybe_dimpack(mp, arg)) { + // dim pack + Slice> dim_pack; + for (auto d : mp) { + dim_pack.append(A, Dim::wrap(d)); + add_dim(dim_pack.back()); + append_flat_handle(dim_pack.back()); + } + _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); + return; 
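+        // Editorial note, not part of the original source: e.g. indexing a
+        // size-12 dimension with the pack (i, j) where i is bound to 3 makes
+        // _bind_dims_to_size infer j.size == 4 and restride one axis as two.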
+ } + } + + append_size(i); + append_flat_handle(arg); + }; + + // pair up the indexing expressions with dimension of self it indexes + // self may have first-class dims, which do not participate the indexing. + for (auto i : self_info.levels.enumerate()) { + auto l = self_info.levels[i]; + auto idx = keys.index(l); + if (idx) { + append_item(i, values[*idx]); + } else if (l.is_positional()) { + // grab and index from the positional list + parse_nones(); + if (!input_it.size()) { + // we might have fewer indices than tensor dimensions, + // which implicitly indexes the remaining dimensions with : + append_flat_handle(no_slice); + append_size(i); + } else { + mpy::handle arg = input_it[0]; + input_it = input_it.slice(1); + append_item(i, arg); + } + } else { + add_dim(l.dim()); + append_flat_handle(l.dim()); + append_size(i); + } + } + // any training Nones may have no existing dimension associated with them in + // self. + parse_nones(); + + // we have to restride the tensor to collapse dimension packs and introduce + // our none dimensions. + if (has_dimpacks_or_none) { + self_info.tensor = A.autorelease(self_info.tensor->as_strided( + at::IntArrayRef(nsz.begin(), nsz.end()), + at::IntArrayRef(nsd.begin(), nsd.end()), + self_info.tensor->storage_offset())); + } + + // figure out what the shape of the indexing tensors will be + // and what the shape of the resulting tensor will be + Slice result_levels; + Slice index_levels; + int64_t tensor_insert_point = -1; + bool requires_getindex = false; + auto mark_tensor_index = [&] { + if (tensor_insert_point == -1) { + tensor_insert_point = result_levels.size(); + } else if (tensor_insert_point != result_levels.size()) { + tensor_insert_point = 0; + } + }; + for (auto i : flat_inputs.enumerate()) { + auto inp = flat_inputs[i]; + if (tensor_inputs[i]) { + requires_getindex = true; + mark_tensor_index(); + for (auto l : tensor_inputs[i].levels) { + // std::cout << "Consider to add " << l << "\n"; + if (!index_levels.contains(l)) { + index_levels.append(A, l); + } + } + } else if (Dim::check_exact(inp)) { + auto d = Dim::unchecked_wrap(inp); + // dimensions used once are just binding operations + if (1 == seen_dims_nuses[*seen_dims.index(d)]) { + flat_inputs[i] = no_slice; + result_levels.append(A, d); + } else { + requires_getindex = true; + flat_inputs[i] = mpy::handle(); + tensor_inputs[i] = TensorInfo{ + d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; + if (!index_levels.contains(d)) { + index_levels.append(A, d); + } + mark_tensor_index(); + } + } else { + if (inp.ptr() != no_slice.ptr()) { + requires_getindex = true; + } + if (!mpy::is_int(inp)) { + // note: actual positional indexes are accurately computed later + result_levels.append(A, -1); + } + } + } + + // indexing dimensions appear in the tensor at the _first use of a tensor_ in + // the indexing. 
So insert the indexing leveles into the result klevels at + // this spot + if (tensor_insert_point != -1) { + result_levels.insert( + A, + result_levels.slice(tensor_insert_point, tensor_insert_point), + index_levels); + } + + // std::cout << "flat inputs: " << flat_inputs << "\n"; + // std::cout << "result_levels: " << result_levels << "\n"; + // std::cout << "index_levels: " << index_levels << "\n"; + + // get all the tensors to be the right shape for indexing + if (requires_getindex) { + for (auto i : flat_inputs.enumerate()) { + if (tensor_inputs[i]) { + AT_ASSERT(!flat_inputs[i].ptr()); + // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << + // "\n"; + TensorRef t = tensor_inputs[i].tensor; + if (!tensor_inputs[i].has_device && device_holding_tensor) { + t = A.autorelease(t->to(device_holding_tensor->device())); + } + flat_inputs[i] = handle_from_tensor( + A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); + } + } + } + + // previously we didn't know how many positional dimensions there would be so + // we couldn't number them right so fill it in now. + auto seen_positionals = 0; + for (auto i : result_levels.reversed_enumerate()) { + if (result_levels[i].is_positional()) { + result_levels[i] = -(++seen_positionals); + } + } + + return IndexingInfo{ + false, + requires_getindex, + self_info.tensor, + flat_inputs, + result_levels, + self_info.has_device}; +} +namespace { +mpy::object __getitem__(Arena& A, mpy::handle self, mpy::handle index) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self)); + if (iinfo.can_call_original) { + return mpy::object::checked_steal( + THPVariable_getitem(self.ptr(), index.ptr())); + } + + return invoke_getitem(A, iinfo); +} + +void __setitem__( + Arena& A, + mpy::handle self, + mpy::handle index, + mpy::handle rhs) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); + if (iinfo.can_call_original) { + if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + return; + } + + auto rhs_info = TensorInfo::create(A, rhs, false, false); + if (rhs_info) { // otherwise rhs can be a scalar... 
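+    // Editorial note, not part of the original source: every first-class dim
+    // on the rhs must already appear in the indexed result on the lhs; the
+    // rhs is then rearranged to the lhs layout via _match_levels before the
+    // underlying setitem/copy_ runs.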
+ for (auto l : rhs_info.levels) { + if (!iinfo.result_levels.contains(l)) { + if (l.is_positional()) { + mpy::raise_error( + DimensionBindError(), + "rhs contains too many dimensions (%d) compared to indexed value (%d)", + ndim_of_levels(iinfo.result_levels), + rhs_info.ndim()); + } else { + auto tup = levels_to_tuple(iinfo.result_levels); + mpy::raise_error( + DimensionBindError(), + "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", + l.dim().ptr(), + tup.ptr()); + } + } + } + auto rhs_matched = + _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); + rhs = handle_from_tensor(A, rhs_matched); + } + self = handle_from_tensor(A, iinfo.self); + + if (iinfo.advanced_indexing) { + auto tup = slice_to_tuple(iinfo.flat_inputs); + if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + } else { + torch_Tensor_copy_.call(self, rhs); + } +} +} // namespace + +PyObject* Tensor_getitem(PyObject* self, PyObject* index) { + Arena A; + PY_BEGIN + return __getitem__(A, self, index).release(); + PY_END(nullptr); +} + +int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) { + Arena A; + PY_BEGIN + __setitem__(A, self, index, value); + return 0; + PY_END(-1); +} + +namespace { +PyObject* py___getitem__( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 2); + return __getitem__(A, args[0], args[1]).release(); + PY_END(nullptr) +} + +PyObject* py___setitem__( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 3); + __setitem__(A, args[0], args[1], args[2]); + Py_RETURN_NONE; + PY_END(nullptr) +} + +PyObject* py_index( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, dims, indices; + va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); + return index(A, self, dims, indices).release(); + PY_END(nullptr) +} + +PyObject* py_stack( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle tensors, new_dim, dim; + va.parse( + "stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); + + Slice result_levels; + Slice infos; + mpy::sequence_view sv(tensors); + auto new_dim_d = Dim::wrap(new_dim); + for (auto i : sv.enumerate()) { + infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); + for (auto l : infos.back().levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + new_dim_d->set_size(infos.size()); + std::vector inputs; + inputs.reserve(infos.size()); + for (auto in : infos) { + inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); + } + auto ndim = ndim_of_levels(result_levels); + int64_t rawdim = 0; + if (dim.ptr()) { + auto d = _wrap_dim(dim, ndim, false); + auto idx = result_levels.index(d); + if (!idx) { + mpy::raise_error( + PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); + } + rawdim = *idx; + } + auto result = at::stack(inputs, rawdim); + result_levels.insert(A, rawdim, new_dim_d); + return Tensor::from_positional(A, std::move(result), result_levels, true) + .release(); + PY_END(nullptr) +} + +PyObject* py_split( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* 
    kwnames) {
+  Arena A;
+  PY_BEGIN
+  maybeInitializeGlobals();
+  mpy::vector_args va(args, nargs, kwnames);
+  mpy::handle self, split_size_or_sections, dim;
+  va.parse(
+      "split",
+      {"self", "split_size_or_sections", "dim"},
+      {&self, &split_size_or_sections, &dim},
+      2);
+  bool dim_is_object = dim.ptr() && Dim::check_exact(dim);
+  Slice<mpy::handle> sizes;
+
+  bool all_dims = true;
+  bool all_ints = true;
+
+  if (!mpy::is_int(split_size_or_sections)) {
+    mpy::sequence_view sv(split_size_or_sections);
+    for (auto i : sv.enumerate()) {
+      sizes.append(A, A.autorelease(sv[i]));
+      if (Dim::check_exact(sizes.back())) {
+        all_ints = false;
+      } else {
+        all_dims = false;
+      }
+    }
+  }
+  if (all_ints) {
+    if (dim_is_object) {
+      mpy::raise_error(
+          PyExc_TypeError,
+          "when dim is specified as a Dim object, split sizes must also be dimensions.");
+    }
+    // call original split (if self has dimensions this will use torch function
+    // to do the split)
+    return torch_Tensor_split
+        .call_vector(mpy::vector_args(args, nargs, kwnames))
+        .release();
+  }
+  if (!all_dims) {
+    mpy::raise_error(
+        PyExc_TypeError, "split list must be ints or dims but got a mix");
+  }
+
+  auto self_info = TensorInfo::create(A, self, false);
+  auto ndim = self_info.ndim();
+  if (!dim_is_object && ndim == 0) {
+    mpy::raise_error(
+        PyExc_TypeError, "split expects at least a 1-dimensional tensor");
+  }
+  DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim;
+
+  auto idx = self_info.levels.index(dim_l);
+  if (!idx) {
+    if (!dim.ptr()) {
+      dim = A.autorelease(mpy::from_int(0));
+    }
+    mpy::raise_error(
+        PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr());
+  }
+  Slice<int64_t> indices;
+
+  int64_t total_size = 0;
+  Slice<int64_t> unbound;
+  for (auto i : sizes.enumerate()) {
+    auto d = Dim::unchecked_wrap(sizes[i]);
+    if (d->is_bound()) {
+      indices.append(A, d->size());
+      total_size += indices.back();
+    } else {
+      indices.append(A, 0);
+      unbound.append(A, i);
+    }
+  }
+  auto tensor_size = self_info.tensor->sizes()[*idx];
+
+  if (unbound.size()) {
+    if (total_size > tensor_size) {
+      mpy::raise_error(
+          PyExc_TypeError,
+          "sizes of target dimensions add up to more (%d) than source dim (%d)",
+          int(total_size),
+          int(tensor_size));
+    }
+    auto remaining_size = tensor_size - total_size;
+    auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size();
+    for (auto u : unbound) {
+      auto sz = std::min(chunk_size, remaining_size);
+      Dim::unchecked_wrap(sizes[u])->set_size(sz);
+      indices[u] = sz;
+      remaining_size -= sz;
+    }
+  } else if (tensor_size != total_size) {
+    mpy::raise_error(
+        PyExc_TypeError,
+        "sum of sizes of target dimensions (%d) does not match the source dim (%d)",
+        int(total_size),
+        int(tensor_size));
+  }
+
+  auto result_tensors = self_info.tensor->split_with_sizes(
+      at::IntArrayRef(indices.begin(), indices.end()), *idx);
+  mpy::tuple result(result_tensors.size());
+  Slice<DimEntry> new_levels;
+  new_levels.extend(A, self_info.levels);
+  for (auto i : sizes.enumerate()) {
+    new_levels[*idx] = Dim::unchecked_wrap(sizes[i]);
+    result.set(
+        i,
+        Tensor::from_positional(
+            A, std::move(result_tensors[i]), new_levels, true));
+  }
+
+  return result.release();
+
+  PY_END(nullptr)
+}
+
+Slice<DimEntry> _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) {
+  auto de = _wrap_dim(d, N, keepdim);
+  Slice<DimEntry> r;
+  if (!de.is_none()) {
+    r.append(A, de);
+  } else {
+    mpy::sequence_view sq(d);
+    for (auto i : sq.enumerate()) {
+      r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim));
+    }
+  }
+  return r;
+}
+
+struct WrappedOperator : public
mpy::base { + mpy::object orig; + PyMethodDef method_def; + mpy::object name, doc; + + bool is_pointwise = false; + int64_t dim_offset = 0; + int64_t keepdim_offset = 1; + std::string dim_name; + bool single_dim = false; + bool reduce = true; + + static PyTypeObject Type; + + void init( + mpy::object orig_, + PyCFunction wrapper_implementation, + std::string dim_name_ = "") { + orig = std::move(orig_); + method_def.ml_meth = wrapper_implementation; + name = orig.attr("__name__"); + doc = orig.attr("__doc__"); + dim_name = std::move(dim_name_); + if (!mpy::is_none(doc) && !dim_name.empty()) { + doc = mpy::unicode_from_format( + "%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", + doc.ptr(), + dim_name.c_str()); + } + method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); + method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); + method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; + } + + mpy::object function() { + return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); + } +}; +} // namespace + +PyTypeObject WrappedOperator::Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_C.WrappedOperator", /* tp_name */ + sizeof(WrappedOperator), /* tp_basicsize */ + 0, /* tp_itemsize */ + WrappedOperator::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "Wrapped Object Holder", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + WrappedOperator::new_stub, /* tp_new */ +}; + +namespace { +PyObject* patched_dim_method( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + auto self = WrappedOperator::unchecked_wrap(self_); + PY_BEGIN + + mpy::vector_args va(args, nargs, kwnames); + + auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + return idx == -1 ? 
mpy::handle() : va[idx]; + }; + Slice patched_args; + patched_args.extend(A, va.begin(), va.end()); + auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + if (idx == -1) { + mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); + } + patched_args[idx] = value; + }; + + auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); + if (!dim.ptr()) { + auto info = TensorInfo::create(A, args[0], true); + EnableAllLayers l(A, info.levels); + l.inplace_update_layers(info.batchedtensor, info.levels); + patched_args[0] = handle_from_tensor(A, info.batchedtensor); + auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); + return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device) + .release(); + } + + auto info = TensorInfo::create(A, args[0]); + auto keepdim = false; + if (self->reduce) { + auto py_keepdim = _getarg("keepdim", self->keepdim_offset); + if (py_keepdim.ptr()) { + keepdim = mpy::to_bool(py_keepdim); + } + } + + auto ndim = info.ndim(); + auto dims = _wrap_dims(A, dim, ndim, keepdim); + Slice dim_indices; + auto seen = A.allocate(info.levels.size()); + std::fill(seen, seen + info.levels.size(), false); + + for (auto d : dims) { + auto midx = info.levels.index(d); + if (!midx) { + auto tup = levels_to_tuple(info.levels); + mpy::raise_error( + PyExc_ValueError, + "Tensor with dimensions %R does not contain one of %R\n", + tup.ptr(), + dim.ptr()); + } + seen[*midx] = true; + dim_indices.append(A, *midx); + } + Slice new_levels; + if (self->reduce && !keepdim) { + for (auto i : info.levels.enumerate()) { + if (!seen[i]) { + new_levels.append(A, info.levels[i]); + } + } + } else { + new_levels = info.levels; + } + mpy::object py_indices; + if (dim_indices.size() == 1) { + py_indices = mpy::from_int(dim_indices[0]); + } else { + mpy::tuple tup(dim_indices.size()); + for (auto i : dim_indices.enumerate()) { + tup.set(i, mpy::from_int(dim_indices[i])); + } + py_indices = std::move(tup); + } + _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); + patched_args[0] = handle_from_tensor(A, info.tensor); + auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional( + A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); + } + return h; + }; + return tree_map(A, wrap, r).release(); + PY_END(nullptr) +} + +PyObject* _wrap( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + +#define ARGS(_) \ + _(mpy::handle, orig) \ + _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ + _(mpy::handle, dim_name) _(mpy::handle, single_dim) \ + _(mpy::handle, reduce) + MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) + + std::string dim_name_str; + if (dim_name.ptr()) { + dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); + } else { + dim_name_str = "dim"; + } + auto info = WrappedOperator::create( + mpy::object::borrow(orig), + (PyCFunction)(void*)patched_dim_method, + std::move(dim_name_str)); + if (dim_offset.ptr()) { + info->dim_offset = mpy::to_int(dim_offset); + } + if (keepdim_offset.ptr()) { + info->keepdim_offset = mpy::to_int(keepdim_offset); + } + + if (single_dim.ptr()) { + info->single_dim = mpy::to_bool(single_dim); + } + if (reduce.ptr()) { + info->reduce = mpy::to_bool(reduce); + } + return info->function().release(); +#undef ARGS + + 
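+  // Editorial note, not part of the original source: the offsets parsed
+  // above are interpreted by patched_dim_method, which adds 1 to skip
+  // `self`; e.g. wrapping a reduction like sum(self, dim, keepdim) would
+  // pass dim_offset=0 and keepdim_offset=1 (also the defaults).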
PY_END(nullptr) +} + +PyObject* call_torch_function( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + Arena A; + maybeInitializeGlobals(); + auto info = WrappedOperator::unchecked_wrap(self); + return __torch_function__( + A, + info->orig, + mpy::vector_args(args, nargs, kwnames), + info->is_pointwise) + .release(); + PY_END(nullptr) +} + +PyObject* _wrap_method( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + AT_ASSERT(nargs == 2); + // XXX - ignore python function wrapped, we will call torch function directly + mpy::handle orig = args[0]; + if (!pointwise.ptr()) { + auto dim = mpy::import("functorch.dim"); + pointwise = dim.attr("pointwise"); + } + auto info = WrappedOperator::create( + mpy::object::borrow(orig), (PyCFunction)(void*)call_torch_function); + info->is_pointwise = pointwise.contains(orig); + return PyInstanceMethod_New(info->function().release()); + PY_END(nullptr); +} + +PyObject* Tensor_sum( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + auto self_ = Tensor::unchecked_wrap(args[0]); + auto d = self_->delayed(); + if (!d) { + return _Tensor_sum.call_vector(va).release(); + } + mpy::handle self, dim, keepdim, dtype; + va.parse( + "sum", + {"self", "dim", "keepdim", "dtype"}, + {&self, &dim, &keepdim, &dtype}, + 1, + 1); + + if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { + // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; + return _Tensor_sum.call_vector(va).release(); + } + auto levels = self_->levels(); + + auto N = ndim_of_levels(levels); + auto reduced_dims = _wrap_dims(A, dim, N, false); + + return dot(A, + TensorInfo::create(A, d->args[0], false), + TensorInfo::create(A, d->args[1], false), + reduced_dims) + .release(); + PY_END(nullptr) +} + +PyObject* _parse_test( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + maybeInitializeGlobals(); + + int required = mpy::to_int(args[0]); + int kwonly = mpy::to_int(args[1]); + + mpy::vector_args va(args + 2, nargs - 2, kwnames); + + mpy::handle a, b, c, d; + va.parse( + "_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); + mpy::tuple r(4); + r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); + r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); + r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); + r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); + return r.release(); + + PY_END(nullptr) +} + +PyObject* _set_pointwise_optimize( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + mpy::handle value; + mpy::vector_args va(args, nargs, kwnames); + va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); + pointwise_optimize = mpy::to_bool(value); + Py_RETURN_NONE; + PY_END(nullptr) +} + +PyObject* _patch_tensor_class( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + + auto torch = mpy::import("torch"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + replaceMappingIfMatches(py_TensorBase); + + Py_RETURN_NONE; + PY_END(nullptr) +} + +const char* dims_doc = R"""( +dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...] + +Creates and returns one or more Dim objects. + +Arg: + n (int, optional): The number of dimensions to create. 
Can be omitted if sizes is specified. + sizes (List[Optional[int]], optional): A list the same size as the number of dimensions to be + created, specifying each dimensions size, or None to leave the size unset. + +Example:: + >>> batch, channel, width, height = dims(4) + >>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224]) +)"""; + +PyMethodDef methods[] = { + {"dims", + (PyCFunction)(void*)_dims, + METH_FASTCALL | METH_KEYWORDS, + dims_doc}, + {"dimlists", + (PyCFunction)(void*)_dims, + METH_FASTCALL | METH_KEYWORDS}, + {"_test_c", (PyCFunction)(void*)test_c, METH_FASTCALL | METH_KEYWORDS}, + {"_wrap_method", + (PyCFunction)(void*)_wrap_method, + METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_from_positional", + (PyCFunction)(void*)py_Tensor_from_positional, + METH_FASTCALL | METH_KEYWORDS}, + {"__torch_function__", + (PyCFunction)(void*)py___torch_function__, + METH_FASTCALL | METH_KEYWORDS}, + {"tree_flatten", + (PyCFunction)(void*)py_tree_flatten, + METH_FASTCALL | METH_KEYWORDS}, + {"order", (PyCFunction)(void*)order, METH_FASTCALL | METH_KEYWORDS}, + {"index", (PyCFunction)(void*)py_index, METH_FASTCALL | METH_KEYWORDS}, + {"stack", (PyCFunction)(void*)py_stack, METH_FASTCALL | METH_KEYWORDS}, + {"split", (PyCFunction)(void*)py_split, METH_FASTCALL | METH_KEYWORDS}, + {"expand", (PyCFunction)(void*)expand, METH_FASTCALL | METH_KEYWORDS}, + {"__getitem__", + (PyCFunction)(void*)py___getitem__, + METH_FASTCALL | METH_KEYWORDS}, + {"__setitem__", + (PyCFunction)(void*)py___setitem__, + METH_FASTCALL | METH_KEYWORDS}, + {"_wrap", (PyCFunction)(void*)_wrap, METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_sum", + (PyCFunction)(void*)Tensor_sum, + METH_FASTCALL | METH_KEYWORDS}, + {"_parse_test", + (PyCFunction)(void*)_parse_test, + METH_FASTCALL | METH_KEYWORDS}, + {"_set_pointwise_optimize", + (PyCFunction)(void*)_set_pointwise_optimize, + METH_FASTCALL | METH_KEYWORDS}, + {"_patch_tensor_class", + (PyCFunction)(void*)_patch_tensor_class, + METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_C", /* name of module */ + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + methods}; +} // namespace + +PyObject* Dim_init() { + Arena A; + try { + mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); + Dim::ready(mod, "Dim"); + DimList::ready(mod, "DimList"); + Tensor::ready(mod, "Tensor"); + WrappedOperator::ready(mod, "_WrappedOperator"); + Py_INCREF(&PyInstanceMethod_Type); + PyModule_AddObject( + mod.ptr(), "_instancemethod", (PyObject*)&PyInstanceMethod_Type); + + initializeGlobals(A); + return mod.release(); + } catch (mpy::exception_set& err) { + return nullptr; + } +} + +#endif diff --git a/functorch/csrc/dim/dim.h b/functorch/csrc/dim/dim.h new file mode 100644 index 000000000000..627caa729fc2 --- /dev/null +++ b/functorch/csrc/dim/dim.h @@ -0,0 +1,8 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
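+//
+// Editorial note, not part of the original source: Dim_init() builds the
+// `_C` extension module (registering Dim, DimList, Tensor and the wrapped
+// operator helper) and returns nullptr on failure, per CPython convention;
+// presumably it is declared here so functorch/csrc/init_dim_only.cpp can
+// call it when the extension is imported.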
+#pragma once +#include +PyObject* Dim_init(); diff --git a/functorch/csrc/dim/dim_opcode.c b/functorch/csrc/dim/dim_opcode.c new file mode 100644 index 000000000000..1b5d06773445 --- /dev/null +++ b/functorch/csrc/dim/dim_opcode.c @@ -0,0 +1,17 @@ +#include +#if defined(_WIN32) && IS_PYTHON_3_11_PLUS +#define Py_BUILD_CORE +#define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt, _PyOpcode_Caches + +#if IS_PYTHON_3_13_PLUS +#include // To get PyUnstable_Code_GetFirstFree +#define NEED_OPCODE_METADATA +#include "internal/pycore_opcode_metadata.h" +#undef NEED_OPCODE_METADATA +#else +#include "internal/pycore_opcode.h" +#endif + +#undef NEED_OPCODE_TABLES +#undef Py_BUILD_CORE +#endif diff --git a/functorch/csrc/dim/minpybind.h b/functorch/csrc/dim/minpybind.h new file mode 100644 index 000000000000..ceced399b40d --- /dev/null +++ b/functorch/csrc/dim/minpybind.h @@ -0,0 +1,692 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#define PY_SSIZE_T_CLEAN +#include +#include +#include +#include + +#define PY_BEGIN try { +#define PY_END(v) } catch(mpy::exception_set & err) { return (v); } + +#if PY_VERSION_HEX < 0x03080000 + #define PY_VECTORCALL _PyObject_FastCallKeywords +#else + #define PY_VECTORCALL _PyObject_Vectorcall +#endif + +struct irange { + public: + irange(int64_t end) + : irange(0, end, 1) {} + irange(int64_t begin, int64_t end, int64_t step = 1) + : begin_(begin), end_(end), step_(step) {} + int64_t operator*() const { + return begin_; + } + irange& operator++() { + begin_ += step_; + return *this; + } + bool operator!=(const irange& other) { + return begin_ != other.begin_; + } + irange begin() { + return *this; + } + irange end() { + return irange {end_, end_, step_}; + } + private: + int64_t begin_; + int64_t end_; + int64_t step_; +}; + +namespace mpy { + +struct exception_set { +}; + +struct object; +struct vector_args; + +struct handle { + handle(PyObject* ptr) + : ptr_(ptr) {} + handle() = default; + + + PyObject* ptr() const { + return ptr_; + } + object attr(const char* key); + bool hasattr(const char* key); + handle type() const { + return (PyObject*) Py_TYPE(ptr()); + } + + template + object call(Args&&... 
args); + object call_object(mpy::handle args); + object call_object(mpy::handle args, mpy::handle kwargs); + object call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames); + object call_vector(vector_args args); + bool operator==(handle rhs) { + return ptr_ == rhs.ptr_; + } + + static handle checked(PyObject* ptr) { + if (!ptr) { + throw exception_set(); + } + return ptr; + } + +protected: + PyObject* ptr_ = nullptr; +}; + + +template +struct obj; + +template +struct hdl : public handle { + T* ptr() { + return (T*) handle::ptr(); + } + T* operator->() { + return ptr(); + } + hdl(T* ptr) + : hdl((PyObject*) ptr) {} + hdl(const obj& o) + : hdl(o.ptr()) {} +private: + hdl(handle h) : handle(h) {} +}; + +struct object : public handle { + object() = default; + object(const object& other) + : handle(other.ptr_) { + Py_XINCREF(ptr_); + } + object(object&& other) noexcept + : handle(other.ptr_) { + other.ptr_ = nullptr; + } + object& operator=(const object& other) { + return *this = object(other); + } + object& operator=(object&& other) noexcept { + PyObject* tmp = ptr_; + ptr_ = other.ptr_; + other.ptr_ = tmp; + return *this; + } + ~object() { + Py_XDECREF(ptr_); + } + static object steal(handle o) { + return object(o.ptr()); + } + static object checked_steal(handle o) { + if (!o.ptr()) { + throw exception_set(); + } + return steal(o); + } + static object borrow(handle o) { + Py_XINCREF(o.ptr()); + return steal(o); + } + PyObject* release() { + auto tmp = ptr_; + ptr_ = nullptr; + return tmp; + } +protected: + explicit object(PyObject* ptr) + : handle(ptr) {} +}; + +template +struct obj : public object { + obj() = default; + obj(const obj& other) + : object(other.ptr_) { + Py_XINCREF(ptr_); + } + obj(obj&& other) noexcept + : object(other.ptr_) { + other.ptr_ = nullptr; + } + obj& operator=(const obj& other) { + return *this = obj(other); + } + obj& operator=(obj&& other) noexcept { + PyObject* tmp = ptr_; + ptr_ = other.ptr_; + other.ptr_ = tmp; + return *this; + } + static obj steal(hdl o) { + return obj(o.ptr()); + } + static obj checked_steal(hdl o) { + if (!o.ptr()) { + throw exception_set(); + } + return steal(o); + } + static obj borrow(hdl o) { + Py_XINCREF(o.ptr()); + return steal(o); + } + T* ptr() const { + return (T*) object::ptr(); + } + T* operator->() { + return ptr(); + } +protected: + explicit obj(T* ptr) + : object((PyObject*)ptr) {} +}; + + +static bool isinstance(handle h, handle c) { + return PyObject_IsInstance(h.ptr(), c.ptr()); +} + +[[ noreturn ]] inline void raise_error(handle exception, const char *format, ...) { + va_list args; + va_start(args, format); + PyErr_FormatV(exception.ptr(), format, args); + va_end(args); + throw exception_set(); +} + +template +struct base { + PyObject_HEAD + PyObject* ptr() const { + return (PyObject*) this; + } + static obj alloc(PyTypeObject* type = nullptr) { + if (!type) { + type = &T::Type; + } + auto self = (T*) type->tp_alloc(type, 0); + if (!self) { + throw mpy::exception_set(); + } + new (self) T; + return obj::steal(self); + } + template + static obj create(Args ... 
args) { + auto self = alloc(); + self->init(std::forward(args)...); + return self; + } + static bool check(handle v) { + return isinstance(v, (PyObject*)&T::Type); + } + + static hdl unchecked_wrap(handle self_) { + return hdl((T*)self_.ptr()); + } + static hdl wrap(handle self_) { + if (!check(self_)) { + raise_error(PyExc_ValueError, "not an instance of %S", &T::Type); + } + return unchecked_wrap(self_); + } + + static obj unchecked_wrap(object self_) { + return obj::steal(unchecked_wrap(self_.release())); + } + static obj wrap(object self_) { + return obj::steal(wrap(self_.release())); + } + + static PyObject* new_stub(PyTypeObject *type, PyObject *args, PyObject *kwds) { + PY_BEGIN + return (PyObject*) alloc(type).release(); + PY_END(nullptr) + } + static void dealloc_stub(PyObject *self) { + ((T*)self)->~T(); + Py_TYPE(self)->tp_free(self); + } + static void ready(mpy::handle mod, const char* name) { + if (PyType_Ready(&T::Type)) { + throw exception_set(); + } + if(PyModule_AddObject(mod.ptr(), name, (PyObject*) &T::Type) < 0) { + throw exception_set(); + } + } +}; + +inline object handle::attr(const char* key) { + return object::checked_steal(PyObject_GetAttrString(ptr(), key)); +} + +inline bool handle::hasattr(const char* key) { + return PyObject_HasAttrString(ptr(), key); +} + +inline object import(const char* module) { + return object::checked_steal(PyImport_ImportModule(module)); +} + +template +inline object handle::call(Args&&... args) { + return object::checked_steal(PyObject_CallFunctionObjArgs(ptr_, args.ptr()..., nullptr)); +} + +inline object handle::call_object(mpy::handle args) { + return object::checked_steal(PyObject_CallObject(ptr(), args.ptr())); +} + + +inline object handle::call_object(mpy::handle args, mpy::handle kwargs) { + return object::checked_steal(PyObject_Call(ptr(), args.ptr(), kwargs.ptr())); +} + +inline object handle::call_vector(mpy::handle* begin, Py_ssize_t nargs, mpy::handle kwnames) { + return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) begin, nargs, kwnames.ptr())); +} + +struct tuple : public object { + void set(int i, object v) { + PyTuple_SET_ITEM(ptr_, i, v.release()); + } + tuple(int size) + : object(checked_steal(PyTuple_New(size))) {} +}; + +struct list : public object { + void set(int i, object v) { + PyList_SET_ITEM(ptr_, i, v.release()); + } + list(int size) + : object(checked_steal(PyList_New(size))) {} +}; + +namespace{ +mpy::object unicode_from_format(const char* format, ...) { + va_list args; + va_start(args, format); + auto r = PyUnicode_FromFormatV(format, args); + va_end(args); + return mpy::object::checked_steal(r); +} +mpy::object unicode_from_string(const char * str) { + return mpy::object::checked_steal(PyUnicode_FromString(str)); +} + +mpy::object from_int(Py_ssize_t s) { + return mpy::object::checked_steal(PyLong_FromSsize_t(s)); +} +mpy::object from_bool(bool b) { + return mpy::object::borrow(b ? 
+
+namespace {
+mpy::object unicode_from_format(const char* format, ...) {
+    va_list args;
+    va_start(args, format);
+    auto r = PyUnicode_FromFormatV(format, args);
+    va_end(args);
+    return mpy::object::checked_steal(r);
+}
+mpy::object unicode_from_string(const char * str) {
+    return mpy::object::checked_steal(PyUnicode_FromString(str));
+}
+
+mpy::object from_int(Py_ssize_t s) {
+    return mpy::object::checked_steal(PyLong_FromSsize_t(s));
+}
+mpy::object from_bool(bool b) {
+    return mpy::object::borrow(b ? Py_True : Py_False);
+}
+
+bool is_sequence(handle h) {
+    return PySequence_Check(h.ptr());
+}
+}
+
+struct sequence_view : public handle {
+    sequence_view(handle h)
+    : handle(h) {}
+    Py_ssize_t size() const {
+        auto r = PySequence_Size(ptr());
+        if (r == -1 && PyErr_Occurred()) {
+            throw mpy::exception_set();
+        }
+        return r;
+    }
+    irange enumerate() const {
+        return irange(size());
+    }
+    static sequence_view wrap(handle h) {
+        if (!is_sequence(h)) {
+            raise_error(PyExc_ValueError, "expected a sequence");
+        }
+        return sequence_view(h);
+    }
+    mpy::object operator[](Py_ssize_t i) const {
+        return mpy::object::checked_steal(PySequence_GetItem(ptr(), i));
+    }
+};
+
+namespace {
+mpy::object repr(handle h) {
+    return mpy::object::checked_steal(PyObject_Repr(h.ptr()));
+}
+
+mpy::object str(handle h) {
+    return mpy::object::checked_steal(PyObject_Str(h.ptr()));
+}
+
+bool is_int(handle h) {
+    return PyLong_Check(h.ptr());
+}
+
+bool is_none(handle h) {
+    return h.ptr() == Py_None;
+}
+
+Py_ssize_t to_int(handle h) {
+    Py_ssize_t r = PyLong_AsSsize_t(h.ptr());
+    if (r == -1 && PyErr_Occurred()) {
+        throw mpy::exception_set();
+    }
+    return r;
+}
+
+bool to_bool(handle h) {
+    return PyObject_IsTrue(h.ptr()) != 0;
+}
+}
+
+struct slice_view {
+    slice_view(handle h, Py_ssize_t size) {
+        if (PySlice_Unpack(h.ptr(), &start, &stop, &step) == -1) {
+            throw mpy::exception_set();
+        }
+        slicelength = PySlice_AdjustIndices(size, &start, &stop, step);
+    }
+    Py_ssize_t start, stop, step, slicelength;
+};
+
+static bool is_slice(handle h) {
+    return PySlice_Check(h.ptr());
+}
+
+inline std::ostream& operator<<(std::ostream& ss, handle h) {
+    ss << PyUnicode_AsUTF8(str(h).ptr());
+    return ss;
+}
+
+struct tuple_view : public handle {
+    tuple_view() = default;
+    tuple_view(handle h) : handle(h) {}
+
+    Py_ssize_t size() const {
+        return PyTuple_GET_SIZE(ptr());
+    }
+
+    irange enumerate() const {
+        return irange(size());
+    }
+
+    handle operator[](Py_ssize_t i) {
+        return PyTuple_GET_ITEM(ptr(), i);
+    }
+
+    static bool check(handle h) {
+        return PyTuple_Check(h.ptr());
+    }
+};
+
+struct list_view : public handle {
+    list_view() = default;
+    list_view(handle h) : handle(h) {}
+    Py_ssize_t size() const {
+        return PyList_GET_SIZE(ptr());
+    }
+
+    irange enumerate() const {
+        return irange(size());
+    }
+
+    handle operator[](Py_ssize_t i) {
+        return PyList_GET_ITEM(ptr(), i);
+    }
+
+    static bool check(handle h) {
+        return PyList_Check(h.ptr());
+    }
+};
+
+struct dict_view : public handle {
+    dict_view() = default;
+    dict_view(handle h) : handle(h) {}
+    object keys() const {
+        return mpy::object::checked_steal(PyDict_Keys(ptr()));
+    }
+    object values() const {
+        return mpy::object::checked_steal(PyDict_Values(ptr()));
+    }
+    object items() const {
+        return mpy::object::checked_steal(PyDict_Items(ptr()));
+    }
+    bool contains(handle k) const {
+        return PyDict_Contains(ptr(), k.ptr());
+    }
+    handle operator[](handle k) {
+        return mpy::handle::checked(PyDict_GetItem(ptr(), k.ptr()));
+    }
+    static bool check(handle h) {
+        return PyDict_Check(h.ptr());
+    }
+    bool next(Py_ssize_t* pos, mpy::handle* key, mpy::handle* value) {
+        PyObject *k = nullptr, *v = nullptr;
+        auto r = PyDict_Next(ptr(), pos, &k, &v);
+        *key = k;
+        *value = v;
+        return r;
+    }
+    void set(handle k, handle v) {
+        if (-1 == PyDict_SetItem(ptr(), k.ptr(), v.ptr())) {
+            throw exception_set();
+        }
+    }
+};
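+
+// Iterating with dict_view::next mirrors PyDict_Next (sketch; `d` is a
+// hypothetical dict_view over a live dict):
+//
+//   Py_ssize_t pos = 0;
+//   mpy::handle k, v;
+//   while (d.next(&pos, &k, &v)) {
+//       // k and v are borrowed; valid only while the dict is not mutated
+//   }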
+
+
+struct kwnames_view : public handle {
+    kwnames_view() = default;
+    kwnames_view(handle h) : handle(h) {}
+
+    Py_ssize_t size() const {
+        return PyTuple_GET_SIZE(ptr());
+    }
+
+    irange enumerate() const {
+        return irange(size());
+    }
+
+    const char* operator[](Py_ssize_t i) const {
+        PyObject* obj = PyTuple_GET_ITEM(ptr(), i);
+        return PyUnicode_AsUTF8(obj);
+    }
+
+    static bool check(handle h) {
+        return PyTuple_Check(h.ptr());
+    }
+};
+
+inline mpy::object funcname(mpy::handle func) {
+    if (func.hasattr("__name__")) {
+        return func.attr("__name__");
+    } else {
+        return mpy::str(func);
+    }
+}
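+
+// vector_args (below) models the PEP 590 vectorcall convention: `args` points
+// at `nargs` positional values, followed by one value for each name in
+// `kwnames`. A hypothetical handler sketch (function and names illustrative;
+// real callers may also need PyVectorcall_NARGS to strip the offset flag):
+//
+//   PyObject* my_fn(PyObject* self, PyObject* const* args,
+//                   Py_ssize_t nargs, PyObject* kwnames) {
+//       mpy::vector_args va(args, nargs, kwnames);
+//       mpy::handle x, y;
+//       va.parse("my_fn", {"x", "y"}, {&x, &y}, /*required=*/1);
+//       // x is required, y is optional; both are borrowed handles
+//       ...
+//   }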
+
+struct vector_args {
+    vector_args(PyObject *const *a,
+                Py_ssize_t n,
+                PyObject *k)
+    : vector_args((mpy::handle*)a, n, k) {}
+    vector_args(mpy::handle* a,
+                Py_ssize_t n,
+                mpy::handle k)
+    : args((mpy::handle*)a), nargs(n), kwnames(k) {}
+    mpy::handle* args;
+    Py_ssize_t nargs;
+    kwnames_view kwnames;
+
+    mpy::handle* begin() {
+        return args;
+    }
+    mpy::handle* end() {
+        return args + size();
+    }
+
+    mpy::handle operator[](int64_t i) const {
+        return args[i];
+    }
+    bool has_keywords() const {
+        return kwnames.ptr();
+    }
+    irange enumerate_positional() {
+        return irange(nargs);
+    }
+    irange enumerate_all() {
+        return irange(size());
+    }
+    int64_t size() const {
+        return nargs + (has_keywords() ? kwnames.size() : 0);
+    }
+
+    // bind a test function so this can be tested; first two args for
+    // required/kwonly, then return what was parsed...
+
+    // provide the wrong kwarg
+    // don't provide a required arg
+    // don't provide an optional arg
+    // provide a kwarg that is the name of an already provided positional
+    // provide a kwonly argument positionally
+    // provide keyword arguments in the wrong order
+    // provide only keyword arguments
+    void parse(const char * fname_cstr, std::initializer_list<const char*> names, std::initializer_list<mpy::handle*> values, int required, int kwonly=0) {
+        auto error = [&]() {
+            // rather than try to match the slower infrastructure's error messages
+            // exactly, once we have detected an error, just use that
+            // infrastructure to format it and throw it
+
+            // have to leak this, because python expects these to last
+            const char** names_buf = new const char*[names.size() + 1];
+            std::copy(names.begin(), names.end(), &names_buf[0]);
+            names_buf[names.size()] = nullptr;
+
+#if PY_VERSION_HEX < 0x03080000
+            char* format_str = new char[names.size() + 3];
+            int i = 0;
+            char* format_it = format_str;
+            for (auto it = names.begin(); it != names.end(); ++it, ++i) {
+                if (i == required) {
+                    *format_it++ = '|';
+                }
+                if (i == (int)names.size() - kwonly) {
+                    *format_it++ = '$';
+                }
+                *format_it++ = 'O';
+            }
+            *format_it++ = '\0';
+            _PyArg_Parser* _parser = new _PyArg_Parser{format_str, &names_buf[0], fname_cstr, 0};
+            PyObject *dummy = NULL;
+            _PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy);
+#else
+            _PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0};
+            auto buf = std::make_unique<PyObject*[]>(names.size());
+            _PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]);
+#endif
+            throw exception_set();
+        };
+
+        auto values_it = values.begin();
+        auto names_it = names.begin();
+        auto npositional = values.size() - kwonly;
+
+        if (nargs > (Py_ssize_t)npositional) {
+            // TOO MANY ARGUMENTS
+            error();
+        }
+        for (auto i : irange(nargs)) {
+            *(*values_it++) = args[i];
+            ++names_it;
+        }
+
+        if (!kwnames.ptr()) {
+            if (nargs < required) {
+                // not enough positional arguments
+                error();
+            }
+        } else {
+            int consumed = 0;
+            for (auto i : irange(nargs, values.size())) {
+                bool success = i >= required;
+                const char* target_name = *(names_it++);
+                for (auto j : kwnames.enumerate()) {
+                    if (!strcmp(target_name, kwnames[j])) {
+                        *(*values_it) = args[nargs + j];
+                        ++consumed;
+                        success = true;
+                        break;
+                    }
+                }
+                ++values_it;
+                if (!success) {
+                    // REQUIRED ARGUMENT NOT SPECIFIED
+                    error();
+                }
+            }
+            if (consumed != kwnames.size()) {
+                // NOT ALL KWNAMES ARGUMENTS WERE USED
+                error();
+            }
+        }
+    }
+    int index(const char* name, int pos) {
+        if (pos < nargs) {
+            return pos;
+        }
+        if (kwnames.ptr()) {
+            for (auto j : kwnames.enumerate()) {
+                if (!strcmp(name, kwnames[j])) {
+                    return nargs + j;
+                }
+            }
+        }
+        return -1;
+    }
+};
+
+inline object handle::call_vector(vector_args args) {
+    return object::checked_steal(PY_VECTORCALL(ptr(), (PyObject*const*) args.args, args.nargs, args.kwnames.ptr()));
+}
+
+
+}
+
+#define MPY_ARGS_NAME(typ, name) #name ,
+#define MPY_ARGS_DECLARE(typ, name) typ name;
+#define MPY_ARGS_POINTER(typ, name) &name ,
+#define MPY_PARSE_ARGS_KWARGS(fmt, FORALL_ARGS) \
+    static char* kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
+    FORALL_ARGS(MPY_ARGS_DECLARE) \
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, fmt, kwlist, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
+        throw mpy::exception_set(); \
+    }
+
+#define MPY_PARSE_ARGS_KWNAMES(fmt, FORALL_ARGS) \
+    static const char * const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
+    FORALL_ARGS(MPY_ARGS_DECLARE) \
+    static _PyArg_Parser parser = {fmt, kwlist, 0}; \
+    if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
+        throw mpy::exception_set(); \
+    }
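+
+// Illustrative use of MPY_PARSE_ARGS_KWNAMES (hypothetical argument list;
+// assumes args/nargs/kwnames from a vectorcall entry point are in scope):
+//
+//   #define MY_ARGS(_) _(mpy::handle, input) _(mpy::handle, dim)
+//   MPY_PARSE_ARGS_KWNAMES("O|O", MY_ARGS)
+//   // declares and binds `input` (required) and `dim` (optional),
+//   // throwing mpy::exception_set() if parsing fails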
diff --git a/functorch/csrc/dim/python_variable_simple.h b/functorch/csrc/dim/python_variable_simple.h
new file mode 100644
index 000000000000..d8c22ca312e3
--- /dev/null
+++ b/functorch/csrc/dim/python_variable_simple.h
@@ -0,0 +1,49 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+// note: pytorch's python_variable.h includes pybind, which conflicts with minpybind,
+// so this file reproduces the minimal API needed to extract Tensors from python objects.
+
+#include <torch/csrc/python_headers.h>
+#include <ATen/core/Tensor.h>
+#include <torch/csrc/Export.h>
+
+// Python object that backs torch.autograd.Variable
+// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+struct THPVariable {
+    PyObject_HEAD;
+    // Payload
+    c10::MaybeOwned<at::Tensor> cdata;
+    // Hooks to be run on backwards pass (corresponds to Python attr
+    // '_backwards_hooks', set by 'register_hook')
+    PyObject* backward_hooks = nullptr;
+};
+
+TORCH_PYTHON_API extern PyObject *THPVariableClass;
+TORCH_PYTHON_API extern PyObject *ParameterClass;
+
+TORCH_PYTHON_API PyObject * THPVariable_Wrap(const at::TensorBase& var);
+
+inline bool THPVariable_Check(PyObject *obj)
+{
+    if (!THPVariableClass)
+        return false;
+
+    const auto result = PyObject_IsInstance(obj, THPVariableClass);
+    AT_ASSERT(result != -1);
+    return result;
+}
+
+inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
+    return *var->cdata;
+}
+
+inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {
+    return THPVariable_Unpack(reinterpret_cast<THPVariable*>(obj));
+}
+
+TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
diff --git a/functorch/csrc/init_dim_only.cpp b/functorch/csrc/init_dim_only.cpp
new file mode 100644
index 000000000000..88d4cbcff795
--- /dev/null
+++ b/functorch/csrc/init_dim_only.cpp
@@ -0,0 +1,22 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <functorch/csrc/dim/dim.h>
+#include <torch/csrc/utils/pybind.h>
+
+namespace at {
+namespace functorch {
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    // initialize first-class dims and install it as a submodule on _C
+    auto dim = Dim_init();
+    if (!dim) {
+        throw py::error_already_set();
+    }
+    py::setattr(m, "dim", py::reinterpret_steal<py::object>(dim));
+}
+
+}}
diff --git a/setup.py b/setup.py
index 7f0104846e3e..2bb63a93cec8 100644
--- a/setup.py
+++ b/setup.py
@@ -382,6 +382,12 @@ def _get_package_path(package_name: str) -> Path:
 BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL"))
 BUILD_PYTHON_ONLY = str2bool(os.getenv("BUILD_PYTHON_ONLY"))
 
+# set up appropriate env variables
+if BUILD_LIBTORCH_WHL:
+    # Set up environment variables for ONLY building libtorch.so and not libtorch_python.so;
+    # functorch is not supported without Python
+    os.environ["BUILD_FUNCTORCH"] = "OFF"
+
 if BUILD_PYTHON_ONLY:
     os.environ["BUILD_LIBTORCHLESS"] = "ON"
     os.environ["LIBTORCH_LIB_PATH"] = (_get_package_path("torch") / "lib").as_posix()
@@ -1244,6 +1250,21 @@ class build_ext(setuptools.command.build_ext.build_ext):
     def build_extensions(self) -> None:
         self.create_compile_commands()
 
+        build_lib = Path(self.build_lib).resolve()
+
+        # Copy the cmake-built functorch extension into the wheel layout
+        for ext in self.extensions:
+            if ext.name != "functorch._C":
+                continue
+            fullname = self.get_ext_fullname(ext.name)
+            filename = Path(self.get_ext_filename(fullname))
+            src = filename.with_stem("functorch")
+            dst = build_lib / filename
+            if src.exists():
+                report(f"Copying {ext.name} from {src} to {dst}")
+                dst.parent.mkdir(parents=True, exist_ok=True)
+                self.copy_file(src, dst)
+
         super().build_extensions()
 
     def get_outputs(self) -> list[str]:
@@ -1531,6 +1552,11 @@ def configure_extension_build() -> tuple[
         )
         ext_modules.append(C)
 
+    # These extensions are built by cmake and copied manually in build_extensions(),
+    # defined above in the build_ext implementation
+    if cmake_cache_vars["BUILD_FUNCTORCH"]:
+        ext_modules.append(Extension(name="functorch._C", sources=[]))
+
     cmdclass = {
         "bdist_wheel": bdist_wheel,
         "build_ext": build_ext,