mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Fix] Completely remove stride normalization on DLPack Tensor (#164161)
A followup on PR #163282 Fixes #163274 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164161 Approved by: https://github.com/ngimel, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
6adaa328f4
commit
7fee6bbf34
@ -389,37 +389,16 @@ void fillVersion<DLManagedTensorVersioned>(
|
||||
// constructed out of ATen tensor
|
||||
template <class T>
|
||||
T* toDLPackImpl(const Tensor& src) {
|
||||
auto view = src;
|
||||
|
||||
// Detect whether there is need to normalize the strides
|
||||
// Background: gh-83069
|
||||
//
|
||||
// However, normalizing strides can come at a high-cost
|
||||
// to slow down toDLPack conversion 3x, so we
|
||||
// only normalize if needed.
|
||||
//
|
||||
// The following code detects whether the src follows
|
||||
// a continuous pattern. If the src follows such pattern (common-case)
|
||||
// then we do not need to normalize the strides.
|
||||
bool need_normalize_strides = src.dim() == 1 && src.size(0) == 1 && src.stride(0) != 1;
|
||||
// less common case, try normalizing the strides
|
||||
if (need_normalize_strides) {
|
||||
// create a new tensor with possibly normalized strides
|
||||
// gh-83069
|
||||
auto shape = src.sizes();
|
||||
view = src.as_strided(shape, {1}, src.storage_offset());
|
||||
}
|
||||
|
||||
ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
|
||||
atDLMTensor->handle = view;
|
||||
atDLMTensor->handle = src;
|
||||
atDLMTensor->tensor.manager_ctx = atDLMTensor;
|
||||
atDLMTensor->tensor.deleter = &deleter<T>;
|
||||
atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
|
||||
atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
|
||||
atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
|
||||
atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
|
||||
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
|
||||
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(view.sizes().data());
|
||||
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(view.strides().data());
|
||||
atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
|
||||
atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
|
||||
atDLMTensor->tensor.dl_tensor.byte_offset = 0;
|
||||
fillVersion(&atDLMTensor->tensor);
|
||||
|
||||
|
Reference in New Issue
Block a user