Revert "Optimize multi_tensor_apply (#119153)"

This reverts commit 24be7daf799ed94e1964e2ce440ccaad15962719.

Reverted https://github.com/pytorch/pytorch/pull/119153 on behalf of https://github.com/yifuwang due to This PR is breaking cuda graph for multi_tensor_apply ([comment](https://github.com/pytorch/pytorch/pull/119153#issuecomment-1939365823))
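For context, here is a minimal sketch (not part of this commit) of the kind of workload the revert reason refers to: capturing a foreach op, which dispatches through multi_tensor_apply, inside a CUDA graph. It assumes the at::cuda::CUDAGraph C++ API and uses a hypothetical helper name; the reverted change packed kernel metadata into pinned host memory and copied it to the device at launch time, which the revert reason reports as breaking this capture path.

    #include <ATen/ATen.h>
    #include <ATen/cuda/CUDAGraph.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <c10/cuda/CUDAStream.h>
    #include <vector>

    // Hypothetical helper: capture a _foreach_ op (which runs through
    // multi_tensor_apply) into a CUDA graph, then replay it.
    void capture_foreach_add(std::vector<at::Tensor>& params) {
      // Graph capture must happen on a non-default stream.
      c10::cuda::CUDAStreamGuard guard(c10::cuda::getStreamFromPool());
      at::cuda::CUDAGraph graph;
      graph.capture_begin();
      at::_foreach_add_(params, 1.0);  // multi_tensor_apply under the hood
      graph.capture_end();
      graph.replay();  // re-runs the captured kernels on replay
    }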
PyTorch MergeBot
2024-02-12 19:11:29 +00:00
parent 8d8fb9783c
commit c24b74efc7
4 changed files with 46 additions and 190 deletions

View File

@@ -18,7 +18,7 @@ inline void increment_version(TensorList tensors) {
}
// Initializes args and checks if all args are aligned
template <int depth, typename T, template <int> class TensorListMetadata>
template <int depth, typename T>
__device__ bool init_args(
T** args,
TensorListMetadata<depth>& tl,
@@ -206,7 +206,7 @@ __device__ __forceinline__ void pointwise_op_scalar(
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
@@ -254,7 +254,7 @@ struct BinaryOpScalarListFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
@@ -306,7 +306,7 @@ struct BinaryOpListAlphaFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
@@ -363,7 +363,6 @@ struct BinaryOpScalarTensorFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
template <template <int> class TensorListMetadata>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<1>& tl) {
@@ -405,7 +404,7 @@ struct ZeroFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
@@ -457,7 +456,7 @@ struct UnaryOpFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
@@ -505,7 +504,7 @@ struct PointwiseOpScalarListFunctor {
template <typename T, int depth>
struct PointwiseOpListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
@@ -556,7 +555,7 @@ struct PointwiseOpListFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,
@@ -610,7 +609,7 @@ struct TernaryOpListFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op, template <int> class TensorListMetadata>
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,

View File

@@ -31,7 +31,6 @@ template <
int res_arg_index = 0>
struct LpNormFunctor {
using opmath_t = typename at::opmath_type<T>;
template <template <int> class TensorListMetadata>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth>& tl,

View File

@@ -61,7 +61,6 @@ struct FusedSgdMathFunctor {
static_assert(
depth == 2 || depth == 3,
"depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
template <template <int> class TensorListMetadata>
C10_DEVICE __forceinline__ void operator()(
const int chunk_size,
TensorListMetadata<depth>& tl,

View File

@@ -1,5 +1,4 @@
#pragma once
#include <ATen/ceil_div.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
@@ -7,12 +6,6 @@
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <vector>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif
namespace at::native {
namespace {
@@ -46,7 +39,7 @@ __device__ __forceinline__ void load_store(
}
template <int n>
struct TensorListMetadataStatic {
struct TensorListMetadata {
const void* addresses[n][depth_to_max_tensors[n - 1]];
int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
@@ -54,15 +47,6 @@ struct TensorListMetadataStatic {
int start_tensor_this_launch;
};
template <int n>
struct TensorListMetadata {
const void** addresses[n];
int64_t* numel_for_tensor;
size_t* block_to_tensor;
size_t* block_to_chunk;
int start_tensor_this_launch;
};
template <typename scalar_vals_t, int n>
struct TensorListScalarListMetadata {
const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
@@ -111,73 +95,6 @@ struct FusedOptimizerTensorListMetadata {
int start_tensor_this_launch;
};
bool can_use_static_tensor_list_meta(
std::vector<std::vector<at::Tensor>>& tensor_lists,
int depth) {
const int64_t n_tensors = tensor_lists[0].size();
if (n_tensors > depth_to_max_tensors[depth - 1]) {
return false;
}
int64_t num_blocks = 0;
for (const auto t : c10::irange(n_tensors)) {
const auto numel = tensor_lists[0][t].numel();
const auto chunks = at::ceil_div(numel, kChunkSize);
num_blocks += chunks;
if (num_blocks > depth_to_max_blocks[depth - 1]) {
return false;
}
}
return true;
}
// Helper for transferring multiple std::vector<T> onto device with a single
// page-locked cudaMemcpyAsync.
struct VecPacker {
std::vector<const void*> ptrs;
std::vector<size_t> sizes;
std::vector<size_t> offsets;
int64_t packed_numel = 0;
at::Tensor packed;
template <typename T>
// Add a vector to be copied to device
// NOTE: VecPacker doesn't make copies of the added vectors. They have to be
// kept alive by the caller until .pack() is called.
void add(const std::vector<T>& vec) {
// 16 would cover alignment for the largest known T (c10::complex)
static const size_t alignment = 16;
static_assert(alignment % sizeof(T) == 0);
ptrs.push_back(vec.data());
const auto vec_bytes = sizeof(T) * vec.size();
const auto vec_bytes_aligned = at::round_up(vec_bytes, alignment);
sizes.push_back(vec_bytes);
offsets.push_back(packed_numel);
packed_numel += vec_bytes_aligned;
}
// Copy all previously added vectors onto device and return their device
// pointers in the order they are added. We leverage the stream awareness of
// CUDACachingAllocator to manage the lifetime of the device arguments - the
// device memory is guaranteed to be alive as long as VecPacker is destroyed
// after the kernel that consumes it.
std::vector<void*> pack(const at::Device& device) {
packed = at::empty(
{packed_numel},
at::TensorOptions().dtype(at::kByte).pinned_memory(true));
for (const auto i : c10::irange(ptrs.size())) {
memcpy(packed.data_ptr<uint8_t>() + offsets[i], ptrs[i], sizes[i]);
}
packed = packed.to(device, /*non_blocking=*/true);
std::vector<void*> dev_ptrs;
dev_ptrs.reserve(ptrs.size());
for (const auto offset : offsets) {
dev_ptrs.push_back(packed.data_ptr<uint8_t>() + offset);
}
return dev_ptrs;
}
};
template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ void multi_tensor_apply_kernel(
@@ -298,7 +215,7 @@ void multi_tensor_apply(
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_static(
void multi_tensor_apply(
std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args) {
@@ -306,7 +223,7 @@ void multi_tensor_apply_static(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
TensorListMetadataStatic<depth> tensorListMeta;
TensorListMetadata<depth> tensorListMeta;
tensorListMeta.start_tensor_this_launch = 0;
int loc_block_info = 0;
@@ -324,17 +241,49 @@ void multi_tensor_apply_static(
}
loc_tensor_info++;
// see note: [chunking territory].
const auto numel = tensor_lists[0][t].numel();
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
for (auto chunk = 0; chunk < chunks; chunk++) {
tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tensorListMeta.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
const bool tensors_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks - 1);
const bool blocks_full =
(loc_block_info == depth_to_max_blocks[depth - 1]);
if (tensors_full || blocks_full) {
multi_tensor_apply_kernel<<<
loc_block_info,
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(
tensorListMeta, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Reset.
loc_block_info = 0;
if (chunk == chunks - 1) {
loc_tensor_info = 0;
tensorListMeta.start_tensor_this_launch = t + 1;
} else {
tensorListMeta.numel_for_tensor[0] =
tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
for (int d = 0; d < depth; d++) {
tensorListMeta.addresses[d][0] =
tensorListMeta.addresses[d][loc_tensor_info - 1];
}
loc_tensor_info = 1;
tensorListMeta.start_tensor_this_launch = t;
}
}
}
}
TORCH_CHECK(loc_tensor_info < depth_to_max_tensors[depth - 1]);
TORCH_CHECK(loc_block_info < depth_to_max_blocks[depth - 1]);
// see note: [finishing what we started]
if (loc_block_info != 0) {
multi_tensor_apply_kernel<<<
loc_block_info,
@@ -345,96 +294,6 @@ void multi_tensor_apply_static(
}
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args) {
// Note: [static arg vs. dynamic arg]
// Due to the dynamic nature of the workload, the kernel arguments aren't
// guaranteed to fit in the static 4kb kernel argument memory. Previously
// with the apex implementation, we overcame this limitation by dividing a
// multi_tensor_apply workload into multiple kernel launches. However, this
// led to low sustained occupancy, affecting the performance of memory bound
// ops.
//
// Based on the observation that the kernel argument memory limitation
// doesn't correlate well with available SM resources, we have adopted a
// different approach. When the kernel arguments fit into the static kernel
// argument memory, we use this memory to transfer the arguments. Conversely,
// when the kernel arguments don't fit into the static kernel argument
// memory, instead of sacrificing sustained occupancy, we use a page-locked
// cudaMemcpyAsync to transfer the arguments, then perform the entire
// workload in a single kernel.
if (can_use_static_tensor_list_meta(tensor_lists, depth)) {
multi_tensor_apply_static<depth, T, ArgTypes...>(
tensor_lists, callable, args...);
return;
}
TORCH_CHECK(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
std::vector<const void*> addresses[depth];
std::vector<int64_t> numel_for_tensor;
std::vector<size_t> block_to_tensor;
std::vector<size_t> block_to_chunk;
for (int d = 0; d < depth; ++d) {
addresses[d].reserve(n_tensors);
}
numel_for_tensor.reserve(n_tensors);
block_to_tensor.reserve(n_tensors); // reserve for lowerbound
block_to_chunk.reserve(n_tensors); // reserve for lowerbound
for (size_t t = 0; t < n_tensors; t++) {
const auto numel = tensor_lists[0][t].numel();
// short-circuit to avoid adding empty tensors to tensorListMeta
if (numel == 0) {
continue;
}
numel_for_tensor.push_back(numel);
for (int d = 0; d < depth; d++) {
addresses[d].push_back(tensor_lists[d][t].const_data_ptr());
}
const auto chunks = at::ceil_div(numel, kChunkSize);
block_to_tensor.insert(block_to_tensor.end(), chunks, t);
block_to_chunk.resize(block_to_chunk.size() + chunks);
std::iota(block_to_chunk.end() - chunks, block_to_chunk.end(), 0);
}
VecPacker packer;
for (auto d = 0; d < depth; ++d) {
packer.add(addresses[d]);
}
packer.add(numel_for_tensor);
packer.add(block_to_tensor);
packer.add(block_to_chunk);
auto device = tensor_lists[0][0].device();
auto dev_ptrs = packer.pack(device);
TensorListMetadata<depth> tl;
for (auto d = 0; d < depth; ++d) {
tl.addresses[d] = static_cast<const void**>(dev_ptrs[d]);
}
tl.numel_for_tensor = static_cast<int64_t*>(dev_ptrs[depth]);
tl.block_to_tensor = static_cast<size_t*>(dev_ptrs[depth + 1]);
tl.block_to_chunk = static_cast<size_t*>(dev_ptrs[depth + 2]);
tl.start_tensor_this_launch = 0;
if (block_to_tensor.size() > 0) {
multi_tensor_apply_kernel<<<
block_to_tensor.size(),
kBlockSize,
0,
at::cuda::getCurrentCUDAStream()>>>(tl, callable, args...);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_for_fused_optimizer(
std::vector<std::vector<at::Tensor>>& tensor_lists,