From 1b99c1859c3a50a8c2ca7198c8427dfd31178a68 Mon Sep 17 00:00:00 2001
From: PaliC
Date: Thu, 24 Jul 2025 11:22:56 -0700
Subject: [PATCH] [BE] Make PyObjectSlot use a global PyInterpreter and remove
 (#158427)

This PR is a bit more involved, but it effectively works to drastically
simplify PyObjectSlot and PyInterpreter:

1) PyObjectSlot now uses a global PyInterpreter, since there is only one.
   From here we change all of the call sites to rely on this assumption.
2) We also remove the "tags" of the PyInterpreter by deleting
   `PyInterpreterStatus`.

For the reviewer: sadly, `functorch/csrc/dim/dim.cpp` needed to be linted,
so there is an unreadable number of changes there. Fortunately, the only
actual change in that file is the following, which just removes
`getPyInterpreter()` from the `check_pyobj` call:

```
 mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
-    // fast case: tensor is live in python
-    std::optional<PyObject*> mb_obj =
-        t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
-    if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
-        return *mb_obj;
-    }
-    return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
-}
-}
+  // fast case: tensor is live in python
+  std::optional<PyObject*> mb_obj =
+      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          /*ignore_hermetic_tls=*/false);
+  if (mb_obj.has_value() &&
+      !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
+    return *mb_obj;
+  }
+  return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
+}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158427
Approved by: https://github.com/albanD
---
 build_variables.bzl                     |    1 +
 c10/core/impl/PyInterpreter.h           |   20 -
 c10/core/impl/PyInterpreterHooks.cpp    |   32 +
 c10/core/impl/PyInterpreterHooks.h      |   39 +
 c10/core/impl/PyObjectSlot.cpp          |    5 -
 c10/core/impl/PyObjectSlot.h            |   20 +-
 functorch/csrc/dim/dim.cpp              | 5768 ++++++++++++-----------
 torch/_dynamo/trace_rules.py            |    1 -
 torch/csrc/Module.cpp                   |   10 +-
 torch/csrc/PyInterpreter.cpp            |    6 +-
 torch/csrc/PyInterpreter.h              |    2 +-
 torch/csrc/PyInterpreterHooks.h         |   15 +
 torch/csrc/Storage.cpp                  |   47 +-
 torch/csrc/Storage.h                    |    1 -
 torch/csrc/StorageMethods.cpp           |    5 +-
 torch/csrc/StorageSharing.cpp           |   18 +-
 torch/csrc/autograd/python_variable.cpp |   62 +-
 torch/csrc/utils/python_dispatch.cpp    |   18 +-
 18 files changed, 3224 insertions(+), 2846 deletions(-)
 create mode 100644 c10/core/impl/PyInterpreterHooks.cpp
 create mode 100644 c10/core/impl/PyInterpreterHooks.h
 create mode 100644 torch/csrc/PyInterpreterHooks.h

diff --git a/build_variables.bzl b/build_variables.bzl
index 6f55b156f8a5..1dda77b63750 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -865,6 +865,7 @@ libtorch_python_core_sources = [
     "torch/csrc/QScheme.cpp",
     "torch/csrc/Module.cpp",
     "torch/csrc/PyInterpreter.cpp",
+    "torch/csrc/PyInterpreterHooks.cpp",
     "torch/csrc/python_dimname.cpp",
     "torch/csrc/Size.cpp",
     "torch/csrc/Storage.cpp",
diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h
index 43492443c530..09d4801f7d83 100644
--- a/c10/core/impl/PyInterpreter.h
+++ b/c10/core/impl/PyInterpreter.h
@@ -240,24 +240,4 @@ struct C10_API PyInterpreter {
   void disarm() noexcept;
 };
 
-// PyInterpreterStatus describes what the state of its interpreter tag
-// is, relative to the thread currently holding the GIL.
-enum class PyInterpreterStatus {
-  // We just allocated the Tensor, it hasn't escaped to other threads,
-  // we know that it definitely hasn't been tagged to be associated
-  // with an interpreter.
-  DEFINITELY_UNINITIALIZED,
-  // We queried the interpreter field and it looked uninitialized. But
-  // another thread may have raced with us to tag it with some other
-  // interpreter id. So we will have to do a CEX to make sure we can
-  // actually nab it.
-  MAYBE_UNINITIALIZED,
-  // We queried the interpreter field and it was tagged to belong to us.
-  // This means we have sole write access (as we hold the GIL for this
-  // interpreter)
-  TAGGED_BY_US,
-  // Someone else tagged this. We can't use this TensorImpl from Python.
-  TAGGED_BY_OTHER,
-};
-
 } // namespace c10::impl
diff --git a/c10/core/impl/PyInterpreterHooks.cpp b/c10/core/impl/PyInterpreterHooks.cpp
new file mode 100644
index 000000000000..bd5325cf49c2
--- /dev/null
+++ b/c10/core/impl/PyInterpreterHooks.cpp
@@ -0,0 +1,32 @@
+#include <c10/core/impl/PyInterpreterHooks.h>
+
+namespace c10::impl {
+
+// Define the registry
+C10_DEFINE_REGISTRY(
+    PyInterpreterHooksRegistry,
+    PyInterpreterHooksInterface,
+    PyInterpreterHooksArgs)
+
+const PyInterpreterHooksInterface& getPyInterpreterHooks() {
+  auto create_impl = [] {
+#if !defined C10_MOBILE
+    auto hooks = PyInterpreterHooksRegistry()->Create(
+        "PyInterpreterHooks", PyInterpreterHooksArgs{});
+    if (hooks) {
+      return hooks;
+    }
+#endif
+    // Return stub implementation that will throw errors when methods are called
+    return std::make_unique<PyInterpreterHooksInterface>();
+  };
+  static auto hooks = create_impl();
+  return *hooks;
+}
+
+// Main function to get global PyInterpreter
+PyInterpreter* getGlobalPyInterpreter() {
+  return getPyInterpreterHooks().getPyInterpreter();
+}
+
+} // namespace c10::impl
diff --git a/c10/core/impl/PyInterpreterHooks.h b/c10/core/impl/PyInterpreterHooks.h
new file mode 100644
index 000000000000..32a17ad9a8a0
--- /dev/null
+++ b/c10/core/impl/PyInterpreterHooks.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace c10::impl {
+
+// Minimal interface for PyInterpreter hooks
+struct C10_API PyInterpreterHooksInterface {
+  virtual ~PyInterpreterHooksInterface() = default;
+
+  // Get the PyInterpreter instance
+  // Stub implementation throws error when Python is not available
+  virtual PyInterpreter* getPyInterpreter() const {
+    TORCH_CHECK(
+        false,
+        "PyTorch was compiled without Python support. "
" + "Cannot access Python interpreter from C++."); + } +}; + +struct C10_API PyInterpreterHooksArgs{}; + +C10_DECLARE_REGISTRY( + PyInterpreterHooksRegistry, + PyInterpreterHooksInterface, + PyInterpreterHooksArgs); + +#define REGISTER_PYTHON_HOOKS(clsname) \ + C10_REGISTER_CLASS(PyInterpreterHooksRegistry, clsname, clsname) + +// Get the global PyInterpreter hooks instance +C10_API const PyInterpreterHooksInterface& getPyInterpreterHooks(); + +C10_API PyInterpreter* getGlobalPyInterpreter(); + +} // namespace c10::impl diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp index 62af2eae8e37..0f1bfb211074 100644 --- a/c10/core/impl/PyObjectSlot.cpp +++ b/c10/core/impl/PyObjectSlot.cpp @@ -34,11 +34,6 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { reinterpret_cast(pyobj_) & ~0x1ULL); } -void PyObjectSlot::unchecked_clear_pyobj(PyInterpreter* interpreter) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(interpreter == pyobj_interpreter_.load()); - pyobj_ = nullptr; -} - PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter) { diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index af8b9fa4d0ec..58b2490eba00 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -24,11 +25,9 @@ struct C10_API PyObjectSlot { // // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after // PyObject if necessary! - void init_pyobj( - PyInterpreter* self_interpreter, - PyObject* pyobj, - PyInterpreterStatus status) { - pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed); + void init_pyobj(PyObject* pyobj) { + pyobj_interpreter_.store( + getGlobalPyInterpreter(), std::memory_order_relaxed); pyobj_ = pyobj; } @@ -53,9 +52,10 @@ struct C10_API PyObjectSlot { // // NB: this lives in header so that we can avoid actually creating the // std::optional - std::optional check_pyobj( - PyInterpreter* self_interpreter, - bool ignore_hermetic_tls = false) const { + + // @todo alban: I'm not too sure what's going on here, we can probably delete + // it but it's worthwhile making sure + std::optional check_pyobj(bool ignore_hermetic_tls = false) const { impl::PyInterpreter* interpreter = pyobj_interpreter_.load(std::memory_order_acquire); if (interpreter == nullptr) { @@ -69,10 +69,6 @@ struct C10_API PyObjectSlot { } } - // Clear the PyObject field for an interpreter, in situations where we - // statically know the tensor is tagged with our interpreter. 
- void unchecked_clear_pyobj(PyInterpreter* interpreter); - PyInterpreter& load_pyobj_interpreter() const; bool owns_pyobj(); diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 19270d2f9225..8f1e561e2051 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -6,7 +6,6 @@ #include - // Many APIs have changed/don't exist anymore #if IS_PYTHON_3_12_PLUS @@ -14,24 +13,25 @@ // Re-enable this some day PyObject* Dim_init() { - PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); - return nullptr; + PyErr_SetString( + PyExc_RuntimeError, "First class dim doesn't work with python 3.12"); + return nullptr; } #else -#include "minpybind.h" #include #include -#include -#include #include +#include +#include #include -//#include -#include +#include "minpybind.h" +// #include +#include #include #include -#include +#include #include #include "arena.h" #include "dim.h" @@ -71,3115 +71,3498 @@ PyTypeObject* DimType = nullptr; PyObject* Tensor_getitem(PyObject* self, PyObject* index); int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value); -namespace{ +namespace { void maybeInitializeGlobals() { - // globals that depend on the python dim library, - // which we can't lookup until we finish initializing the _C module - if (_Tensor.ptr()) { - return; - } - auto dim = mpy::import("functorch.dim"); - _Tensor = dim.attr("_Tensor"); - pointwise = dim.attr("pointwise"); - _Tensor_sum = _Tensor.attr("sum"); - DimType = (PyTypeObject*) mpy::import("functorch.dim").attr("Dim").ptr(); + // globals that depend on the python dim library, + // which we can't lookup until we finish initializing the _C module + if (_Tensor.ptr()) { + return; + } + auto dim = mpy::import("functorch.dim"); + _Tensor = dim.attr("_Tensor"); + pointwise = dim.attr("pointwise"); + _Tensor_sum = _Tensor.attr("sum"); + DimType = (PyTypeObject*)mpy::import("functorch.dim").attr("Dim").ptr(); } void replaceMappingIfMatches(mpy::handle tp) { - auto T = (PyTypeObject*) tp.ptr(); - bool recurse = false; - if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { - T->tp_as_mapping->mp_subscript = Tensor_getitem; - recurse = true; - } - if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { - T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; - recurse = true; - } - if (recurse) { - auto result = tp.attr("__subclasses__").call(); - mpy::list_view lv(result); - for (auto i : lv.enumerate()) { - replaceMappingIfMatches(lv[i]); - } + auto T = (PyTypeObject*)tp.ptr(); + bool recurse = false; + if (T->tp_as_mapping->mp_subscript == THPVariable_getitem) { + T->tp_as_mapping->mp_subscript = Tensor_getitem; + recurse = true; + } + if (T->tp_as_mapping->mp_ass_subscript == THPVariable_setitem) { + T->tp_as_mapping->mp_ass_subscript = Tensor_setitem; + recurse = true; + } + if (recurse) { + auto result = tp.attr("__subclasses__").call(); + mpy::list_view lv(result); + for (auto i : lv.enumerate()) { + replaceMappingIfMatches(lv[i]); } + } } -void initializeGlobals(Arena & A) { - auto torch = mpy::import("torch"); - torch_Tensor = (PyTypeObject*) torch.attr("Tensor").ptr(); - torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__"); - - torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand"); - torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split"); - torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_"); - auto py_TensorBase = torch.attr("_C").attr("TensorBase"); - auto TensorBase = (PyTypeObject*) py_TensorBase.ptr(); - 
THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; - THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; - NamedTuple = mpy::import("typing").attr("NamedTuple"); - no_slice = PySlice_New(NULL, NULL, NULL); +void initializeGlobals(Arena& A) { + auto torch = mpy::import("torch"); + torch_Tensor = (PyTypeObject*)torch.attr("Tensor").ptr(); + torch_Tensor___mul__ = torch.attr("Tensor").attr("__mul__"); + torch_Tensor_expand = torch.attr("_C").attr("TensorBase").attr("expand"); + torch_Tensor_split = torch.attr("_C").attr("TensorBase").attr("split"); + torch_Tensor_copy_ = torch.attr("Tensor").attr("copy_"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + auto TensorBase = (PyTypeObject*)py_TensorBase.ptr(); + THPVariable_getitem = TensorBase->tp_as_mapping->mp_subscript; + THPVariable_setitem = TensorBase->tp_as_mapping->mp_ass_subscript; + NamedTuple = mpy::import("typing").attr("NamedTuple"); + no_slice = PySlice_New(NULL, NULL, NULL); } mpy::handle DimensionBindError_; mpy::handle DimensionBindError() { - if(!DimensionBindError_.ptr()) { - DimensionBindError_ = mpy::import("functorch.dim").attr("DimensionBindError"); - } - return DimensionBindError_; + if (!DimensionBindError_.ptr()) { + DimensionBindError_ = + mpy::import("functorch.dim").attr("DimensionBindError"); + } + return DimensionBindError_; } static int64_t n_dims_created = 65; struct Dim : public mpy::base { - int64_t level_; // for stable comparisons in prototype - mpy::object name_; - Dim() - : level_(n_dims_created++) {} - void init(mpy::object name, int64_t s = -1) { - name_ = std::move(name); - size_ = s; - } + int64_t level_; // for stable comparisons in prototype + mpy::object name_; + Dim() : level_(n_dims_created++) {} + void init(mpy::object name, int64_t s = -1) { + name_ = std::move(name); + size_ = s; + } - static bool check_exact(mpy::handle v) { - return Py_TYPE(v.ptr()) == DimType; - } + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == DimType; + } - int64_t size() const { - if (size_ == -1) { - mpy::raise_error(PyExc_ValueError, "dimension %S is unbound", name_.ptr()); - } - return size_; + int64_t size() const { + if (size_ == -1) { + mpy::raise_error( + PyExc_ValueError, "dimension %S is unbound", name_.ptr()); } - void set_size(int64_t v) { - if (size_ == -1) { - size_ = v; - } else if(size_ != v) { - mpy::raise_error(DimensionBindError(), "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", this, this->size_, v); - } + return size_; + } + void set_size(int64_t v) { + if (size_ == -1) { + size_ = v; + } else if (size_ != v) { + mpy::raise_error( + DimensionBindError(), + "Dim '%R' previously bound to a dimension of size %lld cannot bind to a dimension of size %lld", + this, + this->size_, + v); } - bool is_bound() const { - return size_ != -1; + } + bool is_bound() const { + return size_ != -1; + } + static mpy::obj create(mpy::object name, int64_t s = -1) { + if (!DimType) { + maybeInitializeGlobals(); } - static mpy::obj create(mpy::object name, int64_t s = -1) { - if (!DimType) { - maybeInitializeGlobals(); - } - auto r = Dim::alloc(DimType); - r->init(std::move(name), s); - return r; + auto r = Dim::alloc(DimType); + r->init(std::move(name), s); + return r; + } + static PyTypeObject Type; + const at::Tensor& range() { + if (!range_.defined()) { + range_ = at::arange(size()); } - static PyTypeObject Type; - const at::Tensor& range() { - if (!range_.defined()) { - range_ = at::arange(size()); - } - 
return range_; + return range_; + } + const at::Tensor& batchtensor() { + if (!batchtensor_.defined()) { + batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); } - const at::Tensor& batchtensor() { - if (!batchtensor_.defined()) { - batchtensor_ = at::functorch::addBatchDim(range(), 0, level_); - } - return batchtensor_; - } -private: - int64_t size_{-1}; - at::Tensor range_; - at::Tensor batchtensor_; + return batchtensor_; + } + + private: + int64_t size_{-1}; + at::Tensor range_; + at::Tensor batchtensor_; }; - struct DimEntry { - // union of either a negative number indicating which dimension this is from the rhs, - // or a pointer to a first-class dimension. - // pointers do not have their highest bit set, so checking the number is negative tells us - // that it is not a dim. - bool is_positional() const { - return data_ < 0; - } - bool is_none() const { - return data_ == 0; - } - int64_t position() const { - return data_; - } - mpy::hdl dim() const { - Dim* result; - std::memcpy(&result, &data_, sizeof(Dim*)); - return mpy::hdl(result); - } + // union of either a negative number indicating which dimension this is from + // the rhs, or a pointer to a first-class dimension. pointers do not have + // their highest bit set, so checking the number is negative tells us that it + // is not a dim. + bool is_positional() const { + return data_ < 0; + } + bool is_none() const { + return data_ == 0; + } + int64_t position() const { + return data_; + } + mpy::hdl dim() const { + Dim* result; + std::memcpy(&result, &data_, sizeof(Dim*)); + return mpy::hdl(result); + } - DimEntry() - : data_(0) {} + DimEntry() : data_(0) {} - DimEntry(int64_t pos) - : data_(pos) { - AT_ASSERT(pos < 0); - } - DimEntry(mpy::hdl d) { - std::memcpy(&data_, &d, sizeof(int64_t)); - } - bool operator==(const DimEntry& rhs) const { - return data_ == rhs.data_; - } -private: - int64_t data_; + DimEntry(int64_t pos) : data_(pos) { + AT_ASSERT(pos < 0); + } + DimEntry(mpy::hdl d) { + std::memcpy(&data_, &d, sizeof(int64_t)); + } + bool operator==(const DimEntry& rhs) const { + return data_ == rhs.data_; + } + + private: + int64_t data_; }; // Dim wrapper methods DimEntry _wrap_dim(mpy::handle d, size_t N, bool keepdim) { - if (Dim::check(d)) { - if (keepdim) { - mpy::raise_error(PyExc_ValueError, "cannot preserve first-class dimensions with keepdim=True"); - } - return Dim::unchecked_wrap(d); - } else if (mpy::is_int(d)) { - auto i = mpy::to_int(d); - while (i >= 0) { - i -= N; - } - return i; - } else { - return DimEntry(); + if (Dim::check(d)) { + if (keepdim) { + mpy::raise_error( + PyExc_ValueError, + "cannot preserve first-class dimensions with keepdim=True"); } + return Dim::unchecked_wrap(d); + } else if (mpy::is_int(d)) { + auto i = mpy::to_int(d); + while (i >= 0) { + i -= N; + } + return i; + } else { + return DimEntry(); + } } - -int Dim_init(mpy::hdl self, PyObject *args, PyObject *kwds) { - PY_BEGIN - static constexpr const char* kwlist[] = {"name", "size", nullptr}; - mpy::handle name; - mpy::handle size = nullptr; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", const_cast(kwlist), &name, &size)) { - return -1; - } - self->init(mpy::object::borrow(name), (size.ptr() && !mpy::is_none(size)) ? 
mpy::to_int(size) : -1); - return 0; - PY_END(-1) +int Dim_init(mpy::hdl self, PyObject* args, PyObject* kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"name", "size", nullptr}; + mpy::handle name; + mpy::handle size = nullptr; + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "O|O", const_cast(kwlist), &name, &size)) { + return -1; + } + self->init( + mpy::object::borrow(name), + (size.ptr() && !mpy::is_none(size)) ? mpy::to_int(size) : -1); + return 0; + PY_END(-1) } PyObject* Dim_repr(Dim* self) { - PY_BEGIN - mpy::object name = (self->name_.ptr()) ? self->name_ : mpy::unicode_from_string(""); - return name.release(); - PY_END(nullptr) + PY_BEGIN + mpy::object name = (self->name_.ptr()) + ? self->name_ + : mpy::unicode_from_string(""); + return name.release(); + PY_END(nullptr) } - PyObject* Dim_getsize(Dim* self, void*) { - PY_BEGIN - return mpy::from_int(self->size()).release(); - PY_END(nullptr) + PY_BEGIN + return mpy::from_int(self->size()).release(); + PY_END(nullptr) } int Dim_setsize(Dim* self, PyObject* size, void*) { - PY_BEGIN - self->set_size(mpy::to_int(size)); - return 0; - PY_END(-1) + PY_BEGIN + self->set_size(mpy::to_int(size)); + return 0; + PY_END(-1) } PyObject* Dim_getis_bound(Dim* self, void*) { - return PyBool_FromLong(self->is_bound()); + return PyBool_FromLong(self->is_bound()); } PyObject* Dim_getlevel(Dim* self, void*) { - return PyLong_FromLong(self->level_); + return PyLong_FromLong(self->level_); } PyObject* Dim_get_levels(Dim* self, void*) { - mpy::tuple t(1); - t.set(0, mpy::object::borrow(self->ptr())); - return t.release(); + mpy::tuple t(1); + t.set(0, mpy::object::borrow(self->ptr())); + return t.release(); } PyObject* Dim_get_has_device(Dim* self, void*) { - Py_RETURN_FALSE; + Py_RETURN_FALSE; } PyObject* Dim_get_tensor(Dim* self, void*) { - return THPVariable_Wrap(self->range()); + return THPVariable_Wrap(self->range()); } PyObject* Dim_get_batchtensor(Dim* self, void*) { - return THPVariable_Wrap(self->batchtensor()); + return THPVariable_Wrap(self->batchtensor()); } - PyGetSetDef Dim_getsetters[] = { - {"size", (getter) Dim_getsize, (setter) Dim_setsize, - "Dimension size", NULL}, - {"is_bound", (getter) Dim_getis_bound, NULL, "is_bound", NULL}, - {"_level", (getter) Dim_getlevel, NULL, "_level", NULL}, - {"_levels", (getter) Dim_get_levels, NULL, "_levels", NULL}, - {"_has_device", (getter) Dim_get_has_device, NULL, "_has_device", NULL}, - {"_tensor", (getter) Dim_get_tensor, NULL, "_tensor", NULL}, - {"_batchtensor", (getter) Dim_get_batchtensor, NULL, "_batchtensor", NULL}, - {"ndim", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_int(1).release(); }, NULL, "ndim", NULL}, - {NULL} /* Sentinel */ -}; + {"size", (getter)Dim_getsize, (setter)Dim_setsize, "Dimension size", NULL}, + {"is_bound", (getter)Dim_getis_bound, NULL, "is_bound", NULL}, + {"_level", (getter)Dim_getlevel, NULL, "_level", NULL}, + {"_levels", (getter)Dim_get_levels, NULL, "_levels", NULL}, + {"_has_device", (getter)Dim_get_has_device, NULL, "_has_device", NULL}, + {"_tensor", (getter)Dim_get_tensor, NULL, "_tensor", NULL}, + {"_batchtensor", (getter)Dim_get_batchtensor, NULL, "_batchtensor", NULL}, + {"ndim", + (getter)[](PyObject* self, void*) + ->PyObject* {return mpy::from_int(1).release(); +} // namespace +, NULL, "ndim", NULL +} +, { + NULL +} /* Sentinel */ +} +; } PyTypeObject Dim::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.Dim", /* tp_name */ - sizeof(Dim), /* tp_basicsize */ - 0, /* tp_itemsize */ - Dim::dealloc_stub, /* 
tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)Dim_repr, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - "Dim Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - Dim_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)(void*)static_cast,PyObject*,PyObject*)>(Dim_init), /* tp_init */ - 0, /* tp_alloc */ - Dim::new_stub, /* tp_new */ + "_C.Dim", /* tp_name */ + sizeof(Dim), /* tp_basicsize */ + 0, /* tp_itemsize */ + Dim::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)Dim_repr, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Dim Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + Dim_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)(void*)static_cast, PyObject*, PyObject*)>( + Dim_init), /* tp_init */ + 0, /* tp_alloc */ + Dim::new_stub, /* tp_new */ }; // class DimList ------------ struct DimList : public mpy::base { - mpy::object name_; - std::vector> dims_; - static PyTypeObject Type; - void init(mpy::object name) { - name_ = std::move(name); + mpy::object name_; + std::vector> dims_; + static PyTypeObject Type; + void init(mpy::object name) { + name_ = std::move(name); + } + void set_dims(std::vector> dims) { + bound_ = true; + dims_ = std::move(dims); + } + bool is_bound() { + return bound_; + } + void bind_len(int64_t size) { + if (bound_) { + int64_t b_size = dims_.size(); + if (b_size != size) { + mpy::raise_error( + DimensionBindError(), + "Dimlist has size %lld but it is being bound to size %d", + b_size, + size); + } + } else { + bound_ = true; + dims_.resize(size); + for (Py_ssize_t i = 0; i < size; ++i) { + dims_[i] = + Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); + } } - void set_dims(std::vector> dims) { - bound_ = true; - dims_ = std::move(dims); + } + int64_t size() const { + if (!bound_) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); } - bool is_bound() { - return bound_; - } - void bind_len(int64_t size) { - if (bound_) { - int64_t b_size = dims_.size(); - if (b_size != size) { - mpy::raise_error(DimensionBindError(), "Dimlist has size %lld but it is being bound to size %d", b_size, size); - } - } else { - bound_ = true; - dims_.resize(size); - for (Py_ssize_t i = 0; i < size; ++i) { - dims_[i] = Dim::create(mpy::unicode_from_format("%S%i", name_.ptr(), (int)i)); - } - } - } - int64_t size() const { - if (!bound_) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - return dims_.size(); - } - 
void set_bound(bool b) { - bound_ = b; - } -private: - bool bound_ = false; + return dims_.size(); + } + void set_bound(bool b) { + bound_ = b; + } + + private: + bool bound_ = false; }; - -static int DimList_init(DimList *self, PyObject *args, PyObject *kwds); +static int DimList_init(DimList* self, PyObject* args, PyObject* kwds); static PyObject* DimList_repr(DimList* self) { - PY_BEGIN - if (self->is_bound()) { - size_t size = self->dims_.size(); - mpy::tuple t(size); - for(size_t i = 0; i < size; ++i) { - t.set(i, self->dims_[i]); - } - return mpy::repr(t).release(); - } else if(!mpy::is_none(self->name_)) { - return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); - } else { - return mpy::unicode_from_string("").release(); + PY_BEGIN + if (self->is_bound()) { + size_t size = self->dims_.size(); + mpy::tuple t(size); + for (size_t i = 0; i < size; ++i) { + t.set(i, self->dims_[i]); } - PY_END(nullptr) + return mpy::repr(t).release(); + } else if (!mpy::is_none(self->name_)) { + return mpy::unicode_from_format("*%S", self->name_.ptr()).release(); + } else { + return mpy::unicode_from_string("").release(); + } + PY_END(nullptr) } -static PyObject* DimList_bind(DimList *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - mpy::handle sizes; - static const char * const _keywords[] = {"sizes", nullptr}; - static _PyArg_Parser parser = {"O", _keywords, 0}; - if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { - return nullptr; - } - if (!mpy::is_sequence(sizes)) { - mpy::raise_error(PyExc_ValueError, "expected a sequence"); - } - mpy::sequence_view seq = sizes; - auto size = seq.size(); - self->bind_len(size); - for (Py_ssize_t i = 0; i < size; ++i) { - self->dims_[i]->set_size(mpy::to_int(seq[i])); - } - Py_RETURN_NONE; - PY_END(nullptr) +static PyObject* DimList_bind( + DimList* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + mpy::handle sizes; + static const char* const _keywords[] = {"sizes", nullptr}; + static _PyArg_Parser parser = {"O", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) { + return nullptr; + } + if (!mpy::is_sequence(sizes)) { + mpy::raise_error(PyExc_ValueError, "expected a sequence"); + } + mpy::sequence_view seq = sizes; + auto size = seq.size(); + self->bind_len(size); + for (Py_ssize_t i = 0; i < size; ++i) { + self->dims_[i]->set_size(mpy::to_int(seq[i])); + } + Py_RETURN_NONE; + PY_END(nullptr) } -static PyObject* DimList_bind_len(DimList *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - int size; - static const char * const _keywords[] = {"N", nullptr}; - static _PyArg_Parser parser = {"i", _keywords, 0}; - if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { - return nullptr; - } - self->bind_len(size); - Py_RETURN_NONE; - PY_END(nullptr) +static PyObject* DimList_bind_len( + DimList* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + int size; + static const char* const _keywords[] = {"N", nullptr}; + static _PyArg_Parser parser = {"i", _keywords, 0}; + if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) { + return nullptr; + } + self->bind_len(size); + Py_RETURN_NONE; + PY_END(nullptr) } static PyMethodDef DimList_methods[] = { - {"bind", (PyCFunction)(void*) DimList_bind, METH_FASTCALL | METH_KEYWORDS}, - {"bind_len", (PyCFunction)(void*) DimList_bind_len, METH_FASTCALL | METH_KEYWORDS}, - 
{NULL, NULL, 0, NULL} /* Sentinel */ + {"bind", (PyCFunction)(void*)DimList_bind, METH_FASTCALL | METH_KEYWORDS}, + {"bind_len", + (PyCFunction)(void*)DimList_bind_len, + METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ }; - static Py_ssize_t DimList_len(DimList* self) { - PY_BEGIN - return self->size(); - PY_END(-1) + PY_BEGIN + return self->size(); + PY_END(-1) } -static PyObject * DimList_item(DimList* self, Py_ssize_t idx) { - PY_BEGIN - if (!self->is_bound()) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - if (idx < 0 || (size_t) idx >= self->dims_.size()) { - mpy::raise_error(PyExc_IndexError, "index out of bounds"); - } - mpy::object r = self->dims_[idx]; - return r.release(); - PY_END(nullptr) +static PyObject* DimList_item(DimList* self, Py_ssize_t idx) { + PY_BEGIN + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); + } + if (idx < 0 || (size_t)idx >= self->dims_.size()) { + mpy::raise_error(PyExc_IndexError, "index out of bounds"); + } + mpy::object r = self->dims_[idx]; + return r.release(); + PY_END(nullptr) } -PySequenceMethods DimList_seq { - (lenfunc) DimList_len, //lenfunc sq_length; - 0, //binaryfunc sq_concat; - 0, //ssizeargfunc sq_repeat; - (ssizeargfunc) DimList_item, //ssizeargfunc sq_item; - 0, //void *was_sq_slice; - 0, //ssizeobjargproc sq_ass_item; - 0, //void *was_sq_ass_slice; - 0, //objobjproc sq_contains; +PySequenceMethods DimList_seq{ + (lenfunc)DimList_len, // lenfunc sq_length; + 0, // binaryfunc sq_concat; + 0, // ssizeargfunc sq_repeat; + (ssizeargfunc)DimList_item, // ssizeargfunc sq_item; + 0, // void *was_sq_slice; + 0, // ssizeobjargproc sq_ass_item; + 0, // void *was_sq_ass_slice; + 0, // objobjproc sq_contains; - 0, //binaryfunc sq_inplace_concat; - 0, //ssizeargfunc sq_inplace_repeat; + 0, // binaryfunc sq_inplace_concat; + 0, // ssizeargfunc sq_inplace_repeat; }; static PyObject* DimList_getis_bound(DimList* self, void*) { - return PyBool_FromLong(self->is_bound()); + return PyBool_FromLong(self->is_bound()); } static PyGetSetDef DimList_getsetters[] = { - {"is_bound", (getter) DimList_getis_bound, NULL, "is_bound", NULL}, - {NULL} /* Sentinel */ + {"is_bound", (getter)DimList_getis_bound, NULL, "is_bound", NULL}, + {NULL} /* Sentinel */ }; - static PyObject* DimList_subscript(DimList* self, mpy::handle idx) { - PY_BEGIN - if (mpy::is_int(idx)) { - return DimList_item(self, mpy::to_int(idx)); - } else if (mpy::is_slice(idx)) { - if (!self->is_bound()) { - mpy::raise_error(DimensionBindError(), "DimList not bound"); - } - mpy::slice_view s(idx, self->dims_.size()); - mpy::tuple r(s.slicelength); - for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { - r.set(j++, self->dims_[i]); - } - return r.release(); - } else { - mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); - return nullptr; + PY_BEGIN + if (mpy::is_int(idx)) { + return DimList_item(self, mpy::to_int(idx)); + } else if (mpy::is_slice(idx)) { + if (!self->is_bound()) { + mpy::raise_error(DimensionBindError(), "DimList not bound"); } - PY_END(nullptr) + mpy::slice_view s(idx, self->dims_.size()); + mpy::tuple r(s.slicelength); + for (Py_ssize_t i = s.start, j = 0; i < s.stop; i += s.step) { + r.set(j++, self->dims_[i]); + } + return r.release(); + } else { + mpy::raise_error(PyExc_ValueError, "expected an int or a slice"); + return nullptr; + } + PY_END(nullptr) } PyMappingMethods DimList_mapping = { - 0, //lenfunc mp_length; - (binaryfunc)(void*) DimList_subscript, //binaryfunc 
mp_subscript; - 0, //objobjargproc mp_ass_subscript; + 0, // lenfunc mp_length; + (binaryfunc)(void*)DimList_subscript, // binaryfunc mp_subscript; + 0, // objobjargproc mp_ass_subscript; }; - - PyTypeObject DimList::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.DimList", /* tp_name */ - sizeof(DimList), /* tp_basicsize */ - 0, /* tp_itemsize */ - DimList::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)DimList_repr, /* tp_repr */ - 0, /* tp_as_number */ - &DimList_seq, /* tp_as_sequence */ - &DimList_mapping, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - 0, /* tp_flags */ - "DimList Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - DimList_methods, /* tp_methods */ - 0, /* tp_members */ - DimList_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc) DimList_init, /* tp_init */ - 0, /* tp_alloc */ - DimList::new_stub, /* tp_new */ + "_C.DimList", /* tp_name */ + sizeof(DimList), /* tp_basicsize */ + 0, /* tp_itemsize */ + DimList::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + (reprfunc)DimList_repr, /* tp_repr */ + 0, /* tp_as_number */ + &DimList_seq, /* tp_as_sequence */ + &DimList_mapping, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + 0, /* tp_flags */ + "DimList Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DimList_methods, /* tp_methods */ + 0, /* tp_members */ + DimList_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)DimList_init, /* tp_init */ + 0, /* tp_alloc */ + DimList::new_stub, /* tp_new */ }; -static int DimList_init(DimList *self, PyObject *args, PyObject *kwds) { - PY_BEGIN - static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; - mpy::handle len_or_dims = nullptr; - PyObject* name = nullptr; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { - return -1; - } - self->init(mpy::object::borrow(name ? 
name : Py_None)); - if (len_or_dims.ptr()) { - if(mpy::is_int(len_or_dims)) { - self->bind_len(mpy::to_int(len_or_dims)); - } else if (mpy::is_sequence(len_or_dims)) { - mpy::sequence_view s(len_or_dims); - std::vector> dims; - size_t size = s.size(); - dims.reserve(size); - for (size_t i = 0; i < size; ++i) { - auto r = s[i]; - if (mpy::is_int(r)) { - dims.emplace_back(Dim::create(mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), mpy::to_int(r))); - } else { - dims.emplace_back(Dim::wrap(r)); - } - } - self->set_dims(std::move(dims)); +static int DimList_init(DimList* self, PyObject* args, PyObject* kwds) { + PY_BEGIN + static constexpr const char* kwlist[] = {"len_or_dims", "name", nullptr}; + mpy::handle len_or_dims = nullptr; + PyObject* name = nullptr; + if (!PyArg_ParseTupleAndKeywords( + args, kwds, "|OO", const_cast(kwlist), &len_or_dims, &name)) { + return -1; + } + self->init(mpy::object::borrow(name ? name : Py_None)); + if (len_or_dims.ptr()) { + if (mpy::is_int(len_or_dims)) { + self->bind_len(mpy::to_int(len_or_dims)); + } else if (mpy::is_sequence(len_or_dims)) { + mpy::sequence_view s(len_or_dims); + std::vector> dims; + size_t size = s.size(); + dims.reserve(size); + for (size_t i = 0; i < size; ++i) { + auto r = s[i]; + if (mpy::is_int(r)) { + dims.emplace_back(Dim::create( + mpy::unicode_from_format("%S%i", self->name_.ptr(), (int)i), + mpy::to_int(r))); } else { - PyErr_Format(PyExc_ValueError, "expected a length or a sequence of dimensions"); - return -1; + dims.emplace_back(Dim::wrap(r)); } - return 0; + } + self->set_dims(std::move(dims)); + } else { + PyErr_Format( + PyExc_ValueError, "expected a length or a sequence of dimensions"); + return -1; } return 0; - PY_END(-1); + } + return 0; + PY_END(-1); } // Tensor ----------------------------- PyTypeObject* TensorType = nullptr; // the python wrapper type. 
-mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise); +mpy::object run_torch_function( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise); -namespace{ +namespace { at::Tensor _add_batch_dims(Arena& A, at::Tensor t, Slice levels_) { - auto levels = Slice(); - levels.extend(A, levels_); - while (true) { - int64_t min_real_index = -1; - int64_t min_index = -1; - int64_t min_value = INT_MAX; - int64_t i = 0; - int64_t r = 0; - for (auto l : levels) { - if (!l.is_none()) { - if (!l.is_positional() && l.dim()->level_ < min_value) { - min_value = l.dim()->level_; - min_index = i; - min_real_index = r; - } - ++i; - } - ++r; + auto levels = Slice(); + levels.extend(A, levels_); + while (true) { + int64_t min_real_index = -1; + int64_t min_index = -1; + int64_t min_value = INT_MAX; + int64_t i = 0; + int64_t r = 0; + for (auto l : levels) { + if (!l.is_none()) { + if (!l.is_positional() && l.dim()->level_ < min_value) { + min_value = l.dim()->level_; + min_index = i; + min_real_index = r; } - if (min_index == -1) { - return t; - } - auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); - t = std::move(t2); - levels[min_real_index] = DimEntry(); + ++i; + } + ++r; } + if (min_index == -1) { + return t; + } + auto t2 = at::functorch::addBatchDim(std::move(t), min_index, min_value); + t = std::move(t2); + levels[min_real_index] = DimEntry(); + } } - - struct DelayedOperator { - DelayedOperator(mpy::object o, mpy::vector_args a) - : orig(std::move(o)), args(a) { - auto all = a.size(); - // this will outlive the call so - // take ownership of temporaries - // in vector args - auto buf = new mpy::handle[all]; - memcpy(buf, args.args, sizeof(mpy::handle)*all); - args.args = buf; - for (auto i : args.enumerate_all()) { - Py_INCREF(args.args[i].ptr()); - } - Py_XINCREF(args.kwnames.ptr()); + DelayedOperator(mpy::object o, mpy::vector_args a) + : orig(std::move(o)), args(a) { + auto all = a.size(); + // this will outlive the call so + // take ownership of temporaries + // in vector args + auto buf = new mpy::handle[all]; + memcpy(buf, args.args, sizeof(mpy::handle) * all); + args.args = buf; + for (auto i : args.enumerate_all()) { + Py_INCREF(args.args[i].ptr()); } - ~DelayedOperator() { - for (auto i : args.enumerate_all()) { - Py_DECREF(args[i].ptr()); - } - if (args.has_keywords()) { - Py_XDECREF(args.kwnames.ptr()); - } - delete [] args.args; + Py_XINCREF(args.kwnames.ptr()); + } + ~DelayedOperator() { + for (auto i : args.enumerate_all()) { + Py_DECREF(args[i].ptr()); } - mpy::object orig; - mpy::vector_args args; + if (args.has_keywords()) { + Py_XDECREF(args.kwnames.ptr()); + } + delete[] args.args; + } + mpy::object orig; + mpy::vector_args args; }; void free_levels_dims(Slice levels) { - for(auto e : levels) { - if (!e.is_positional()) { - mpy::object::steal(e.dim()); - } + for (auto e : levels) { + if (!e.is_positional()) { + mpy::object::steal(e.dim()); } + } } -} +} // namespace struct Tensor : public mpy::base { -private: - at::Tensor tensor_; - at::Tensor batchtensor_; - OwnedSlice levels_; - bool has_device_; - std::unique_ptr delayed_; -public: + private: + at::Tensor tensor_; + at::Tensor batchtensor_; + OwnedSlice levels_; + bool has_device_; + std::unique_ptr delayed_; - at::Tensor& tensor(Arena& A) { - if (C10_UNLIKELY(!tensor_.defined())) { - AT_ASSERT(delayed_); - auto t = Tensor::wrap(run_torch_function(A, delayed_->orig, delayed_->args, true)); - tensor_ = t->tensor(A); - delayed_.reset(); - 
// don't force creation of batch tensor if it wasn't already provided. - batchtensor_ = t->batchtensor_; - AT_ASSERT(levels() == t->levels()); - } - return tensor_; + public: + at::Tensor& tensor(Arena& A) { + if (C10_UNLIKELY(!tensor_.defined())) { + AT_ASSERT(delayed_); + auto t = Tensor::wrap( + run_torch_function(A, delayed_->orig, delayed_->args, true)); + tensor_ = t->tensor(A); + delayed_.reset(); + // don't force creation of batch tensor if it wasn't already provided. + batchtensor_ = t->batchtensor_; + AT_ASSERT(levels() == t->levels()); } - at::Tensor& batchtensor(Arena& A) { - if (C10_UNLIKELY(!batchtensor_.defined())) { - batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); - } - return batchtensor_; + return tensor_; + } + at::Tensor& batchtensor(Arena& A) { + if (C10_UNLIKELY(!batchtensor_.defined())) { + batchtensor_ = _add_batch_dims(A, tensor(A), levels_.slice()); } - Slice levels() { - return levels_.slice(); - } - bool has_device() { - return has_device_; - } - DelayedOperator* delayed() { - return delayed_.get(); - } - static PyTypeObject Type; + return batchtensor_; + } + Slice levels() { + return levels_.slice(); + } + bool has_device() { + return has_device_; + } + DelayedOperator* delayed() { + return delayed_.get(); + } + static PyTypeObject Type; - static bool check_exact(mpy::handle v) { - return Py_TYPE(v.ptr()) == TensorType; - } + static bool check_exact(mpy::handle v) { + return Py_TYPE(v.ptr()) == TensorType; + } - - static mpy::obj create() { - if (!TensorType) { - TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").release(); - } - return Tensor::alloc(TensorType); + static mpy::obj create() { + if (!TensorType) { + TensorType = + (PyTypeObject*)mpy::import("functorch.dim").attr("Tensor").release(); } - void capture_levels(Slice levels) { - // grab ownership of the dims inside levels - for (auto l : levels) { - if (!l.is_positional()) { - mpy::object::borrow(l.dim()).release(); - } - } - levels_.set(levels, free_levels_dims); + return Tensor::alloc(TensorType); + } + void capture_levels(Slice levels) { + // grab ownership of the dims inside levels + for (auto l : levels) { + if (!l.is_positional()) { + mpy::object::borrow(l.dim()).release(); + } } - static mpy::object from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device); - static mpy::obj create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device); - friend struct EnableAllLayers; + levels_.set(levels, free_levels_dims); + } + static mpy::object from_positional( + Arena& A, + at::Tensor tensor, + Slice levels, + bool has_device); + static mpy::obj create_delayed( + mpy::object op, + mpy::vector_args args, + Slice levels, + bool has_device); + friend struct EnableAllLayers; }; -namespace{ +namespace { // version in header does a unnecessary refcount +/- -at::functorch::BatchedTensorImpl* maybeGetBatchedImpl(const at::Tensor& tensor) { - if (at::functorch::isBatchedTensor(tensor)) { - return static_cast(tensor.unsafeGetTensorImpl()); - } - return nullptr; +at::functorch::BatchedTensorImpl* maybeGetBatchedImpl( + const at::Tensor& tensor) { + if (at::functorch::isBatchedTensor(tensor)) { + return static_cast( + tensor.unsafeGetTensorImpl()); + } + return nullptr; } TensorRef unchecked_tensor_from(mpy::handle p) { - auto v = (THPVariable*) p.ptr(); - return TensorRef(*v->cdata); + auto v = (THPVariable*)p.ptr(); + return TensorRef(*v->cdata); } static int64_t ndim_of_levels(Slice levels) { - int64_t r = 0; - for (auto l : levels) { 
- if (l.is_positional()) { - ++r; - } + int64_t r = 0; + for (auto l : levels) { + if (l.is_positional()) { + ++r; } - return r; + } + return r; } struct TensorInfo { - TensorRef tensor; - Slice levels; - bool has_device; - TensorRef batchedtensor; - int64_t ndim() const { - return ndim_of_levels(levels); - } - operator bool() const { - return tensor; - } + TensorRef tensor; + Slice levels; + bool has_device; + TensorRef batchedtensor; + int64_t ndim() const { + return ndim_of_levels(levels); + } + operator bool() const { + return tensor; + } - static TensorInfo create(Arena& A, mpy::handle h, bool ensure_batched=true, bool ensure_present=true) { - if (Tensor::check_exact(h)) { - auto t = Tensor::unchecked_wrap(h); - return TensorInfo {t->tensor(A), t->levels(), t->has_device(), ensure_batched ? t->batchtensor(A) : TensorRef()}; - } else if (Dim::check_exact(h)) { - auto d = Dim::unchecked_wrap(h); - return TensorInfo {d->range(), Slice(A, DimEntry(d)), false, ensure_batched ? d->batchtensor() : TensorRef()}; - } else if (THPVariable_Check(h.ptr())) { - TensorRef t = unchecked_tensor_from(h); - Slice levels; - for (auto i : irange(-t->dim(), 0)) { - levels.append(A, i); - } - return TensorInfo {t, levels, true, t}; - } else { - if (ensure_present) { - mpy::raise_error(PyExc_ValueError, "expected a tensor object"); - } - return TensorInfo {}; - } + static TensorInfo create( + Arena& A, + mpy::handle h, + bool ensure_batched = true, + bool ensure_present = true) { + if (Tensor::check_exact(h)) { + auto t = Tensor::unchecked_wrap(h); + return TensorInfo{ + t->tensor(A), + t->levels(), + t->has_device(), + ensure_batched ? t->batchtensor(A) : TensorRef()}; + } else if (Dim::check_exact(h)) { + auto d = Dim::unchecked_wrap(h); + return TensorInfo{ + d->range(), + Slice(A, DimEntry(d)), + false, + ensure_batched ? 
d->batchtensor() : TensorRef()}; + } else if (THPVariable_Check(h.ptr())) { + TensorRef t = unchecked_tensor_from(h); + Slice levels; + for (auto i : irange(-t->dim(), 0)) { + levels.append(A, i); + } + return TensorInfo{t, levels, true, t}; + } else { + if (ensure_present) { + mpy::raise_error(PyExc_ValueError, "expected a tensor object"); + } + return TensorInfo{}; } - - + } }; -static PyObject* py_Tensor_from_positional(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - #define ARGS(_) _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) - MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) - #undef ARGS +static PyObject* py_Tensor_from_positional( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN +#define ARGS(_) \ + _(mpy::handle, tensor) _(mpy::handle, py_levels) _(int, has_device) + MPY_PARSE_ARGS_KWNAMES("OOp", ARGS) +#undef ARGS - if (!THPVariable_Check(tensor.ptr())) { - mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); - } + if (!THPVariable_Check(tensor.ptr())) { + mpy::raise_error(PyExc_ValueError, "_tensor is not a Tensor?"); + } - Slice levels; - mpy::sequence_view sq(py_levels); - for (auto i : sq.enumerate()) { - mpy::object v = sq[i]; - if (mpy::is_int(v)) { - auto vi = mpy::to_int(v); - levels.append(A, vi); - } else { - auto dim = Dim::wrap(std::move(v)); - mpy::hdl hdim = dim; - levels.append(A, hdim); - } + Slice levels; + mpy::sequence_view sq(py_levels); + for (auto i : sq.enumerate()) { + mpy::object v = sq[i]; + if (mpy::is_int(v)) { + auto vi = mpy::to_int(v); + levels.append(A, vi); + } else { + auto dim = Dim::wrap(std::move(v)); + mpy::hdl hdim = dim; + levels.append(A, hdim); } - return Tensor::from_positional(A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0).release(); - PY_END(nullptr) + } + return Tensor::from_positional( + A, THPVariable_Unpack(tensor.ptr()), levels, has_device != 0) + .release(); + PY_END(nullptr) } +} // namespace + +mpy::object Tensor::from_positional( + Arena& A, + at::Tensor tensor, + Slice levels, + bool has_device) { + size_t seen_dims = 0; + int last = 0; + // auto sz = tensor.sizes(); + for (auto i : levels.enumerate()) { + auto l = levels[i]; + if (l.is_positional()) { + AT_ASSERT(last == 0 || last + 1 == l.position()); + last = l.position(); + } else { + mpy::object::borrow(l.dim()).release(); + // AT_ASSERT(sz[i] == l.dim()->size()); + ++seen_dims; + } + } + AT_ASSERT(last == 0 || last == -1); + if (!seen_dims) { + return mpy::object::steal(THPVariable_Wrap(tensor)); + } + + mpy::obj self = Tensor::create(); + self->tensor_ = std::move(tensor); + AT_ASSERT(self->tensor_.dim() == levels.size()); + self->levels_.set(levels, free_levels_dims); + self->has_device_ = has_device; + mpy::object r = std::move(self); + return r; } -mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice levels, bool has_device) { - size_t seen_dims = 0; - int last = 0; - //auto sz = tensor.sizes(); - for (auto i : levels.enumerate()) { - auto l = levels[i]; - if (l.is_positional()) { - AT_ASSERT(last == 0 || last + 1 == l.position()); - last = l.position(); - } else { - mpy::object::borrow(l.dim()).release(); - //AT_ASSERT(sz[i] == l.dim()->size()); - ++seen_dims; - } - } - AT_ASSERT(last == 0 || last == -1); - if (!seen_dims) { - return mpy::object::steal(THPVariable_Wrap(tensor)); - } - - mpy::obj self = Tensor::create(); - self->tensor_ = std::move(tensor); - AT_ASSERT(self->tensor_.dim() == 
levels.size()); - self->levels_.set(levels, free_levels_dims); - self->has_device_ = has_device; - mpy::object r = std::move(self); - return r; +mpy::obj Tensor::create_delayed( + mpy::object op, + mpy::vector_args args, + Slice levels, + bool has_device) { + mpy::obj self = Tensor::create(); + self->capture_levels(levels); + self->has_device_ = has_device; + self->delayed_ = std::make_unique(std::move(op), args); + return self; } - -mpy::obj Tensor::create_delayed(mpy::object op, mpy::vector_args args, Slice levels, bool has_device) { - mpy::obj self = Tensor::create(); - self->capture_levels(levels); - self->has_device_ = has_device; - self->delayed_ = std::make_unique(std::move(op), args); - return self; -} - -namespace{ +namespace { mpy::list slice_to_list(Slice h) { - mpy::list lst(h.size()); - for (auto i : h.enumerate()) { - lst.set(i, mpy::object::borrow(h[i])); - } - return lst; + mpy::list lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; } mpy::tuple slice_to_tuple(Slice h) { - mpy::tuple lst(h.size()); - for (auto i : h.enumerate()) { - lst.set(i, mpy::object::borrow(h[i])); - } - return lst; + mpy::tuple lst(h.size()); + for (auto i : h.enumerate()) { + lst.set(i, mpy::object::borrow(h[i])); + } + return lst; } enum UType { - U_ELEM, - U_TUPLE_LIKE, - U_DICT, + U_ELEM, + U_TUPLE_LIKE, + U_DICT, }; struct Unflatten { - mpy::object operator()(Slice& elements) { - mpy::object r; - switch (type) { - case U_ELEM: { - r = mpy::object::borrow(elements[0]); - elements = elements.slice(1); - } break; - case U_TUPLE_LIKE: { - mpy::tuple tup(children.size()); - for (auto i : children.enumerate()) { - tup.set(i, children[i](elements)); - } - r = obj.call(tup); - } break; - case U_DICT: { - r = mpy::object::checked_steal(PyDict_New()); - mpy::dict_view rv(r); - mpy::dict_view d(obj); - Py_ssize_t pos = 0; - mpy::handle k, v; - for (int i = 0; d.next(&pos, &k, &v); ++i) { - rv.set(k, children[i](elements)); - } - } break; + mpy::object operator()(Slice& elements) { + mpy::object r; + switch (type) { + case U_ELEM: { + r = mpy::object::borrow(elements[0]); + elements = elements.slice(1); + } break; + case U_TUPLE_LIKE: { + mpy::tuple tup(children.size()); + for (auto i : children.enumerate()) { + tup.set(i, children[i](elements)); } - return r; - } - UType type; - mpy::handle obj; - Slice children; -}; - -Unflatten tree_flatten(Arena& A, mpy::handle agg, Slice& flat_elements) { - Slice c; - UType utype; - mpy::handle obj; - if (mpy::list_view::check(agg)) { - obj = agg.type(); - utype = U_TUPLE_LIKE; - mpy::list_view l(agg); - for (auto i : l.enumerate()) { - c.append(A, tree_flatten(A, l[i], flat_elements)); - } - } else if (mpy::tuple_view::check(agg)) { - obj = agg.type(); - utype = U_TUPLE_LIKE; - // includes named tuples - mpy::tuple_view l(agg); - for (auto i : l.enumerate()) { - c.append(A, tree_flatten(A, l[i], flat_elements)); - } - } else if (mpy::dict_view::check(agg)) { - utype = U_DICT; - mpy::dict_view d(agg); - obj = agg; + r = obj.call(tup); + } break; + case U_DICT: { + r = mpy::object::checked_steal(PyDict_New()); + mpy::dict_view rv(r); + mpy::dict_view d(obj); Py_ssize_t pos = 0; mpy::handle k, v; - while (d.next(&pos, &k, &v)) { - c.append(A, tree_flatten(A, v, flat_elements)); + for (int i = 0; d.next(&pos, &k, &v); ++i) { + rv.set(k, children[i](elements)); } - } else { - utype = U_ELEM; - flat_elements.append(A, agg); + } break; } - return Unflatten {utype, obj, c}; + return r; + } + UType type; + mpy::handle 
obj; + Slice children; +}; + +Unflatten tree_flatten( + Arena& A, + mpy::handle agg, + Slice& flat_elements) { + Slice c; + UType utype; + mpy::handle obj; + if (mpy::list_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + mpy::list_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::tuple_view::check(agg)) { + obj = agg.type(); + utype = U_TUPLE_LIKE; + // includes named tuples + mpy::tuple_view l(agg); + for (auto i : l.enumerate()) { + c.append(A, tree_flatten(A, l[i], flat_elements)); + } + } else if (mpy::dict_view::check(agg)) { + utype = U_DICT; + mpy::dict_view d(agg); + obj = agg; + Py_ssize_t pos = 0; + mpy::handle k, v; + while (d.next(&pos, &k, &v)) { + c.append(A, tree_flatten(A, v, flat_elements)); + } + } else { + utype = U_ELEM; + flat_elements.append(A, agg); + } + return Unflatten{utype, obj, c}; } struct UnflattenVectorArgs { - mpy::vector_args operator()(Arena& A, Slice& elements) { - if (!had_nested) { - auto args = elements.begin(); - elements = Slice(); - return mpy::vector_args(args, nargs, kwnames); - } - Slice args; - for (auto u : children) { - args.append(A, A.autorelease(u(elements))); - } - return mpy::vector_args(args.begin(), nargs, kwnames); + mpy::vector_args operator()(Arena& A, Slice& elements) { + if (!had_nested) { + auto args = elements.begin(); + elements = Slice(); + return mpy::vector_args(args, nargs, kwnames); } - Slice children; - Py_ssize_t nargs; - mpy::handle kwnames; - bool had_nested; + Slice args; + for (auto u : children) { + args.append(A, A.autorelease(u(elements))); + } + return mpy::vector_args(args.begin(), nargs, kwnames); + } + Slice children; + Py_ssize_t nargs; + mpy::handle kwnames; + bool had_nested; }; -UnflattenVectorArgs tree_flatten(Arena& A, mpy::vector_args args, Slice& flat_elements) { - UnflattenVectorArgs r; - r.kwnames = args.kwnames; - r.nargs = args.nargs; - r.had_nested = false; - auto N = args.size(); - for(auto i : irange(N)) { - auto typ = Py_TYPE(args[i].ptr()); - // fast checks that this thing isn't something that is nested. - bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || typ == TensorType || typ == DimType; - if (!is_element) { - flat_elements.extend(A, args.args, args.args + i); - for (auto j : irange(i)) { - (void)j; - r.children.append(A, Unflatten {U_ELEM}); - } - for (auto j : irange(i, N)) { - r.children.append(A, tree_flatten(A, args[j], flat_elements)); - if (r.children.back().type != U_ELEM) { - r.had_nested = true; - } - } - return r; +UnflattenVectorArgs tree_flatten( + Arena& A, + mpy::vector_args args, + Slice& flat_elements) { + UnflattenVectorArgs r; + r.kwnames = args.kwnames; + r.nargs = args.nargs; + r.had_nested = false; + auto N = args.size(); + for (auto i : irange(N)) { + auto typ = Py_TYPE(args[i].ptr()); + // fast checks that this thing isn't something that is nested. 
+ bool is_element = !typ->tp_as_sequence || typ == torch_Tensor || + typ == TensorType || typ == DimType; + if (!is_element) { + flat_elements.extend(A, args.args, args.args + i); + for (auto j : irange(i)) { + (void)j; + r.children.append(A, Unflatten{U_ELEM}); + } + for (auto j : irange(i, N)) { + r.children.append(A, tree_flatten(A, args[j], flat_elements)); + if (r.children.back().type != U_ELEM) { + r.had_nested = true; } + } + return r; } - flat_elements.extend(A, args.args, args.args + N); - return r; + } + flat_elements.extend(A, args.args, args.args + N); + return r; } - struct UnflattenArena { - Arena A; - Unflatten unflatten; + Arena A; + Unflatten unflatten; }; -PyObject* py_unflatten(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - #define ARGS(_) _(mpy::handle, ns) - MPY_PARSE_ARGS_KWNAMES("O", ARGS) - #undef ARGS - mpy::sequence_view sv(ns); - // because we do not have a autorelase pool yet... - Arena A; - Slice slice; - mpy::handle Tuple = (PyObject*) &PyTuple_Type; - auto inputs = Tuple.call(ns); - mpy::tuple_view tv(inputs); - for (auto i : tv.enumerate()) { - slice.append(A, tv[i]); - } - auto AA = (UnflattenArena*) PyCapsule_GetPointer(self, "arena"); - auto r = AA->unflatten(slice).release(); - AT_ASSERT(r != nullptr); - return r; - PY_END(nullptr) +PyObject* py_unflatten( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN +#define ARGS(_) _(mpy::handle, ns) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) +#undef ARGS + mpy::sequence_view sv(ns); + // because we do not have a autorelase pool yet... + Arena A; + Slice slice; + mpy::handle Tuple = (PyObject*)&PyTuple_Type; + auto inputs = Tuple.call(ns); + mpy::tuple_view tv(inputs); + for (auto i : tv.enumerate()) { + slice.append(A, tv[i]); + } + auto AA = (UnflattenArena*)PyCapsule_GetPointer(self, "arena"); + auto r = AA->unflatten(slice).release(); + AT_ASSERT(r != nullptr); + return r; + PY_END(nullptr) } -PyMethodDef py_unflatten_def = {"unflatten", (PyCFunction)(void*) py_unflatten, METH_FASTCALL | METH_KEYWORDS}; +PyMethodDef py_unflatten_def = { + "unflatten", + (PyCFunction)(void*)py_unflatten, + METH_FASTCALL | METH_KEYWORDS}; -void free_unflatten_arena(PyObject * pc) { - delete (UnflattenArena*) PyCapsule_GetPointer(pc, "arena"); +void free_unflatten_arena(PyObject* pc) { + delete (UnflattenArena*)PyCapsule_GetPointer(pc, "arena"); } -PyObject* py_tree_flatten(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - #define ARGS(_) _(mpy::handle, tree) - MPY_PARSE_ARGS_KWNAMES("O", ARGS) - #undef ARGS - auto A = new UnflattenArena; - Slice elements; - A->unflatten = tree_flatten(A->A, tree, elements); - auto cap = mpy::object::checked_steal(PyCapsule_New(A, "arena", free_unflatten_arena)); - auto unflatten = mpy::object::checked_steal(PyCFunction_New(&py_unflatten_def, cap.release())); - mpy::tuple r(2); - r.set(0, slice_to_list(elements)); - r.set(1, std::move(unflatten)); - return r.release(); - PY_END(nullptr) +PyObject* py_tree_flatten( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN +#define ARGS(_) _(mpy::handle, tree) + MPY_PARSE_ARGS_KWNAMES("O", ARGS) +#undef ARGS + auto A = new UnflattenArena; + Slice elements; + A->unflatten = tree_flatten(A->A, tree, elements); + auto cap = mpy::object::checked_steal( + PyCapsule_New(A, "arena", free_unflatten_arena)); + auto unflatten = mpy::object::checked_steal( + 
PyCFunction_New(&py_unflatten_def, cap.release())); + mpy::tuple r(2); + r.set(0, slice_to_list(elements)); + r.set(1, std::move(unflatten)); + return r.release(); + PY_END(nullptr) } - - -mpy::object tree_map(Arena& A, const std::function& fn, mpy::handle agg) { - Slice elements; - auto unflatten = tree_flatten(A, agg, elements); - for (auto i : elements.enumerate()) { - elements[i] = fn(elements[i]); - } - return unflatten(elements); +mpy::object tree_map( + Arena& A, + const std::function& fn, + mpy::handle agg) { + Slice elements; + auto unflatten = tree_flatten(A, agg, elements); + for (auto i : elements.enumerate()) { + elements[i] = fn(elements[i]); + } + return unflatten(elements); } // prereq: isinstance(h, _Tensor) int64_t _Tensor_ndim(mpy::handle h) { - if (Tensor::check(h)) { - int64_t r = 0; - for (auto l : Tensor::unchecked_wrap(h)->levels()) { - if (l.is_positional()) { - ++r; - } - } - return r; + if (Tensor::check(h)) { + int64_t r = 0; + for (auto l : Tensor::unchecked_wrap(h)->levels()) { + if (l.is_positional()) { + ++r; + } } - // Dim or DelayedMulTensor - return 0; + return r; + } + // Dim or DelayedMulTensor + return 0; } mpy::handle handle_from_tensor(Arena& A, TensorRef t) { - // fast case: tensor is live in python - std::optional mb_obj = - t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false); - if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { - return *mb_obj; - } - return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); -} + // fast case: tensor is live in python + std::optional mb_obj = + t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( + /*ignore_hermetic_tls=*/false); + if (mb_obj.has_value() && + !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { + return *mb_obj; + } + return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t))); } +} // namespace struct EnableAllLayers { - EnableAllLayers(Arena& A, Slice levels) { - std::vector> layers; - layers.reserve(levels.size()); - for (auto l : levels) { - if (!l.is_positional()) { - auto d = l.dim(); - levels_to_dim_.append(A, d); - } - } - std::sort(levels_to_dim_.begin(), levels_to_dim_.end(), [](mpy::hdl lhs, mpy::hdl rhs) { return lhs->level_ < rhs->level_;}); + EnableAllLayers(Arena& A, Slice levels) { + std::vector> layers; + layers.reserve(levels.size()); + for (auto l : levels) { + if (!l.is_positional()) { + auto d = l.dim(); + levels_to_dim_.append(A, d); + } + } + std::sort( + levels_to_dim_.begin(), + levels_to_dim_.end(), + [](mpy::hdl lhs, mpy::hdl rhs) { + return lhs->level_ < rhs->level_; + }); - for (auto i : levels_to_dim_.enumerate()) { - auto batch_size = levels_to_dim_[i]->size(); - auto level = at::functorch::initAndPushDynamicLayer(at::functorch::TransformType::Vmap, batch_size, at::functorch::RandomnessType::Different); - if (i == 0) { - levels_start_ = level; - } - } + for (auto i : levels_to_dim_.enumerate()) { + auto batch_size = levels_to_dim_[i]->size(); + auto level = at::functorch::initAndPushDynamicLayer( + at::functorch::TransformType::Vmap, + batch_size, + at::functorch::RandomnessType::Different); + if (i == 0) { + levels_start_ = level; + } + } + } + + ~EnableAllLayers() { + auto to_remove = levels_start_ + levels_to_dim_.size() - 1; + for (auto i : levels_to_dim_.enumerate()) { + AT_ASSERT( + at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == + to_remove - i); + } + } + + mpy::obj from_batched( + Arena& A, + at::Tensor batchedtensor, + bool 
has_device) { + Slice levels; + for (auto i : irange(-batchedtensor.dim(), 0)) { + levels.append(A, i); + } + TensorRef tensor; + at::functorch::BatchedTensorImpl* impl = maybeGetBatchedImpl(batchedtensor); + while (true) { + auto level = impl->level(); + AT_ASSERT( + level >= levels_start_ && + level < levels_start_ + levels_to_dim_.size()); + mpy::hdl dim = levels_to_dim_[level - levels_start_].ptr(); + levels.insert(A, impl->bdim(), dim); + at::functorch::BatchedTensorImpl* nimpl = + maybeGetBatchedImpl(impl->value()); + if (!nimpl) { + tensor = impl->value(); + break; + } + impl = nimpl; } - ~EnableAllLayers() { - auto to_remove = levels_start_ + levels_to_dim_.size() - 1; - for (auto i : levels_to_dim_.enumerate()) { - AT_ASSERT(at::functorch::popDynamicLayerAndDeleteMetadata().layerId() == to_remove - i); - } + mpy::obj self = Tensor::create(); + // grab ownership of the tensors + self->tensor_ = *tensor; + self->batchtensor_ = std::move(batchedtensor); + self->has_device_ = has_device; + self->capture_levels(levels); + return self; + } + void inplace_update_layers(TensorRef batchtensor, Slice levels) { + // XXX - requires a patch to functorch to att set_level + auto impl = maybeGetBatchedImpl(*batchtensor); + for (auto i : levels_to_dim_.reversed_enumerate()) { + if (!impl) { + break; + } + if (levels.contains(levels_to_dim_[i])) { + impl->_unsafe_set_level(levels_start_ + i); + impl = maybeGetBatchedImpl(impl->value()); + } } + } - mpy::obj from_batched(Arena& A, at::Tensor batchedtensor, bool has_device) { - Slice levels; - for (auto i : irange(-batchedtensor.dim(), 0)) { - levels.append(A, i); - } - TensorRef tensor; - at::functorch::BatchedTensorImpl * impl = maybeGetBatchedImpl(batchedtensor); - while(true) { - auto level = impl->level(); - AT_ASSERT(level >= levels_start_ && level < levels_start_ + levels_to_dim_.size()); - mpy::hdl dim = levels_to_dim_[level - levels_start_].ptr(); - levels.insert(A, impl->bdim(), dim); - at::functorch::BatchedTensorImpl * nimpl = maybeGetBatchedImpl(impl->value()); - if (!nimpl) { - tensor = impl->value(); - break; - } - impl = nimpl; - } - - mpy::obj self = Tensor::create(); - // grab ownership of the tensors - self->tensor_ = *tensor; - self->batchtensor_ = std::move(batchedtensor); - self->has_device_ = has_device; - self->capture_levels(levels); - return self; - } - void inplace_update_layers(TensorRef batchtensor, Slice levels) { - // XXX - requires a patch to functorch to att set_level - auto impl = maybeGetBatchedImpl(*batchtensor); - for (auto i : levels_to_dim_.reversed_enumerate()) { - if (!impl) { - break; - } - if (levels.contains(levels_to_dim_[i])) { - impl->_unsafe_set_level(levels_start_ + i); - impl = maybeGetBatchedImpl(impl->value()); - - } - } - } -private: - int64_t levels_start_{}; - Slice> levels_to_dim_; + private: + int64_t levels_start_{}; + Slice> levels_to_dim_; }; -namespace{ -TensorRef _match_levels(Arena& A, TensorRef v, Slice from_levels, Slice to_levels, bool drop_levels=false) { - if (from_levels == to_levels) { - return v; - } - // drop_levels -> if a dim appears in from_levels but not to_levels, it is assumed it has stride 0. - at::IntArrayRef sz = v->sizes(); - at::IntArrayRef sd = v->strides(); - AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); - Slice nsz; - Slice nsd; - for (auto l : to_levels) { - auto oidx = from_levels.index(l); - if (!oidx) { - nsz.append(A, l.is_positional() ? 
1 : l.dim()->size()); - nsd.append(A, 0); - } else { - auto idx = *oidx; - nsz.append(A, sz[idx]); - nsd.append(A, sd[idx]); - } - } - return A.autorelease(v->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()), at::IntArrayRef(nsd.begin(), nsd.end()), v->storage_offset())); -} -} -mpy::object run_torch_function(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { - if (!pointwise_optimize) { - is_pointwise = false; - } - // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : "functorch") << " " << orig << "\n"; - - Slice> all_dims; - Slice flat_args; - auto unflatten_args = tree_flatten(A, args, flat_args); - TensorRef device_holding_tensor; - - Slice infos; - Slice result_levels; - for (auto f : flat_args) { - infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); - if (infos.back()) { - TensorInfo& info = infos.back(); - AT_ASSERT(is_pointwise || info.batchedtensor); - if (!device_holding_tensor && info.has_device) { - device_holding_tensor = infos.back().tensor; - } - for (auto l : info.levels) { - if (!result_levels.contains(l)) { - result_levels.append(A, l); - } - } - } - } - - if (is_pointwise) { - for (auto i : flat_args.enumerate()) { - if (infos[i]) { - TensorRef tensor = infos[i].tensor; - if (device_holding_tensor && !infos[i].has_device) { - tensor = A.autorelease(tensor->to(device_holding_tensor->device())); - } - auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); - flat_args[i] = handle_from_tensor(A, std::move(ml)); - } - } - - Slice flat_it = flat_args; - mpy::vector_args uargs = unflatten_args(A, flat_it); - - mpy::object result = orig.call_vector(uargs); - - // fast wrap for normal case where operator just returns a tensor. - if (THPVariable_Check(result.ptr())) { - return Tensor::from_positional(A, THPVariable_Unpack(result.ptr()), result_levels, device_holding_tensor); - } - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())){ - return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), result_levels, device_holding_tensor)); - } - return h; - }; - return tree_map(A, wrap, result); +namespace { +TensorRef _match_levels( + Arena& A, + TensorRef v, + Slice from_levels, + Slice to_levels, + bool drop_levels = false) { + if (from_levels == to_levels) { + return v; + } + // drop_levels -> if a dim appears in from_levels but not to_levels, it is + // assumed it has stride 0. + at::IntArrayRef sz = v->sizes(); + at::IntArrayRef sd = v->strides(); + AT_ASSERT(drop_levels || from_levels.size() <= to_levels.size()); + Slice nsz; + Slice nsd; + for (auto l : to_levels) { + auto oidx = from_levels.index(l); + if (!oidx) { + nsz.append(A, l.is_positional() ? 
1 : l.dim()->size()); + nsd.append(A, 0); } else { - // std::cout << orig << " calling functorch...\n"; - // std::cout << "rl: " << result_levels << "\n"; - EnableAllLayers guard(A, result_levels); - for (auto i : flat_args.enumerate()) { - if (infos[i]) { - TensorRef batched = infos[i].batchedtensor; - if (device_holding_tensor && !infos[i].has_device) { - batched = A.autorelease(batched->to(device_holding_tensor->device())); - } - guard.inplace_update_layers(batched, infos[i].levels); - flat_args[i] = handle_from_tensor(A, batched); - } - } - Slice flat_it = flat_args; - mpy::vector_args uargs = unflatten_args(A, flat_it); - AT_ASSERT(flat_it.size() == 0); - mpy::object result = orig.call_vector(uargs); - auto wrap = [&](mpy::handle h) { - if (THPVariable_Check(h.ptr())) { - return A.autorelease(guard.from_batched(A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); - } - return h; - }; - if (THPVariable_Check(result.ptr())) { - return guard.from_batched(A, THPVariable_Unpack(result.ptr()), device_holding_tensor); - } - return tree_map(A, wrap, result); + auto idx = *oidx; + nsz.append(A, sz[idx]); + nsd.append(A, sd[idx]); } + } + return A.autorelease(v->as_strided( + at::IntArrayRef(nsz.begin(), nsz.end()), + at::IntArrayRef(nsd.begin(), nsd.end()), + v->storage_offset())); +} +} // namespace +mpy::object run_torch_function( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise) { + if (!pointwise_optimize) { + is_pointwise = false; + } + // std::cout << "__torch_function__ " << ((is_pointwise) ? "pointwise" : + // "functorch") << " " << orig << "\n"; + + Slice> all_dims; + Slice flat_args; + auto unflatten_args = tree_flatten(A, args, flat_args); + TensorRef device_holding_tensor; + + Slice infos; + Slice result_levels; + for (auto f : flat_args) { + infos.append(A, TensorInfo::create(A, f, !is_pointwise, false)); + if (infos.back()) { + TensorInfo& info = infos.back(); + AT_ASSERT(is_pointwise || info.batchedtensor); + if (!device_holding_tensor && info.has_device) { + device_holding_tensor = infos.back().tensor; + } + for (auto l : info.levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } + } + } + } + + if (is_pointwise) { + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef tensor = infos[i].tensor; + if (device_holding_tensor && !infos[i].has_device) { + tensor = A.autorelease(tensor->to(device_holding_tensor->device())); + } + auto ml = _match_levels(A, tensor, infos[i].levels, result_levels); + flat_args[i] = handle_from_tensor(A, std::move(ml)); + } + } + + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + + mpy::object result = orig.call_vector(uargs); + + // fast wrap for normal case where operator just returns a tensor. 
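A minimal sketch of the Python-level behavior that the pointwise branch above backs, assuming a PyTorch build where `functorch.dim` is importable; the tensor and dim names below are illustrative only and are not part of this patch:

```
# Illustrative only: pointwise ops on first-class-dim tensors align dims by name.
import torch
from functorch.dim import dims

i, j = dims(2)                  # create two first-class dimensions
x = torch.randn(3, 4)[i, j]     # bind x's positional dims to i (3) and j (4)
y = torch.randn(4)[j]           # y carries only dim j

z = x * y                       # pointwise: j is aligned, i broadcasts
print(z.order(i, j).shape)      # ordering every dim out yields a plain (3, 4) tensor
```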
+ if (THPVariable_Check(result.ptr())) { + return Tensor::from_positional( + A, + THPVariable_Unpack(result.ptr()), + result_levels, + device_holding_tensor); + } + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional( + A, + THPVariable_Unpack(h.ptr()), + result_levels, + device_holding_tensor)); + } + return h; + }; + return tree_map(A, wrap, result); + } else { + // std::cout << orig << " calling functorch...\n"; + // std::cout << "rl: " << result_levels << "\n"; + EnableAllLayers guard(A, result_levels); + for (auto i : flat_args.enumerate()) { + if (infos[i]) { + TensorRef batched = infos[i].batchedtensor; + if (device_holding_tensor && !infos[i].has_device) { + batched = A.autorelease(batched->to(device_holding_tensor->device())); + } + guard.inplace_update_layers(batched, infos[i].levels); + flat_args[i] = handle_from_tensor(A, batched); + } + } + Slice flat_it = flat_args; + mpy::vector_args uargs = unflatten_args(A, flat_it); + AT_ASSERT(flat_it.size() == 0); + mpy::object result = orig.call_vector(uargs); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(guard.from_batched( + A, THPVariable_Unpack(h.ptr()), device_holding_tensor)); + } + return h; + }; + if (THPVariable_Check(result.ptr())) { + return guard.from_batched( + A, THPVariable_Unpack(result.ptr()), device_holding_tensor); + } + return tree_map(A, wrap, result); + } } -namespace{ +namespace { -mpy::object __torch_function__(Arena &A, mpy::handle orig, mpy::vector_args args, bool is_pointwise) { - if (orig == torch_Tensor___mul__) { - AT_ASSERT(args.nargs == 2 && !args.has_keywords()); - auto lhs = args[0]; - auto rhs = args[1]; - if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { - bool has_device = false; - Slice levels; - for (auto i : args.enumerate_positional()) { - auto t = TensorInfo::create(A, args[i], false); - // something like a mask * rhs, which matrix multiplies don't correctly promote - if (!t.tensor->is_floating_point()) { - return run_torch_function(A, orig, args, is_pointwise); - } - has_device = has_device || t.has_device; - for (auto l : t.levels) { - if (!levels.contains(l)) { - levels.append(A, l); - } - } - } - // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; - return Tensor::create_delayed(mpy::object::borrow(orig), args, levels, has_device); +mpy::object __torch_function__( + Arena& A, + mpy::handle orig, + mpy::vector_args args, + bool is_pointwise) { + if (orig == torch_Tensor___mul__) { + AT_ASSERT(args.nargs == 2 && !args.has_keywords()); + auto lhs = args[0]; + auto rhs = args[1]; + if (mpy::isinstance(lhs, _Tensor) && mpy::isinstance(rhs, _Tensor) && + _Tensor_ndim(lhs) == 0 && _Tensor_ndim(rhs) == 0) { + bool has_device = false; + Slice levels; + for (auto i : args.enumerate_positional()) { + auto t = TensorInfo::create(A, args[i], false); + // something like a mask * rhs, which matrix multiplies don't correctly + // promote + if (!t.tensor->is_floating_point()) { + return run_torch_function(A, orig, args, is_pointwise); } + has_device = has_device || t.has_device; + for (auto l : t.levels) { + if (!levels.contains(l)) { + levels.append(A, l); + } + } + } + // std::cout << "__torch_function__ " << "delay" << " " << orig << "\n"; + return Tensor::create_delayed( + mpy::object::borrow(orig), args, levels, has_device); } - return run_torch_function(A, orig, args, is_pointwise); + } + return 
run_torch_function(A, orig, args, is_pointwise); } -mpy::vector_args as_vector_args(Arena& A, mpy::handle args, mpy::handle kwargs) { - auto pos_args = (mpy::handle*) &PyTuple_GET_ITEM(args.ptr(), 0); - auto pos_n = PyTuple_GET_SIZE(args.ptr()); - if (!kwargs.ptr()) { - return mpy::vector_args(pos_args, pos_n, nullptr); - } - Slice all_args; - Slice kwnames; - all_args.extend(A, pos_args, pos_args + pos_n); - mpy::dict_view dv(kwargs); - Py_ssize_t pos = 0; - mpy::handle key, value; - while (dv.next(&pos, &key, &value)) { - all_args.append(A, value); - kwnames.append(A, key); - } - return mpy::vector_args(all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); +mpy::vector_args as_vector_args( + Arena& A, + mpy::handle args, + mpy::handle kwargs) { + auto pos_args = (mpy::handle*)&PyTuple_GET_ITEM(args.ptr(), 0); + auto pos_n = PyTuple_GET_SIZE(args.ptr()); + if (!kwargs.ptr()) { + return mpy::vector_args(pos_args, pos_n, nullptr); + } + Slice all_args; + Slice kwnames; + all_args.extend(A, pos_args, pos_args + pos_n); + mpy::dict_view dv(kwargs); + Py_ssize_t pos = 0; + mpy::handle key, value; + while (dv.next(&pos, &key, &value)) { + all_args.append(A, value); + kwnames.append(A, key); + } + return mpy::vector_args( + all_args.begin(), pos_n, A.autorelease(slice_to_tuple(kwnames))); } -PyObject* py___torch_function__(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - AT_ASSERT(nargs == 4 || nargs == 5); - auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); - bool is_pointwise = pointwise.contains(args[1]); - return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); - PY_END(nullptr) +PyObject* py___torch_function__( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + AT_ASSERT(nargs == 4 || nargs == 5); + auto va = as_vector_args(A, args[3], nargs == 5 ? args[4] : nullptr); + bool is_pointwise = pointwise.contains(args[1]); + return __torch_function__(A, args[1], std::move(va), is_pointwise).release(); + PY_END(nullptr) } mpy::object levels_to_tuple(Slice slice) { - mpy::tuple t(slice.size()); - for (auto i : slice.enumerate()) { - t.set(i, slice[i].is_positional() ? mpy::from_int(slice[i].position()) : mpy::object::borrow(slice[i].dim())); - } - mpy::object r = std::move(t); - return r; + mpy::tuple t(slice.size()); + for (auto i : slice.enumerate()) { + t.set( + i, + slice[i].is_positional() ? 
mpy::from_int(slice[i].position()) + : mpy::object::borrow(slice[i].dim())); + } + mpy::object r = std::move(t); + return r; } PyObject* Tensor_ndim(Tensor* self, void*) { - Py_ssize_t i = 0; - for (auto l : self->levels()) { - if (l.is_positional()) { - ++i; - } + Py_ssize_t i = 0; + for (auto l : self->levels()) { + if (l.is_positional()) { + ++i; } - return mpy::from_int(i).release(); + } + return mpy::from_int(i).release(); } PyGetSetDef Tensor_getsetters[] = { - {"_has_device", (getter) [](PyObject* self, void*) -> PyObject* { return mpy::from_bool(((Tensor*)self)->has_device()).release(); }, NULL}, - {"_tensor", (getter) [](PyObject* self, void*) -> PyObject* { - Arena A; - return THPVariable_Wrap(((Tensor*)self)->tensor(A)); }, NULL}, - {"_batchtensor", (getter) [](PyObject* self, void*) -> PyObject* { - Arena A; - return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); }, NULL}, - {"_levels", (getter) [](PyObject* self, void*) -> PyObject* { - PY_BEGIN - return levels_to_tuple(((Tensor*)self)->levels()).release(); - PY_END(nullptr) - }}, - {"ndim", (getter) Tensor_ndim, NULL, "ndim", NULL}, - {NULL} /* Sentinel */ -}; + {"_has_device", + (getter)[](PyObject* self, void*) + ->PyObject* { + return mpy::from_bool(((Tensor*)self)->has_device()).release(); +} // namespace +, NULL +} +, {"_tensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; +return THPVariable_Wrap(((Tensor*)self)->tensor(A)); +} +, NULL +} +, {"_batchtensor", (getter)[](PyObject* self, void*)->PyObject* {Arena A; +return THPVariable_Wrap(((Tensor*)self)->batchtensor(A)); +} +, NULL +} +, + {"_levels", + (getter)[](PyObject* self, void*) + ->PyObject* {PY_BEGIN return levels_to_tuple(((Tensor*)self)->levels()) + .release(); +PY_END(nullptr) +} +} +, {"ndim", (getter)Tensor_ndim, NULL, "ndim", NULL}, { + NULL +} /* Sentinel */ +} +; PyMethodDef Tensor_methods[] = { - {NULL, NULL, 0, NULL} /* Sentinel */ + {NULL, NULL, 0, NULL} /* Sentinel */ }; } - PyTypeObject Tensor::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.Tensor", /* tp_name */ - sizeof(Tensor), /* tp_basicsize */ - 0, /* tp_itemsize */ - Tensor::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE , /* tp_flags */ - "Tensor Object", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - Tensor_methods, /* tp_methods */ - 0, /* tp_members */ - Tensor_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - Tensor::new_stub, /* tp_new */ + "_C.Tensor", /* tp_name */ + sizeof(Tensor), /* tp_basicsize */ + 0, /* tp_itemsize */ + Tensor::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "Tensor Object", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* 
tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Tensor_methods, /* tp_methods */ + 0, /* tp_members */ + Tensor_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + Tensor::new_stub, /* tp_new */ }; - // dim() -------------------- static bool relevant_op(_Py_CODEUNIT c) { - switch(c) { - case STORE_NAME: - case STORE_GLOBAL: - case STORE_FAST: - case STORE_DEREF: - return true; - default: - return false; - } + switch (c) { + case STORE_NAME: + case STORE_GLOBAL: + case STORE_FAST: + case STORE_DEREF: + return true; + default: + return false; + } } static mpy::object create_dim(mpy::object name, mpy::handle size) { - auto d = Dim::create(std::move(name)); - if (!mpy::is_none(size)) { - d->set_size(mpy::to_int(size)); - } - return std::move(d); + auto d = Dim::create(std::move(name)); + if (!mpy::is_none(size)) { + d->set_size(mpy::to_int(size)); + } + return std::move(d); } static mpy::object create_dimlist(mpy::object name, mpy::handle size) { - auto d = DimList::create(std::move(name)); - if (!mpy::is_none(size)) { - if (mpy::is_int(size)) { - d->bind_len(mpy::to_int(size)); - } else { - mpy::sequence_view s(size); - d->bind_len(s.size()); - for (auto i : irange(d->size())) { - d->dims_[i]->set_size(mpy::to_int(s[i])); - } - } + auto d = DimList::create(std::move(name)); + if (!mpy::is_none(size)) { + if (mpy::is_int(size)) { + d->bind_len(mpy::to_int(size)); + } else { + mpy::sequence_view s(size); + d->bind_len(s.size()); + for (auto i : irange(d->size())) { + d->dims_[i]->set_size(mpy::to_int(s[i])); + } } - return std::move(d); + } + return std::move(d); } - - -// Python wrappers that make new reflection primitives available for older runtimes +// Python wrappers that make new reflection primitives available for older +// runtimes #if !(IS_PYTHON_3_11_PLUS) #define _PyCode_CODE(CO) ((_Py_CODEUNIT*)PyBytes_AS_STRING((CO)->co_code)) #endif -namespace{ +namespace { struct PyInstDecoder { - PyInstDecoder(PyCodeObject* code_object, int lasti) - : code_object_(code_object), code_(_PyCode_CODE(code_object)), offset_(lasti / sizeof(_Py_CODEUNIT)) {} - // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols - // See https://github.com/pytorch/pytorch/issues/93854 - void next() { - #if IS_PYTHON_3_11_PLUS - offset_ += _PyOpcode_Caches[opcode()]; - #endif - offset_ += 1; - } - int opcode() { - auto r = _Py_OPCODE(code_[offset_]); - #if IS_PYTHON_3_11_PLUS - r = _PyOpcode_Deopt[r]; - #endif - return r; - } - int oparg() { - return _Py_OPARG(code_[offset_]); - } + PyInstDecoder(PyCodeObject* code_object, int lasti) + : code_object_(code_object), + code_(_PyCode_CODE(code_object)), + offset_(lasti / sizeof(_Py_CODEUNIT)) {} + // On Windows, _PyOpcode_Caches and _PyOpcode_Deopt are private symbols + // See https://github.com/pytorch/pytorch/issues/93854 + void next() { +#if IS_PYTHON_3_11_PLUS + offset_ += _PyOpcode_Caches[opcode()]; +#endif + offset_ += 1; + } + int opcode() { + auto r = _Py_OPCODE(code_[offset_]); +#if IS_PYTHON_3_11_PLUS + r = _PyOpcode_Deopt[r]; +#endif + return r; + } + int oparg() { + return _Py_OPARG(code_[offset_]); + } - mpy::object name() { - mpy::object names; - switch(opcode()) { - case STORE_NAME: - case STORE_GLOBAL: - names = mpy::object::borrow(code_object_->co_names); - break; - case STORE_FAST: - names = mpy::object::steal(PyCode_GetVarnames(code_object_)); - 
break; - case STORE_DEREF: - names = mpy::object::steal(PyCode_GetCellvars(code_object_)); - break; - default: - return mpy::object(); - } - return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); + mpy::object name() { + mpy::object names; + switch (opcode()) { + case STORE_NAME: + case STORE_GLOBAL: + names = mpy::object::borrow(code_object_->co_names); + break; + case STORE_FAST: + names = mpy::object::steal(PyCode_GetVarnames(code_object_)); + break; + case STORE_DEREF: + names = mpy::object::steal(PyCode_GetCellvars(code_object_)); + break; + default: + return mpy::object(); } -private: - PyCodeObject* code_object_; - _Py_CODEUNIT* code_; - int offset_; + return mpy::object::steal(PySequence_GetItem(names.ptr(), oparg())); + } + + private: + PyCodeObject* code_object_; + _Py_CODEUNIT* code_; + int offset_; }; -template -static PyObject* _dims(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - Py_ssize_t specified_ndims = -1; - Py_ssize_t found_ndims = 0; - Py_ssize_t sizes = -1; - mpy::handle n = Py_None; - mpy::handle py_sizes = Py_None; +template +static PyObject* _dims( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + Py_ssize_t specified_ndims = -1; + Py_ssize_t found_ndims = 0; + Py_ssize_t sizes = -1; + mpy::handle n = Py_None; + mpy::handle py_sizes = Py_None; - if (nargs || kwnames) { - mpy::vector_args va(args, nargs, kwnames); - va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); - if (!mpy::is_none(py_sizes)) { - sizes = mpy::sequence_view(py_sizes).size(); - specified_ndims = sizes; - } - if (!mpy::is_none(n)) { - specified_ndims = mpy::to_int(n); - } + if (nargs || kwnames) { + mpy::vector_args va(args, nargs, kwnames); + va.parse("dims", {"n", "sizes"}, {&n, &py_sizes}, 0); + if (!mpy::is_none(py_sizes)) { + sizes = mpy::sequence_view(py_sizes).size(); + specified_ndims = sizes; } + if (!mpy::is_none(n)) { + specified_ndims = mpy::to_int(n); + } + } - PyThreadState* state = PyThreadState_GET(); - auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); - auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); - auto lasti = PyFrame_GetLasti(f.ptr()); - auto decoder = PyInstDecoder(c.ptr(), lasti); - #if IS_PYTHON_3_11_PLUS - // When py3.11 adapts bytecode lasti points to the precall - // rather than the call instruction after it - if (decoder.opcode() == PRECALL) { - decoder.next(); - } - #endif + PyThreadState* state = PyThreadState_GET(); + auto f = mpy::obj::steal(PyThreadState_GetFrame(state)); + auto c = mpy::obj::steal(PyFrame_GetCode(f.ptr())); + auto lasti = PyFrame_GetLasti(f.ptr()); + auto decoder = PyInstDecoder(c.ptr(), lasti); +#if IS_PYTHON_3_11_PLUS + // When py3.11 adapts bytecode lasti points to the precall + // rather than the call instruction after it + if (decoder.opcode() == PRECALL) { decoder.next(); + } +#endif + decoder.next(); - if (relevant_op(decoder.opcode())) { - found_ndims = 1; - } else if (decoder.opcode() == UNPACK_SEQUENCE) { - found_ndims = decoder.oparg(); - decoder.next(); - } + if (relevant_op(decoder.opcode())) { + found_ndims = 1; + } else if (decoder.opcode() == UNPACK_SEQUENCE) { + found_ndims = decoder.oparg(); + decoder.next(); + } - if (specified_ndims == -1) { - if (found_ndims == 0) { - mpy::raise_error(PyExc_SyntaxError, "dims() must be assigned to a sequence of variable names or have argument n specified"); - } - specified_ndims = found_ndims; - } - if (found_ndims != specified_ndims) { - found_ndims = 0; // avoid 
taking the wrong names for dimensions + if (specified_ndims == -1) { + if (found_ndims == 0) { + mpy::raise_error( + PyExc_SyntaxError, + "dims() must be assigned to a sequence of variable names or have argument n specified"); } + specified_ndims = found_ndims; + } + if (found_ndims != specified_ndims) { + found_ndims = 0; // avoid taking the wrong names for dimensions + } - auto genobject = [&](int i) -> mpy::object { - mpy::object name; - if (i < found_ndims) { - name = decoder.name(); - } - if (!name.ptr()) { - name = mpy::unicode_from_format("d%d", i); - found_ndims = 0; // once we fail at finding a name, we can find any more - } else { - decoder.next(); - } - return create_object(std::move(name), sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); - }; - if (sizes != -1 && sizes != specified_ndims) { - mpy::raise_error(PyExc_ValueError, "expected %d sizes but found %d", int(specified_ndims), int(sizes)); + auto genobject = [&](int i) -> mpy::object { + mpy::object name; + if (i < found_ndims) { + name = decoder.name(); } - if (specified_ndims == 1) { - return genobject(0).release(); + if (!name.ptr()) { + name = mpy::unicode_from_format("d%d", i); + found_ndims = 0; // once we fail at finding a name, we can find any more + } else { + decoder.next(); } - mpy::tuple result(specified_ndims); - for (int i = 0; i < specified_ndims; ++i) { - result.set(i, genobject(i)); - } - return result.release(); - PY_END(nullptr) + return create_object( + std::move(name), + sizes != -1 ? mpy::sequence_view(py_sizes)[i] : mpy::handle(Py_None)); + }; + if (sizes != -1 && sizes != specified_ndims) { + mpy::raise_error( + PyExc_ValueError, + "expected %d sizes but found %d", + int(specified_ndims), + int(sizes)); + } + if (specified_ndims == 1) { + return genobject(0).release(); + } + mpy::tuple result(specified_ndims); + for (int i = 0; i < specified_ndims; ++i) { + result.set(i, genobject(i)); + } + return result.release(); + PY_END(nullptr) } struct DotPart { - Slice dims; - size_t total_size = 1; - void append(Arena& A, mpy::hdl d) { - total_size *= d->size(); - dims.append(A, d); - } + Slice dims; + size_t total_size = 1; + void append(Arena& A, mpy::hdl d) { + total_size *= d->size(); + dims.append(A, d); + } }; -template +template static at::ArrayRef as_array_ref(Slice t) { - return at::ArrayRef(t.begin(), t.end()); + return at::ArrayRef(t.begin(), t.end()); } -static TensorRef dot_prepare(Arena& A, std::initializer_list parts, const TensorInfo& t) { - Slice new_levels; - bool needs_reshape = false; - for (auto p : parts) { - if (p.dims.size() != 1) { - needs_reshape = true; - } - new_levels.extend(A, p.dims); +static TensorRef dot_prepare( + Arena& A, + std::initializer_list parts, + const TensorInfo& t) { + Slice new_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; } - auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); - if (!needs_reshape) { - return r; - } - Slice view; - for (auto p : parts) { - view.append(A, p.total_size); - } - return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); + new_levels.extend(A, p.dims); + } + auto r = _match_levels(A, t.tensor, t.levels, new_levels, true); + if (!needs_reshape) { + return r; + } + Slice view; + for (auto p : parts) { + view.append(A, p.total_size); + } + return A.autorelease(r->reshape(at::IntArrayRef(view.begin(), view.end()))); } -static mpy::object dot_finish(Arena& A, std::initializer_list parts, at::Tensor r) { - Slice 
result_levels; - bool needs_reshape = false; - for (auto p : parts) { - if (p.dims.size() != 1) { - needs_reshape = true; - } - result_levels.extend(A, p.dims); +static mpy::object dot_finish( + Arena& A, + std::initializer_list parts, + at::Tensor r) { + Slice result_levels; + bool needs_reshape = false; + for (auto p : parts) { + if (p.dims.size() != 1) { + needs_reshape = true; } - if (needs_reshape) { - Slice new_size; - for (auto l : result_levels) { - new_size.append(A, l.dim()->size()); - } - r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); + result_levels.extend(A, p.dims); + } + if (needs_reshape) { + Slice new_size; + for (auto l : result_levels) { + new_size.append(A, l.dim()->size()); } - return Tensor::from_positional(A, std::move(r), result_levels, true); + r = r.reshape(at::IntArrayRef(new_size.begin(), new_size.end())); + } + return Tensor::from_positional(A, std::move(r), result_levels, true); } +static mpy::object dot( + Arena& A, + TensorInfo lhs, + TensorInfo rhs, + Slice sum) { + auto lhs_strides = lhs.tensor->strides(); + auto rhs_strides = rhs.tensor->strides(); + DotPart lro_dims; + DotPart lo_dims; + DotPart ro_dims; + DotPart lr_dims; -static mpy::object dot(Arena& A, TensorInfo lhs, TensorInfo rhs, Slice sum) { - auto lhs_strides = lhs.tensor->strides(); - auto rhs_strides = rhs.tensor->strides(); - - DotPart lro_dims; - DotPart lo_dims; - DotPart ro_dims; - DotPart lr_dims; - - auto insert_dim = [&] (mpy::hdl d, std::optional lhs_idx, std::optional rhs_idx) { - bool reduced = sum.contains(d); - int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; - int64_t rhs_stride = rhs_idx ? rhs_strides[*rhs_idx] : 0; - if (reduced) { - // lr - lr_dims.append(A, d); - } else { - if ((lhs_stride == 0) == (rhs_stride == 0)) { - // lro - lro_dims.append(A, d); - } else if (lhs_stride != 0) { - // lo - lo_dims.append(A, d); - } else { - AT_ASSERT(rhs_stride != 0); - ro_dims.append(A, d); - } - } - }; - - - auto rhs_seen = A.allocate(rhs.levels.size()); - std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); - - for (auto i : lhs.levels.enumerate()) { - auto d = lhs.levels[i]; - auto rhs_idx = rhs.levels.index(d); - if (rhs_idx) { - rhs_seen[*rhs_idx] = true; - } - insert_dim(d.dim(), i, rhs_idx); - } - - for (auto i : rhs.levels.enumerate()) { - if (rhs_seen[i]) { - continue; - } - auto d = rhs.levels[i]; - insert_dim(d.dim(), std::nullopt, i); - } - - if (lr_dims.dims.size() != sum.size()) { - for (auto & d : sum) { - if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { - mpy::raise_error(DimensionBindError(), "summing over non-existent dimension %S", d.dim().ptr()); - } - } - } - - // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; - // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << " " << lr_dims.dims << "\n"; - - // no batch, just call mm - if (lro_dims.dims.size() != 0) { - auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); - auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); - return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); + auto insert_dim = [&](mpy::hdl d, + std::optional lhs_idx, + std::optional rhs_idx) { + bool reduced = sum.contains(d); + int64_t lhs_stride = lhs_idx ? lhs_strides[*lhs_idx] : 0; + int64_t rhs_stride = rhs_idx ? 
rhs_strides[*rhs_idx] : 0; + if (reduced) { + // lr + lr_dims.append(A, d); } else { - auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); - auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); - return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); + if ((lhs_stride == 0) == (rhs_stride == 0)) { + // lro + lro_dims.append(A, d); + } else if (lhs_stride != 0) { + // lo + lo_dims.append(A, d); + } else { + AT_ASSERT(rhs_stride != 0); + ro_dims.append(A, d); + } } + }; + auto rhs_seen = A.allocate(rhs.levels.size()); + std::fill(rhs_seen, rhs_seen + rhs.levels.size(), false); + + for (auto i : lhs.levels.enumerate()) { + auto d = lhs.levels[i]; + auto rhs_idx = rhs.levels.index(d); + if (rhs_idx) { + rhs_seen[*rhs_idx] = true; + } + insert_dim(d.dim(), i, rhs_idx); + } + + for (auto i : rhs.levels.enumerate()) { + if (rhs_seen[i]) { + continue; + } + auto d = rhs.levels[i]; + insert_dim(d.dim(), std::nullopt, i); + } + + if (lr_dims.dims.size() != sum.size()) { + for (auto& d : sum) { + if (!lhs.levels.contains(d) && !rhs.levels.contains(d)) { + mpy::raise_error( + DimensionBindError(), + "summing over non-existent dimension %S", + d.dim().ptr()); + } + } + } + + // std::cout << lhs.levels << " " << rhs.levels << " " << sum << "\n"; + // std::cout << lro_dims.dims << " " << lo_dims.dims << " " << ro_dims.dims << + // " " << lr_dims.dims << "\n"; + + // no batch, just call mm + if (lro_dims.dims.size() != 0) { + auto lhs_ = dot_prepare(A, {lro_dims, lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lro_dims, lr_dims, ro_dims}, rhs); + return dot_finish(A, {lro_dims, lo_dims, ro_dims}, at::bmm(*lhs_, *rhs_)); + } else { + auto lhs_ = dot_prepare(A, {lo_dims, lr_dims}, lhs); + auto rhs_ = dot_prepare(A, {lr_dims, ro_dims}, rhs); + return dot_finish(A, {lo_dims, ro_dims}, at::mm(*lhs_, *rhs_)); + } } -static PyObject* test_c(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN +static PyObject* test_c( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN - Arena A; - Slice s(A, 3, 4, 5); - AT_ASSERT(s.size() == 3 && s.capacity() == 8); - AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); - s.append(A, 6); - AT_ASSERT(s[3] == 6); - for(int i : irange(10)) { - s.append(A, i); - } - AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); + Arena A; + Slice s(A, 3, 4, 5); + AT_ASSERT(s.size() == 3 && s.capacity() == 8); + AT_ASSERT(s[0] == 3 && s[1] == 4 && s[2] == 5); + s.append(A, 6); + AT_ASSERT(s[3] == 6); + for (int i : irange(10)) { + s.append(A, i); + } + AT_ASSERT(s[0] == 3 && s.back() == 9 && s.size() == 14 && s.capacity() == 16); - Slice s2(A, -1, -2, -3); - AT_ASSERT(s2[1] == -2 && s[0] == 3); + Slice s2(A, -1, -2, -3); + AT_ASSERT(s2[1] == -2 && s[0] == 3); - auto ss = s.slice(1,2); - AT_ASSERT(ss.size() == 1); - AT_ASSERT(ss[0] == 4); - AT_ASSERT(ss.capacity() == 1); - ss.append(A, -4); - AT_ASSERT(ss.size() == 2 && ss[1] == -4); - ss[0] = 3; - AT_ASSERT(s[1] == 4); + auto ss = s.slice(1, 2); + AT_ASSERT(ss.size() == 1); + AT_ASSERT(ss[0] == 4); + AT_ASSERT(ss.capacity() == 1); + ss.append(A, -4); + AT_ASSERT(ss.size() == 2 && ss[1] == -4); + ss[0] = 3; + AT_ASSERT(s[1] == 4); - s.insert(A, s.slice(1, 4), ss); - AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); + s.insert(A, s.slice(1, 4), ss); + AT_ASSERT(s[1] == 3 && s[2] == -4 && s[3] == 0); - auto sz = s.size(); - s.insert(A, s.slice(1, 1), 4); - AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); + auto sz = s.size(); + 
s.insert(A, s.slice(1, 1), 4); + AT_ASSERT(s[1] == 4 && sz + 1 == s.size()); + Slice d(A, 0, 1, 2, 3, 4); - Slice d(A, 0, 1, 2, 3, 4); + Slice b(A, 0, 1, 2, 3, 4); + b.insert(A, b.slice(1, 1), d); + AT_ASSERT(b.size() == 10); + AT_ASSERT(b[1] == 0); + AT_ASSERT(b[5] == 4); + AT_ASSERT(b.back() == 4); - Slice b(A, 0, 1, 2, 3, 4); - b.insert(A, b.slice(1,1), d); - AT_ASSERT(b.size() == 10); - AT_ASSERT(b[1] == 0); - AT_ASSERT(b[5] == 4); - AT_ASSERT(b.back() == 4); + Py_RETURN_NONE; - Py_RETURN_NONE; - - PY_END(nullptr); + PY_END(nullptr); } +static PyObject* order( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + if (kwnames) { + mpy::raise_error( + PyExc_TypeError, "unexpected keyword arguments %S", kwnames); + } + AT_ASSERT(nargs-- > 0); + Slice orig_levels; + Slice levels; + TensorRef data; + mpy::handle self = args++[0]; + bool has_device; + if (Tensor::check_exact(self)) { + auto t = Tensor::unchecked_wrap(self); + orig_levels = t->levels(); + data = t->tensor(A); + has_device = t->has_device(); + } else { + auto d = Dim::unchecked_wrap(self); + orig_levels.append(A, d); + data = d->range(); + has_device = false; + } -static PyObject* order(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - if (kwnames) { - mpy::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames); + Slice flat_positional_dims; + Slice> to_flatten; + levels.extend(A, orig_levels); + + int orig_ndim = ndim_of_levels(levels); + auto append = [&](DimEntry d) { + auto midx = levels.index(d); + if (!midx) { + if (d.is_positional()) { + mpy::raise_error( + PyExc_ValueError, + "tensor has %d positional dimensions, but %d specified, or it was specified twice", + int(orig_ndim), + int(d.position() + orig_ndim)); + } else { + mpy::raise_error( + PyExc_ValueError, + "tensor of dimensions %R does not contain dim %R or it was specified twice", + levels_to_tuple(orig_levels).ptr(), + d.dim().ptr()); + } } - AT_ASSERT(nargs-- > 0); - Slice orig_levels; - Slice levels; - TensorRef data; - mpy::handle self = args++[0]; - bool has_device; - if (Tensor::check_exact(self)) { - auto t = Tensor::unchecked_wrap(self); - orig_levels = t->levels(); - data = t->tensor(A); - has_device = t->has_device(); + levels[*midx] = DimEntry(); + flat_positional_dims.append(A, d); + }; + + int n_new_positional = 0; + for (auto i : irange(nargs)) { + mpy::handle arg = args[i]; + DimEntry entry = _wrap_dim(arg, orig_ndim, false); + if (!entry.is_none()) { + append(entry); + ++n_new_positional; + } else if (DimList::check(arg)) { + auto dl = DimList::unchecked_wrap(arg); + for (mpy::obj& d : dl->dims_) { + append(mpy::hdl(d)); + ++n_new_positional; + } } else { - auto d = Dim::unchecked_wrap(self); - orig_levels.append(A, d); - data = d->range(); - has_device = false; + ++n_new_positional; + if (!mpy::is_sequence(arg)) { + mpy::raise_error( + PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); + } + mpy::sequence_view sq(arg); + auto N = sq.size(); + to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); + for (auto j : irange(N)) { + DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); + if (e.is_none()) { + mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); + } + append(e); + } } + } - Slice flat_positional_dims; - Slice> to_flatten; - levels.extend(A, orig_levels); - - int orig_ndim = ndim_of_levels(levels); - auto append = [&](DimEntry d) { - auto midx = levels.index(d); - if 
(!midx) { - if (d.is_positional()) { - mpy::raise_error(PyExc_ValueError, "tensor has %d positional dimensions, but %d specified, or it was specified twice", int(orig_ndim), int(d.position() + orig_ndim)); - } else { - mpy::raise_error(PyExc_ValueError, "tensor of dimensions %R does not contain dim %R or it was specified twice", levels_to_tuple(orig_levels).ptr(), d.dim().ptr()); - } - } - levels[*midx] = DimEntry(); - flat_positional_dims.append(A, d); - }; - - int n_new_positional = 0; - for (auto i :irange(nargs)) { - mpy::handle arg = args[i]; - DimEntry entry = _wrap_dim(arg, orig_ndim, false); - if (!entry.is_none()) { - append(entry); - ++n_new_positional; - } else if (DimList::check(arg)) { - auto dl = DimList::unchecked_wrap(arg); - for (mpy::obj & d : dl->dims_) { - append(mpy::hdl(d)); - ++n_new_positional; - } - } else { - ++n_new_positional; - if (!mpy::is_sequence(arg)) { - mpy::raise_error(PyExc_ValueError, "expected a Dim, List[Dim], or Sequence[Dim]"); - } - mpy::sequence_view sq(arg); - auto N = sq.size(); - to_flatten.append(A, std::make_pair(flat_positional_dims.size(), N)); - for (auto j : irange(N)) { - DimEntry e = _wrap_dim(A.autorelease(sq[j]), orig_ndim, false); - if (e.is_none()) { - mpy::raise_error(PyExc_ValueError, "expected a Dim, or int"); - } - append(e); - } - } + int insert_point = -1; + Slice new_levels; + for (auto l : levels) { + if (l.is_none()) { + continue; } - - int insert_point = -1; - Slice new_levels; - for (auto l : levels) { - if (l.is_none()) { - continue; - } - if (l.is_positional()) { - if (insert_point == -1) { - insert_point = new_levels.size(); - new_levels.extend(A, flat_positional_dims); - } - } - new_levels.append(A, l); - } - if (insert_point == -1) { + if (l.is_positional()) { + if (insert_point == -1) { insert_point = new_levels.size(); new_levels.extend(A, flat_positional_dims); + } } + new_levels.append(A, l); + } + if (insert_point == -1) { + insert_point = new_levels.size(); + new_levels.extend(A, flat_positional_dims); + } - at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); + at::Tensor ndata = *_match_levels(A, data, orig_levels, new_levels); - if (to_flatten.size()) { - Slice view; - auto sz = ndata.sizes(); - // before the new positional dims - for (auto i : irange(0, insert_point)) { - view.append(A, sz[i]); - } - int i = 0; - for (auto to_flat : to_flatten) { - for (;i < to_flat.first; ++i) { - view.append(A, sz[insert_point + i]); - } - int64_t new_size = 1; - int last = i + to_flat.second; - for (; i < last; ++i) { - new_size *= sz[insert_point + i]; - } - view.append(A, new_size); - } - for (; i < flat_positional_dims.size(); ++i) { - view.append(A, sz[insert_point + i]); - } - // after the new positional dims - for (auto i : irange(insert_point + flat_positional_dims.size(), levels.size())) { - view.append(A, sz[i]); - } - // we shorted the number of dimension, so remove them from new levels - // we will renumber them later - auto n_to_remove = flat_positional_dims.size() - n_new_positional; - new_levels.insert(A, new_levels.slice(insert_point, insert_point + n_to_remove), Slice()); - ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); + if (to_flatten.size()) { + Slice view; + auto sz = ndata.sizes(); + // before the new positional dims + for (auto i : irange(0, insert_point)) { + view.append(A, sz[i]); } - - // renumber the positional dimension - int seen = 0; - for (auto i : new_levels.reversed_enumerate()) { - if (new_levels[i].is_positional() || (i >= insert_point && i 
< insert_point + n_new_positional)) { - new_levels[i] = --seen; - } + int i = 0; + for (auto to_flat : to_flatten) { + for (; i < to_flat.first; ++i) { + view.append(A, sz[insert_point + i]); + } + int64_t new_size = 1; + int last = i + to_flat.second; + for (; i < last; ++i) { + new_size *= sz[insert_point + i]; + } + view.append(A, new_size); } - return Tensor::from_positional(A, std::move(ndata), new_levels, has_device).release(); + for (; i < flat_positional_dims.size(); ++i) { + view.append(A, sz[insert_point + i]); + } + // after the new positional dims + for (auto i : + irange(insert_point + flat_positional_dims.size(), levels.size())) { + view.append(A, sz[i]); + } + // we shorted the number of dimension, so remove them from new levels + // we will renumber them later + auto n_to_remove = flat_positional_dims.size() - n_new_positional; + new_levels.insert( + A, + new_levels.slice(insert_point, insert_point + n_to_remove), + Slice()); + ndata = std::move(ndata).reshape(at::IntArrayRef(view.begin(), view.end())); + } - PY_END(nullptr) + // renumber the positional dimension + int seen = 0; + for (auto i : new_levels.reversed_enumerate()) { + if (new_levels[i].is_positional() || + (i >= insert_point && i < insert_point + n_new_positional)) { + new_levels[i] = --seen; + } + } + return Tensor::from_positional(A, std::move(ndata), new_levels, has_device) + .release(); + + PY_END(nullptr) } -static PyObject* expand(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs-- > 0); - auto info = TensorInfo::create(A, args++[0], false); - for (auto i : irange(nargs)) { - if (!Dim::check(args[i])) { - maybeInitializeGlobals(); - mpy::vector_args vargs(args - 1, nargs + 1, kwnames); - if (THPVariable_Check(args[-1])) { - return torch_Tensor_expand.call_vector(vargs).release(); - } else { - return __torch_function__(A, torch_Tensor_expand, vargs, false).release(); - } - } +static PyObject* expand( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs-- > 0); + auto info = TensorInfo::create(A, args++[0], false); + for (auto i : irange(nargs)) { + if (!Dim::check(args[i])) { + maybeInitializeGlobals(); + mpy::vector_args vargs(args - 1, nargs + 1, kwnames); + if (THPVariable_Check(args[-1])) { + return torch_Tensor_expand.call_vector(vargs).release(); + } else { + return __torch_function__(A, torch_Tensor_expand, vargs, false) + .release(); + } } - const at::Tensor& data = *info.tensor; - auto levels = info.levels; - Slice new_levels; - Slice sz; - Slice sd; - for (auto i : irange(nargs)) { - auto d = Dim::unchecked_wrap(args[i]); - if (levels.contains(d) || new_levels.contains(d)) { - mpy::raise_error(DimensionBindError(), "expanding dimension %R already exists in tensor with dims", d.ptr()); - } - new_levels.append(A, d); - sz.append(A, d->size()); - sd.append(A, 0); + } + const at::Tensor& data = *info.tensor; + auto levels = info.levels; + Slice new_levels; + Slice sz; + Slice sd; + for (auto i : irange(nargs)) { + auto d = Dim::unchecked_wrap(args[i]); + if (levels.contains(d) || new_levels.contains(d)) { + mpy::raise_error( + DimensionBindError(), + "expanding dimension %R already exists in tensor with dims", + d.ptr()); } - new_levels.extend(A, levels); - at::IntArrayRef osz = data.sizes(); - at::IntArrayRef osd = data.strides(); - sz.extend(A, osz.begin(), osz.end()); - sd.extend(A, osd.begin(), osd.end()); - at::Tensor ndata = 
data.as_strided(at::IntArrayRef(sz.begin(), sz.end()), at::IntArrayRef(sd.begin(), sd.end()), data.storage_offset()); - return Tensor::from_positional(A, std::move(ndata), new_levels, info.has_device).release(); - PY_END(nullptr) + new_levels.append(A, d); + sz.append(A, d->size()); + sd.append(A, 0); + } + new_levels.extend(A, levels); + at::IntArrayRef osz = data.sizes(); + at::IntArrayRef osd = data.strides(); + sz.extend(A, osz.begin(), osz.end()); + sd.extend(A, osd.begin(), osd.end()); + at::Tensor ndata = data.as_strided( + at::IntArrayRef(sz.begin(), sz.end()), + at::IntArrayRef(sd.begin(), sd.end()), + data.storage_offset()); + return Tensor::from_positional( + A, std::move(ndata), new_levels, info.has_device) + .release(); + PY_END(nullptr) } - -static void _bind_dims_to_size(Arena & A, int64_t sz, int64_t sd, - Slice> dims, Slice& nsz, Slice& nsd) { - int64_t rhs_prod = 1; - for (auto i : dims.enumerate()) { - if (!dims[i]->is_bound()) { - for (auto j : irange(i + 1, dims.size())) { - if (!dims[j]->is_bound()) { - mpy::raise_error(DimensionBindError(), "cannot infer the sizes of two dimensions at once %R and %R", dims[i].ptr(), dims[j].ptr()); - } - rhs_prod *= dims[j]->size(); - } - if (sz % rhs_prod != 0) { - mpy::tuple tup(dims.size()); - for (auto j : dims.enumerate()) { - tup.set(j, dims[j]->is_bound() ? mpy::from_int(dims[j]->size()) : mpy::unicode_from_string("?")); - } - mpy::raise_error(DimensionBindError(), "inferred dimension does not evenly fit into larger dimension: %d vs %R", (int) sz, tup.ptr()); - } - int64_t inferred_size = sz / rhs_prod; - dims[i]->set_size(inferred_size); - rhs_prod = sz; - break; +static void _bind_dims_to_size( + Arena& A, + int64_t sz, + int64_t sd, + Slice> dims, + Slice& nsz, + Slice& nsd) { + int64_t rhs_prod = 1; + for (auto i : dims.enumerate()) { + if (!dims[i]->is_bound()) { + for (auto j : irange(i + 1, dims.size())) { + if (!dims[j]->is_bound()) { + mpy::raise_error( + DimensionBindError(), + "cannot infer the sizes of two dimensions at once %R and %R", + dims[i].ptr(), + dims[j].ptr()); } - rhs_prod *= dims[i]->size(); - } - if (rhs_prod != sz) { + rhs_prod *= dims[j]->size(); + } + if (sz % rhs_prod != 0) { mpy::tuple tup(dims.size()); for (auto j : dims.enumerate()) { - tup.set(j, mpy::object::borrow(dims[j])); + tup.set( + j, + dims[j]->is_bound() ? 
mpy::from_int(dims[j]->size()) + : mpy::unicode_from_string("?")); } - mpy::raise_error(DimensionBindError(), "Dimension sizes to do not match (%d != %d) when matching dimension pack %R", (int) sz, (int) rhs_prod, tup.ptr()); + mpy::raise_error( + DimensionBindError(), + "inferred dimension does not evenly fit into larger dimension: %d vs %R", + (int)sz, + tup.ptr()); + } + int64_t inferred_size = sz / rhs_prod; + dims[i]->set_size(inferred_size); + rhs_prod = sz; + break; } - auto new_strides = A.allocate(dims.size()); - auto prev_stride = sd; - for (auto i : dims.reversed_enumerate()) { - new_strides[i] = prev_stride; - prev_stride = dims[i]->size()*prev_stride; - } - for (auto i : dims.enumerate()) { - nsd.append(A, new_strides[i]); - nsz.append(A, dims[i]->size()); + rhs_prod *= dims[i]->size(); + } + if (rhs_prod != sz) { + mpy::tuple tup(dims.size()); + for (auto j : dims.enumerate()) { + tup.set(j, mpy::object::borrow(dims[j])); } + mpy::raise_error( + DimensionBindError(), + "Dimension sizes to do not match (%d != %d) when matching dimension pack %R", + (int)sz, + (int)rhs_prod, + tup.ptr()); + } + auto new_strides = A.allocate(dims.size()); + auto prev_stride = sd; + for (auto i : dims.reversed_enumerate()) { + new_strides[i] = prev_stride; + prev_stride = dims[i]->size() * prev_stride; + } + for (auto i : dims.enumerate()) { + nsd.append(A, new_strides[i]); + nsz.append(A, dims[i]->size()); + } } static bool has_dims(mpy::handle d) { - return Dim::check_exact(d) || Tensor::check_exact(d); + return Dim::check_exact(d) || Tensor::check_exact(d); } struct IndexingInfo { - bool can_call_original; // if true, then it is safe to just call getitem or setitem, these objects do not need special handling - bool advanced_indexing; // requires actual lookup - TensorRef self; - Slice flat_inputs; - Slice result_levels; - bool has_device; + bool can_call_original; // if true, then it is safe to just call getitem or + // setitem, these objects do not need special handling + bool advanced_indexing; // requires actual lookup + TensorRef self; + Slice flat_inputs; + Slice result_levels; + bool has_device; }; -} +} // namespace -IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none); -namespace{ +IndexingInfo getsetitem_flat( + Arena& A, + TensorInfo self_info, + Slice input, + Slice keys, + Slice values, + bool has_dimpacks_or_none); +namespace { Slice as_slice(mpy::tuple_view tv) { - PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(),0); - return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); + PyObject** begin = &PyTuple_GET_ITEM(tv.ptr(), 0); + return Slice( + (mpy::handle*)begin, (mpy::handle*)(begin + tv.size())); } Slice as_slice(mpy::list_view tv) { - PyObject** begin = &PyList_GET_ITEM(tv.ptr(),0); - return Slice((mpy::handle*)begin, (mpy::handle*) (begin + tv.size())); + PyObject** begin = &PyList_GET_ITEM(tv.ptr(), 0); + return Slice( + (mpy::handle*)begin, (mpy::handle*)(begin + tv.size())); } - -bool maybe_dimpack(Slice& elements, mpy::handle s, bool check_first=true) { - // can we avoid rechecking? - if (mpy::list_view::check(s)) { - mpy::list_view tv(s); - if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { - elements = as_slice(tv); - return true; - } +bool maybe_dimpack( + Slice& elements, + mpy::handle s, + bool check_first = true) { + // can we avoid rechecking? 
+ if (mpy::list_view::check(s)) { + mpy::list_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; } - // can we avoid rechecking? - if (mpy::tuple_view::check(s)) { - mpy::tuple_view tv(s); - if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { - elements = as_slice(tv); - return true; - } + } + // can we avoid rechecking? + if (mpy::tuple_view::check(s)) { + mpy::tuple_view tv(s); + if (!check_first || (tv.size() && Dim::check_exact(tv[0]))) { + elements = as_slice(tv); + return true; } - return false; + } + return false; }; bool is_dimpack(mpy::handle s) { - Slice e; - return maybe_dimpack(e, s); + Slice e; + return maybe_dimpack(e, s); } mpy::object invoke_getitem(Arena& A, const IndexingInfo& iinfo) { - at::Tensor rtensor; - if (iinfo.advanced_indexing) { - auto self_hdl = handle_from_tensor(A, iinfo.self); - auto tup = slice_to_tuple(iinfo.flat_inputs); - // std::cout << "calling original getindex " << self_hdl << " " << tup << "\n"; - auto pytensor = mpy::object::checked_steal(THPVariable_getitem(self_hdl.ptr(), tup.ptr())); - rtensor = THPVariable_Unpack(pytensor.ptr()); - } else { - // std::cout << "skipping original getindex\n"; - rtensor = *iinfo.self; - } - // std::cout << "returning (from_positional)\n"; - return Tensor::from_positional(A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); + at::Tensor rtensor; + if (iinfo.advanced_indexing) { + auto self_hdl = handle_from_tensor(A, iinfo.self); + auto tup = slice_to_tuple(iinfo.flat_inputs); + // std::cout << "calling original getindex " << self_hdl << " " << tup << + // "\n"; + auto pytensor = mpy::object::checked_steal( + THPVariable_getitem(self_hdl.ptr(), tup.ptr())); + rtensor = THPVariable_Unpack(pytensor.ptr()); + } else { + // std::cout << "skipping original getindex\n"; + rtensor = *iinfo.self; + } + // std::cout << "returning (from_positional)\n"; + return Tensor::from_positional( + A, std::move(rtensor), iinfo.result_levels, iinfo.has_device); } -mpy::object index(Arena& A, mpy::handle self, mpy::handle dims, mpy::handle indices) { - maybeInitializeGlobals(); - Slice dims_list; - Slice indices_list; - // we allow for matching single dims to multiple dims, - // so we first have to normalize everything into the case where there is a list on lhs and the rhs - bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); - bool rhs_list = mpy::tuple_view::check(indices) || mpy::list_view::check(indices); - if (lhs_list && rhs_list) { - mpy::sequence_view dv(dims); - mpy::sequence_view ind(indices); - Py_ssize_t N = dv.size(); - if (N != ind.size()) { - mpy::raise_error(PyExc_TypeError, "dims (%d) and indices (%d) must have the same length", int(N), int(ind.size())); - } - for (auto i : irange(N)) { - dims_list.append(A, A.autorelease(dv[i])); - indices_list.append(A, A.autorelease(ind[i])); - } +mpy::object index( + Arena& A, + mpy::handle self, + mpy::handle dims, + mpy::handle indices) { + maybeInitializeGlobals(); + Slice dims_list; + Slice indices_list; + // we allow for matching single dims to multiple dims, + // so we first have to normalize everything into the case where there is a + // list on lhs and the rhs + bool lhs_list = mpy::tuple_view::check(dims) || mpy::list_view::check(dims); + bool rhs_list = + mpy::tuple_view::check(indices) || mpy::list_view::check(indices); + if (lhs_list && rhs_list) { + mpy::sequence_view dv(dims); + mpy::sequence_view ind(indices); + Py_ssize_t N = dv.size(); + if (N != 
ind.size()) { + mpy::raise_error( + PyExc_TypeError, + "dims (%d) and indices (%d) must have the same length", + int(N), + int(ind.size())); + } + for (auto i : irange(N)) { + dims_list.append(A, A.autorelease(dv[i])); + indices_list.append(A, A.autorelease(ind[i])); + } + } else { + dims_list.append(A, dims); + indices_list.append(A, indices); + } + + // dims being indexed can be grouped together into a single index space, and + // we have to flatten them int a single dimension before we can index them... + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + Slice new_levels; + Slice to_flatten; + Slice dims_list_flat; + auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { + auto d = _wrap_dim(s, ndim, false); + if (d.is_none()) { + mpy::raise_error( + PyExc_TypeError, + "expected a dimension specifyer but found %R", + s.ptr()); + } + return d; + }; + auto dim_not_present = [&](DimEntry d) { + if (d.is_positional()) { + mpy::raise_error( + PyExc_TypeError, + "dimension %d not in tensor of %d dimensions", + d.position() + ndim, + ndim); } else { - dims_list.append(A, dims); - indices_list.append(A, indices); + mpy::raise_error( + PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); } + }; - // dims being indexed can be grouped together into a single index space, and we have to - // flatten them int a single dimension before we can index them... - auto self_info = TensorInfo::create(A, self, false); - auto ndim = self_info.ndim(); - Slice new_levels; - Slice to_flatten; - Slice dims_list_flat; - auto parse_dim_entry = [&](mpy::handle s) -> DimEntry { - auto d = _wrap_dim(s, ndim, false); - if (d.is_none()) { - mpy::raise_error(PyExc_TypeError, "expected a dimension specifyer but found %R", s.ptr()); + for (auto i : dims_list.enumerate()) { + Slice m; + if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { + if (m.size() == 0) { + // plausible semantics work for this to have 0 elements (e.g. the index + // will always be 0) + dims_list_flat.append(A, DimEntry()); // value is just dropped + } + auto first = parse_dim_entry(m[0]); + dims_list_flat.append(A, first); + if (m.size() == 1) { + continue; + } + if (to_flatten.size() == 0) { + new_levels.extend(A, self_info.levels); + } + Slice rest; + for (auto i : irange(1, m.size())) { + auto d = parse_dim_entry(m[i]); + if (!new_levels.remove(A, d)) { + dim_not_present(d); } - return d; - }; - auto dim_not_present = [&](DimEntry d) { - if (d.is_positional()) { - mpy::raise_error(PyExc_TypeError, "dimension %d not in tensor of %d dimensions", d.position() + ndim , ndim); - } else { - mpy::raise_error(PyExc_TypeError, "dimension %R not in tensor", d.dim()->ptr()); - } - }; + rest.append(A, d); + } - for (auto i : dims_list.enumerate()) { - Slice m; - if (maybe_dimpack(m, dims_list[i], /*check_first=*/false)) { - if (m.size() == 0) { - // plausible semantics work for this to have 0 elements (e.g. 
the index will always be 0) - dims_list_flat.append(A, DimEntry()); // value is just dropped - } - auto first = parse_dim_entry(m[0]); - dims_list_flat.append(A, first); - if (m.size() == 1) { - continue; - } - if (to_flatten.size() == 0) { - new_levels.extend(A, self_info.levels); - } - Slice rest; - for (auto i : irange(1, m.size())) { - auto d = parse_dim_entry(m[i]); - if (!new_levels.remove(A, d)) { - dim_not_present(d); - } - rest.append(A, d); - } - - auto first_idx = new_levels.index(first); - if (!first_idx) { - dim_not_present(first); - } - new_levels.insert(A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); - to_flatten.extend(A, rest); - } else { - dims_list_flat.append(A, parse_dim_entry(dims_list[i])); - } + auto first_idx = new_levels.index(first); + if (!first_idx) { + dim_not_present(first); + } + new_levels.insert( + A, new_levels.slice(*first_idx + 1, *first_idx + 1), rest); + to_flatten.extend(A, rest); + } else { + dims_list_flat.append(A, parse_dim_entry(dims_list[i])); } - if (to_flatten.size() > 0) { - TensorRef rearranged = _match_levels(A, self_info.tensor, self_info.levels, new_levels); - at::IntArrayRef sizes = rearranged->sizes(); - Slice new_sizes; - Slice reshape_levels; - for (auto i : new_levels.enumerate()) { - if (to_flatten.contains(new_levels[i])) { - new_sizes.back() *= sizes[i]; - } else { - new_sizes.append(A, sizes[i]); - reshape_levels.append(A, new_levels[i]); - } - } - self_info.tensor = A.autorelease(rearranged->reshape(at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); + } + if (to_flatten.size() > 0) { + TensorRef rearranged = + _match_levels(A, self_info.tensor, self_info.levels, new_levels); + at::IntArrayRef sizes = rearranged->sizes(); + Slice new_sizes; + Slice reshape_levels; + for (auto i : new_levels.enumerate()) { + if (to_flatten.contains(new_levels[i])) { + new_sizes.back() *= sizes[i]; + } else { + new_sizes.append(A, sizes[i]); + reshape_levels.append(A, new_levels[i]); + } + } + self_info.tensor = A.autorelease(rearranged->reshape( + at::IntArrayRef(new_sizes.begin(), new_sizes.end()))); - self_info.levels = reshape_levels; // note: we are using the first level in a flattened group to represent the group for the rest of the op - // we need to be careful not to rely the dimensions size because it doesn't match the size of the whole group + self_info.levels = + reshape_levels; // note: we are using the first level in a flattened + // group to represent the group for the rest of the op + // we need to be careful not to rely the dimensions size + // because it doesn't match the size of the whole group + } + bool has_dimpacks = false; + for (auto idx : indices_list) { + if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) { + has_dimpacks = true; + break; } - bool has_dimpacks = false; - for (auto idx : indices_list) { - if (mpy::tuple_view::check(idx) || mpy::list_view::check(idx)) { - has_dimpacks = true; - break; - } - } - IndexingInfo info = getsetitem_flat(A, self_info, Slice(), dims_list_flat, indices_list, has_dimpacks); - return invoke_getitem(A, info); + } + IndexingInfo info = getsetitem_flat( + A, + self_info, + Slice(), + dims_list_flat, + indices_list, + has_dimpacks); + return invoke_getitem(A, info); } // true -- the indices were flattened out of a tuple, list or sequence... 
Slice slice_from_sequence(Arena& A, mpy::handle value) { - if (mpy::tuple_view::check(value)) { - return as_slice(mpy::tuple_view(value)); - } else if (mpy::list_view::check(value)) { - return as_slice(mpy::list_view(value)); - } else { - mpy::sequence_view sv(value); - Slice r; - for (auto i : sv.enumerate()) { - r.append(A, A.autorelease(sv[i])); - } - return r; + if (mpy::tuple_view::check(value)) { + return as_slice(mpy::tuple_view(value)); + } else if (mpy::list_view::check(value)) { + return as_slice(mpy::list_view(value)); + } else { + mpy::sequence_view sv(value); + Slice r; + for (auto i : sv.enumerate()) { + r.append(A, A.autorelease(sv[i])); } + return r; + } } bool extractIndices(Arena& A, mpy::handle index, Slice& indices) { - if (mpy::tuple_view::check(index)) { - indices.extend(A, as_slice(mpy::tuple_view(index))); - return true; - } else if (THPVariable_Check(index.ptr())) { - indices.append(A, index); - return false; - } else if (!mpy::is_sequence(index)) { - indices.append(A, index); - return false; - } - // a copy of treatSequenceAsTuple modified to add Dim and our wrapped tensors.. - mpy::sequence_view sv(index); - if (sv.size() >= 32) { - indices.extend(A, slice_from_sequence(A, index)); - return true; - } - for (auto i : sv.enumerate()) { - mpy::handle item; - try { - item = sv[i]; - } catch (mpy::exception_set & e) { - PyErr_Clear(); - indices.append(A, index); - return false; - } - if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || mpy::is_none(item) || has_dims(item)) { - indices.extend(A, slice_from_sequence(A, index)); - return true; - } - } + if (mpy::tuple_view::check(index)) { + indices.extend(A, as_slice(mpy::tuple_view(index))); + return true; + } else if (THPVariable_Check(index.ptr())) { indices.append(A, index); return false; + } else if (!mpy::is_sequence(index)) { + indices.append(A, index); + return false; + } + // a copy of treatSequenceAsTuple modified to add Dim and our wrapped + // tensors.. 
+ mpy::sequence_view sv(index); + if (sv.size() >= 32) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + for (auto i : sv.enumerate()) { + mpy::handle item; + try { + item = sv[i]; + } catch (mpy::exception_set& e) { + PyErr_Clear(); + indices.append(A, index); + return false; + } + if (THPVariable_Check(item.ptr()) || mpy::is_sequence(item) || + PySlice_Check(item.ptr()) || item.ptr() == Py_Ellipsis || + mpy::is_none(item) || has_dims(item)) { + indices.extend(A, slice_from_sequence(A, index)); + return true; + } + } + indices.append(A, index); + return false; } -IndexingInfo getsetitem(Arena & A, mpy::handle self, mpy::handle index, bool tensors_have_dims) { - bool can_call_original_getitem = !tensors_have_dims; +IndexingInfo getsetitem( + Arena& A, + mpy::handle self, + mpy::handle index, + bool tensors_have_dims) { + bool can_call_original_getitem = !tensors_have_dims; - Slice input; - if (has_dims(index)) { - input.append(A, index); - } else { - bool is_sequence = extractIndices(A, index, input); - // nothing about first class dims here, fallback to getitem - if (can_call_original_getitem && !is_sequence) { - return { true }; - } + Slice input; + if (has_dims(index)) { + input.append(A, index); + } else { + bool is_sequence = extractIndices(A, index, input); + // nothing about first class dims here, fallback to getitem + if (can_call_original_getitem && !is_sequence) { + return {true}; } + } - int64_t dims_indexed = 0; - int64_t expanding_object = -1; - DimList* unbound_dim_list = nullptr; - auto check_expanding = [&](int64_t i) { - if (expanding_object != -1) { - mpy::raise_error(DimensionBindError(), "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", (int) expanding_object, (int) i); - } - expanding_object = i; - }; - Slice dimlists; - - // calculate how many dimensioned have been indexed in order to compute the size of ... - // or expand a potentially unbound dimension list. - - bool has_dimpacks_or_none = false; - for (auto i : input.enumerate()) { - mpy::handle s = input[i]; - if (Dim::check_exact(s) || Tensor::check_exact(s)) { - can_call_original_getitem = false; - ++dims_indexed; - } else if (s.ptr() == Py_Ellipsis) { - check_expanding(i); - } else if (DimList::check(s)) { - can_call_original_getitem = false; - auto dl = DimList::unchecked_wrap(s); - if (!dl->is_bound()) { - check_expanding(i); - unbound_dim_list = dl.ptr(); - } else { - dims_indexed += dl->dims_.size(); - } - dimlists.append(A, i); - } else if (mpy::is_none(s)) { - has_dimpacks_or_none = true; - } else if (is_dimpack(s)) { - can_call_original_getitem = false; - has_dimpacks_or_none = true; - ++dims_indexed; - } else { - ++dims_indexed; - } - } - - // at this point if we haven't seen any Dim objects, we also can fallback to the original getitem. - if (can_call_original_getitem) { - return {true}; - } - - // std::cout << "__getitem__ " << self << " " << index << "\n"; - - TensorInfo self_info = TensorInfo::create(A, self, false, true); - auto ndim = self_info.ndim(); - if (dims_indexed > ndim) { - mpy::raise_error(PyExc_ValueError, "at least %d indices were supplied but the tensor only has %d dimensions", (int) dims_indexed, (int) ndim); - } - // expand any unbound dimension list, or expand ... into individual : slices. 
- auto expanding_dims = ndim - dims_indexed; + int64_t dims_indexed = 0; + int64_t expanding_object = -1; + DimList* unbound_dim_list = nullptr; + auto check_expanding = [&](int64_t i) { if (expanding_object != -1) { - if (unbound_dim_list) { - unbound_dim_list->bind_len(expanding_dims); - } else { - // ... - Slice no_slices; - for (auto i : irange(expanding_dims)) { - (void) i; - no_slices.append(A, no_slice); - } - input.insert(A, input.slice(expanding_object, expanding_object + 1), no_slices); - } + mpy::raise_error( + DimensionBindError(), + "at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets %d and %d", + (int)expanding_object, + (int)i); } + expanding_object = i; + }; + Slice dimlists; - // flatten out any dimensions stored in dimlist elements directly into the inputs - // std::cout << dimlists << " <- dim lists!\n"; - for (int64_t i = dimlists.size() - 1; i >=0; --i) { - auto idx = dimlists[i]; - // we added more elements to input because of ... - // so we need to also adjust the index to get back to where the - // dimlist existed - if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { - idx += expanding_dims; - } - auto dl = DimList::unchecked_wrap(input[idx]); - // XXX would be better if we used an OwnedSlice in DimList - Slice more_dims((mpy::handle*) &*dl->dims_.begin(), (mpy::handle*) &*dl->dims_.end()); - input.insert(A, input.slice(idx, idx + 1), more_dims); - } + // calculate how many dimensioned have been indexed in order to compute the + // size of ... or expand a potentially unbound dimension list. - return getsetitem_flat(A, self_info, input, Slice(), Slice(), has_dimpacks_or_none); -} -} -IndexingInfo getsetitem_flat(Arena& A, TensorInfo self_info, Slice input, Slice keys, Slice values, bool has_dimpacks_or_none) { - // At this point: - // ..., DimList have been eliminated - // Dim, Tensor, Tuple[Dim,...], int, slice still remain - - - // we have to count how many times we see a dimension. - // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires advanced indexing. 
- Slice> seen_dims; - Slice seen_dims_nuses; - auto add_dim = [&](mpy::hdl entry) { - auto midx = seen_dims.index(entry); - if (!midx) { - seen_dims.append(A, entry); - seen_dims_nuses.append(A, 1); - } else { - ++seen_dims_nuses[*midx]; - } - }; - - Slice input_it = input; - - Slice flat_inputs; - // flat inputs will start with an empty mpy::handle if the - // actual value is in the tensor-like object in the tensor info - Slice tensor_inputs; - - auto append_flat_handle = [&](mpy::handle h) { - flat_inputs.append(A, h); - tensor_inputs.append(A, TensorInfo()); - }; - TensorRef device_holding_tensor; - auto append_tensor_input = [&](TensorInfo ti) { - flat_inputs.append(A, mpy::handle()); - tensor_inputs.append(A, ti); - if (ti.has_device && !device_holding_tensor) { - device_holding_tensor = ti.tensor; - } - }; - - Slice nsz; - Slice nsd; - at::IntArrayRef sz = self_info.tensor->sizes(); - at::IntArrayRef sd = self_info.tensor->strides(); - - auto append_size = [&](int i) { - if (has_dimpacks_or_none) { - nsz.append(A, sz[i]); - nsd.append(A, sd[i]); - } - }; - // std::cout << "self levels: " << self_info.levels << "\n"; - - auto parse_nones = [&]() { - while (input_it.size() && mpy::is_none(input_it[0])) { - append_flat_handle(no_slice); - nsz.append(A, 1); - nsd.append(A, 0); - input_it = input_it.slice(1); - } - }; - - - auto append_item = [&](int i, mpy::handle arg) { - if (Dim::check_exact(arg)) { - auto d = Dim::unchecked_wrap(arg); - d->set_size(sz[i]); - add_dim(d); - append_size(i); - append_flat_handle(arg); - return; - } - auto info = TensorInfo::create(A, arg, false, false); - if (info) { - append_size(i); - append_tensor_input(info); - for (auto il : info.levels) { - if (!il.is_positional()) { - add_dim(il.dim()); - } - } - return; - } - - if (has_dimpacks_or_none) { - Slice mp; - if (maybe_dimpack(mp, arg)) { - // dim pack - Slice> dim_pack; - for (auto d : mp) { - dim_pack.append(A, Dim::wrap(d)); - add_dim(dim_pack.back()); - append_flat_handle(dim_pack.back()); - } - _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); - return; - } - } - - append_size(i); - append_flat_handle(arg); - }; - - // pair up the indexing expressions with dimension of self it indexes - // self may have first-class dims, which do not participate the indexing. - for (auto i : self_info.levels.enumerate()) { - auto l = self_info.levels[i]; - auto idx = keys.index(l); - if (idx) { - append_item(i, values[*idx]); - } else if (l.is_positional()) { - // grab and index from the positional list - parse_nones(); - if (!input_it.size()) { - // we might have fewer indices than tensor dimensions, - // which implicitly indexes the remaining dimensions with : - append_flat_handle(no_slice); - append_size(i); - } else { - mpy::handle arg = input_it[0]; - input_it = input_it.slice(1); - append_item(i, arg); - } - } else { - add_dim(l.dim()); - append_flat_handle(l.dim()); - append_size(i); - } - } - // any training Nones may have no existing dimension associated with them in self. - parse_nones(); - - // we have to restride the tensor to collapse dimension packs and introduce our none dimensions. 
- if (has_dimpacks_or_none) { - self_info.tensor = A.autorelease(self_info.tensor->as_strided(at::IntArrayRef(nsz.begin(), nsz.end()),at::IntArrayRef(nsd.begin(), nsd.end()), self_info.tensor->storage_offset())); - } - - - // figure out what the shape of the indexing tensors will be - // and what the shape of the resulting tensor will be - Slice result_levels; - Slice index_levels; - int64_t tensor_insert_point = -1; - bool requires_getindex = false; - auto mark_tensor_index = [&] { - if (tensor_insert_point == -1) { - tensor_insert_point = result_levels.size(); - } else if (tensor_insert_point != result_levels.size()) { - tensor_insert_point = 0; - } - }; - for (auto i : flat_inputs.enumerate()) { - auto inp = flat_inputs[i]; - if(tensor_inputs[i]) { - requires_getindex = true; - mark_tensor_index(); - for (auto l : tensor_inputs[i].levels) { - // std::cout << "Consider to add " << l << "\n"; - if (!index_levels.contains(l)) { - index_levels.append(A, l); - } - } - } else if (Dim::check_exact(inp)) { - auto d = Dim::unchecked_wrap(inp); - // dimensions used once are just binding operations - if (1 == seen_dims_nuses[*seen_dims.index(d)]) { - flat_inputs[i] = no_slice; - result_levels.append(A, d); - } else { - requires_getindex = true; - flat_inputs[i] = mpy::handle(); - tensor_inputs[i] = TensorInfo {d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; - if (!index_levels.contains(d)) { - index_levels.append(A, d); - } - mark_tensor_index(); - } - } else { - if (inp.ptr() != no_slice.ptr()) { - requires_getindex = true; - } - if (!mpy::is_int(inp)) { - // note: actual positional indexes are accurately computed later - result_levels.append(A, -1); - } - } - } - - // indexing dimensions appear in the tensor at the _first use of a tensor_ in the indexing. So insert - // the indexing leveles into the result klevels at this spot - if (tensor_insert_point != -1) { - result_levels.insert(A, result_levels.slice(tensor_insert_point, tensor_insert_point), index_levels); - } - - // std::cout << "flat inputs: " << flat_inputs << "\n"; - // std::cout << "result_levels: " << result_levels << "\n"; - // std::cout << "index_levels: " << index_levels << "\n"; - - // get all the tensors to be the right shape for indexing - if (requires_getindex) { - for (auto i : flat_inputs.enumerate()) { - if (tensor_inputs[i]) { - AT_ASSERT(!flat_inputs[i].ptr()); - // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << "\n"; - TensorRef t = tensor_inputs[i].tensor; - if (!tensor_inputs[i].has_device && device_holding_tensor) { - t = A.autorelease(t->to(device_holding_tensor->device())); - } - flat_inputs[i] = handle_from_tensor(A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); - } - } - } - - // previously we didn't know how many positional dimensions there would be so we couldn't number them right - // so fill it in now. 
- auto seen_positionals = 0; - for (auto i : result_levels.reversed_enumerate()) { - if (result_levels[i].is_positional()) { - result_levels[i] = -(++seen_positionals); - } - } - - return IndexingInfo {false, requires_getindex, self_info.tensor, flat_inputs, result_levels, self_info.has_device}; -} -namespace{ -mpy::object __getitem__(Arena & A, mpy::handle self, mpy::handle index) { - maybeInitializeGlobals(); - auto iinfo = getsetitem(A, self, index, has_dims(self)); - if (iinfo.can_call_original) { - return mpy::object::checked_steal(THPVariable_getitem(self.ptr(), index.ptr())); - } - - return invoke_getitem(A, iinfo); -} - - - -void __setitem__(Arena & A, mpy::handle self, mpy::handle index, mpy::handle rhs) { - maybeInitializeGlobals(); - auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); - if (iinfo.can_call_original) { - if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { - throw mpy::exception_set(); - } - return; - } - - auto rhs_info = TensorInfo::create(A, rhs, false, false); - if (rhs_info) { // otherwise rhs can be a scalar... - for (auto l : rhs_info.levels) { - if (!iinfo.result_levels.contains(l)) { - if (l.is_positional()) { - mpy::raise_error(DimensionBindError(), "rhs contains too many dimensions (%d) compared to indexed value (%d)", ndim_of_levels(iinfo.result_levels), rhs_info.ndim()); - } else { - auto tup = levels_to_tuple(iinfo.result_levels); - mpy::raise_error(DimensionBindError(), "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", l.dim().ptr(), tup.ptr()); - } - } - } - auto rhs_matched = _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); - rhs = handle_from_tensor(A, rhs_matched); - } - self = handle_from_tensor(A, iinfo.self); - - if (iinfo.advanced_indexing) { - auto tup = slice_to_tuple(iinfo.flat_inputs); - if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { - throw mpy::exception_set(); - } + bool has_dimpacks_or_none = false; + for (auto i : input.enumerate()) { + mpy::handle s = input[i]; + if (Dim::check_exact(s) || Tensor::check_exact(s)) { + can_call_original_getitem = false; + ++dims_indexed; + } else if (s.ptr() == Py_Ellipsis) { + check_expanding(i); + } else if (DimList::check(s)) { + can_call_original_getitem = false; + auto dl = DimList::unchecked_wrap(s); + if (!dl->is_bound()) { + check_expanding(i); + unbound_dim_list = dl.ptr(); + } else { + dims_indexed += dl->dims_.size(); + } + dimlists.append(A, i); + } else if (mpy::is_none(s)) { + has_dimpacks_or_none = true; + } else if (is_dimpack(s)) { + can_call_original_getitem = false; + has_dimpacks_or_none = true; + ++dims_indexed; } else { - torch_Tensor_copy_.call(self, rhs); + ++dims_indexed; } + } + + // at this point if we haven't seen any Dim objects, we also can fallback to + // the original getitem. + if (can_call_original_getitem) { + return {true}; + } + + // std::cout << "__getitem__ " << self << " " << index << "\n"; + + TensorInfo self_info = TensorInfo::create(A, self, false, true); + auto ndim = self_info.ndim(); + if (dims_indexed > ndim) { + mpy::raise_error( + PyExc_ValueError, + "at least %d indices were supplied but the tensor only has %d dimensions", + (int)dims_indexed, + (int)ndim); + } + // expand any unbound dimension list, or expand ... into individual : slices. + auto expanding_dims = ndim - dims_indexed; + if (expanding_object != -1) { + if (unbound_dim_list) { + unbound_dim_list->bind_len(expanding_dims); + } else { + // ... 
+ Slice no_slices; + for (auto i : irange(expanding_dims)) { + (void)i; + no_slices.append(A, no_slice); + } + input.insert( + A, input.slice(expanding_object, expanding_object + 1), no_slices); + } + } + + // flatten out any dimensions stored in dimlist elements directly into the + // inputs std::cout << dimlists << " <- dim lists!\n"; + for (int64_t i = dimlists.size() - 1; i >= 0; --i) { + auto idx = dimlists[i]; + // we added more elements to input because of ... + // so we need to also adjust the index to get back to where the + // dimlist existed + if (!unbound_dim_list && expanding_object != -1 && idx > expanding_object) { + idx += expanding_dims; + } + auto dl = DimList::unchecked_wrap(input[idx]); + // XXX would be better if we used an OwnedSlice in DimList + Slice more_dims( + (mpy::handle*)&*dl->dims_.begin(), (mpy::handle*)&*dl->dims_.end()); + input.insert(A, input.slice(idx, idx + 1), more_dims); + } + + return getsetitem_flat( + A, + self_info, + input, + Slice(), + Slice(), + has_dimpacks_or_none); } +} // namespace +IndexingInfo getsetitem_flat( + Arena& A, + TensorInfo self_info, + Slice input, + Slice keys, + Slice values, + bool has_dimpacks_or_none) { + // At this point: + // ..., DimList have been eliminated + // Dim, Tensor, Tuple[Dim,...], int, slice still remain + + // we have to count how many times we see a dimension. + // A[i,j] is a simple binding operation, but A[i, i+j] or A[i, i] requires + // advanced indexing. + Slice> seen_dims; + Slice seen_dims_nuses; + auto add_dim = [&](mpy::hdl entry) { + auto midx = seen_dims.index(entry); + if (!midx) { + seen_dims.append(A, entry); + seen_dims_nuses.append(A, 1); + } else { + ++seen_dims_nuses[*midx]; + } + }; + + Slice input_it = input; + + Slice flat_inputs; + // flat inputs will start with an empty mpy::handle if the + // actual value is in the tensor-like object in the tensor info + Slice tensor_inputs; + + auto append_flat_handle = [&](mpy::handle h) { + flat_inputs.append(A, h); + tensor_inputs.append(A, TensorInfo()); + }; + TensorRef device_holding_tensor; + auto append_tensor_input = [&](TensorInfo ti) { + flat_inputs.append(A, mpy::handle()); + tensor_inputs.append(A, ti); + if (ti.has_device && !device_holding_tensor) { + device_holding_tensor = ti.tensor; + } + }; + + Slice nsz; + Slice nsd; + at::IntArrayRef sz = self_info.tensor->sizes(); + at::IntArrayRef sd = self_info.tensor->strides(); + + auto append_size = [&](int i) { + if (has_dimpacks_or_none) { + nsz.append(A, sz[i]); + nsd.append(A, sd[i]); + } + }; + // std::cout << "self levels: " << self_info.levels << "\n"; + + auto parse_nones = [&]() { + while (input_it.size() && mpy::is_none(input_it[0])) { + append_flat_handle(no_slice); + nsz.append(A, 1); + nsd.append(A, 0); + input_it = input_it.slice(1); + } + }; + + auto append_item = [&](int i, mpy::handle arg) { + if (Dim::check_exact(arg)) { + auto d = Dim::unchecked_wrap(arg); + d->set_size(sz[i]); + add_dim(d); + append_size(i); + append_flat_handle(arg); + return; + } + auto info = TensorInfo::create(A, arg, false, false); + if (info) { + append_size(i); + append_tensor_input(info); + for (auto il : info.levels) { + if (!il.is_positional()) { + add_dim(il.dim()); + } + } + return; + } + + if (has_dimpacks_or_none) { + Slice mp; + if (maybe_dimpack(mp, arg)) { + // dim pack + Slice> dim_pack; + for (auto d : mp) { + dim_pack.append(A, Dim::wrap(d)); + add_dim(dim_pack.back()); + append_flat_handle(dim_pack.back()); + } + _bind_dims_to_size(A, sz[i], sd[i], dim_pack, nsz, nsd); + return; + 
} + } + + append_size(i); + append_flat_handle(arg); + }; + + // pair up the indexing expressions with dimension of self it indexes + // self may have first-class dims, which do not participate the indexing. + for (auto i : self_info.levels.enumerate()) { + auto l = self_info.levels[i]; + auto idx = keys.index(l); + if (idx) { + append_item(i, values[*idx]); + } else if (l.is_positional()) { + // grab and index from the positional list + parse_nones(); + if (!input_it.size()) { + // we might have fewer indices than tensor dimensions, + // which implicitly indexes the remaining dimensions with : + append_flat_handle(no_slice); + append_size(i); + } else { + mpy::handle arg = input_it[0]; + input_it = input_it.slice(1); + append_item(i, arg); + } + } else { + add_dim(l.dim()); + append_flat_handle(l.dim()); + append_size(i); + } + } + // any training Nones may have no existing dimension associated with them in + // self. + parse_nones(); + + // we have to restride the tensor to collapse dimension packs and introduce + // our none dimensions. + if (has_dimpacks_or_none) { + self_info.tensor = A.autorelease(self_info.tensor->as_strided( + at::IntArrayRef(nsz.begin(), nsz.end()), + at::IntArrayRef(nsd.begin(), nsd.end()), + self_info.tensor->storage_offset())); + } + + // figure out what the shape of the indexing tensors will be + // and what the shape of the resulting tensor will be + Slice result_levels; + Slice index_levels; + int64_t tensor_insert_point = -1; + bool requires_getindex = false; + auto mark_tensor_index = [&] { + if (tensor_insert_point == -1) { + tensor_insert_point = result_levels.size(); + } else if (tensor_insert_point != result_levels.size()) { + tensor_insert_point = 0; + } + }; + for (auto i : flat_inputs.enumerate()) { + auto inp = flat_inputs[i]; + if (tensor_inputs[i]) { + requires_getindex = true; + mark_tensor_index(); + for (auto l : tensor_inputs[i].levels) { + // std::cout << "Consider to add " << l << "\n"; + if (!index_levels.contains(l)) { + index_levels.append(A, l); + } + } + } else if (Dim::check_exact(inp)) { + auto d = Dim::unchecked_wrap(inp); + // dimensions used once are just binding operations + if (1 == seen_dims_nuses[*seen_dims.index(d)]) { + flat_inputs[i] = no_slice; + result_levels.append(A, d); + } else { + requires_getindex = true; + flat_inputs[i] = mpy::handle(); + tensor_inputs[i] = TensorInfo{ + d->range(), Slice(A, DimEntry(d)), false, TensorRef()}; + if (!index_levels.contains(d)) { + index_levels.append(A, d); + } + mark_tensor_index(); + } + } else { + if (inp.ptr() != no_slice.ptr()) { + requires_getindex = true; + } + if (!mpy::is_int(inp)) { + // note: actual positional indexes are accurately computed later + result_levels.append(A, -1); + } + } + } + + // indexing dimensions appear in the tensor at the _first use of a tensor_ in + // the indexing. 
So insert the indexing leveles into the result klevels at + // this spot + if (tensor_insert_point != -1) { + result_levels.insert( + A, + result_levels.slice(tensor_insert_point, tensor_insert_point), + index_levels); + } + + // std::cout << "flat inputs: " << flat_inputs << "\n"; + // std::cout << "result_levels: " << result_levels << "\n"; + // std::cout << "index_levels: " << index_levels << "\n"; + + // get all the tensors to be the right shape for indexing + if (requires_getindex) { + for (auto i : flat_inputs.enumerate()) { + if (tensor_inputs[i]) { + AT_ASSERT(!flat_inputs[i].ptr()); + // std::cout << "tensor " << i << " " << tensor_inputs[i].levels << + // "\n"; + TensorRef t = tensor_inputs[i].tensor; + if (!tensor_inputs[i].has_device && device_holding_tensor) { + t = A.autorelease(t->to(device_holding_tensor->device())); + } + flat_inputs[i] = handle_from_tensor( + A, _match_levels(A, t, tensor_inputs[i].levels, index_levels)); + } + } + } + + // previously we didn't know how many positional dimensions there would be so + // we couldn't number them right so fill it in now. + auto seen_positionals = 0; + for (auto i : result_levels.reversed_enumerate()) { + if (result_levels[i].is_positional()) { + result_levels[i] = -(++seen_positionals); + } + } + + return IndexingInfo{ + false, + requires_getindex, + self_info.tensor, + flat_inputs, + result_levels, + self_info.has_device}; } +namespace { +mpy::object __getitem__(Arena& A, mpy::handle self, mpy::handle index) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self)); + if (iinfo.can_call_original) { + return mpy::object::checked_steal( + THPVariable_getitem(self.ptr(), index.ptr())); + } + + return invoke_getitem(A, iinfo); +} + +void __setitem__( + Arena& A, + mpy::handle self, + mpy::handle index, + mpy::handle rhs) { + maybeInitializeGlobals(); + auto iinfo = getsetitem(A, self, index, has_dims(self) || has_dims(rhs)); + if (iinfo.can_call_original) { + if (-1 == THPVariable_setitem(self.ptr(), index.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + return; + } + + auto rhs_info = TensorInfo::create(A, rhs, false, false); + if (rhs_info) { // otherwise rhs can be a scalar... 
+ for (auto l : rhs_info.levels) { + if (!iinfo.result_levels.contains(l)) { + if (l.is_positional()) { + mpy::raise_error( + DimensionBindError(), + "rhs contains too many dimensions (%d) compared to indexed value (%d)", + ndim_of_levels(iinfo.result_levels), + rhs_info.ndim()); + } else { + auto tup = levels_to_tuple(iinfo.result_levels); + mpy::raise_error( + DimensionBindError(), + "rhs of setitem contains dimension %R which is not in the dimension on the left (%R)", + l.dim().ptr(), + tup.ptr()); + } + } + } + auto rhs_matched = + _match_levels(A, rhs_info.tensor, rhs_info.levels, iinfo.result_levels); + rhs = handle_from_tensor(A, rhs_matched); + } + self = handle_from_tensor(A, iinfo.self); + + if (iinfo.advanced_indexing) { + auto tup = slice_to_tuple(iinfo.flat_inputs); + if (-1 == THPVariable_setitem(self.ptr(), tup.ptr(), rhs.ptr())) { + throw mpy::exception_set(); + } + } else { + torch_Tensor_copy_.call(self, rhs); + } +} +} // namespace PyObject* Tensor_getitem(PyObject* self, PyObject* index) { - Arena A; - PY_BEGIN - return __getitem__(A, self, index).release(); - PY_END(nullptr); + Arena A; + PY_BEGIN + return __getitem__(A, self, index).release(); + PY_END(nullptr); } int Tensor_setitem(PyObject* self, PyObject* index, PyObject* value) { - Arena A; - PY_BEGIN - __setitem__(A, self, index, value); - return 0; - PY_END(-1); + Arena A; + PY_BEGIN + __setitem__(A, self, index, value); + return 0; + PY_END(-1); } -namespace{ -PyObject* py___getitem__(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs == 2); - return __getitem__(A, args[0], args[1]).release(); - PY_END(nullptr) +namespace { +PyObject* py___getitem__( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 2); + return __getitem__(A, args[0], args[1]).release(); + PY_END(nullptr) } -PyObject* py___setitem__(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - AT_ASSERT(nargs == 3); - __setitem__(A, args[0], args[1], args[2]); - Py_RETURN_NONE; - PY_END(nullptr) +PyObject* py___setitem__( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + AT_ASSERT(nargs == 3); + __setitem__(A, args[0], args[1], args[2]); + Py_RETURN_NONE; + PY_END(nullptr) } - -PyObject* py_index(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - mpy::vector_args va(args, nargs, kwnames); - mpy::handle self, dims, indices; - va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); - return index(A, self, dims, indices).release(); - PY_END(nullptr) +PyObject* py_index( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, dims, indices; + va.parse("index", {"self", "dims", "indices"}, {&self, &dims, &indices}, 3); + return index(A, self, dims, indices).release(); + PY_END(nullptr) } +PyObject* py_stack( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + mpy::vector_args va(args, nargs, kwnames); + mpy::handle tensors, new_dim, dim; + va.parse( + "stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); -PyObject* py_stack(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - mpy::vector_args 
va(args, nargs, kwnames); - mpy::handle tensors, new_dim, dim; - va.parse("stack", {"tensors", "new_dim", "dim"}, {&tensors, &new_dim, &dim}, 2); - - Slice result_levels; - Slice infos; - mpy::sequence_view sv(tensors); - auto new_dim_d = Dim::wrap(new_dim); - for (auto i : sv.enumerate()) { - infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); - for (auto l : infos.back().levels) { - if (!result_levels.contains(l)) { - result_levels.append(A, l); - } - } + Slice result_levels; + Slice infos; + mpy::sequence_view sv(tensors); + auto new_dim_d = Dim::wrap(new_dim); + for (auto i : sv.enumerate()) { + infos.append(A, TensorInfo::create(A, A.autorelease(sv[i]), false)); + for (auto l : infos.back().levels) { + if (!result_levels.contains(l)) { + result_levels.append(A, l); + } } - new_dim_d->set_size(infos.size()); - std::vector inputs; - inputs.reserve(infos.size()); - for (auto in : infos) { - inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); - } - auto ndim = ndim_of_levels(result_levels); - int64_t rawdim = 0; - if (dim.ptr()) { - auto d = _wrap_dim(dim, ndim, false); - auto idx = result_levels.index(d); - if (!idx) { - mpy::raise_error(PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); - } - rawdim = *idx; - } - auto result = at::stack(inputs, rawdim); - result_levels.insert(A, rawdim, new_dim_d); - return Tensor::from_positional(A, std::move(result), result_levels, true).release(); - PY_END(nullptr) -} - -PyObject* py_split(PyObject *_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - mpy::vector_args va(args, nargs, kwnames); - mpy::handle self, split_size_or_sections, dim; - va.parse("split", {"self", "split_size_or_sections", "dim"}, {&self, &split_size_or_sections, &dim}, 2); - bool dim_is_object = dim.ptr() && Dim::check_exact(dim); - Slice sizes; - - bool all_dims = true; - bool all_ints = true; - - if (!mpy::is_int(split_size_or_sections)) { - mpy::sequence_view sv(split_size_or_sections); - for (auto i : sv.enumerate()) { - sizes.append(A, A.autorelease(sv[i])); - if (Dim::check_exact(sizes.back())) { - all_ints = false; - } else { - all_dims = false; - } - } - } - if (all_ints) { - if (dim_is_object) { - mpy::raise_error(PyExc_TypeError, "when dim is specified as a Dim object, split sizes must also be dimensions."); - } - // call original split (if self has dimensions this will use torch function to do the split) - return torch_Tensor_split.call_vector(mpy::vector_args(args, nargs, kwnames)).release(); - } - if (!all_dims) { - mpy::raise_error(PyExc_TypeError, "split list must be ints or dims but got a mix"); - } - - auto self_info = TensorInfo::create(A, self, false); - auto ndim = self_info.ndim(); - if (!dim_is_object&& ndim == 0) { - mpy::raise_error(PyExc_TypeError, "split expects at least a 1-dimension tensor"); - } - DimEntry dim_l = dim.ptr() ? 
_wrap_dim(dim, ndim, false) : -ndim; - - auto idx = self_info.levels.index(dim_l); + } + new_dim_d->set_size(infos.size()); + std::vector inputs; + inputs.reserve(infos.size()); + for (auto in : infos) { + inputs.emplace_back(*_match_levels(A, in.tensor, in.levels, result_levels)); + } + auto ndim = ndim_of_levels(result_levels); + int64_t rawdim = 0; + if (dim.ptr()) { + auto d = _wrap_dim(dim, ndim, false); + auto idx = result_levels.index(d); if (!idx) { - if (!dim.ptr()) { - dim = A.autorelease(mpy::from_int(0)); - } - mpy::raise_error(PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); + mpy::raise_error( + PyExc_TypeError, "Dimension %R does not exist in inputs", dim.ptr()); } - Slice indices; + rawdim = *idx; + } + auto result = at::stack(inputs, rawdim); + result_levels.insert(A, rawdim, new_dim_d); + return Tensor::from_positional(A, std::move(result), result_levels, true) + .release(); + PY_END(nullptr) +} - int64_t total_size = 0; - Slice unbound; - for (auto i : sizes.enumerate()) { - auto d = Dim::unchecked_wrap(sizes[i]); - if (d->is_bound()) { - indices.append(A, d->size()); - total_size += indices.back(); - } else { - indices.append(A, 0); - unbound.append(A, i); - } +PyObject* py_split( + PyObject* _, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + mpy::handle self, split_size_or_sections, dim; + va.parse( + "split", + {"self", "split_size_or_sections", "dim"}, + {&self, &split_size_or_sections, &dim}, + 2); + bool dim_is_object = dim.ptr() && Dim::check_exact(dim); + Slice sizes; + + bool all_dims = true; + bool all_ints = true; + + if (!mpy::is_int(split_size_or_sections)) { + mpy::sequence_view sv(split_size_or_sections); + for (auto i : sv.enumerate()) { + sizes.append(A, A.autorelease(sv[i])); + if (Dim::check_exact(sizes.back())) { + all_ints = false; + } else { + all_dims = false; + } } - auto tensor_size = self_info.tensor->sizes()[*idx]; - - if (unbound.size()) { - if (total_size > tensor_size) { - mpy::raise_error(PyExc_TypeError, "sizes of target dimensions add up to more (%d) than source dim (%d)", int(total_size), int(tensor_size)); - } - auto remaining_size = tensor_size - total_size; - auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); - for (auto u : unbound) { - auto sz = std::min(chunk_size, remaining_size); - Dim::unchecked_wrap(sizes[u])->set_size(sz); - indices[u] = sz; - remaining_size -= sz; - } - } else if (tensor_size != total_size) { - mpy::raise_error(PyExc_TypeError, "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", int(total_size), int(tensor_size)); + } + if (all_ints) { + if (dim_is_object) { + mpy::raise_error( + PyExc_TypeError, + "when dim is specified as a Dim object, split sizes must also be dimensions."); } + // call original split (if self has dimensions this will use torch function + // to do the split) + return torch_Tensor_split + .call_vector(mpy::vector_args(args, nargs, kwnames)) + .release(); + } + if (!all_dims) { + mpy::raise_error( + PyExc_TypeError, "split list must be ints or dims but got a mix"); + } - auto result_tensors = self_info.tensor->split_with_sizes(at::IntArrayRef(indices.begin(), indices.end()), *idx); - mpy::tuple result(result_tensors.size()); - Slice new_levels; - new_levels.extend(A, self_info.levels); - for (auto i : sizes.enumerate()) { - new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); - result.set(i, 
Tensor::from_positional(A, std::move(result_tensors[i]), new_levels, true)); + auto self_info = TensorInfo::create(A, self, false); + auto ndim = self_info.ndim(); + if (!dim_is_object && ndim == 0) { + mpy::raise_error( + PyExc_TypeError, "split expects at least a 1-dimension tensor"); + } + DimEntry dim_l = dim.ptr() ? _wrap_dim(dim, ndim, false) : -ndim; + + auto idx = self_info.levels.index(dim_l); + if (!idx) { + if (!dim.ptr()) { + dim = A.autorelease(mpy::from_int(0)); } + mpy::raise_error( + PyExc_TypeError, "tensor does not contain dimension %R", dim.ptr()); + } + Slice indices; - return result.release(); + int64_t total_size = 0; + Slice unbound; + for (auto i : sizes.enumerate()) { + auto d = Dim::unchecked_wrap(sizes[i]); + if (d->is_bound()) { + indices.append(A, d->size()); + total_size += indices.back(); + } else { + indices.append(A, 0); + unbound.append(A, i); + } + } + auto tensor_size = self_info.tensor->sizes()[*idx]; - PY_END(nullptr) + if (unbound.size()) { + if (total_size > tensor_size) { + mpy::raise_error( + PyExc_TypeError, + "sizes of target dimensions add up to more (%d) than source dim (%d)", + int(total_size), + int(tensor_size)); + } + auto remaining_size = tensor_size - total_size; + auto chunk_size = (remaining_size + unbound.size() - 1) / unbound.size(); + for (auto u : unbound) { + auto sz = std::min(chunk_size, remaining_size); + Dim::unchecked_wrap(sizes[u])->set_size(sz); + indices[u] = sz; + remaining_size -= sz; + } + } else if (tensor_size != total_size) { + mpy::raise_error( + PyExc_TypeError, + "sum of sizes of target dimensions (%d) do not match the than source dim (%d)", + int(total_size), + int(tensor_size)); + } + + auto result_tensors = self_info.tensor->split_with_sizes( + at::IntArrayRef(indices.begin(), indices.end()), *idx); + mpy::tuple result(result_tensors.size()); + Slice new_levels; + new_levels.extend(A, self_info.levels); + for (auto i : sizes.enumerate()) { + new_levels[*idx] = Dim::unchecked_wrap(sizes[i]); + result.set( + i, + Tensor::from_positional( + A, std::move(result_tensors[i]), new_levels, true)); + } + + return result.release(); + + PY_END(nullptr) } Slice _wrap_dims(Arena& A, mpy::handle d, size_t N, bool keepdim) { - auto de = _wrap_dim(d, N, keepdim); - Slice r; - if (!de.is_none()) { - r.append(A, de); - } else { - mpy::sequence_view sq(d); - for (auto i : sq.enumerate()) { - r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); - } + auto de = _wrap_dim(d, N, keepdim); + Slice r; + if (!de.is_none()) { + r.append(A, de); + } else { + mpy::sequence_view sq(d); + for (auto i : sq.enumerate()) { + r.append(A, _wrap_dim(A.autorelease(sq[i]), N, keepdim)); } - return r; + } + return r; } struct WrappedOperator : public mpy::base { - mpy::object orig; - PyMethodDef method_def; - mpy::object name, doc; + mpy::object orig; + PyMethodDef method_def; + mpy::object name, doc; - bool is_pointwise = false; - int64_t dim_offset = 0; - int64_t keepdim_offset = 1; - std::string dim_name; - bool single_dim = false; - bool reduce = true; + bool is_pointwise = false; + int64_t dim_offset = 0; + int64_t keepdim_offset = 1; + std::string dim_name; + bool single_dim = false; + bool reduce = true; - static PyTypeObject Type; + static PyTypeObject Type; - void init(mpy::object orig_, PyCFunction wrapper_implementation, std::string dim_name_="") { - orig = std::move(orig_); - method_def.ml_meth = wrapper_implementation; - name = orig.attr("__name__"); - doc = orig.attr("__doc__"); - dim_name = std::move(dim_name_); - if 
(!mpy::is_none(doc) && !dim_name.empty()) { - doc = mpy::unicode_from_format("%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", doc.ptr(), dim_name.c_str()); - } - method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); - method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); - method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; - } - - mpy::object function() { - return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); + void init( + mpy::object orig_, + PyCFunction wrapper_implementation, + std::string dim_name_ = "") { + orig = std::move(orig_); + method_def.ml_meth = wrapper_implementation; + name = orig.attr("__name__"); + doc = orig.attr("__doc__"); + dim_name = std::move(dim_name_); + if (!mpy::is_none(doc) && !dim_name.empty()) { + doc = mpy::unicode_from_format( + "%S\nArgument '%s' can be either an integer or a torchdim.Dim object.\n", + doc.ptr(), + dim_name.c_str()); } + method_def.ml_name = mpy::is_none(name) ? "" : PyUnicode_AsUTF8(name.ptr()); + method_def.ml_doc = mpy::is_none(doc) ? "" : PyUnicode_AsUTF8(doc.ptr()); + method_def.ml_flags = METH_FASTCALL | METH_KEYWORDS; + } + mpy::object function() { + return mpy::object::checked_steal(PyCFunction_New(&method_def, ptr())); + } }; -} +} // namespace PyTypeObject WrappedOperator::Type = { PyVarObject_HEAD_INIT(NULL, 0) - "_C.WrappedOperator", /* tp_name */ - sizeof(WrappedOperator), /* tp_basicsize */ - 0, /* tp_itemsize */ - WrappedOperator::dealloc_stub, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ + "_C.WrappedOperator", /* tp_name */ + sizeof(WrappedOperator), /* tp_basicsize */ + 0, /* tp_itemsize */ + WrappedOperator::dealloc_stub, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_as_async */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ - "Wrapped Object Holder", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - WrappedOperator::new_stub, /* tp_new */ + "Wrapped Object Holder", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + 0, /* tp_init */ + 0, /* tp_alloc */ + WrappedOperator::new_stub, /* tp_new */ }; -namespace{ -PyObject* patched_dim_method(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - auto self = WrappedOperator::unchecked_wrap(self_); - PY_BEGIN +namespace { +PyObject* patched_dim_method( + PyObject* self_, + PyObject* const* args, + Py_ssize_t 
nargs, + PyObject* kwnames) { + Arena A; + auto self = WrappedOperator::unchecked_wrap(self_); + PY_BEGIN - mpy::vector_args va(args, nargs, kwnames); + mpy::vector_args va(args, nargs, kwnames); - auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { - auto offset = offset_ + 1; // do not include self - auto idx = va.index(name, offset); - return idx == -1 ? mpy::handle() : va[idx]; - }; - Slice patched_args; - patched_args.extend(A, va.begin(), va.end()); - auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { - auto offset = offset_ + 1; // do not include self - auto idx = va.index(name, offset); - if (idx == -1) { - mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); - } - patched_args[idx] = value; - }; - - auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); - if (!dim.ptr()) { - auto info = TensorInfo::create(A, args[0], true); - EnableAllLayers l(A, info.levels); - l.inplace_update_layers(info.batchedtensor, info.levels); - patched_args[0] = handle_from_tensor(A, info.batchedtensor); - auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); - return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device).release(); + auto _getarg = [&](const char* name, int64_t offset_) -> mpy::handle { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + return idx == -1 ? mpy::handle() : va[idx]; + }; + Slice patched_args; + patched_args.extend(A, va.begin(), va.end()); + auto _patcharg = [&](const char* name, int64_t offset_, mpy::handle value) { + auto offset = offset_ + 1; // do not include self + auto idx = va.index(name, offset); + if (idx == -1) { + mpy::raise_error(PyExc_ValueError, "Missing argument %s", name); } + patched_args[idx] = value; + }; - auto info = TensorInfo::create(A, args[0]); - auto keepdim = false; - if (self->reduce) { - auto py_keepdim = _getarg("keepdim", self->keepdim_offset); - if (py_keepdim.ptr()) { - keepdim = mpy::to_bool(py_keepdim); - } - } - - auto ndim = info.ndim(); - auto dims = _wrap_dims(A, dim, ndim, keepdim); - Slice dim_indices; - auto seen = A.allocate(info.levels.size()); - std::fill(seen, seen + info.levels.size(), false); - - for (auto d : dims) { - auto midx = info.levels.index(d); - if (!midx) { - auto tup = levels_to_tuple(info.levels); - mpy::raise_error(PyExc_ValueError, "Tensor with dimensions %R does not contain one of %R\n", tup.ptr(), dim.ptr()); - } - seen[*midx] = true; - dim_indices.append(A, *midx); - } - Slice new_levels; - if (self->reduce && !keepdim) { - for (auto i : info.levels.enumerate()) { - if (!seen[i]) { - new_levels.append(A, info.levels[i]); - } - } - } else { - new_levels = info.levels; - } - mpy::object py_indices; - if (dim_indices.size() == 1) { - py_indices = mpy::from_int(dim_indices[0]); - } else { - mpy::tuple tup(dim_indices.size()); - for (auto i : dim_indices.enumerate()) { - tup.set(i, mpy::from_int(dim_indices[i])); - } - py_indices = std::move(tup); - } - _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); - patched_args[0] = handle_from_tensor(A, info.tensor); + auto dim = _getarg(self->dim_name.c_str(), self->dim_offset); + if (!dim.ptr()) { + auto info = TensorInfo::create(A, args[0], true); + EnableAllLayers l(A, info.levels); + l.inplace_update_layers(info.batchedtensor, info.levels); + patched_args[0] = handle_from_tensor(A, info.batchedtensor); auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); - auto wrap = [&](mpy::handle h) { - if 
(THPVariable_Check(h.ptr())) { - return A.autorelease(Tensor::from_positional(A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); - } - return h; - }; - return tree_map(A, wrap, r).release(); - PY_END(nullptr) + return l.from_batched(A, THPVariable_Unpack(r.ptr()), info.has_device) + .release(); + } + + auto info = TensorInfo::create(A, args[0]); + auto keepdim = false; + if (self->reduce) { + auto py_keepdim = _getarg("keepdim", self->keepdim_offset); + if (py_keepdim.ptr()) { + keepdim = mpy::to_bool(py_keepdim); + } + } + + auto ndim = info.ndim(); + auto dims = _wrap_dims(A, dim, ndim, keepdim); + Slice dim_indices; + auto seen = A.allocate(info.levels.size()); + std::fill(seen, seen + info.levels.size(), false); + + for (auto d : dims) { + auto midx = info.levels.index(d); + if (!midx) { + auto tup = levels_to_tuple(info.levels); + mpy::raise_error( + PyExc_ValueError, + "Tensor with dimensions %R does not contain one of %R\n", + tup.ptr(), + dim.ptr()); + } + seen[*midx] = true; + dim_indices.append(A, *midx); + } + Slice new_levels; + if (self->reduce && !keepdim) { + for (auto i : info.levels.enumerate()) { + if (!seen[i]) { + new_levels.append(A, info.levels[i]); + } + } + } else { + new_levels = info.levels; + } + mpy::object py_indices; + if (dim_indices.size() == 1) { + py_indices = mpy::from_int(dim_indices[0]); + } else { + mpy::tuple tup(dim_indices.size()); + for (auto i : dim_indices.enumerate()) { + tup.set(i, mpy::from_int(dim_indices[i])); + } + py_indices = std::move(tup); + } + _patcharg(self->dim_name.c_str(), self->dim_offset, py_indices); + patched_args[0] = handle_from_tensor(A, info.tensor); + auto r = self->orig.call_vector(patched_args.begin(), nargs, kwnames); + auto wrap = [&](mpy::handle h) { + if (THPVariable_Check(h.ptr())) { + return A.autorelease(Tensor::from_positional( + A, THPVariable_Unpack(h.ptr()), new_levels, info.has_device)); + } + return h; + }; + return tree_map(A, wrap, r).release(); + PY_END(nullptr) } -PyObject* _wrap(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN +PyObject* _wrap( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN - #define ARGS(_) _(mpy::handle, orig) _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ - _(mpy::handle, dim_name) _(mpy::handle, single_dim) _(mpy::handle, reduce) - MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) +#define ARGS(_) \ + _(mpy::handle, orig) \ + _(mpy::handle, dim_offset) _(mpy::handle, keepdim_offset) \ + _(mpy::handle, dim_name) _(mpy::handle, single_dim) \ + _(mpy::handle, reduce) + MPY_PARSE_ARGS_KWNAMES("O|OOOOO", ARGS) - std::string dim_name_str; - if (dim_name.ptr()) { - dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); - } else { - dim_name_str = "dim"; - } - auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) patched_dim_method, std::move(dim_name_str)); - if (dim_offset.ptr()) { - info->dim_offset = mpy::to_int(dim_offset); - } - if (keepdim_offset.ptr()) { - info->keepdim_offset = mpy::to_int(keepdim_offset); - } + std::string dim_name_str; + if (dim_name.ptr()) { + dim_name_str = PyUnicode_AsUTF8(dim_name.ptr()); + } else { + dim_name_str = "dim"; + } + auto info = WrappedOperator::create( + mpy::object::borrow(orig), + (PyCFunction)(void*)patched_dim_method, + std::move(dim_name_str)); + if (dim_offset.ptr()) { + info->dim_offset = mpy::to_int(dim_offset); + } + if (keepdim_offset.ptr()) { + info->keepdim_offset = 
mpy::to_int(keepdim_offset); + } - if (single_dim.ptr()) { - info->single_dim = mpy::to_bool(single_dim); - } - if (reduce.ptr()) { - info->reduce = mpy::to_bool(reduce); - } - return info->function().release(); - #undef ARGS + if (single_dim.ptr()) { + info->single_dim = mpy::to_bool(single_dim); + } + if (reduce.ptr()) { + info->reduce = mpy::to_bool(reduce); + } + return info->function().release(); +#undef ARGS - PY_END(nullptr) + PY_END(nullptr) } -PyObject* call_torch_function(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - Arena A; - maybeInitializeGlobals(); - auto info = WrappedOperator::unchecked_wrap(self); - return __torch_function__(A, info->orig, mpy::vector_args(args, nargs, kwnames), info->is_pointwise).release(); - PY_END(nullptr) +PyObject* call_torch_function( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + Arena A; + maybeInitializeGlobals(); + auto info = WrappedOperator::unchecked_wrap(self); + return __torch_function__( + A, + info->orig, + mpy::vector_args(args, nargs, kwnames), + info->is_pointwise) + .release(); + PY_END(nullptr) } -PyObject* _wrap_method(PyObject *self, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - AT_ASSERT(nargs == 2); - // XXX - ignore python function wrapped, we will call torch function directly - mpy::handle orig = args[0]; - if (!pointwise.ptr()) { - auto dim = mpy::import("functorch.dim"); - pointwise = dim.attr("pointwise"); - } - auto info = WrappedOperator::create(mpy::object::borrow(orig), (PyCFunction)(void*) call_torch_function); - info->is_pointwise = pointwise.contains(orig); - return PyInstanceMethod_New(info->function().release()); - PY_END(nullptr); +PyObject* _wrap_method( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + AT_ASSERT(nargs == 2); + // XXX - ignore python function wrapped, we will call torch function directly + mpy::handle orig = args[0]; + if (!pointwise.ptr()) { + auto dim = mpy::import("functorch.dim"); + pointwise = dim.attr("pointwise"); + } + auto info = WrappedOperator::create( + mpy::object::borrow(orig), (PyCFunction)(void*)call_torch_function); + info->is_pointwise = pointwise.contains(orig); + return PyInstanceMethod_New(info->function().release()); + PY_END(nullptr); } +PyObject* Tensor_sum( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + Arena A; + PY_BEGIN + maybeInitializeGlobals(); + mpy::vector_args va(args, nargs, kwnames); + auto self_ = Tensor::unchecked_wrap(args[0]); + auto d = self_->delayed(); + if (!d) { + return _Tensor_sum.call_vector(va).release(); + } + mpy::handle self, dim, keepdim, dtype; + va.parse( + "sum", + {"self", "dim", "keepdim", "dtype"}, + {&self, &dim, &keepdim, &dtype}, + 1, + 1); -PyObject* Tensor_sum(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - Arena A; - PY_BEGIN - maybeInitializeGlobals(); - mpy::vector_args va(args, nargs, kwnames); - auto self_ = Tensor::unchecked_wrap(args[0]); - auto d = self_->delayed(); - if (!d) { - return _Tensor_sum.call_vector(va).release(); - } - mpy::handle self, dim, keepdim, dtype; - va.parse("sum", {"self", "dim", "keepdim", "dtype"}, {&self, &dim, &keepdim, &dtype}, 1, 1); + if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { + // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; + return _Tensor_sum.call_vector(va).release(); + } + auto 
levels = self_->levels(); - if (dtype.ptr() || (keepdim.ptr() && mpy::to_bool(keepdim))) { - // std::cout << "SKIPPING fusion because dtype or keepdim=True specified\n"; - return _Tensor_sum.call_vector(va).release(); - } - auto levels = self_->levels(); + auto N = ndim_of_levels(levels); + auto reduced_dims = _wrap_dims(A, dim, N, false); - auto N = ndim_of_levels(levels); - auto reduced_dims = _wrap_dims(A, dim, N, false); - - return dot(A, TensorInfo::create(A, d->args[0], false), TensorInfo::create(A, d->args[1], false), reduced_dims).release(); - PY_END(nullptr) + return dot(A, + TensorInfo::create(A, d->args[0], false), + TensorInfo::create(A, d->args[1], false), + reduced_dims) + .release(); + PY_END(nullptr) } -PyObject* _parse_test(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - maybeInitializeGlobals(); +PyObject* _parse_test( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + maybeInitializeGlobals(); - int required = mpy::to_int(args[0]); - int kwonly = mpy::to_int(args[1]); + int required = mpy::to_int(args[0]); + int kwonly = mpy::to_int(args[1]); - mpy::vector_args va(args + 2, nargs - 2, kwnames); + mpy::vector_args va(args + 2, nargs - 2, kwnames); + mpy::handle a, b, c, d; + va.parse( + "_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); + mpy::tuple r(4); + r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); + r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); + r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); + r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); + return r.release(); - mpy::handle a, b, c, d; - va.parse("_parse_test", {"a", "b", "c", "d"}, {&a, &b, &c, &d}, required, kwonly); - mpy::tuple r(4); - r.set(0, mpy::object::borrow(a.ptr() ? a : Py_None)); - r.set(1, mpy::object::borrow(b.ptr() ? b : Py_None)); - r.set(2, mpy::object::borrow(c.ptr() ? c : Py_None)); - r.set(3, mpy::object::borrow(d.ptr() ? d : Py_None)); - return r.release(); - - PY_END(nullptr) + PY_END(nullptr) } -PyObject* _set_pointwise_optimize(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN - mpy::handle value; - mpy::vector_args va(args, nargs, kwnames); - va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); - pointwise_optimize = mpy::to_bool(value); - Py_RETURN_NONE; - PY_END(nullptr) +PyObject* _set_pointwise_optimize( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN + mpy::handle value; + mpy::vector_args va(args, nargs, kwnames); + va.parse("_set_pointwise_optimization", {"value"}, {&value}, 1); + pointwise_optimize = mpy::to_bool(value); + Py_RETURN_NONE; + PY_END(nullptr) } -PyObject* _patch_tensor_class(PyObject * self_, - PyObject *const *args, - Py_ssize_t nargs, - PyObject *kwnames) { - PY_BEGIN +PyObject* _patch_tensor_class( + PyObject* self_, + PyObject* const* args, + Py_ssize_t nargs, + PyObject* kwnames) { + PY_BEGIN - auto torch = mpy::import("torch"); - auto py_TensorBase = torch.attr("_C").attr("TensorBase"); - replaceMappingIfMatches(py_TensorBase); + auto torch = mpy::import("torch"); + auto py_TensorBase = torch.attr("_C").attr("TensorBase"); + replaceMappingIfMatches(py_TensorBase); - Py_RETURN_NONE; - PY_END(nullptr) + Py_RETURN_NONE; + PY_END(nullptr) } - const char* dims_doc = R"""( dims(n=None, sizes=None) -> torchdim.Dim or Tuple[torchdim.Dim, ...] 
@@ -3196,54 +3579,79 @@ Example:: )"""; PyMethodDef methods[] = { - {"dims", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS, dims_doc}, - {"dimlists", (PyCFunction)(void*) _dims, METH_FASTCALL | METH_KEYWORDS}, - {"_test_c", (PyCFunction)(void*) test_c, METH_FASTCALL | METH_KEYWORDS}, - {"_wrap_method", (PyCFunction)(void*) _wrap_method, METH_FASTCALL | METH_KEYWORDS}, - {"Tensor_from_positional", (PyCFunction)(void*) py_Tensor_from_positional, METH_FASTCALL | METH_KEYWORDS}, - {"__torch_function__", (PyCFunction)(void*) py___torch_function__, METH_FASTCALL | METH_KEYWORDS}, - {"tree_flatten", (PyCFunction)(void*) py_tree_flatten, METH_FASTCALL | METH_KEYWORDS}, - {"order", (PyCFunction)(void*) order, METH_FASTCALL | METH_KEYWORDS}, - {"index", (PyCFunction)(void*) py_index, METH_FASTCALL | METH_KEYWORDS}, - {"stack", (PyCFunction)(void*) py_stack, METH_FASTCALL | METH_KEYWORDS}, - {"split", (PyCFunction)(void*) py_split, METH_FASTCALL | METH_KEYWORDS}, - {"expand", (PyCFunction)(void*) expand, METH_FASTCALL | METH_KEYWORDS}, - {"__getitem__", (PyCFunction)(void*) py___getitem__, METH_FASTCALL | METH_KEYWORDS}, - {"__setitem__", (PyCFunction)(void*) py___setitem__, METH_FASTCALL | METH_KEYWORDS}, - {"_wrap", (PyCFunction)(void*) _wrap, METH_FASTCALL | METH_KEYWORDS}, - {"Tensor_sum", (PyCFunction)(void*) Tensor_sum, METH_FASTCALL | METH_KEYWORDS}, - {"_parse_test", (PyCFunction)(void*) _parse_test, METH_FASTCALL | METH_KEYWORDS}, - {"_set_pointwise_optimize", (PyCFunction)(void*) _set_pointwise_optimize, METH_FASTCALL | METH_KEYWORDS}, - {"_patch_tensor_class", (PyCFunction)(void*) _patch_tensor_class, METH_FASTCALL | METH_KEYWORDS}, - {NULL, NULL, 0, NULL} /* Sentinel */ + {"dims", + (PyCFunction)(void*)_dims, + METH_FASTCALL | METH_KEYWORDS, + dims_doc}, + {"dimlists", + (PyCFunction)(void*)_dims, + METH_FASTCALL | METH_KEYWORDS}, + {"_test_c", (PyCFunction)(void*)test_c, METH_FASTCALL | METH_KEYWORDS}, + {"_wrap_method", + (PyCFunction)(void*)_wrap_method, + METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_from_positional", + (PyCFunction)(void*)py_Tensor_from_positional, + METH_FASTCALL | METH_KEYWORDS}, + {"__torch_function__", + (PyCFunction)(void*)py___torch_function__, + METH_FASTCALL | METH_KEYWORDS}, + {"tree_flatten", + (PyCFunction)(void*)py_tree_flatten, + METH_FASTCALL | METH_KEYWORDS}, + {"order", (PyCFunction)(void*)order, METH_FASTCALL | METH_KEYWORDS}, + {"index", (PyCFunction)(void*)py_index, METH_FASTCALL | METH_KEYWORDS}, + {"stack", (PyCFunction)(void*)py_stack, METH_FASTCALL | METH_KEYWORDS}, + {"split", (PyCFunction)(void*)py_split, METH_FASTCALL | METH_KEYWORDS}, + {"expand", (PyCFunction)(void*)expand, METH_FASTCALL | METH_KEYWORDS}, + {"__getitem__", + (PyCFunction)(void*)py___getitem__, + METH_FASTCALL | METH_KEYWORDS}, + {"__setitem__", + (PyCFunction)(void*)py___setitem__, + METH_FASTCALL | METH_KEYWORDS}, + {"_wrap", (PyCFunction)(void*)_wrap, METH_FASTCALL | METH_KEYWORDS}, + {"Tensor_sum", + (PyCFunction)(void*)Tensor_sum, + METH_FASTCALL | METH_KEYWORDS}, + {"_parse_test", + (PyCFunction)(void*)_parse_test, + METH_FASTCALL | METH_KEYWORDS}, + {"_set_pointwise_optimize", + (PyCFunction)(void*)_set_pointwise_optimize, + METH_FASTCALL | METH_KEYWORDS}, + {"_patch_tensor_class", + (PyCFunction)(void*)_patch_tensor_class, + METH_FASTCALL | METH_KEYWORDS}, + {NULL, NULL, 0, NULL} /* Sentinel */ }; struct PyModuleDef module_def = { PyModuleDef_HEAD_INIT, - "_C", /* name of module */ + "_C", /* name of module */ NULL, /* module documentation, may be NULL */ - 
-1, /* size of per-interpreter state of the module, - or -1 if the module keeps state in global variables. */ - methods -}; -} + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + methods}; +} // namespace PyObject* Dim_init() { - Arena A; - try { - mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); - Dim::ready(mod, "Dim"); - DimList::ready(mod, "DimList"); - Tensor::ready(mod, "Tensor"); - WrappedOperator::ready(mod, "_WrappedOperator"); - Py_INCREF(&PyInstanceMethod_Type); - PyModule_AddObject(mod.ptr(), "_instancemethod", (PyObject *)&PyInstanceMethod_Type); + Arena A; + try { + mpy::object mod = mpy::object::checked_steal(PyModule_Create(&module_def)); + Dim::ready(mod, "Dim"); + DimList::ready(mod, "DimList"); + Tensor::ready(mod, "Tensor"); + WrappedOperator::ready(mod, "_WrappedOperator"); + Py_INCREF(&PyInstanceMethod_Type); + PyModule_AddObject( + mod.ptr(), "_instancemethod", (PyObject*)&PyInstanceMethod_Type); - initializeGlobals(A); - return mod.release(); - } catch(mpy::exception_set& err) { - return nullptr; - } + initializeGlobals(A); + return mod.release(); + } catch (mpy::exception_set& err) { + return nullptr; + } } #endif diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 53c9175a62d8..a3beb561f186 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -583,7 +583,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._C._dispatch_has_kernel", "torch._C._dispatch_is_alias_key", "torch._C._dispatch_is_included_in_alias", - "torch._C._dispatch_is_main_interpreter", "torch._C._dispatch_isTensorSubclassLike", "torch._C._dispatch_key_for_device", "torch._C._dispatch_key_name", diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 20116b97a481..1aa8a8b6df8a 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -409,10 +409,10 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { // associated with the TensorImpl. Swap this field as well. 
std::optional<PyObject*> mb_obj_a = a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
- getPyInterpreter(), /*ignore_hermetic_tls=*/false);
+ /*ignore_hermetic_tls=*/false);
std::optional<PyObject*> mb_obj_b = b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
- getPyInterpreter(), /*ignore_hermetic_tls=*/false);
+ /*ignore_hermetic_tls=*/false);
TORCH_INTERNAL_ASSERT( mb_obj_a.has_value() && mb_obj_b.has_value(), "Both tensors should have PyObjects tagged by the current python interpreter");
@@ -422,10 +422,8 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { a->cdata = b->cdata; b->cdata = tmp;
- a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
- getPyInterpreter(), a_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);
- b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
- getPyInterpreter(), b_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);
+ a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_);
+ b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp
index f944bb5c5461..f289a286b19c 100644
--- a/torch/csrc/PyInterpreter.cpp
+++ b/torch/csrc/PyInterpreter.cpp
@@ -586,7 +586,7 @@ static void set_tensor_attr_with_capsule( py::capsule& capsule, const char* attr_name) { std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
- getPyInterpreter(), /*ignore_hermetic_tls=*/false);
+ /*ignore_hermetic_tls=*/false);
TORCH_CHECK( mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto obj = mb_obj.value();
@@ -987,7 +987,3 @@ py::handle getTorchApiFunction(const c10::OperatorHandle& op) { c10::impl::PyInterpreter* getPyInterpreter() { return torch::detail::self_interpreter.get(); }
-
-bool isMainPyInterpreter() {
- return torch::detail::self_interpreter.is_main_interpreter();
-}
diff --git a/torch/csrc/PyInterpreter.h b/torch/csrc/PyInterpreter.h
index 82ca11e2c5d0..0ff9f79d02c2 100644
--- a/torch/csrc/PyInterpreter.h
+++ b/torch/csrc/PyInterpreter.h
@@ -10,4 +10,4 @@ TORCH_PYTHON_API py::handle getTorchApiFunction(const c10::OperatorHandle& op); // TODO: Move these to a proper namespace TORCH_PYTHON_API c10::impl::PyInterpreter* getPyInterpreter();
-TORCH_PYTHON_API bool isMainPyInterpreter();
+TORCH_PYTHON_API void initializeGlobalPyInterpreter();
diff --git a/torch/csrc/PyInterpreterHooks.h b/torch/csrc/PyInterpreterHooks.h
new file mode 100644
index 000000000000..1def7b8c55ae
--- /dev/null
+++ b/torch/csrc/PyInterpreterHooks.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <c10/core/impl/PyInterpreterHooks.h>
+
+namespace torch::detail {
+
+// Concrete implementation of PyInterpreterHooks
+class PyInterpreterHooks : public c10::impl::PyInterpreterHooksInterface {
+ public:
+ explicit PyInterpreterHooks(c10::impl::PyInterpreterHooksArgs);
+
+ c10::impl::PyInterpreter* getPyInterpreter() const override;
+};
+
+} // namespace torch::detail
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index cc682a2644af..08112b41aaae 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -35,7 +35,6 @@ PyTypeObject* THPStorageClass = nullptr; PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage,
- c10::impl::PyInterpreterStatus status,
bool allow_preexisting_pyobj) {
TORCH_CHECK( PyType_IsSubtype(type, &THPStorageType),
@@ -43,7 +42,7 @@ PyObject* THPStorage_NewWithStorage( "Storage is not possible. 
Make sure your class inherits from Storage."); auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); + /*ignore_hermetic_tls=*/false); if (maybe_pyobj.has_value() && maybe_pyobj.value()) { TORCH_CHECK( allow_preexisting_pyobj, @@ -78,8 +77,7 @@ PyObject* THPStorage_NewWithStorage( if (!c10::impl::HermeticPyObjectTLS::get_state()) { s->is_hermetic = false; const auto& storage = THPStorage_Unpack(s); - storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj( - getPyInterpreter(), obj, status); + storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); } else { s->is_hermetic = true; } @@ -91,17 +89,12 @@ PyObject* THPStorage_NewWithStorage( PyObject* THPStorage_Wrap(c10::Storage storage) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPStorage_NewWithStorage( - THPStorageClass, - std::move(storage), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); } c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); std::optional maybe_pyobj = pyobj_slot->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); - c10::impl::PyInterpreterStatus status = - c10::impl::PyInterpreterStatus::TAGGED_BY_US; + /*ignore_hermetic_tls=*/false); if (maybe_pyobj.has_value()) { auto obj = *maybe_pyobj; if (obj) { @@ -120,15 +113,8 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { return obj; } } - status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; - } else { - if (storage.use_count() <= 1) { - status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; - } else { - status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; - } } - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage), status); + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); } static bool THPStorage_isPreservable(THPStorage* self) { @@ -142,8 +128,7 @@ static bool THPStorage_isPreservable(THPStorage* self) { } if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/true) != - (PyObject*)self) { + /*ignore_hermetic_tls=*/true) != (PyObject*)self) { return false; } if (storage.use_count() <= 1) { @@ -161,11 +146,10 @@ static bool THPStorage_tryPreserve(THPStorage* self) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/true); // NOTE: It is possible to just set the PyObjectSlot here, but the point is - // that we should have already set PyObjectSlot when the storage PyObject was - // created. + // that we should have already set PyObjectSlot when the storage PyObject + // was created. TORCH_INTERNAL_ASSERT( maybe_pyobj.has_value(), "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject"); @@ -373,8 +357,7 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + device_opt)); // torch.Storage(size, *, ...) } else if (r.idx == 1) { @@ -387,8 +370,7 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + device_opt)); // torch.Storage(sequence, *, ...) 
} else if (r.idx == 2) { @@ -412,8 +394,7 @@ static PyObject* THPStorage_pynew( at::DataPtr(), allocator, /*resizable=*/true, - device_opt), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + device_opt)); THPObjectPtr item; try { const auto& storage = THPStorage_Unpack(self); @@ -509,10 +490,8 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { /* resizable */ false, device_opt); - PyObject* _ret = THPStorage_NewWithStorage( - Py_TYPE(self), - std::move(new_storage_impl), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + PyObject* _ret = + THPStorage_NewWithStorage(Py_TYPE(self), std::move(new_storage_impl)); return _ret; } diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index ce86475d6a95..698cd80548ef 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -19,7 +19,6 @@ TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( PyTypeObject* type, c10::Storage _storage, - c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); TORCH_PYTHON_API extern PyTypeObject* THPStorageClass; diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 8e5a99e4da7f..da64bcfbd500 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -390,10 +390,7 @@ static PyObject* THPStorage_fromFile( storage->set_nbytes(actual_nbytes); } - return THPStorage_NewWithStorage( - THPStorageClass, - std::move(storage), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index 9f7d667613dc..e58865bb60a8 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -86,8 +86,7 @@ static PyObject* THPStorage_pyNewFilenameStorage( THManagedMapAllocator::makeDataPtr( "", handle.c_str(), flags, static_cast(size)), /*allocator=*/nullptr, - /*resizable=*/false), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + /*resizable=*/false)); END_HANDLE_TH_ERRORS } @@ -182,8 +181,7 @@ static PyObject* THPStorage_newSharedFilename( THManagedMapAllocator::makeDataPtr( manager_handle, object_handle, flags, size), /*allocator=*/nullptr, - /*resizable=*/false), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + /*resizable=*/false)); END_HANDLE_TH_ERRORS } @@ -197,9 +195,7 @@ static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) { return nullptr; } return THPStorage_NewWithStorage( - THPStorageClass, - at::new_shm_fd_storage(size), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + THPStorageClass, at::new_shm_fd_storage(size)); END_HANDLE_TH_ERRORS } @@ -278,8 +274,7 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) { at::MapAllocator::makeDataPtr( at::WITH_FD, "", fd, flags, size, nullptr), /*allocator=*/nullptr, - /*resizable=*/false), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + /*resizable=*/false)); END_HANDLE_TH_ERRORS } @@ -560,10 +555,7 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) { base->set_resizable(false); base->set_received_cuda(true); - return THPStorage_NewWithStorage( - THPStorageClass, - std::move(base), - c10::impl::PyInterpreterStatus::TAGGED_BY_US); + return THPStorage_NewWithStorage(THPStorageClass, std::move(base)); #else TORCH_CHECK(false, "CUDA is not available"); #endif diff --git a/torch/csrc/autograd/python_variable.cpp 
b/torch/csrc/autograd/python_variable.cpp index b0235da869fb..c184dd63d294 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -209,7 +209,6 @@ PyObject* ParameterClass = nullptr; static PyObject* THPVariable_NewWithVar( PyTypeObject* type, const at::TensorBase& _var, - c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); // clang-tidy gets confused by static const @@ -261,16 +260,12 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { } if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPVariable_NewWithVar( - (PyTypeObject*)THPVariableClass, - var, - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); } std::optional mb_obj = var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); - c10::impl::PyInterpreterStatus status{}; + /*ignore_hermetic_tls=*/false); if (mb_obj.has_value()) { auto obj = *mb_obj; if (obj) { @@ -295,27 +290,17 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR // being a thing, the PyObject field will get cleared when all references // to the Python object are removed. - status = c10::impl::PyInterpreterStatus::TAGGED_BY_US; - } else { - // Assumption: if a Tensor has been shared across threads, this induces - // a refcount bump. Therefore, if the use count 1, we are the sole thread - // with access to this tensor and no race is possible. - if (var.use_count() <= 1) { - status = c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED; - } else { - status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED; - } } if (C10_LIKELY(var.device().type() != c10::kXLA)) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); } if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status); + return THPVariable_NewWithVar((PyTypeObject*)clazz, var); } - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); } static bool isResurrectable(THPVariable* self) { @@ -344,8 +329,7 @@ static bool isResurrectable(THPVariable* self) { } // Check if this is hermetic. If it is, no resurrection. 
if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false) != - (PyObject*)self) { + /*ignore_hermetic_tls=*/false) != (PyObject*)self) { return false; } return true; @@ -371,7 +355,6 @@ static bool THPVariable_tryResurrect(THPVariable* self) { c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false); TORCH_INTERNAL_ASSERT( @@ -587,10 +570,7 @@ static PyObject* THPVariable_as_subclass( // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; c10::impl::DisablePythonDispatcher dpd_g; - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - self.alias(), - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias()); END_HANDLE_TH_ERRORS } @@ -642,10 +622,7 @@ static PyObject* THPVariable_make_subclass( data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); } - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - data, - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)cls, data); END_HANDLE_TH_ERRORS } @@ -790,10 +767,7 @@ static PyObject* THPVariable_make_wrapper_subclass( tensor.unsafeGetTensorImpl()->set_python_custom_layout(true); } - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - tensor, - c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); + return THPVariable_NewWithVar((PyTypeObject*)cls, tensor); END_HANDLE_TH_ERRORS } @@ -1821,7 +1795,6 @@ PyObject* THPVariable_pynew( return THPVariable_NewWithVar( type, tensor, - c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED, /*allow_preexisting_pyobj=*/true); END_HANDLE_TH_ERRORS } @@ -1874,8 +1847,7 @@ static int THPVariable_subclass_clear(THPVariable* self) { if (!self->cdata.unsafeIsBorrowed() && tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - getPyInterpreter(), /*ignore_hermetic_tls=*/false) == - (PyObject*)self) { + /*ignore_hermetic_tls=*/false) == (PyObject*)self) { // TODO: empirically, on OS X this assert appears to be untrue // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn // distributed/rpc/test_process_group_agent.py @@ -2047,17 +2019,10 @@ static void THPVariable_subclass_dealloc(PyObject* self) { Py_DECREF(type); } -// Creates a new Python object for a Variable. The status parameter -// specifies what the interpreter tag status on the object is; for -// example, if you ran check_pyobj, the return optional of this object -// tells you if the tensor was already tagged or not so you can pass -// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where -// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED. -// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. +// Creates a new Python object for a Variable. 
static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
const at::TensorBase& _var,
- c10::impl::PyInterpreterStatus status,
bool allow_preexisting_pyobj) {
// Make sure that the reinterpret into a THPVariable* will be valid
TORCH_CHECK(
@@ -2068,7 +2033,7 @@ static PyObject* THPVariable_NewWithVar(
// This function overwrite the Tensor's pyobj field without extra checks
// Make sure it is not set otherwise we would leak memory
auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
- getPyInterpreter(), /*ignore_hermetic_tls=*/false);
+ /*ignore_hermetic_tls=*/false);
// Under some circumstances, we may attempt to create a new Python
// object for a variable that already has a Python object. The most common
@@ -2150,8 +2115,7 @@ static PyObject* THPVariable_NewWithVar(
// Normal codepath
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
const auto& var = THPVariable_Unpack(v);
- var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
- getPyInterpreter(), obj, status);
+ var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
if (check_has_torch_dispatch(obj)) {
var.unsafeGetTensorImpl()->set_python_dispatch(true);
}
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index b2b0e848a7e7..019ce2070634 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -209,12 +209,10 @@ class PythonKernelHolder : public c10::OperatorKernel { } };
+// @todo sahanp: AFAICT only REGISTER is used in the codebase. This can be
+// removed / simplified
static torch::_RegisterOrVerify register_or_verify() {
- if (isMainPyInterpreter()) {
- return torch::_RegisterOrVerify::REGISTER;
- } else {
- return torch::_RegisterOrVerify::VERIFY;
- }
+ return torch::_RegisterOrVerify::REGISTER;
}
static py::object ophandle_call_boxed(
@@ -287,7 +285,6 @@ void initDispatchBindings(PyObject* module) { .def( "reset", [](const py::object& self) {
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().reset();
return;
},
@@ -297,7 +294,6 @@ .def( "def_", [](py::object self, const char* schema, const char* alias) {
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(
torch::schema(schema, parseAliasAnalysisKind(alias)));
return self;
@@ -311,7 +307,6 @@ .def( "def_legacy", [](py::object self, const char* schema) {
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
return self;
},
@@ -331,7 +326,6 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) {
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(
name, dispatch_str(dispatch, [](const at::Tensor& a) { return a;
@@ -349,7 +343,6 @@ void initDispatchBindings(PyObject* module) { const char* dispatch, const char* alias, const char* debug) {
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().def(
torch::schema(schema, parseAliasAnalysisKind(alias)),
dispatch_str(dispatch, [](const at::Tensor& a) {
@@ -370,7 +363,6 @@ void initDispatchBindings(PyObject* module) { const char* name, const char* dispatch, const char* debug) {
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().impl(
name, dispatch_str(dispatch, [](const at::Tensor& a) { return a;
@@ -465,7 +457,6 @@ void initDispatchBindings(PyObject* module) { .def( "fallback_fallthrough", [](py::object self, const char* dispatch) {
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
self.cast<torch::Library&>().fallback(
dispatch_str(dispatch, CppFunction::makeFallthrough()));
return self;
@@ -480,7 +471,6 @@ void initDispatchBindings(PyObject* module) { bool with_keyset) { HANDLE_TH_ERRORS auto& lib = self.cast<torch::Library&>();
- TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
if (func.is(py::module::import("torch.library") .attr("fallthrough_kernel"))) { lib.fallback(
@@ -913,8 +903,6 @@ void initDispatchBindings(PyObject* module) { handle.setReportErrorCallback_(std::move(callback_obj)); });
- m.def(
- "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
m.def("_dispatch_pystub", [](const char* name, const char* overload) { return c10::Dispatcher::singleton().getPyStub( c10::OperatorName(name, overload));