mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
add strides to slow path
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78610
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent 1d7627955b
commit a90f006fe5
@@ -218,6 +218,7 @@ void concrete_dispatch_fn(
 bool concrete_is_contiguous_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
 c10::Device concrete_device_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
 int64_t concrete_dim_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
+c10::IntArrayRef concrete_strides_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
 
 class PyInterpreterHolder {
  public:
@@ -229,7 +230,8 @@ class PyInterpreterHolder {
             &concrete_dispatch_fn,
             &concrete_is_contiguous_fn,
             &concrete_device_fn,
-            &concrete_dim_fn)) {}
+            &concrete_dim_fn,
+            &concrete_strides_fn)) {}
   // NB: intentionally leaks the memory
   ~PyInterpreterHolder() {
     impl_->disarm();
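The two hunks above first declare the new hook and then thread it through the PyInterpreterHolder constructor: the c10::impl::PyInterpreter is assembled from one free function per metadata query (dispatch, is_contiguous, device, dim, ...), so exposing strides means adding one more concrete_*_fn pointer. The standalone sketch below shows only that registration pattern with toy types; the names are illustrative and are not the real c10::impl::PyInterpreter API.

#include <array>
#include <cstdint>
#include <iostream>

// Toy stand-ins: only the function-pointer wiring is the point here.
struct ToyTensorImpl {
  std::array<int64_t, 2> strides{12, 1};
};
struct ToyInterpreter;

using strides_fn_t =
    const int64_t* (*)(const ToyInterpreter*, const ToyTensorImpl*);

struct ToyInterpreter {
  // The real PyInterpreter takes one pointer per hook; the PR appends a
  // strides hook to that list in the same way.
  explicit ToyInterpreter(strides_fn_t strides_fn) : strides_fn_(strides_fn) {}
  const int64_t* strides(const ToyTensorImpl* t) const {
    return strides_fn_(this, t);
  }
  strides_fn_t strides_fn_;
};

// Stands in for concrete_strides_fn; the real one asks Python via
// __torch_dispatch__ instead of reading a local array.
const int64_t* toy_concrete_strides(const ToyInterpreter*, const ToyTensorImpl* t) {
  return t->strides.data();
}

int main() {
  ToyTensorImpl impl;
  ToyInterpreter interp(&toy_concrete_strides);  // mirrors passing &concrete_strides_fn
  std::cout << interp.strides(&impl)[0] << "\n"; // prints 12
  return 0;
}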
@@ -1905,7 +1907,6 @@ bool isPythonTensor(const Tensor& tensor) {
 py::object torchDispatchFromTensorImpl(const c10::TensorImpl* self, const char* func_name, PyObject* torch_api_function, const char* module_name) {
   TORCH_CHECK(PyGILState_Check(), "GIL must be held before you call parseIValuesToPyArgsKwargs");
-
   // Setup the arguments expected for the detach call
   std::vector<py::handle> overloaded_args;
   // TODO: there should be a shorter way to spell this
   // TODO: fix the constness of target
@@ -2100,4 +2101,41 @@ c10::Device concrete_device_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) {
   return toDevice(out.ptr());
 }
 
+c10::IntArrayRef concrete_strides_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self) {
+  pybind11::gil_scoped_acquire gil;
+  at::impl::MaybeSetTLSOnEntryGuard guard;
+
+  auto out = torchDispatchFromTensorImpl(
+      self,
+      "stride",
+      py::module::import("torch")
+          .attr("ops")
+          .attr("aten")
+          .attr("stride")
+          .ptr(),
+      "torch.ops.aten");
+
+  if (out == Py_None) {
+    return self->strides_default();
+  }
+
+  py::object values = py::reinterpret_steal<py::object>(out.ptr());
+
+  c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self);
+  c10::optional<PyObject*> mb_obj = ptr->check_pyobj(getPyInterpreter());
+  TORCH_CHECK(mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
+  PyObject* subclass = *mb_obj;
+  Py_INCREF(subclass);
+  py::object sub = py::reinterpret_steal<py::object>(subclass);
+
+  py::object os = py::module_::import("torch").attr("overrides");
+  py::function get_buffer = py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
+  auto buffer = get_buffer(sub, values);
+  auto result = THPUtils_unpackLongs(buffer.ptr());
+  int64_t* start = (int64_t*) result[0];
+  int64_t len = result[1];
+
+  return c10::IntArrayRef(start, len);
+}
+
 } // anonymous namespace
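One subtlety in the new concrete_strides_fn: c10::IntArrayRef is a non-owning (pointer, length) view, so the int64 buffer it points at (obtained above via get_buffer and THPUtils_unpackLongs) must outlive the returned value; it cannot be a temporary local. Below is a small self-contained sketch of that view semantics, assuming only the public c10/util/ArrayRef.h header; the std::vector is just a stand-in for whatever buffer owns the strides.

#include <c10/util/ArrayRef.h>  // defines c10::IntArrayRef = c10::ArrayRef<int64_t>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical strides for a 3x4 contiguous tensor; this buffer must stay
  // alive for as long as any IntArrayRef views it.
  std::vector<int64_t> strides = {4, 1};

  // IntArrayRef neither copies nor owns: it is exactly the (start, len) pair
  // that concrete_strides_fn returns.
  c10::IntArrayRef view(strides.data(), strides.size());

  for (int64_t s : view) {
    std::cout << s << " ";  // prints: 4 1
  }
  std::cout << "\n";
  return 0;
}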