add strides to slow path

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78610

Approved by: https://github.com/ezyang
This commit is contained in:
George Qi
2022-06-10 03:02:28 +00:00
committed by PyTorch MergeBot
parent 1d7627955b
commit a90f006fe5
8 changed files with 125 additions and 30 deletions

View File

@ -218,6 +218,7 @@ void concrete_dispatch_fn(
// Forward declarations of the per-tensor callbacks whose addresses are taken
// inside PyInterpreterHolder below. Each receives the (unnamed) PyInterpreter
// and the TensorImpl being queried; the definitions live further down in this
// file.
bool concrete_is_contiguous_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
c10::Device concrete_device_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
int64_t concrete_dim_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
c10::IntArrayRef concrete_strides_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
class PyInterpreterHolder {
public:
@ -229,7 +230,8 @@ class PyInterpreterHolder {
&concrete_dispatch_fn,
&concrete_is_contiguous_fn,
&concrete_device_fn,
&concrete_dim_fn)) {}
&concrete_dim_fn,
&concrete_strides_fn)) {}
// NB: intentionally leaks the memory
~PyInterpreterHolder() {
impl_->disarm();
@ -1905,7 +1907,6 @@ bool isPythonTensor(const Tensor& tensor) {
py::object torchDispatchFromTensorImpl(const c10::TensorImpl* self, const char* func_name, PyObject* torch_api_function, const char* module_name) {
TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call parseIValuesToPyArgsKwargs");
// Setup the arguments expected for the detach call
std::vector<py::handle> overloaded_args;
// TODO: there should be a shorter way to spell this
// TODO: fix the constness of target
@ -2100,4 +2101,41 @@ c10::Device concrete_device_fn(const c10::impl::PyInterpreter*, const c10::Tenso
return toDevice(out.ptr());
}
// Slow-path strides hook: asks the Python side for the tensor's strides by
// dispatching torch.ops.aten.stride via torchDispatchFromTensorImpl.
//
// Returns self->strides_default() when Python returns None. Otherwise the
// returned object is passed, together with the tensor's PyObject, to
// torch.overrides.get_buffer, whose result is unpacked as two integers:
// the address of an int64_t array and its element count. The returned
// IntArrayRef is a non-owning view over that array.
// NOTE(review): this assumes get_buffer keeps the backing storage alive for
// as long as callers use the returned view -- confirm against
// torch.overrides.get_buffer.
//
// Fails via TORCH_CHECK if the tensor has no PyObject for this interpreter
// or if get_buffer does not yield exactly (address, non-negative length).
c10::IntArrayRef concrete_strides_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) {
  pybind11::gil_scoped_acquire gil;
  at::impl::MaybeSetTLSOnEntryGuard guard;

  auto out = torchDispatchFromTensorImpl(
      self,
      "stride",
      py::module_::import("torch")
          .attr("ops")
          .attr("aten")
          .attr("stride")
          .ptr(),
      "torch.ops.aten");

  if (out.is_none()) {
    return self->strides_default();
  }

  // `out` already owns the reference returned from Python, so it is used
  // directly below. Wrapping out.ptr() with reinterpret_steal would create a
  // second owner and decref the object twice on scope exit.

  // TODO: fix the constness of check_pyobj so this cast is not needed.
  c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self);
  c10::optional<PyObject*> mb_obj = ptr->check_pyobj(getPyInterpreter());
  TORCH_CHECK(mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
  // Borrow: the TensorImpl keeps its own reference to the PyObject.
  py::object sub = py::reinterpret_borrow<py::object>(*mb_obj);

  py::object os = py::module_::import("torch").attr("overrides");
  py::function get_buffer = py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
  auto buffer = get_buffer(sub, out);

  const auto result = THPUtils_unpackLongs(buffer.ptr());
  TORCH_CHECK(
      result.size() == 2 && result[1] >= 0,
      "torch.overrides.get_buffer must return (address, length) with a non-negative length");
  // The first integer is the raw address of the int64_t buffer.
  int64_t* start = reinterpret_cast<int64_t*>(result[0]);
  int64_t len = result[1];
  return c10::IntArrayRef(start, len);
}
} // anonymous namespace