empty_strided: Factor out generic implementation (#70614)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70614

This creates an `empty_strided_generic` function which, similar to `empty_generic`, is a device-independent tensor constructor. This also adds `at::detail::empty_strided_cpu` to complement `at::detail::empty_cpu`.

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D33623679

Pulled By: ngimel

fbshipit-source-id: 85994e88d664870bf425f398dfcdfc467885c694
(cherry picked from commit 2ff2a89df5752cfad667463aa3c3bffe8479ec9a)
committed by: PyTorch MergeBot
parent: d5e9a276ea
commit: 87215ed526
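The diff below shows the pattern this refactor establishes: the StorageImpl/TensorImpl construction moves into `at::detail::empty_strided_generic`, and each device's `empty_strided` kernel only has to supply an allocator and a dispatch key. As a rough illustration of how an out-of-tree backend could reuse the same helper (a sketch modeled on the CUDA and Meta call sites below; the backend name, allocator choice, and dispatch key here are hypothetical, not part of this PR):

#include <ATen/EmptyTensor.h>
#include <c10/core/Allocator.h>
#include <c10/core/DispatchKeySet.h>

namespace my_backend {

// Hypothetical glue code: only the allocator and the dispatch key are
// backend-specific; empty_strided_generic handles the storage and
// TensorImpl plumbing, as in the CPU, CUDA and Meta kernels in this diff.
at::TensorBase empty_strided_mybackend(
    at::IntArrayRef size,
    at::IntArrayRef stride,
    at::ScalarType dtype) {
  auto* allocator = c10::GetAllocator(c10::DeviceType::CPU);  // stand-in allocator
  constexpr c10::DispatchKeySet ks(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_strided_generic(size, stride, allocator, ks, dtype);
}

} // namespace my_backend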
@@ -19,6 +19,22 @@ void check_size_nonnegative(IntArrayRef size) {
   }
 }

+size_t computeStorageNbytes(
+    IntArrayRef sizes,
+    IntArrayRef strides,
+    size_t itemsize_bytes) {
+  // size of the underlying storage is 1 bigger than the offset
+  // of the last element according to stride
+  size_t size = 1;
+  for (const auto i : c10::irange(sizes.size())) {
+    if(sizes[i] == 0) {
+      return 0;
+    }
+    size += strides[i]*(sizes[i]-1);
+  }
+  return size * itemsize_bytes;
+}
+
 TensorBase empty_generic(
     IntArrayRef size,
     c10::Allocator* allocator,
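A quick worked example of the formula above (my own illustration, not part of the commit): for sizes {2, 3}, strides {3, 1} and 4-byte elements, the last element sits at offset 3*(2-1) + 1*(3-1) = 5, so the storage must hold 6 elements, i.e. 24 bytes, while any zero-sized dimension short-circuits to 0. A standalone sketch of that arithmetic:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone replica of the computeStorageNbytes logic, for illustration only.
size_t storage_nbytes(const std::vector<int64_t>& sizes,
                      const std::vector<int64_t>& strides,
                      size_t itemsize_bytes) {
  size_t size = 1;  // one past the offset of the last element
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) return 0;  // an empty tensor needs no storage
    size += strides[i] * (sizes[i] - 1);
  }
  return size * itemsize_bytes;
}

int main() {
  assert(storage_nbytes({2, 3}, {3, 1}, 4) == 24);  // 6 elements of float32
  assert(storage_nbytes({0, 3}, {3, 1}, 4) == 0);   // zero-sized dimension
  return 0;
}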
@@ -54,6 +70,29 @@ TensorBase empty_generic(
   return tensor;
 }

+TensorBase empty_strided_generic(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::Allocator* allocator,
+    c10::DispatchKeySet ks,
+    ScalarType scalar_type) {
+  at::detail::check_size_nonnegative(size);
+
+  caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type);
+  int64_t size_bytes = computeStorageNbytes(size, stride, dtype.itemsize());
+  auto storage_impl = c10::make_intrusive<StorageImpl>(
+      c10::StorageImpl::use_byte_size_t(),
+      size_bytes,
+      allocator->allocate(size_bytes),
+      allocator,
+      /*resizeable=*/true);
+
+  auto tensor = detail::make_tensor_base<TensorImpl>(
+      std::move(storage_impl), ks, dtype);
+  tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride);
+  return tensor;
+}
+
 TensorBase empty_cpu(IntArrayRef size, ScalarType dtype, bool pin_memory,
                      c10::optional<c10::MemoryFormat> memory_format_opt) {
   auto allocator = GetCPUAllocatorMaybePinned(pin_memory);
@@ -88,4 +127,41 @@ TensorBase empty_cpu(
       options.memory_format_opt());
 }

+TensorBase empty_strided_cpu(IntArrayRef size, IntArrayRef stride,
+                             ScalarType dtype, bool pin_memory) {
+  auto allocator = at::detail::GetCPUAllocatorMaybePinned(pin_memory);
+  constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU);
+  return at::detail::empty_strided_generic(
+      size, stride, allocator, cpu_ks, dtype);
+}
+
+TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::optional<ScalarType> dtype_opt,
+    c10::optional<Layout> layout_opt,
+    c10::optional<Device> device_opt,
+    c10::optional<bool> pin_memory_opt) {
+  auto device = device_or_default(device_opt);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::CPU);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout_or_default(layout_opt) == Layout::Strided);
+
+  auto pin_memory = pinned_memory_or_default(pin_memory_opt);
+  auto dtype = dtype_or_default(dtype_opt);
+  return at::detail::empty_strided_cpu(size, stride, dtype, pin_memory);
+}
+
+TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions &options) {
+  return at::detail::empty_strided_cpu(
+      size,
+      stride,
+      optTypeMetaToScalarType(options.dtype_opt()),
+      options.layout_opt(),
+      options.device_opt(),
+      options.pinned_memory_opt());
+}
+
 }} // namespace at::detail
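For in-tree callers, the new `at::detail::empty_strided_cpu` overloads added above can be called directly; a small usage sketch (the sizes, strides and dtype are illustrative, not from the diff):

#include <ATen/EmptyTensor.h>

// Illustrative: allocate an uninitialized 2x3 float CPU tensor with
// column-major strides using the simplest of the new overloads.
at::TensorBase make_column_major_f32() {
  return at::detail::empty_strided_cpu(
      /*size=*/{2, 3},
      /*stride=*/{1, 2},
      /*dtype=*/at::kFloat,
      /*pin_memory=*/false);
}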
@@ -5,6 +5,8 @@ namespace at {
 namespace detail {

 TORCH_API void check_size_nonnegative(IntArrayRef size);
+TORCH_API size_t computeStorageNbytes(
+    IntArrayRef sizes, IntArrayRef strides, size_t itemsize);

 TORCH_API TensorBase empty_generic(
     IntArrayRef size,
@@ -13,6 +15,13 @@ TORCH_API TensorBase empty_generic(
     ScalarType scalar_type,
     c10::optional<c10::MemoryFormat> memory_format_opt);

+TORCH_API TensorBase empty_strided_generic(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::Allocator* allocator,
+    c10::DispatchKeySet ks,
+    ScalarType scalar_type);
+
 TORCH_API TensorBase empty_cpu(
     IntArrayRef size,
     ScalarType dtype,
@@ -31,4 +40,23 @@ TORCH_API TensorBase empty_cpu(
     IntArrayRef size,
     const TensorOptions &options);

+TORCH_API TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    ScalarType dtype,
+    bool pin_memory=false);
+
+TORCH_API TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    c10::optional<ScalarType> dtype_opt,
+    c10::optional<Layout> layout_opt,
+    c10::optional<Device> device_opt,
+    c10::optional<bool> pin_memory_opt);
+
+TORCH_API TensorBase empty_strided_cpu(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions &options);
+
 }} // namespace at::detail
@@ -317,22 +317,6 @@ std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
   return strides;
 }

-size_t computeStorageNbytes(
-    IntArrayRef sizes,
-    IntArrayRef strides,
-    size_t itemsize_bytes) {
-  // size of the underlying storage is 1 bigger than the offset
-  // of the last element according to stride
-  size_t size = 1;
-  for (const auto i : c10::irange(sizes.size())) {
-    if(sizes[i] == 0) {
-      return 0;
-    }
-    size += strides[i]*(sizes[i]-1);
-  }
-  return size * itemsize_bytes;
-}
-
 // On a high level,
 // 1. separate `oldshape` into chunks of dimensions, where the dimensions are
 //    ``contiguous'' in each chunk, i.e., oldstride[i] = oldshape[i+1] *
@@ -1,6 +1,7 @@
 #pragma once

 #include <ATen/DimVector.h>
+#include <ATen/EmptyTensor.h>
 #include <ATen/Tensor.h>
 #include <ATen/TensorGeometry.h>
 #include <ATen/Utils.h>
@@ -152,8 +153,6 @@ TORCH_API void check_dim_size(

 namespace detail {
 TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
-TORCH_API size_t
-computeStorageNbytes(IntArrayRef sizes, IntArrayRef strides, size_t itemsize);

 TORCH_API c10::optional<std::vector<int64_t>> computeStride(
     IntArrayRef oldshape,
@@ -60,12 +60,20 @@ Tensor empty_strided_meta(
   c10::optional<Device> device_opt,
   c10::optional<bool> pin_memory_opt
 ) {
-  auto t = at::native::empty_meta({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
-  // Amazingly the CPU implementation will work for us, because most of resize
-  // is generic except the memcpy, but the memcpy will be skipped if the source
-  // storage is nullptr (which it always is, for meta tensors)
-  at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride);
-  return t;
+  auto device = device_or_default(device_opt);
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta);
+  // NB: because there is no SparseMeta (yet), non-strided layout is
+  // exerciseable
+  TORCH_CHECK_NOT_IMPLEMENTED(
+    layout_or_default(layout_opt) == Layout::Strided,
+    "strided meta tensors not supported yet"
+  );
+
+  auto* allocator = GetMetaAllocator();
+  auto dtype = dtype_or_default(dtype_opt);
+  constexpr c10::DispatchKeySet meta_ks(c10::DispatchKey::Meta);
+  return at::detail::empty_strided_generic(
+      size, stride, allocator, meta_ks, dtype);
 }

 } // namespace native
@@ -201,10 +201,7 @@ Tensor empty(

 Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
                          c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  check_size_nonnegative(size);
-  auto t = at::native::empty_cpu({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
-  at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride);
-  return t;
+  return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
 }

 Tensor& empty_out(IntArrayRef size,
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/EmptyTensor.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/InitialTensorOptions.h>
@@ -65,9 +66,13 @@ Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::op
 }

 Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
-  auto t = at::native::empty_cuda({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
-  at::native::resize_impl_cuda_(t.unsafeGetTensorImpl(), size, stride);
-  return t;
+  TORCH_CHECK(device_or_default(device_opt).is_cuda());
+  TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned");
+  auto* allocator = at::cuda::getCUDADeviceAllocator();
+  auto dtype = dtype_or_default(dtype_opt);
+  constexpr c10::DispatchKeySet cuda_ks(c10::DispatchKey::CUDA);
+  return at::detail::empty_strided_generic(
+      size, stride, allocator, cuda_ks, dtype);
 }

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~