Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Optimize multi_tensor_apply (#119153)"
This reverts commit 24be7daf799ed94e1964e2ce440ccaad15962719. Reverted https://github.com/pytorch/pytorch/pull/119153 on behalf of https://github.com/yifuwang due to This PR is breaking cuda graph for multi_tensor_apply ([comment](https://github.com/pytorch/pytorch/pull/119153#issuecomment-1939365823))
@@ -18,7 +18,7 @@ inline void increment_version(TensorList tensors) {
 }
 
 // Initializes args and checks if all args are aligned
-template <int depth, typename T, template <int> class TensorListMetadata>
+template <int depth, typename T>
 __device__ bool init_args(
     T** args,
     TensorListMetadata<depth>& tl,
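For context on the line removed above: templating init_args (and the functors below) on the metadata class let a single device-code body accept either metadata layout, which is what the revert undoes. A minimal sketch of the pattern; StaticMeta, DynamicMeta, and Functor are illustrative names, not code from this diff:

    // Two layouts exposing the same field names; the template-template
    // parameter lets the same operator() compile against both.
    template <int n>
    struct StaticMeta {
      int block_to_tensor[8];  // embedded array, passed by value
    };

    template <int n>
    struct DynamicMeta {
      int* block_to_tensor;  // device pointer, filled out-of-band
    };

    struct Functor {
      template <template <int> class Meta>
      __device__ void operator()(int block_idx, Meta<1>& meta) {
        int tensor = meta.block_to_tensor[block_idx];  // works for either
        (void)tensor;
      }
    };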
@@ -206,7 +206,7 @@ __device__ __forceinline__ void pointwise_op_scalar(
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct BinaryOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
@@ -254,7 +254,7 @@ struct BinaryOpScalarListFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct BinaryOpListAlphaFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
@@ -306,7 +306,7 @@ struct BinaryOpListAlphaFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct BinaryOpScalarTensorFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
@@ -363,7 +363,6 @@ struct BinaryOpScalarTensorFunctor {
 
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct ZeroFunctor {
-  template <template <int> class TensorListMetadata>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<1>& tl) {
@@ -405,7 +404,7 @@ struct ZeroFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct UnaryOpFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
@@ -457,7 +456,7 @@ struct UnaryOpFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct PointwiseOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
@@ -505,7 +504,7 @@ struct PointwiseOpScalarListFunctor {
 template <typename T, int depth>
 struct PointwiseOpListFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
@@ -556,7 +555,7 @@ struct PointwiseOpListFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct TernaryOpListFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,
@@ -610,7 +609,7 @@ struct TernaryOpListFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct TernaryOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
-  template <typename Op, template <int> class TensorListMetadata>
+  template <typename Op>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,

@@ -31,7 +31,6 @@ template <
     int res_arg_index = 0>
 struct LpNormFunctor {
   using opmath_t = typename at::opmath_type<T>;
-  template <template <int> class TensorListMetadata>
   __device__ __forceinline__ void operator()(
       int chunk_size,
       TensorListMetadata<depth>& tl,

@@ -61,7 +61,6 @@ struct FusedSgdMathFunctor {
   static_assert(
       depth == 2 || depth == 3,
       "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
-  template <template <int> class TensorListMetadata>
   C10_DEVICE __forceinline__ void operator()(
       const int chunk_size,
       TensorListMetadata<depth>& tl,

@@ -1,5 +1,4 @@
 #pragma once
-#include <ATen/ceil_div.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -7,12 +6,6 @@
 #include <ATen/native/cuda/MemoryAccess.cuh>
 #include <vector>
 
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#else
-#include <ATen/ops/empty.h>
-#endif
-
 namespace at::native {
 
 namespace {
@@ -46,7 +39,7 @@ __device__ __forceinline__ void load_store(
 }
 
 template <int n>
-struct TensorListMetadataStatic {
+struct TensorListMetadata {
   const void* addresses[n][depth_to_max_tensors[n - 1]];
   int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
   unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
@@ -54,15 +47,6 @@ struct TensorListMetadataStatic {
   int start_tensor_this_launch;
 };
 
-template <int n>
-struct TensorListMetadata {
-  const void** addresses[n];
-  int64_t* numel_for_tensor;
-  size_t* block_to_tensor;
-  size_t* block_to_chunk;
-  int start_tensor_this_launch;
-};
-
 template <typename scalar_vals_t, int n>
 struct TensorListScalarListMetadata {
   const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
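The two structs contrasted above differ mainly in where their payload lives: the static variant embeds fixed-size arrays and travels by value through the kernel parameter buffer, while the pointer-based variant being deleted holds only device pointers. A rough footprint sketch, assuming the upstream constants depth_to_max_tensors[0] = 110 and depth_to_max_blocks[0] = 320 (values not shown in this diff):

    #include <cstddef>
    #include <cstdint>

    constexpr int kMaxTensors = 110; // depth_to_max_tensors[0], assumed
    constexpr int kMaxBlocks = 320;  // depth_to_max_blocks[0], assumed

    // Static layout at depth 1: embedded arrays, passed by value.
    constexpr size_t static_meta_bytes =
        kMaxTensors * sizeof(void*) +        // addresses[1][kMaxTensors]
        kMaxTensors * sizeof(int64_t) +      // numel_for_tensor
        kMaxBlocks * sizeof(unsigned char);  // block_to_tensor
    // Pointer layout: a few device pointers regardless of depth.
    constexpr size_t dynamic_meta_bytes = 4 * sizeof(void*) + sizeof(int);

    static_assert(static_meta_bytes < 4096, "fits the 4 KiB argument space");
    static_assert(dynamic_meta_bytes < 64, "effectively constant-size");

The 4 KiB figure comes from the Note [static arg vs. dynamic arg] quoted further down in this diff.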
@@ -111,73 +95,6 @@ struct FusedOptimizerTensorListMetadata {
   int start_tensor_this_launch;
 };
 
-bool can_use_static_tensor_list_meta(
-    std::vector<std::vector<at::Tensor>>& tensor_lists,
-    int depth) {
-  const int64_t n_tensors = tensor_lists[0].size();
-  if (n_tensors > depth_to_max_tensors[depth - 1]) {
-    return false;
-  }
-  int64_t num_blocks = 0;
-  for (const auto t : c10::irange(n_tensors)) {
-    const auto numel = tensor_lists[0][t].numel();
-    const auto chunks = at::ceil_div(numel, kChunkSize);
-    num_blocks += chunks;
-    if (num_blocks > depth_to_max_blocks[depth - 1]) {
-      return false;
-    }
-  }
-  return true;
-}
-
-// Helper for transferring multiple std::vector<T> onto device with a single
-// page-locked cudaMemcpyAsync.
-struct VecPacker {
-  std::vector<const void*> ptrs;
-  std::vector<size_t> sizes;
-  std::vector<size_t> offsets;
-  int64_t packed_numel = 0;
-  at::Tensor packed;
-
-  // Add a vector to be copied to device
-  // NOTE: VecPacker doesn't make copies of the added vectors. They have to be
-  // kept alive by the caller until .pack() is called.
-  template <typename T>
-  void add(const std::vector<T>& vec) {
-    // 16 would cover alignment for the largest known T (c10::complex)
-    static const size_t alignment = 16;
-    static_assert(alignment % sizeof(T) == 0);
-    ptrs.push_back(vec.data());
-    const auto vec_bytes = sizeof(T) * vec.size();
-    const auto vec_bytes_aligned = at::round_up(vec_bytes, alignment);
-    sizes.push_back(vec_bytes);
-    offsets.push_back(packed_numel);
-    packed_numel += vec_bytes_aligned;
-  }
-
-  // Copy all previously added vectors onto device and return their device
-  // pointers in the order they are added. We leverage the stream awareness of
-  // CUDACachingAllocator to manage the lifetime of the device arguments - the
-  // device memory is guaranteed to be alive as long as VecPacker is destroyed
-  // after the kernel that consumes it.
-  std::vector<void*> pack(const at::Device& device) {
-    packed = at::empty(
-        {packed_numel},
-        at::TensorOptions().dtype(at::kByte).pinned_memory(true));
-    for (const auto i : c10::irange(ptrs.size())) {
-      memcpy(packed.data_ptr<uint8_t>() + offsets[i], ptrs[i], sizes[i]);
-    }
-    packed = packed.to(device, /*non_blocking=*/true);
-
-    std::vector<void*> dev_ptrs;
-    dev_ptrs.reserve(ptrs.size());
-    for (const auto offset : offsets) {
-      dev_ptrs.push_back(packed.data_ptr<uint8_t>() + offset);
-    }
-    return dev_ptrs;
-  }
-};
-
 template <typename T, typename U, typename... ArgTypes>
 C10_LAUNCH_BOUNDS_1(kBlockSize)
 __global__ void multi_tensor_apply_kernel(
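The deleted VecPacker above concatenates every host vector into one pinned buffer at 16-byte-aligned offsets, so a single cudaMemcpyAsync can ship all kernel arguments. A standalone sketch of just the packing arithmetic (plain C++, no ATen; buffer layout and names are illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <vector>

    constexpr size_t kAlign = 16;  // covers the largest element type used
    constexpr size_t round_up(size_t x) {
      return (x + kAlign - 1) / kAlign * kAlign;
    }

    int main() {
      std::vector<int64_t> numel_for_tensor = {65536, 123, 7};
      std::vector<size_t> block_to_tensor = {0, 0, 1, 2};

      // Lay the two vectors into one byte buffer at aligned offsets.
      const size_t off0 = 0;
      const size_t off1 = round_up(numel_for_tensor.size() * sizeof(int64_t));
      std::vector<uint8_t> packed(off1 + block_to_tensor.size() * sizeof(size_t));
      std::memcpy(packed.data() + off0, numel_for_tensor.data(),
                  numel_for_tensor.size() * sizeof(int64_t));
      std::memcpy(packed.data() + off1, block_to_tensor.data(),
                  block_to_tensor.size() * sizeof(size_t));

      // In the removed helper, `packed` is a pinned at::Tensor that is then
      // moved to the device with one non_blocking copy; consumers receive
      // base + offset pointers, exactly like these host-side views.
      auto* numel_view = reinterpret_cast<int64_t*>(packed.data() + off0);
      std::printf("%lld\n", static_cast<long long>(numel_view[1]));  // 123
      return 0;
    }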
@@ -298,7 +215,7 @@ void multi_tensor_apply(
 }
 
 template <int depth, typename T, typename... ArgTypes>
-void multi_tensor_apply_static(
+void multi_tensor_apply(
     std::vector<std::vector<at::Tensor>>& tensor_lists,
     T callable,
     ArgTypes... args) {
@@ -306,7 +223,7 @@ void multi_tensor_apply_static(
       tensor_lists.size() == depth,
       "Number of tensor lists has to match the depth.");
   const size_t n_tensors = tensor_lists[0].size();
-  TensorListMetadataStatic<depth> tensorListMeta;
+  TensorListMetadata<depth> tensorListMeta;
   tensorListMeta.start_tensor_this_launch = 0;
 
   int loc_block_info = 0;
@@ -324,17 +241,49 @@ void multi_tensor_apply_static(
     }
     loc_tensor_info++;
+
+    // see note: [chunking territory].
     const auto numel = tensor_lists[0][t].numel();
     const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
     for (auto chunk = 0; chunk < chunks; chunk++) {
       tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
       tensorListMeta.block_to_chunk[loc_block_info] = chunk;
       loc_block_info++;
+
+      const bool tensors_full =
+          (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+           chunk == chunks - 1);
+      const bool blocks_full =
+          (loc_block_info == depth_to_max_blocks[depth - 1]);
+
+      if (tensors_full || blocks_full) {
+        multi_tensor_apply_kernel<<<
+            loc_block_info,
+            kBlockSize,
+            0,
+            at::cuda::getCurrentCUDAStream()>>>(
+            tensorListMeta, callable, args...);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+        // Reset.
+        loc_block_info = 0;
+        if (chunk == chunks - 1) {
+          loc_tensor_info = 0;
+          tensorListMeta.start_tensor_this_launch = t + 1;
+        } else {
+          tensorListMeta.numel_for_tensor[0] =
+              tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
+          for (int d = 0; d < depth; d++) {
+            tensorListMeta.addresses[d][0] =
+                tensorListMeta.addresses[d][loc_tensor_info - 1];
+          }
+          loc_tensor_info = 1;
+          tensorListMeta.start_tensor_this_launch = t;
+        }
+      }
     }
   }
-  TORCH_CHECK(loc_tensor_info < depth_to_max_tensors[depth - 1]);
-  TORCH_CHECK(loc_block_info < depth_to_max_blocks[depth - 1]);
 
   // see note: [finishing what we started]
   if (loc_block_info != 0) {
     multi_tensor_apply_kernel<<<
         loc_block_info,
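The restored chunk count above is plain integer ceiling division. A tiny self-check, assuming kChunkSize = 65536 as defined upstream in MultiTensorApply.cuh (the constant does not appear in this diff):

    #include <cstdint>

    constexpr int64_t kChunkSize = 65536;  // assumed upstream value

    constexpr int64_t chunk_count(int64_t numel) {
      // identical to the restored expression in the hunk above
      return numel / kChunkSize + (numel % kChunkSize != 0);
    }

    static_assert(chunk_count(1) == 1, "partial chunk still counts");
    static_assert(chunk_count(65536) == 1, "exact multiple, no extra chunk");
    static_assert(chunk_count(65537) == 2, "one element spills into chunk 2");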
@@ -345,96 +294,6 @@ void multi_tensor_apply_static(
   }
 }
 
-template <int depth, typename T, typename... ArgTypes>
-void multi_tensor_apply(
-    std::vector<std::vector<at::Tensor>>& tensor_lists,
-    T callable,
-    ArgTypes... args) {
-  // Note: [static arg vs. dynamic arg]
-  // Due to the dynamic nature of the workload, the kernel arguments aren't
-  // guaranteed to fit in the static 4kb kernel argument memory. Previously
-  // with the apex implementation, we overcame this limitation by dividing a
-  // multi_tensor_apply workload into multiple kernel launches. However, this
-  // led to low sustained occupancy, affecting the performance of memory bound
-  // ops.
-  //
-  // Based on the observation that the kernel argument memory limitation
-  // doesn't correlate well with available SM resources, we have adopted a
-  // different approach. When the kernel arguments fit into the static kernel
-  // argument memory, we use this memory to transfer the arguments. Conversely,
-  // when the kernel arguments don't fit into the static kernel argument
-  // memory, instead of sacrificing sustained occupancy, we use a page-locked
-  // cudaMemcpyAsync to transfer the arguments, then perform the entire
-  // workload in a single kernel.
-  if (can_use_static_tensor_list_meta(tensor_lists, depth)) {
-    multi_tensor_apply_static<depth, T, ArgTypes...>(
-        tensor_lists, callable, args...);
-    return;
-  }
-
-  TORCH_CHECK(
-      tensor_lists.size() == depth,
-      "Number of tensor lists has to match the depth.");
-  const size_t n_tensors = tensor_lists[0].size();
-
-  std::vector<const void*> addresses[depth];
-  std::vector<int64_t> numel_for_tensor;
-  std::vector<size_t> block_to_tensor;
-  std::vector<size_t> block_to_chunk;
-
-  for (int d = 0; d < depth; ++d) {
-    addresses[d].reserve(n_tensors);
-  }
-  numel_for_tensor.reserve(n_tensors);
-  block_to_tensor.reserve(n_tensors); // reserve for lower bound
-  block_to_chunk.reserve(n_tensors); // reserve for lower bound
-
-  for (size_t t = 0; t < n_tensors; t++) {
-    const auto numel = tensor_lists[0][t].numel();
-    // short-circuit to avoid adding empty tensors to tensorListMeta
-    if (numel == 0) {
-      continue;
-    }
-    numel_for_tensor.push_back(numel);
-    for (int d = 0; d < depth; d++) {
-      addresses[d].push_back(tensor_lists[d][t].const_data_ptr());
-    }
-    const auto chunks = at::ceil_div(numel, kChunkSize);
-    block_to_tensor.insert(block_to_tensor.end(), chunks, t);
-    block_to_chunk.resize(block_to_chunk.size() + chunks);
-    std::iota(block_to_chunk.end() - chunks, block_to_chunk.end(), 0);
-  }
-
-  VecPacker packer;
-  for (auto d = 0; d < depth; ++d) {
-    packer.add(addresses[d]);
-  }
-  packer.add(numel_for_tensor);
-  packer.add(block_to_tensor);
-  packer.add(block_to_chunk);
-
-  auto device = tensor_lists[0][0].device();
-  auto dev_ptrs = packer.pack(device);
-
-  TensorListMetadata<depth> tl;
-  for (auto d = 0; d < depth; ++d) {
-    tl.addresses[d] = static_cast<const void**>(dev_ptrs[d]);
-  }
-  tl.numel_for_tensor = static_cast<int64_t*>(dev_ptrs[depth]);
-  tl.block_to_tensor = static_cast<size_t*>(dev_ptrs[depth + 1]);
-  tl.block_to_chunk = static_cast<size_t*>(dev_ptrs[depth + 2]);
-  tl.start_tensor_this_launch = 0;
-
-  if (block_to_tensor.size() > 0) {
-    multi_tensor_apply_kernel<<<
-        block_to_tensor.size(),
-        kBlockSize,
-        0,
-        at::cuda::getCurrentCUDAStream()>>>(tl, callable, args...);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  }
-}
-
 template <int depth, typename T, typename... ArgTypes>
 void multi_tensor_apply_for_fused_optimizer(
     std::vector<std::vector<at::Tensor>>& tensor_lists,
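The stated revert reason ("breaking cuda graph for multi_tensor_apply") is consistent with how the deleted dynamic path stages its arguments: under stream capture, the cudaMemcpyAsync issued by pack() is recorded as a graph node that keeps referencing the same pinned host buffer, but that buffer is rebuilt and released on every call, so graph replays would read stale or invalid host memory. A minimal standalone CUDA sketch of the hazard (not code from the PR):

    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      int *host = nullptr, *dev = nullptr;
      cudaMallocHost(&host, sizeof(int));  // pinned staging buffer
      cudaMalloc(&dev, sizeof(int));
      *host = 42;

      cudaGraph_t graph;
      cudaGraphExec_t exec;
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
      // Recorded as a graph node holding the raw `host` pointer.
      cudaMemcpyAsync(dev, host, sizeof(int), cudaMemcpyHostToDevice, stream);
      cudaStreamEndCapture(stream, &graph);
      cudaGraphInstantiate(&exec, graph, 0);  // CUDA 12.x signature

      *host = 7;  // replay reads the buffer's *current* contents, not 42
      cudaGraphLaunch(exec, stream);
      cudaStreamSynchronize(stream);
      // `dev` now holds 7; had `host` been freed instead, the replay would
      // touch invalid memory.
      std::printf("replayed copy done\n");
      return 0;
    }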