Files
pytorch/torch/csrc/PyInterpreter.cpp
PaliC 1b99c1859c [BE] Make PyObjectSlot use a global PyInterpreter and remove (#158427)
This PR is a bit more involved but effectively works to drastically simplify PyObjectSlot and PyInterpreter.
1) For PyObjectSlot, we now use a global PyInterpreter, since there is only one. All call sites are updated to rely on this assumption.
2) We also remove the "tags" of the PyInterpreter by deprecating `PyInterpreterStatus`.

For the reviewer: unfortunately, `functorch/csrc/dim/dim.cpp` needed to be re-linted, so the diff for that file is very large. Fortunately, the only functional change in that file is the one below, which just removes `getPyInterpreter()` from the `check_pyobj` call.

```
 mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
-    // fast case: tensor is live in python
-    std::optional<PyObject*> mb_obj =
-        t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
-    if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
-        return *mb_obj;
-    }
-    return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
-}
-}
+  // fast case: tensor is live in python
+  std::optional<PyObject*> mb_obj =
+      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          /*ignore_hermetic_tls=*/false);
+  if (mb_obj.has_value() &&
+      !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
+    return *mb_obj;
+  }
+  return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
+}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158427
Approved by: https://github.com/albanD
2025-07-30 17:29:43 +00:00

990 lines
35 KiB
C++

#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <memory>
#include <string>
using namespace torch;
using namespace at;
using namespace c10;
namespace torch::detail {
namespace {
// NB: This is a macro and not a template function (like it was before)
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
//
// Expands to statements in the caller's body. It declares a local `guard`
// (MaybeSetTLSOnEntryGuard). If Python is initialized, it acquires the GIL,
// imports the module named "torch." + DeviceTypeName(device_type, true), and
// calls `<module>._gpu_trace.<func_name>.fire_callbacks(__VA_ARGS__)`.
// HIP is rewritten to CUDA first, which means `device_type` must be an
// assignable lvalue. Any std::exception is logged with LOG(ERROR) and
// swallowed; it is not propagated to the caller.
#define CONCRETE_GPU_TRACE(device_type, func_name, ...)                       \
  at::impl::MaybeSetTLSOnEntryGuard guard;                                    \
  if (Py_IsInitialized()) {                                                   \
    pybind11::gil_scoped_acquire gil;                                         \
    try {                                                                     \
      /* Masquerade hip as cuda because hip uses `torch.cuda` module. */      \
      if (device_type == at::kHIP) {                                          \
        device_type = at::kCUDA;                                              \
      }                                                                       \
      std::string module_name = "torch." + DeviceTypeName(device_type, true); \
      py::module mod = py::module::import(module_name.c_str());               \
      py::object hook =                                                       \
          mod.attr("_gpu_trace").attr(func_name).attr("fire_callbacks");      \
      hook(__VA_ARGS__);                                                      \
    } catch (const std::exception& e) {                                       \
      LOG(ERROR) << device_type                                               \
                 << " trace hook execution failed: " << e.what();             \
    }                                                                         \
  }
// Concrete c10::impl::PyInterpreterVTable backed by this Python interpreter.
// Most overrides are only declared here and defined out-of-line further down
// in this file. The trace_gpu_* hooks are defined inline via
// CONCRETE_GPU_TRACE.
struct ConcretePyInterpreterVTable final
    : public c10::impl::PyInterpreterVTable {
  std::string name() const override;
  void incref(PyObject* pyobj) const override;
  void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
  // TODO: Need to make this work for StorageImpl too. I imagine I'll want to
  // operate upon a PyObjectSlot rather than a TensorImpl
  c10::intrusive_ptr<c10::TensorImpl> detach(
      const c10::TensorImpl* self) const override;
  void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const override;
  void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
  void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const override;
  // NB: this is defined in python_dispatch.cpp
  void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey key,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack,
      bool with_keyset,
      bool with_op) const override {
    torch::impl::dispatch::python_op_registration_trampoline_impl(
        op, key, keyset, stack, with_keyset, with_op);
  }
  // Raises by calling
  // torch._utils_internal.throw_abstract_impl_not_imported_error in Python
  // (under the GIL).
  void throw_abstract_impl_not_imported_error(
      std::string opname,
      const char* pymodule,
      const char* context) const override {
    py::gil_scoped_acquire gil;
    pybind11::module::import("torch._utils_internal")
        .attr("throw_abstract_impl_not_imported_error")(
            opname, pymodule, context);
  }
  // TensorImpl metadata queries. The out-of-line definitions route each one
  // through __torch_dispatch__ and fall back to the TensorImpl default when
  // Python returns None (where a default exists).
  bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat)
      const override;
  bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override;
  c10::Device device(const c10::TensorImpl* self) const override;
  int64_t dim(const c10::TensorImpl* self) const override;
  c10::IntArrayRef strides(const c10::TensorImpl* self) const override;
  c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
  c10::Layout layout(const c10::TensorImpl* self) const override;
  int64_t numel(const c10::TensorImpl* self) const override;
  c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
  c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
  c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
  // GPU tracing hooks. Each one fires
  // torch.<device>._gpu_trace.<Name>.fire_callbacks(...) (HIP is routed to
  // torch.cuda). Exceptions are logged rather than thrown; see
  // CONCRETE_GPU_TRACE.
  void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event)
      const override {
    CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event);
  }
  void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event)
      const override {
    CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event);
  }
  void trace_gpu_event_record(
      at::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream);
  }
  void trace_gpu_event_wait(
      at::DeviceType device_type,
      uintptr_t event,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream);
  }
  void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr)
      const override {
    CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr);
  }
  void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr)
      const override {
    CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr);
  }
  void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream)
      const override {
    CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream);
  }
  void trace_gpu_device_synchronization(
      at::DeviceType device_type) const override {
    CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks");
  }
  void trace_gpu_stream_synchronization(
      at::DeviceType device_type,
      uintptr_t stream) const override {
    CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream);
  }
  void trace_gpu_event_synchronization(
      at::DeviceType device_type,
      uintptr_t event) const override {
    CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event);
  }
  void reset_backward_hooks(const c10::TensorImpl* self) const override;
  // Returns the single vtable object. It is a function-local static, so it is
  // constructed on the first call and never destroyed before program exit.
  static ConcretePyInterpreterVTable* instance() {
    static ConcretePyInterpreterVTable s;
    return &s;
  }
};
// Creates the PyInterpreter for this process and registers it with
// PythonOpRegistrationTrampoline. The holder is neither copyable nor movable.
// Its destructor only disarms the interpreter and never frees it: objects
// that outlive Python teardown may still hold references to it.
class PyInterpreterHolder {
 public:
  PyInterpreterHolder()
      : impl_(new c10::impl::PyInterpreter(
            ConcretePyInterpreterVTable::instance())) {
    // registerInterpreter() receives impl_, so it can only run after impl_
    // has been initialized above.
    is_main_interpreter_ =
        at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_);
  }

  PyInterpreterHolder(const PyInterpreterHolder&) = delete;
  PyInterpreterHolder& operator=(const PyInterpreterHolder&) = delete;
  PyInterpreterHolder(PyInterpreterHolder&&) = delete;
  PyInterpreterHolder& operator=(PyInterpreterHolder&&) = delete;

  // NB: intentionally leaks the PyInterpreter; see the class comment.
  ~PyInterpreterHolder() {
    impl_->disarm();
  }

  // Non-owning pointer to the (leaked) interpreter.
  c10::impl::PyInterpreter* get() const noexcept {
    return impl_;
  }

  // The value registerInterpreter() returned at construction time.
  bool is_main_interpreter() const noexcept {
    return is_main_interpreter_;
  }

 private:
  c10::impl::PyInterpreter* impl_;
  bool is_main_interpreter_{false};
};
// Calls `torch_api_function` through __torch_dispatch__ with `self` as the
// first positional argument, followed by `extra_args`. There are no kwargs.
//
// The caller must hold the GIL. Throws python_error if `torch_api_function`
// is null, if wrapping `self` or allocating the args tuple fails, or if any
// extra arg is null. Returns whatever
// handle_torch_function_no_python_arg_parser returns, as a new reference.
py::object torchDispatchFromTensorImpl(
    const c10::TensorImpl* self,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    // WARNING: MUST NOT BE TENSOR ARGS
    c10::SmallVector<py::object, 1> extra_args = {}) {
  if (torch_api_function == nullptr) {
    throw python_error();
  }
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");

  std::vector<PyObject*> overloaded_args;
  // TODO: there should be a shorter way to spell this
  // TODO: fix the constness of target
  at::Tensor self_t = at::Tensor(
      c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
  // A failed wrap yields nullptr. Propagate the Python error instead of
  // handing a null PyObject* to append_overloaded_tensor / the args tuple.
  if (!self_p) {
    throw python_error();
  }
  // NB: this may not be a python tensor if you got here from a mode!
  // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
  append_overloaded_tensor(&overloaded_args, self_p.ptr());

  auto args = py::reinterpret_steal<py::object>(
      PyTuple_New(static_cast<Py_ssize_t>(1 + extra_args.size())));
  // PyTuple_New returns nullptr (with an exception set) on failure, and
  // PyTuple_SET_ITEM performs no null check of its own.
  if (!args) {
    throw python_error();
  }
  // PyTuple_SET_ITEM steals a reference, so ownership is released into it.
  PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
  Py_ssize_t i = 1;
  for (auto& a : extra_args) {
    if (a.ptr() == nullptr)
      throw python_error();
    PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr());
    i++;
  }

  py::dict kwargs;
  return py::reinterpret_steal<py::object>(
      handle_torch_function_no_python_arg_parser(
          overloaded_args,
          args.ptr(),
          kwargs.ptr(),
          func_name,
          torch_api_function,
          module_name,
          TorchFunctionName::TorchDispatch));
}
// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj has a PyObjectSlot or not.
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
//
// Drops one reference to `pyobj`. The GIL is acquired here, so callers do
// not need to hold it. This does nothing (deliberately leaking the object)
// once Python is no longer initialized.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
    const {
  // Leak the pyobj if not initialized. This can happen if we are running
  // exit handlers that are destructing tensors with residual (owned)
  // PyObjects stored in them.
  if (!Py_IsInitialized())
    return;

  pybind11::gil_scoped_acquire gil;
  // Two possibilities:
  // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
  // Storage. Then we must be careful about PyObject resurrection (see
  // THPVariable_clear).
  // 2. We are decref-ing some other Python object. We don't do
  // PyObject resurrection on non-Tensors, so we just carry on as usual
  //
  // A refcount above 1 means references other than the one dropped below
  // still exist, so the PyObject will survive this decref.
  if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
    if (THPVariable_Check(pyobj)) {
      // It's still alive! This can happen if a weak ref resurrected
      // the PyObject without flipping ownership. At this point it is
      // too late to rescue the object, so just stub out the PyObject
      // so that it fails on subsequent uses. Don't raise an error here;
      // you're probably in a destructor.
      TORCH_WARN(
          "Deallocating Tensor that still has live PyObject references. "
          "This probably happened because you took out a weak reference to "
          "Tensor and didn't call _fix_weakref() after dereferencing it. "
          "Subsequent accesses to this tensor via the PyObject will now fail.");
      // Replace cdata with an empty MaybeOwned so the surviving PyObject no
      // longer refers to the C++ tensor.
      ((THPVariable*)pyobj)->cdata =
          c10::MaybeOwned<torch::autograd::Variable>();
    } else if (THPStorage_Check(pyobj)) {
      // Same situation as above, for UntypedStorage.
      TORCH_WARN(
          "Deallocating UntypedStorage that still has live PyObject references. "
          "This probably happened because you took out a weak reference to "
          "UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
          "Subsequent accesses to this storage via the PyObject will now fail.");
      ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>();
    }
  }
  Py_DECREF(pyobj);
}
// Adds one reference to `pyobj` under the GIL. Once Python is no longer
// initialized, the request is silently ignored.
void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire acquired;
    Py_INCREF(pyobj);
  }
}
// True when the tensor's dispatch key set contains DispatchKey::Python.
bool isPythonTensor(const at::Tensor& tensor) {
  const auto keys = tensor.unsafeGetTensorImpl()->key_set();
  return keys.has(c10::DispatchKey::Python);
}
// Invokes the Python `callback` with the string name of `key`, under the GIL.
void ConcretePyInterpreterVTable::reportErrorCallback(
    PyObject* callback,
    DispatchKey key) const {
  py::gil_scoped_acquire gil;
  // Nothing guarantees that every DispatchKey has a pybind binding, so Python
  // receives the key's string name rather than the enum itself.
  const auto key_name = c10::toString(key);
  py::reinterpret_borrow<py::object>(callback)(key_name);
}
// Boxed __torch_dispatch__ entry point. It pops the op's arguments off
// `stack`, collects every Python tensor (top-level or inside a list argument)
// as an overloaded arg, calls handle_torch_function_no_python_arg_parser, and
// pushes the Python result back onto `stack`.
void ConcretePyInterpreterVTable::dispatch(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) const {
  const auto& schema = op.schema();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  // NB: at the point arguments are pushed to the stack, ALL defaults
  // are already present
  py::gil_scoped_acquire gil;
  std::vector<PyObject*> overloaded_args;
  py::handle api_overload = getTorchApiFunction(op);

  // Records `value` as an overloaded arg if it is a Python tensor.
  auto collect_python_tensor = [&overloaded_args](const auto& value) {
    if (!value.isTensor()) {
      return;
    }
    const auto& tensor = value.toTensor();
    if (isPythonTensor(tensor)) {
      append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
    }
  };

  // Only top-level tensors and the elements of list arguments are inspected.
  for (const auto& ivalue : arguments) {
    if (ivalue.isTensor()) {
      collect_python_tensor(ivalue);
    } else if (ivalue.isList()) {
      for (const auto& element : ivalue.toListRef()) {
        collect_python_tensor(element);
      }
    }
  }

  auto [args, kwargs] = parseIValuesToPyArgsKwargs(op, arguments);

  PyObject* result = handle_torch_function_no_python_arg_parser(
      overloaded_args,
      args.ptr(),
      kwargs.ptr(),
      nullptr,
      api_overload.ptr(),
      nullptr,
      TorchFunctionName::TorchDispatch);
  pushPyOutToStack(
      op,
      stack,
      py::reinterpret_steal<py::object>(result),
      "__torch_dispatch__");
}
// Python-dispatcher entry point. It looks up the handler for the highest
// priority key in `ks`, first in the OpOverload's `_dispatch_cache` dict and
// then via `_get_dispatch(k)` on a miss.
// - If the handler is itself a DispatchKey, the boxed kernel for that key is
//   called directly on `stack`.
// - Otherwise the handler is called from Python with the popped arguments,
//   and its result is pushed back onto `stack`.
void ConcretePyInterpreterVTable::python_dispatcher(
    const c10::OperatorHandle& op,
    c10::DispatchKeySet ks,
    torch::jit::Stack* stack) const {
  py::gil_scoped_acquire gil;
  py::handle api_overload = getTorchApiFunction(op);

  // TODO: if necessary, can optimize to cache the cache lookup
  // TODO: if necessary, can optimize OpOverload to have slots
  auto cache = py::dict(api_overload.attr("_dispatch_cache"));
  if (cache.ptr() == nullptr) {
    throw python_error();
  }

  const c10::DispatchKey k = ks.highestPriorityTypeId();
  // PyDict_GetItemRef returns -1 on error, 0 on a missing key (leaving the
  // out-param null), and 1 on a hit (storing a new reference).
  PyObject* cached_handler = nullptr;
  if (PyDict_GetItemRef(cache.ptr(), py::cast(k).ptr(), &cached_handler) < 0) {
    throw python_error();
  }
  auto handler = py::reinterpret_steal<py::object>(cached_handler);
  if (!handler) {
    // Slow path
    handler = api_overload.attr("_get_dispatch")(k);
  }

  if (py::isinstance<c10::DispatchKey>(handler)) {
    // NB: not redispatch, as that will permanently remove the python
    // dispatcher for subsequent redispatches
    op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack);
    return;
  }

  const auto num_arguments = op.schema().arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);
  auto [args, kwargs] = parseIValuesToPyArgsKwargs(op, arguments);

  auto result = py::reinterpret_steal<py::object>(
      PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr()));
  if (!result) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, std::move(result), "Python dispatcher");
}
// Routes `self.detach()` through torch.ops.aten.detach.default via
// __torch_dispatch__. Raises if Python returns anything other than a Tensor.
c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object detach_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("aten")
                                   .attr("detach")
                                   .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "detach", detach_op.ptr(), "torch.ops.aten");
  TORCH_CHECK(
      THPVariable_Check(result.ptr()),
      "detach returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result.ptr())),
      ", expected Tensor");
  return THPVariable_Unpack(result.ptr()).getIntrusivePtr();
}
// Asks Python (via __torch_dispatch__) whether `self` is contiguous in
// `memory_format`. A None result falls back to is_contiguous_default.
bool ConcretePyInterpreterVTable::is_contiguous(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object packet = py::module::import("torch")
                                .attr("ops")
                                .attr("aten")
                                .attr("is_contiguous");
  // For backwards compatibility, Contiguous goes through the `default`
  // overload (no extra args). Every other format uses the `memory_format`
  // overload.
  const py::object result = (memory_format == at::MemoryFormat::Contiguous)
      ? torchDispatchFromTensorImpl(
            self,
            "is_contiguous",
            packet.attr("default").ptr(),
            "torch.ops.aten")
      : torchDispatchFromTensorImpl(
            self,
            "is_contiguous",
            packet.attr("memory_format").ptr(),
            "torch.ops.aten",
            {py::cast(memory_format)});

  if (result.is_none()) {
    return self->is_contiguous_default(memory_format);
  }
  TORCH_CHECK(
      PyBool_Check(result.ptr()),
      "is_contiguous returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result.ptr())),
      ", expected bool");
  return PyObject_IsTrue(result.ptr()) != 0;
}
// Asks Python whether `self` has strides like `memory_format`. A None result
// falls back to is_strides_like_default.
bool ConcretePyInterpreterVTable::is_strides_like(
    const c10::TensorImpl* self,
    at::MemoryFormat memory_format) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  // NB: intentionally suffixed with _format to avoid
  // triggering matches against "_like" suffix
  const py::object strides_like_op = py::module::import("torch")
                                         .attr("ops")
                                         .attr("aten")
                                         .attr("is_strides_like_format")
                                         .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self,
      "is_strides_like",
      strides_like_op.ptr(),
      "torch.ops.aten",
      {py::cast(memory_format)});
  if (result.is_none()) {
    return self->is_strides_like_default(memory_format);
  }
  TORCH_CHECK(
      PyBool_Check(result.ptr()),
      "is_strides_like_format returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result.ptr())),
      ", expected bool");
  return PyObject_IsTrue(result.ptr()) != 0;
}
// Asks Python whether `self` is non-overlapping and dense. A None result
// falls back to is_non_overlapping_and_dense_default.
bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object query_op = py::module::import("torch")
                                  .attr("ops")
                                  .attr("aten")
                                  .attr("is_non_overlapping_and_dense")
                                  .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "is_non_overlapping_and_dense", query_op.ptr(), "torch.ops.aten");
  if (result.is_none()) {
    return self->is_non_overlapping_and_dense_default();
  }
  TORCH_CHECK(
      PyBool_Check(result.ptr()),
      "is_non_overlapping_and_dense returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result.ptr())),
      ", expected bool");
  return PyObject_IsTrue(result.ptr()) != 0;
}
// Returns the dimensionality reported by Python's aten.dim.default. There is
// no default fallback: the result must be a Python int.
int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object dim_op = py::module::import("torch")
                                .attr("ops")
                                .attr("aten")
                                .attr("dim")
                                .attr("default");
  const py::object result =
      torchDispatchFromTensorImpl(self, "dim", dim_op.ptr(), "torch.ops.aten");
  TORCH_CHECK(
      PyLong_Check(result.ptr()),
      "dim returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result.ptr())),
      ", expected int");
  return THPUtils_unpackLong(result.ptr());
}
// Returns the device reported by Python's prim.device.default, converted
// with toDevice().
c10::Device ConcretePyInterpreterVTable::device(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object device_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("prim")
                                   .attr("device")
                                   .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "device", device_op.ptr(), "torch.ops.prim");
  return toDevice(result.ptr());
}
// Stores `capsule` as attribute `attr_name` on the PyObject attached to
// `tensor`. Throws if the tensor has no associated PyObject.
static void set_tensor_attr_with_capsule(
    const c10::TensorImpl* tensor,
    py::capsule& capsule,
    const char* attr_name) {
  const std::optional<PyObject*> maybe_pyobj =
      tensor->pyobj_slot()->check_pyobj(/*ignore_hermetic_tls=*/false);
  TORCH_CHECK(
      maybe_pyobj.has_value(), "Tensor subclass's PyInterpreter has no value");
  py::handle(*maybe_pyobj).attr(attr_name) = capsule;
}
// Note [Tensor Subclass custom size/stride caching strategy]
// Tensor subclasses can use __torch_dispatch__ to override size/stride calls.
// However, this presents a problem:
// (1) When you return a custom (maybe symbolic) size/stride
//     from python, we need to stash this fresh vector of ints/symints
//     somewhere so that it has the same lifetime as the tensor.
// (2) If the subclass experiences a metadata mutation,
//     this stashed vector is no longer valid, so we need to allocate a fresh
//     buffer to store the new sizes the next time someone asks for them.
//
// We handle this in the same way that `TensorImpl::sizes_default()`
// handles its buffer: we simply reallocate the buffer whenever
// the number of dimensions changes due to a resize.
// Notably, we do *not* reallocate the buffer if the values changed,
// but the number of dimensions stayed the same (e.g. `.transpose_()`).
//
// Storage layout:
// - The buffer lives in a py::capsule stored on the tensor's PyObject under
//   `base_attr_name`.
// - Its logical length is stored under `<base_attr_name>_len`.
//
// The function copies the elements of `obj` into that buffer (reallocating if
// needed) and returns an ArrayRef of len(obj) elements pointing into it. The
// ArrayRef is invalidated by the next reallocation.
template <typename T>
static c10::ArrayRef<T> get_set_cached_attr(
    const c10::TensorImpl* tensor,
    const char* base_attr_name,
    const py::object& obj) {
  // check_pyobj() takes only the `ignore_hermetic_tls` flag now that there is
  // a single global PyInterpreter; set_tensor_attr_with_capsule() calls it the
  // same way. Passing getPyInterpreter() here would only compile through an
  // implicit pointer-to-bool conversion, silently requesting
  // ignore_hermetic_tls=true.
  std::optional<PyObject*> mb_obj =
      tensor->pyobj_slot()->check_pyobj(/*ignore_hermetic_tls=*/false);
  TORCH_CHECK(
      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  auto tensor_obj = mb_obj.value();
  auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");

  bool is_buffer_allocated = false;
  size_t curr_size = 0;
  if (PyObject_HasAttrString(tensor_obj, buffer_len_attr_name.c_str())) {
    auto len_pyobj = py::handle(tensor_obj).attr(buffer_len_attr_name.c_str());
    curr_size = py::cast<size_t>(len_pyobj);
    is_buffer_allocated = true;
  }

  size_t new_size = py::len(obj);

  // We do the smallvector optimization here: any time the new_size is <=5,
  // we always allocate our buffer to size 5, so that if the next resize
  // is also to <=5 elements, we don't need to reallocate.
  // Note: removing this optimization was tried once and tripped ASAN in a
  // batchnorm kernel (the original CI log link has since expired).
  // We should fix this instead.
  //
  // We need to resize if:
  // (1) we haven't allocated our buffer at all yet
  // (2) Our buffer size is different from the new size
  //     (note: we use the small vector optimization, where our buffer
  //     is always allocated to at least size 5, and any resizes
  //     within the <= 5 regime do not require a reallocation).
  auto is_smallvector = curr_size <= 5;
  bool needs_resize = !is_buffer_allocated ||
      (is_smallvector && new_size > 5) ||
      (!is_smallvector && curr_size != new_size);

  if (needs_resize) {
    // If our current buffer is not the right size (either because we haven't
    // allocated it yet, or there was a metadata mutation that changed the
    // number of dims of the tensor), allocate a fresh buffer. Note that this
    // will trash the previous buffer if there already was one, invalidating
    // any existing SymIntArrayRef's from an old .sym_size() call.
    auto new_buffer_size = new_size;
    if (new_size <= 5) {
      // This is the smallvector optimization
      new_buffer_size = 5;
    }
    // Keep the allocation in a unique_ptr until the capsule exists, so a
    // failure while constructing the capsule does not leak it. After that,
    // the capsule's destructor owns the buffer and frees it with delete[].
    std::unique_ptr<T[]> owned_buffer(new T[new_buffer_size]);
    auto capsule = py::capsule(owned_buffer.get(), [](void* p) {
      delete[] reinterpret_cast<T*>(p);
    });
    T* ptr = owned_buffer.release();

    int64_t idx = 0;
    for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
      ptr[idx] = py::cast<T>(*it);
    }
    // Set the buffer
    set_tensor_attr_with_capsule(tensor, capsule, base_attr_name);
    // Set the len buffer
    py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()) = new_size;
  } else {
    TORCH_INTERNAL_ASSERT(PyObject_HasAttrString(tensor_obj, base_attr_name));
    auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
    void* buffer_pycapsule =
        PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
    auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);

    // Overwrite the buffer with our new values, but only if any of them
    // changed (due to a metadata mutation).
    // This is technically not thread safe, because the update happens lazily.
    // The original metadata mutation call on the tensor might have been
    // thread safe (e.g. a .resize_() call), but we won't actually mutate the
    // size buffer until the first call to .sizes() which the user might not
    // access in a thread-safe way. For now we are not explicitly locking, but
    // maybe we should.
    int64_t idx = 0;
    // Quick sanity assert that our buffer size is large enough
    // to compare against all the elements in the new buffer.
    size_t curr_buffer_size = 5;
    if (curr_buffer_size < curr_size) {
      curr_buffer_size = curr_size;
    }
    TORCH_INTERNAL_ASSERT(curr_buffer_size >= new_size);
    for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) {
      auto actual_val = py::cast<T>(*it);
      if constexpr (std::is_same_v<T, c10::SymInt>) {
        // if our SymInts are symbolic, we are *not* doing an equality check
        // on the symints. we just want to see if the nodes are the same. this
        // is because we don't want to introduce any guards here.
        if (!curr_buffer[idx].is_same(actual_val)) {
          curr_buffer[idx] = actual_val;
        }
      } else {
        if (curr_buffer[idx] != actual_val) {
          curr_buffer[idx] = actual_val;
        }
      }
    }
  }

  // The correct data is now stored at the buffer - read and return it.
  auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name);
  void* buffer_pycapsule =
      PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr);
  auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule);
  return c10::ArrayRef<T>(curr_buffer, new_size);
}
// Returns strides from Python's aten.stride.default, cached on the tensor's
// PyObject (see get_set_cached_attr). A None result falls back to
// strides_default, which is only allowed for non-symbolic tensors.
c10::IntArrayRef ConcretePyInterpreterVTable::strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object stride_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("aten")
                                   .attr("stride")
                                   .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "stride", stride_op.ptr(), "torch.ops.aten");
  if (result.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call strides on a tensor with symbolic shapes/strides");
    return self->strides_default();
  }
  const bool is_sequence =
      py::isinstance<py::tuple>(result) || py::isinstance<py::list>(result);
  TORCH_CHECK(is_sequence, "strides must be a list or a tuple");
  return get_set_cached_attr<int64_t>(self, "_strides_capsule", result);
}
// Returns sizes from Python's aten.size.default, cached on the tensor's
// PyObject (see get_set_cached_attr). A None result falls back to
// sizes_default, which is only allowed for non-symbolic tensors.
c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  const py::object size_op = py::module::import("torch")
                                 .attr("ops")
                                 .attr("aten")
                                 .attr("size")
                                 .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "size", size_op.ptr(), "torch.ops.aten");
  if (result.is_none()) {
    TORCH_CHECK(
        !self->has_symbolic_sizes_strides(),
        "Cannot call sizes on a tensor with symbolic shapes/strides");
    return self->sizes_default();
  }
  const bool is_sequence =
      py::isinstance<py::tuple>(result) || py::isinstance<py::list>(result);
  TORCH_CHECK(is_sequence, "sizes must be a list or a tuple");
  return get_set_cached_attr<int64_t>(self, "_sizes_capsule", result);
  END_HANDLE_TH_ERRORS_PYBIND
}
// Returns symbolic sizes from Python's aten.sym_size.default, cached on the
// tensor's PyObject. A None result falls back to sym_sizes_default.
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  const py::object sym_size_op = py::module::import("torch")
                                     .attr("ops")
                                     .attr("aten")
                                     .attr("sym_size")
                                     .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "sym_size", sym_size_op.ptr(), "torch.ops.aten");
  if (result.is_none()) {
    return self->sym_sizes_default();
  }
  const bool is_sequence =
      py::isinstance<py::tuple>(result) || py::isinstance<py::list>(result);
  TORCH_CHECK(is_sequence, "sym_size must be a list or a tuple");
  // See Note [Tensor Subclass custom size/stride caching strategy]
  return get_set_cached_attr<c10::SymInt>(self, "_sym_sizes_capsule", result);
  END_HANDLE_TH_ERRORS_PYBIND
}
// Returns the layout reported by Python's prim.layout.default. The result may
// be a torch.layout object or a plain int that encodes the enum value.
c10::Layout ConcretePyInterpreterVTable::layout(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object layout_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("prim")
                                   .attr("layout")
                                   .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "layout", layout_op.ptr(), "torch.ops.prim");
  if (THPLayout_Check(result.ptr())) {
    return toLayout(result.ptr());
  }
  TORCH_CHECK(
      PyLong_Check(result.ptr()),
      "layout returned invalid type ",
      py::detail::get_fully_qualified_tp_name(Py_TYPE(result.ptr())),
      ", expected Layout");
  return c10::Layout(py::cast<int64_t>(result));
}
// Returns numel from Python's aten.numel.default. A None result falls back to
// numel_default, which is only allowed for non-symbolic tensors.
int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object numel_op = py::module::import("torch")
                                  .attr("ops")
                                  .attr("aten")
                                  .attr("numel")
                                  .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "numel", numel_op.ptr(), "torch.ops.aten");
  if (!result.is_none()) {
    return py::cast<int64_t>(result);
  }
  TORCH_CHECK(
      !self->has_symbolic_sizes_strides(),
      "Cannot call sizes on a tensor with symbolic shapes/strides");
  return self->numel_default();
}
// Returns the (possibly symbolic) numel from Python's aten.sym_numel.default.
// A None result falls back to sym_numel_default. A plain int is wrapped
// into a SymInt.
c10::SymInt ConcretePyInterpreterVTable::sym_numel(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object sym_numel_op = py::module::import("torch")
                                      .attr("ops")
                                      .attr("aten")
                                      .attr("sym_numel")
                                      .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "sym_numel", sym_numel_op.ptr(), "torch.ops.aten");
  if (result.is_none()) {
    return self->sym_numel_default();
  }
  if (torch::is_symint(result)) {
    return result.cast<c10::SymInt>();
  }
  return c10::SymInt{py::cast<int64_t>(result)};
}
// Returns the (possibly symbolic) storage offset from Python's
// aten.sym_storage_offset.default. A None result falls back to
// sym_storage_offset_default. A plain int is wrapped into a SymInt.
c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  const py::object offset_op = py::module::import("torch")
                                   .attr("ops")
                                   .attr("aten")
                                   .attr("sym_storage_offset")
                                   .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "sym_storage_offset", offset_op.ptr(), "torch.ops.aten");
  if (result.is_none()) {
    return self->sym_storage_offset_default();
  }
  if (torch::is_symint(result)) {
    return result.cast<c10::SymInt>();
  }
  return c10::SymInt{py::cast<int64_t>(result)};
}
// Returns symbolic strides from Python's aten.sym_stride.default, cached on
// the tensor's PyObject. A None result falls back to sym_strides_default.
c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  const py::object sym_stride_op = py::module::import("torch")
                                       .attr("ops")
                                       .attr("aten")
                                       .attr("sym_stride")
                                       .attr("default");
  const py::object result = torchDispatchFromTensorImpl(
      self, "sym_stride", sym_stride_op.ptr(), "torch.ops.aten");
  if (result.is_none()) {
    return self->sym_strides_default();
  }
  // Python may hand back a mix of SymInt nodes and ints. They are squeezed
  // into SymInts, since that is the format sym_strides() are stored in.
  const bool is_sequence =
      py::isinstance<py::tuple>(result) || py::isinstance<py::list>(result);
  TORCH_CHECK(is_sequence, "sym_strides must be a list or a tuple");
  return get_set_cached_attr<c10::SymInt>(
      self, "_sym_strides_capsule", result);
  END_HANDLE_TH_ERRORS_PYBIND
}
// Sets `_backward_hooks = None` on the Python object wrapping `self`.
//
// Failures (wrapping `self`, or the setattr itself) are raised as Python
// errors through END_HANDLE_TH_ERRORS_PYBIND rather than being left pending
// on the interpreter.
void ConcretePyInterpreterVTable::reset_backward_hooks(
    const c10::TensorImpl* self) const {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;
  HANDLE_TH_ERRORS
  Tensor self_t =
      Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
                 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
             unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
  auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
  // A failed wrap yields nullptr, which must not reach PyObject_SetAttrString.
  if (!self_p) {
    throw python_error();
  }
  // PyObject_SetAttrString returns -1 with an exception set on failure. If
  // that result were ignored, the exception would stay pending and surface
  // later somewhere unrelated.
  if (PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None) < 0) {
    throw python_error();
  }
  END_HANDLE_TH_ERRORS_PYBIND
}
// Human-readable identifier: the textual form of the global PyInterpreter
// pointer, as produced by operator<<.
std::string ConcretePyInterpreterVTable::name() const {
  std::ostringstream out;
  out << getPyInterpreter();
  return out.str();
}
// The single PyInterpreterHolder for this translation unit. It lives at
// namespace scope, so it is constructed when this file's globals are
// initialized; that allocates the PyInterpreter and registers it with
// PythonOpRegistrationTrampoline. getPyInterpreter() returns its pointer.
PyInterpreterHolder self_interpreter;
} // anonymous namespace
// Returns a handle to the Python OpOverload object for `op`, i.e.
// torch.ops.<namespace>.<name>.<overload_name>, using the `default` overload
// when the schema's overload name is empty. The lookup goes through
// op.getPythonOp(getPyInterpreter(), ...), and the lambda resolves the object
// from Python. NOTE(review): presumably getPythonOp caches the result so the
// lambda only runs on the first lookup per op — confirm in OperatorHandle.
py::handle getTorchApiFunction(const c10::OperatorHandle& op) {
  return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* {
    // Parse the name into namespace and name (no overload_name)
    // TODO: put this into the library
    const auto& schema = op.schema();
    const auto& qualified_name = op.operator_name().name;
    const auto& overload_name = schema.overload_name();
    // Qualified names are expected to look like "<ns>::<name>".
    auto pos = qualified_name.find("::");
    TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
    // Make me some null terminated strings
    std::string ns_str = qualified_name.substr(0, pos);
    const char* ns = ns_str.c_str();
    // Points into qualified_name, just past the "::" separator.
    const char* func_name = qualified_name.c_str() + pos + strlen("::");
    // NOTE(review): both this handle and the pointers returned below come
    // from temporary py::objects that are released at the end of their
    // statements. This relies on torch.ops keeping the packet and overload
    // objects alive elsewhere — confirm before changing the lookup.
    py::handle torch_api_function =
        py::module::import("torch").attr("ops").attr(ns).attr(func_name);
    if (overload_name.empty()) {
      return torch_api_function.attr("default").ptr();
    } else {
      return torch_api_function.attr(overload_name.c_str()).ptr();
    }
  });
}
} // namespace torch::detail
// Returns the process-wide PyInterpreter owned by the file-local holder. The
// holder never reassigns or frees it, so the pointer stays valid for the
// lifetime of the process.
c10::impl::PyInterpreter* getPyInterpreter() {
  const auto& holder = torch::detail::self_interpreter;
  return holder.get();
}