[Fix] Completely remove stride normalization on DLPack Tensor (#164161)

A followup on PR #163282
Fixes #163274
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164161
Approved by: https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
Kathryn-cat
2025-10-14 17:17:08 +00:00
committed by PyTorch MergeBot
parent 6adaa328f4
commit 7fee6bbf34
2 changed files with 6 additions and 27 deletions

View File

@ -389,37 +389,16 @@ void fillVersion<DLManagedTensorVersioned>(
// constructed out of ATen tensor
template <class T>
T* toDLPackImpl(const Tensor& src) {
auto view = src;
// Detect whether there is need to normalize the strides
// Background: gh-83069
//
// However, normalizing strides can come at a high-cost
// to slow down toDLPack conversion 3x, so we
// only normalize if needed.
//
// The following code detects whether the src follows
// a continuous pattern. If the src follows such pattern (common-case)
// then we do not need to normalize the strides.
bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
// less common case, try normalizing the strides
if (need_normalize_strides) {
// create a new tensor with possibly normalized strides
// gh-83069
auto shape = src.sizes();
view = src.as_strided(shape, {1}, src.storage_offset());
}
ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
atDLMTensor->handle = view;
atDLMTensor->handle = src;
atDLMTensor->tensor.manager_ctx = atDLMTensor;
atDLMTensor->tensor.deleter = &deleter<T>;
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
fillVersion(&atDLMTensor->tensor);