Revert "Add support for 32KB multi_tensor_apply kernel arguments (#134373)"

This reverts commit 08184aa85cf183198ebdf2fd7a49fe7bc4842c13.

Reverted https://github.com/pytorch/pytorch/pull/134373 on behalf of https://github.com/drisspg due to See https://github.com/pytorch/pytorch/issues/135126 for more details ([comment](https://github.com/pytorch/pytorch/pull/134373#issuecomment-2329839011))
Author: PyTorch MergeBot
Date: 2024-09-04 19:44:29 +00:00
Parent: dd7cd182ab
Commit: 741d52c69f
19 changed files with 469 additions and 851 deletions
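
For context, the reverted change threaded a `large_kernel_arg` flag through every multi_tensor_apply functor and metadata struct, and chose between a 4KB and a 32KB kernel-argument variant at runtime via `DISPATCH_MULTI_TENSOR_APPLY` and `supports_large_kernel_arg()` (both removed in the hunks below). What follows is a condensed, compilable sketch of that dispatch pattern, not the upstream implementation: the `sketch` namespace and `main` are illustrative, the capability probe is a hard-coded placeholder for the real driver/runtime/architecture checks, and the capacity constants are the depth-1 entries from the `DepthToMaxConfig` hunk in this diff.

// Condensed sketch (assumptions noted above) of the dispatch pattern this
// commit removes.
#include <cstdio>

namespace sketch {

// Placeholder for the deleted at::native::supports_large_kernel_arg(), which
// required CUDART >= 12.1, driver >= 530, and a Volta-or-newer GPU.
inline bool supports_large_kernel_arg() {
  return false;  // assumption: stand-in for the runtime driver/arch query
}

// The 32KB variant scaled the metadata capacities (and widened the per-block
// tensor index); only the depth-1 tensor/block limits are shown here.
template <bool large_kernel_arg>
struct DepthToMaxConfig {
  static constexpr int max_tensors = large_kernel_arg ? 770 : 110;
  static constexpr int max_blocks = large_kernel_arg ? 2240 : 320;
};

// The macro ran the caller's lambda with large_kernel_arg as a constexpr
// local, so each multi_tensor_apply call site was instantiated twice and the
// variant was picked at runtime.
#define DISPATCH_MULTI_TENSOR_APPLY(...)     \
  if (sketch::supports_large_kernel_arg()) { \
    constexpr bool large_kernel_arg = true;  \
    __VA_ARGS__();                           \
  } else {                                   \
    constexpr bool large_kernel_arg = false; \
    __VA_ARGS__();                           \
  }

} // namespace sketch

int main() {
  DISPATCH_MULTI_TENSOR_APPLY([&]() {
    // In the reverted code, multi_tensor_apply<depth>(...) was invoked here
    // with functors templated on large_kernel_arg.
    std::printf("depth-1 tensor limit: %d\n",
                sketch::DepthToMaxConfig<large_kernel_arg>::max_tensors);
  });
  return 0;
}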

@ -157,24 +157,21 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
using opmath_t = at::opmath_type<scalar_t>;
// multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(tensor_lists,
UnaryOpFunctor<scalar_t,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0,
large_kernel_arg>(),
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
// There is a slight asymmetry here with the TensorIterator kernel above.
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
if (!isfinite_ensure_cuda_math(val)) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
});
});
multi_tensor_apply<1>(tensor_lists,
UnaryOpFunctor<scalar_t,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
// There is a slight asymmetry here with the TensorIterator kernel above.
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
if (!isfinite_ensure_cuda_math(val)) {
*found_inf_ptr = 1.f;
}
// Every thread accesses inv_scale, but it will hit in cache.
const auto inv_scale_val = *inv_scale_ptr;
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
});
});
}

@ -41,18 +41,15 @@ std::vector<Tensor> foreach_tensor_list_op(
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<3>(
tensor_lists,
BinaryOpListAlphaFunctor<
T,
/* depth */ 3,
/* r_args_depth */ 2,
/* res_arg_index */ 2,
large_kernel_arg>(),
Op<opmath_t>(),
alpha.to<opmath_t>());
});
multi_tensor_apply<3>(
tensor_lists,
BinaryOpListAlphaFunctor<
T,
/* depth */ 3,
/* r_args_depth */ 2,
/* res_arg_index */ 2>(),
Op<opmath_t>(),
alpha.to<opmath_t>());
return tensor_lists[2];
}
@ -67,18 +64,15 @@ void foreach_tensor_list_op_(
tensor_lists.emplace_back(tensors2.vec());
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
BinaryOpListAlphaFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 2,
/* res_arg_index */ 0,
large_kernel_arg>(),
Op<opmath_t>(),
alpha.to<opmath_t>());
});
multi_tensor_apply<2>(
tensor_lists,
BinaryOpListAlphaFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 2,
/* res_arg_index */ 0>(),
Op<opmath_t>(),
alpha.to<opmath_t>());
increment_version(tensors1);
}
@ -337,15 +331,13 @@ template <
typename src_t,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
int res_arg_index>
struct CopyFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@ -428,17 +420,14 @@ void foreach_tensor_copy_list_kernel_cuda_(
using opmath_t = at::opmath_type<scalar_t>;
AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
if constexpr (std::is_same_v<scalar_t, src_t>) {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
UnaryOpFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1,
large_kernel_arg>(),
Copy<opmath_t, opmath_t>());
});
multi_tensor_apply<2>(
tensor_lists,
UnaryOpFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Copy<opmath_t, opmath_t>());
} else {
// Ref:
// https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
@ -446,18 +435,15 @@ void foreach_tensor_copy_list_kernel_cuda_(
TORCH_WARN_ONCE(
"Casting complex values to real discards the imaginary part");
}
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
CopyFunctor<
scalar_t,
src_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1,
large_kernel_arg>(),
Copy<scalar_t, src_t>());
});
multi_tensor_apply<2>(
tensor_lists,
CopyFunctor<
scalar_t,
src_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Copy<scalar_t, src_t>());
}
});
});

@ -36,18 +36,15 @@ std::vector<Tensor> foreach_binary_op(
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
BinaryOpScalarFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1,
large_kernel_arg>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
multi_tensor_apply<2>(
tensor_lists,
BinaryOpScalarFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
return tensor_lists[1];
}
@ -57,18 +54,15 @@ void foreach_binary_op_(TensorList tensors, const Scalar& scalar) {
tensor_lists.emplace_back(tensors.vec());
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
BinaryOpScalarFunctor<
T,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0,
large_kernel_arg>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
multi_tensor_apply<1>(
tensor_lists,
BinaryOpScalarFunctor<
T,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
increment_version(tensors);
}

@ -36,19 +36,16 @@ std::vector<Tensor> foreach_binary_op(
tensor_lists.emplace_back(vec_res);
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2, opmath_t>(
tensor_lists,
scalars,
BinaryOpScalarListFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1,
large_kernel_arg>(),
multi_tensor_apply<2, opmath_t>(
tensor_lists,
scalars,
BinaryOpScalarListFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Op<opmath_t>());
});
Op<opmath_t>());
return tensor_lists[1];
}
@ -58,18 +55,15 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
tensor_lists.emplace_back(tensors.vec());
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1, opmath_t>(
tensor_lists,
scalars,
BinaryOpScalarListFunctor<
T,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0,
large_kernel_arg>(),
Op<opmath_t>());
});
multi_tensor_apply<1, opmath_t>(
tensor_lists,
scalars,
BinaryOpScalarListFunctor<
T,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>());
increment_version(tensors);
}

@ -46,19 +46,16 @@ std::vector<Tensor> foreach_binary_op(
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
BinaryOpScalarTensorFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1,
large_kernel_arg>(),
Op<opmath_t>(),
scalar.data_ptr<T>(),
alpha.to<opmath_t>());
});
multi_tensor_apply<2>(
tensor_lists,
BinaryOpScalarTensorFunctor<
T,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Op<opmath_t>(),
scalar.data_ptr<T>(),
alpha.to<opmath_t>());
return tensor_lists[1];
}
@ -84,19 +81,16 @@ void foreach_binary_op_(
tensor_lists.emplace_back(tensors.vec());
using opmath_t = at::opmath_type<T>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
BinaryOpScalarTensorFunctor<
T,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0,
large_kernel_arg>(),
Op<opmath_t>(),
scalar.data_ptr<T>(),
alpha.to<opmath_t>());
});
multi_tensor_apply<1>(
tensor_lists,
BinaryOpScalarTensorFunctor<
T,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>(),
scalar.data_ptr<T>(),
alpha.to<opmath_t>());
increment_version(tensors);
}

@ -18,10 +18,10 @@ inline void increment_version(TensorList tensors) {
}
// Initializes args and checks if all args are aligned
template <int depth, typename T, bool large_kernel_arg>
template <int depth, typename T>
__device__ bool init_args(
T** args,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
const int64_t chunk_idx,
const int64_t chunk_size,
const int64_t tensor_loc) {
@ -38,10 +38,10 @@ __device__ bool init_args(
}
// Initializes args and checks if all args are aligned
template <int depth, typename T, typename T2, bool large_kernel_arg>
template <int depth, typename T, typename T2>
__device__ bool init_args(
T** args,
TensorListScalarListMetadata<T2, depth, large_kernel_arg>& tl,
TensorListScalarListMetadata<T2, depth>& tl,
const int64_t chunk_idx,
const int64_t chunk_size,
const int64_t tensor_loc) {
@ -57,10 +57,10 @@ __device__ bool init_args(
return all_aligned;
}
template <int depth, typename T, bool large_kernel_arg>
template <int depth, typename T>
__device__ bool init_args(
T** args,
FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
FusedOptimizerTensorListMetadata<depth>& tl,
const int64_t chunk_idx,
const int64_t chunk_size,
const int64_t tensor_loc) {
@ -203,19 +203,13 @@ __device__ __forceinline__ void pointwise_op_scalar(
//
// Binary Functors
//
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op,
opmath_t scalar) {
const int tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -233,19 +227,13 @@ struct BinaryOpScalarFunctor {
}
};
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarListFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
TensorListScalarListMetadata<opmath_t, depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@ -263,19 +251,13 @@ struct BinaryOpScalarListFunctor {
}
};
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpListAlphaFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op,
opmath_t alpha) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -321,19 +303,13 @@ struct BinaryOpListAlphaFunctor {
}
};
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct BinaryOpScalarTensorFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op,
T* scalar,
opmath_t alpha) {
@ -385,17 +361,11 @@ struct BinaryOpScalarTensorFunctor {
// Unary Functors
//
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<1, large_kernel_arg>& tl) {
TensorListMetadata<1>& tl) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
auto n = tl.numel_for_tensor[tensor_loc];
@ -431,19 +401,13 @@ struct ZeroFunctor {
}
};
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct UnaryOpFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@ -489,19 +453,13 @@ struct UnaryOpFunctor {
// Pointwise Functors
//
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op,
opmath_t scalar) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -519,19 +477,13 @@ struct PointwiseOpScalarFunctor {
}
};
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct PointwiseOpScalarListFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
TensorListScalarListMetadata<opmath_t, depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@ -549,14 +501,13 @@ struct PointwiseOpScalarListFunctor {
}
};
template <typename T, int depth, bool large_kernel_arg>
template <typename T, int depth>
struct PointwiseOpListFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@ -601,19 +552,13 @@ struct PointwiseOpListFunctor {
}
};
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpListFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op) {
static_assert(depth == 3 || depth == 4, "");
static_assert(depth >= r_args_depth, "");
@ -661,19 +606,13 @@ struct TernaryOpListFunctor {
}
};
template <
typename T,
int depth,
int r_args_depth,
int res_arg_index,
bool large_kernel_arg>
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct TernaryOpScalarFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
Op op,
opmath_t alpha) {
static_assert(depth == 2 || depth == 3, "");

@ -46,18 +46,15 @@ std::vector<Tensor> foreach_pointwise_op(
"foreach_pointwise_op_cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<4>(
tensor_lists,
PointwiseOpScalarFunctor<
scalar_t,
/* depth */ 4,
/* r_args_depth */ 3,
/* res_arg_index */ 3,
large_kernel_arg>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
multi_tensor_apply<4>(
tensor_lists,
PointwiseOpScalarFunctor<
scalar_t,
/* depth */ 4,
/* r_args_depth */ 3,
/* res_arg_index */ 3>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
return tensor_lists[3];
@ -81,18 +78,15 @@ void foreach_pointwise_op_(
"foreach_pointwise_op__cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<3>(
tensor_lists,
PointwiseOpScalarFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 3,
/* res_arg_index */ 0,
large_kernel_arg>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
multi_tensor_apply<3>(
tensor_lists,
PointwiseOpScalarFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 3,
/* res_arg_index */ 0>(),
Op<opmath_t>(),
scalar.to<opmath_t>());
});
increment_version(input);
}
@ -116,18 +110,15 @@ void foreach_pointwise_op_(
"foreach_pointwise_op__cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<3, opmath_t>(
tensor_lists,
scalars,
PointwiseOpScalarListFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 3,
/* res_arg_index */ 0,
large_kernel_arg>(),
Op<opmath_t>());
});
multi_tensor_apply<3, opmath_t>(
tensor_lists,
scalars,
PointwiseOpScalarListFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 3,
/* res_arg_index */ 0>(),
Op<opmath_t>());
});
increment_version(input);
}
@ -158,18 +149,15 @@ std::vector<Tensor> foreach_pointwise_op(
"foreach_pointwise_op_cuda",
[&]() {
using opmath_t = at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<4, opmath_t>(
tensor_lists,
scalars,
PointwiseOpScalarListFunctor<
scalar_t,
/* depth */ 4,
/* r_args_depth */ 3,
/* res_arg_index */ 3,
large_kernel_arg>(),
Op<opmath_t>());
});
multi_tensor_apply<4, opmath_t>(
tensor_lists,
scalars,
PointwiseOpScalarListFunctor<
scalar_t,
/* depth */ 4,
/* r_args_depth */ 3,
/* res_arg_index */ 3>(),
Op<opmath_t>());
});
return tensor_lists[3];

@ -50,13 +50,11 @@ template <
typename T,
int depth = 1,
int r_args_depth = 1,
int res_arg_index = 0,
bool large_kernel_arg = false>
int res_arg_index = 0>
struct LpMaxFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
T* output_per_tensor_ptr,
const int max_chunks_per_tensor) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -180,13 +178,11 @@ std::vector<Tensor> foreach_tensor_max_cuda(TensorList tensors) {
tensor_lists[0][0].scalar_type(),
"foreach_tensor_max_cuda_scalar_type",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
LpMaxFunctor<scalar_t, 1, 1, 0, large_kernel_arg>(),
output_per_tensor.mutable_data_ptr<scalar_t>(),
max_chunks_per_tensor);
});
multi_tensor_apply<1>(
tensor_lists,
LpMaxFunctor<scalar_t>(),
output_per_tensor.mutable_data_ptr<scalar_t>(),
max_chunks_per_tensor);
C10_CUDA_KERNEL_LAUNCH_CHECK();
const at::cuda::OptionalCUDAGuard device_guard(
@ -243,14 +239,12 @@ template <
typename out_t,
int depth = 1,
int r_args_depth = 1,
int res_arg_index = 0,
bool large_kernel_arg = false>
int res_arg_index = 0>
struct LpNormFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
using out_opmath_t = typename at::opmath_type<out_t>;
__device__ __forceinline__ void operator()(
int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
out_opmath_t* output_per_tensor_ptr,
const int max_chunks_per_tensor) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -482,50 +476,23 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() {
using out_opmath_t = typename at::opmath_type<out_t>;
if (p == static_cast<double>(1)) {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
LpNormFunctor<
scalar_t,
NormType::L1,
out_t,
1,
1,
0,
large_kernel_arg>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
});
multi_tensor_apply<1>(
tensor_lists,
LpNormFunctor<scalar_t, NormType::L1, out_t>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
} else if (p == static_cast<double>(2)) {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
LpNormFunctor<
scalar_t,
NormType::L2,
out_t,
1,
1,
0,
large_kernel_arg>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
});
multi_tensor_apply<1>(
tensor_lists,
LpNormFunctor<scalar_t, NormType::L2, out_t>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
} else if (p == std::numeric_limits<double>::infinity()) {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
LpNormFunctor<
scalar_t,
NormType::LInf,
out_t,
1,
1,
0,
large_kernel_arg>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
});
multi_tensor_apply<1>(
tensor_lists,
LpNormFunctor<scalar_t, NormType::LInf, out_t>(),
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
max_chunks_per_tensor);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
const at::cuda::OptionalCUDAGuard device_guard(

@ -46,17 +46,14 @@ std::vector<at::Tensor> foreach_tensor_lerp_ternary_cuda(
"foreach_tensor_lerp_ternary_cuda",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<4>(
tensor_lists,
TernaryOpListFunctor<
scalar_t,
/* depth */ 4,
/* r_args_depth */ 3,
/* res_arg_index */ 3,
large_kernel_arg>(),
LerpFunctor<opmath_t>());
});
multi_tensor_apply<4>(
tensor_lists,
TernaryOpListFunctor<
scalar_t,
/* depth */ 4,
/* r_args_depth */ 3,
/* res_arg_index */ 3>(),
LerpFunctor<opmath_t>());
});
return tensor_lists[3];
@ -80,17 +77,14 @@ void foreach_tensor_lerp_ternary_cuda_(
"foreach_tensor_lerp_ternary_cuda_",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<3>(
tensor_lists,
TernaryOpListFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 3,
/* res_arg_index */ 0,
large_kernel_arg>(),
LerpFunctor<opmath_t>());
});
multi_tensor_apply<3>(
tensor_lists,
TernaryOpListFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 3,
/* res_arg_index */ 0>(),
LerpFunctor<opmath_t>());
});
increment_version(tensors1);
}
@ -119,18 +113,15 @@ std::vector<at::Tensor> foreach_tensor_lerp_list_cuda(
"foreach_tensor_lerp_scalar_cuda",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<3>(
tensor_lists,
TernaryOpScalarFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 2,
/* res_arg_index */ 2,
large_kernel_arg>(),
LerpFunctor<opmath_t>(),
weight.to<opmath_t>());
});
multi_tensor_apply<3>(
tensor_lists,
TernaryOpScalarFunctor<
scalar_t,
/* depth */ 3,
/* r_args_depth */ 2,
/* res_arg_index */ 2>(),
LerpFunctor<opmath_t>(),
weight.to<opmath_t>());
});
return tensor_lists[2];
@ -154,18 +145,15 @@ void foreach_tensor_lerp_list_cuda_(
"foreach_tensor_lerp_scalar_cuda_",
[&]() {
using opmath_t = typename at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
TernaryOpScalarFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 2,
/* res_arg_index */ 0,
large_kernel_arg>(),
LerpFunctor<opmath_t>(),
weight.to<opmath_t>());
});
multi_tensor_apply<2>(
tensor_lists,
TernaryOpScalarFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 2,
/* res_arg_index */ 0>(),
LerpFunctor<opmath_t>(),
weight.to<opmath_t>());
});
}
} // namespace at::native

@ -56,17 +56,14 @@ std::vector<Tensor> foreach_unary_op(TensorList tensors) {
tensor_lists.emplace_back(std::move(vec_res));
using opmath_t = typename at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
UnaryOpFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1,
large_kernel_arg>(),
Op<opmath_t>());
});
multi_tensor_apply<2>(
tensor_lists,
UnaryOpFunctor<
scalar_t,
/* depth */ 2,
/* r_args_depth */ 1,
/* res_arg_index */ 1>(),
Op<opmath_t>());
return tensor_lists[1];
}
@ -76,17 +73,14 @@ void foreach_unary_op_(TensorList tensors) {
std::vector<std::vector<at::Tensor>> tensor_lists;
tensor_lists.emplace_back(tensors.vec());
using opmath_t = typename at::opmath_type<scalar_t>;
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
UnaryOpFunctor<
scalar_t,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0,
large_kernel_arg>(),
Op<opmath_t>());
});
multi_tensor_apply<1>(
tensor_lists,
UnaryOpFunctor<
scalar_t,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0>(),
Op<opmath_t>());
increment_version(tensors);
}
@ -401,16 +395,13 @@ void foreach_tensor_zero_cuda_(TensorList tensors) {
tensors[0].scalar_type(),
"foreach_zero_cuda_",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<1>(
tensor_lists,
ZeroFunctor<
scalar_t,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0,
large_kernel_arg>());
});
multi_tensor_apply<1>(
tensor_lists,
ZeroFunctor<
scalar_t,
/* depth */ 1,
/* r_args_depth */ 1,
/* res_arg_index */ 0>());
});
}

@ -56,15 +56,14 @@ C10_DEVICE __forceinline__ void sgd_math(
}
}
template <typename scalar_t, int depth, bool large_kernel_arg>
template <typename scalar_t, int depth>
struct FusedSgdMathFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
static_assert(
depth == 2 || depth == 3,
"depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
C10_DEVICE __forceinline__ void operator()(
const int chunk_size,
TensorListMetadata<depth, large_kernel_arg>& tl,
TensorListMetadata<depth>& tl,
const double weight_decay,
const double momentum,
const float* lr_ptr,
@ -173,21 +172,19 @@ void _fused_sgd_with_momentum_kernel_cuda_(
params[0].scalar_type(),
"fused_sgd_with_momentum_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<3>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
weight_decay,
momentum,
lr_ptr,
lr,
dampening,
nesterov,
maximize,
is_first_step,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply<3>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 3>(),
weight_decay,
momentum,
lr_ptr,
lr,
dampening,
nesterov,
maximize,
is_first_step,
grad_scale_ptr,
found_inf_ptr);
});
}
@ -249,21 +246,19 @@ void _fused_sgd_with_momentum_kernel_cuda_(
params[0].scalar_type(),
"fused_sgd_with_momentum_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<3>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
weight_decay,
momentum,
lr.data_ptr<float>(),
1.0,
dampening,
nesterov,
maximize,
is_first_step,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply<3>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 3>(),
weight_decay,
momentum,
lr.data_ptr<float>(),
1.0,
dampening,
nesterov,
maximize,
is_first_step,
grad_scale_ptr,
found_inf_ptr);
});
}
@ -317,21 +312,19 @@ void _fused_sgd_kernel_cuda_(
params[0].scalar_type(),
"fused_sgd_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
weight_decay,
momentum,
lr_ptr,
lr,
dampening,
nesterov,
maximize,
/* is_first_step */ false,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply<2>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 2>(),
weight_decay,
momentum,
lr_ptr,
lr,
dampening,
nesterov,
maximize,
/* is_first_step */ false,
grad_scale_ptr,
found_inf_ptr);
});
}
@ -411,21 +404,19 @@ void _fused_sgd_kernel_cuda_(
params[0].scalar_type(),
"fused_sgd_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply<2>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
weight_decay,
momentum,
lr.data_ptr<float>(),
1.0,
dampening,
nesterov,
maximize,
/* is_first_step */ false,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply<2>(
tensor_lists,
FusedSgdMathFunctor<scalar_t, 2>(),
weight_decay,
momentum,
lr.data_ptr<float>(),
1.0,
dampening,
nesterov,
maximize,
/* is_first_step */ false,
grad_scale_ptr,
found_inf_ptr);
});
}

@ -1,25 +0,0 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <cuda_runtime.h>
namespace at::native {
bool supports_large_kernel_arg() {
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && CUDART_VERSION >= 12010
static std::optional<bool> supports_large_kernel_arg_ = std::nullopt;
if (!supports_large_kernel_arg_.has_value()) {
int driver_ver = 0;
AT_CUDA_CHECK(cudaDriverGetVersion(&driver_ver));
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
supports_large_kernel_arg_ = (driver_ver >= 12010) && prop->major >= 7;
}
const bool is_capturing = at::cuda::currentStreamCaptureStatusMayInitCtx() !=
at::cuda::CaptureStatus::None;
return !is_capturing && *supports_large_kernel_arg_;
#else
return false;
#endif
}
} // namespace at::native

@ -8,105 +8,20 @@
namespace at::native {
// NOTE: [32KB kernel argument size support]
// 32KB kernel argument size support has three requirements:
// - CUDART_VERSION >= 12010
// - Driver version >= 530
// - GPU arch >= VOLTA
//
// Due to minor version compatibility, it is possible for binaries built with
// CUDART_VERSION >= 12010 to run with driver version < 530. Since driver
// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have
// to build both 4KB and 32KB kernels and determine the appropriate kernel to
// dispatch at runtime.
//
// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated.
//
// - If CUDART_VERSION >= 12010:
// - Host code:
// - We always instantiate the launching stub for both 4KB and 32KB kernels.
// - Device code:
// - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB
// kernels.
// - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty
// 32KB kernel (formal parameter space overflowed). Thus, we only
// instantiate a declaration for 32KB kernels. This is valid as long as the
// declaration-only kernel is not launched.
//
// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and
// GPU arch >= VOLTA.
//
// - TODO(yifu): once there's a CUDART version that is not compatible with any
// driver version below 530, we can determine at compile time to not compile
// the kernels for 4KB kernel argument size.
//
// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/
bool supports_large_kernel_arg();
namespace {
static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;
// MSVC has a problem with constexpr and can't handle passing them to templates
// as arguments. We need to replace it with const static.
// https://github.com/state-spaces/mamba/issues/12#issuecomment-1848835662
#if !defined(_WIN32)
#define SWITCH_TYPE constexpr bool
#else
#define SWITCH_TYPE const static bool
#endif
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && \
CUDART_VERSION >= 12010
#define DISPATCH_MULTI_TENSOR_APPLY(...) \
if (at::native::supports_large_kernel_arg()) { \
SWITCH_TYPE large_kernel_arg C10_UNUSED = true; \
__VA_ARGS__(); \
} else { \
SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
__VA_ARGS__(); \
}
#else
#define DISPATCH_MULTI_TENSOR_APPLY(...) \
do { \
SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
__VA_ARGS__(); \
} while (0);
#endif
template <bool large_kernel_arg>
struct DepthToMaxConfig;
template <>
struct DepthToMaxConfig<false> {
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
static constexpr int depth_to_max_tensors_scalarlist[5] =
{96, 64, 48, 36, 30};
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
72,
60};
using TensorIdxType = unsigned char;
};
template <>
struct DepthToMaxConfig<true> {
// TODO(yifu): These values are not yet optimally tuned. I simply multiplied
// the values tuned for 4KB kernel argument size limit by 7 (the kernel
// argument size limit increased by 8x but we need to change the type of
// block_to_tensor from unsigned char to uint16_t to support larger number of
// tensors).
static constexpr int depth_to_max_tensors[5] = {770, 448, 336, 252, 210};
static constexpr int depth_to_max_blocks[5] = {2240, 2240, 2240, 2240, 2240};
static constexpr int depth_to_max_tensors_scalarlist[5] =
{672, 448, 336, 252, 210};
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
504,
420};
using TensorIdxType = uint16_t;
};
// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
72,
60};
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
@ -123,101 +38,73 @@ __device__ __forceinline__ void load_store(
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template <int n, bool large_kernel_arg>
template <int n>
struct TensorListMetadata {
using Conf = DepthToMaxConfig<large_kernel_arg>;
const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
typename Conf::TensorIdxType
block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
const void* addresses[n][depth_to_max_tensors[n - 1]];
int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]];
int start_tensor_this_launch;
};
template <typename scalar_vals_t, int n, bool large_kernel_arg>
template <typename scalar_vals_t, int n>
struct TensorListScalarListMetadata {
using Conf = DepthToMaxConfig<large_kernel_arg>;
const void* addresses[n][Conf::depth_to_max_tensors_scalarlist[n - 1]];
int64_t numel_for_tensor[Conf::depth_to_max_tensors_scalarlist[n - 1]];
scalar_vals_t scalar_vals[Conf::depth_to_max_tensors_scalarlist[n - 1]];
typename Conf::TensorIdxType
block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]];
};
// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
// 4kb with `c10::complex<double>`
template <bool large_kernel_arg>
struct TensorListScalarListMetadata<c10::complex<double>, 1, large_kernel_arg> {
using Conf = DepthToMaxConfig<large_kernel_arg>;
const void*
addresses[1][Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
int64_t numel_for_tensor
[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 1> {
const void* addresses[1]
[depth_to_max_tensors_scalarlist_of_complex_double[0]];
int64_t
numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
c10::complex<double>
scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
typename Conf::TensorIdxType
block_to_tensor[Conf::depth_to_max_blocks[1 - 1]];
int block_to_chunk[Conf::depth_to_max_blocks[1 - 1]];
scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
int block_to_chunk[depth_to_max_blocks[1 - 1]];
};
template <bool large_kernel_arg>
struct TensorListScalarListMetadata<c10::complex<double>, 2, large_kernel_arg> {
using Conf = DepthToMaxConfig<large_kernel_arg>;
const void*
addresses[2][Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
int64_t numel_for_tensor
[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
template <>
struct TensorListScalarListMetadata<c10::complex<double>, 2> {
const void* addresses[2]
[depth_to_max_tensors_scalarlist_of_complex_double[1]];
int64_t
numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
c10::complex<double>
scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
typename Conf::TensorIdxType
block_to_tensor[Conf::depth_to_max_blocks[2 - 1]];
int block_to_chunk[Conf::depth_to_max_blocks[2 - 1]];
scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
int block_to_chunk[depth_to_max_blocks[2 - 1]];
};
// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
// whose each element is `at::Tensor` of 1 element representing the number of
// `step`s called so far.
template <int n, bool large_kernel_arg>
template <int n>
struct FusedOptimizerTensorListMetadata {
using Conf = DepthToMaxConfig<large_kernel_arg>;
const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
const void*
state_steps_addresses[Conf::depth_to_max_tensors_scalarlist[n - 1]];
typename Conf::TensorIdxType
block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
const void* addresses[n][depth_to_max_tensors[n - 1]];
int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]];
int start_tensor_this_launch;
};
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
__global__ void multi_tensor_apply_kernel(
T tensorListMeta,
U callable,
ArgTypes... args) {
// Hand the chunk information to the user-supplied functor to process however
// it likes.
callable(kChunkSize, tensorListMeta, args...);
}
#else
// When compiling device code with __CUDA_ARCH__ < 700, we only instantiate a
// declaration for the 32KB kernels.
// For details see: [32KB kernel argument size support]
#pragma nv_diag_suppress 114 // Function was referenced but not defined
template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args);
#pragma nv_diag_default 114 // Function was referenced but not defined
#endif
template <typename T, typename U, typename... ArgTypes>
C10_LAUNCH_BOUNDS_1(kBlockSize)
__global__ typename std::enable_if<!U::use_large_kernel_arg, void>::type
multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
callable(kChunkSize, tensorListMeta, args...);
}
} // namespace
@ -246,10 +133,7 @@ void multi_tensor_apply(
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
using scalar_vals_t = typename T::opmath_t;
TensorListScalarListMetadata<scalar_vals_t, depth, T::use_large_kernel_arg>
tensorListMeta;
using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
int loc_block_info = 0;
int loc_tensor_info = 0;
@ -283,11 +167,10 @@ void multi_tensor_apply(
// a tensor is not considered full unless all its chunks have been
// processed
const bool tensors_full =
(loc_tensor_info ==
Conf::depth_to_max_tensors_scalarlist[depth - 1] &&
(loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
chunk == chunks - 1);
const bool blocks_full =
(loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
(loc_block_info == depth_to_max_blocks[depth - 1]);
if (tensors_full || blocks_full) {
multi_tensor_apply_kernel<<<
@ -340,11 +223,9 @@ void multi_tensor_apply(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth.");
const size_t n_tensors = tensor_lists[0].size();
TensorListMetadata<depth, T::use_large_kernel_arg> tensorListMeta;
TensorListMetadata<depth> tensorListMeta;
tensorListMeta.start_tensor_this_launch = 0;
using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (size_t t = 0; t < n_tensors; t++) {
@ -369,10 +250,10 @@ void multi_tensor_apply(
loc_block_info++;
const bool tensors_full =
(loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks - 1);
const bool blocks_full =
(loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
(loc_block_info == depth_to_max_blocks[depth - 1]);
if (tensors_full || blocks_full) {
multi_tensor_apply_kernel<<<
@ -423,10 +304,7 @@ void multi_tensor_apply_for_fused_optimizer(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth");
const auto num_tensors = tensor_lists[0].size();
FusedOptimizerTensorListMetadata<depth, T::use_large_kernel_arg>
tensorListMeta;
using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
FusedOptimizerTensorListMetadata<depth> tensorListMeta;
int loc_block_info = 0;
int loc_tensor_info = 0;
@ -455,10 +333,9 @@ void multi_tensor_apply_for_fused_optimizer(
loc_block_info++;
const auto tensor_full =
(loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks - 1);
const auto blocks_full =
loc_block_info == Conf::depth_to_max_blocks[depth - 1];
const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
if (tensor_full || blocks_full) {
multi_tensor_apply_kernel<<<

@ -42,26 +42,19 @@ void _fused_adam_amsgrad_cuda_impl_(
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
5,
ADAM_MODE::ORIGINAL,
true,
large_kernel_arg>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}
@ -100,26 +93,19 @@ void _fused_adam_amsgrad_cuda_impl_(
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
5,
ADAM_MODE::ORIGINAL,
true,
large_kernel_arg>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}

@ -37,26 +37,19 @@ void _fused_adam_cuda_impl_(
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
4,
ADAM_MODE::ORIGINAL,
false,
large_kernel_arg>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}
@ -90,26 +83,19 @@ void _fused_adam_cuda_impl_(
params[0].scalar_type(),
"fused_adam_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
4,
ADAM_MODE::ORIGINAL,
false,
large_kernel_arg>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}

@ -102,21 +102,15 @@ C10_DEVICE inline void adam_math(
// parameter updates accordingly. To be functionally on par with `torch.optim`
// optimizers and `_multi_tensor` ones, the kernel below writes out gradients
// only when `grad_scale_ptr != nullptr`.
template <
typename scalar_type,
int depth,
ADAM_MODE adam_mode,
bool amsgrad,
bool large_kernel_arg>
template <typename scalar_type, int depth, ADAM_MODE adam_mode, bool amsgrad>
struct FusedAdamMathFunctor {
static constexpr bool use_large_kernel_arg = large_kernel_arg;
static_assert(
depth == 4 || depth == 5,
"depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
using opmath_t = at::opmath_type<scalar_type>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
FusedOptimizerTensorListMetadata<depth>& tl,
const float* lr_ptr,
const double& lr,
const double& beta1,

@ -43,26 +43,19 @@ void _fused_adamw_amsgrad_cuda_impl_(
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
5,
ADAM_MODE::ADAMW,
true,
large_kernel_arg>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}
@ -101,26 +94,19 @@ void _fused_adamw_amsgrad_cuda_impl_(
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
5,
ADAM_MODE::ADAMW,
true,
large_kernel_arg>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<5>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}

@ -38,26 +38,19 @@ void _fused_adamw_cuda_impl_(
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
4,
ADAM_MODE::ADAMW,
false,
large_kernel_arg>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
lr_ptr, // unused
lr,
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}
@ -91,26 +84,19 @@ void _fused_adamw_cuda_impl_(
params[0].scalar_type(),
"fused_adamw_kernel_cuda",
[&]() {
DISPATCH_MULTI_TENSOR_APPLY([&]() {
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<
scalar_t,
4,
ADAM_MODE::ADAMW,
false,
large_kernel_arg>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
multi_tensor_apply_for_fused_optimizer<4>(
tensor_lists,
state_steps,
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
lr_ptr,
1.0, // unused
beta1,
beta2,
weight_decay,
eps,
maximize,
grad_scale_ptr,
found_inf_ptr);
});
}

@ -1455,7 +1455,6 @@ aten_cuda_cu_source_list = [
"aten/src/ATen/native/cuda/Equal.cpp",
"aten/src/ATen/native/cuda/GridSampler.cpp",
"aten/src/ATen/native/cuda/IndexKernel.cpp",
"aten/src/ATen/native/cuda/MultiTensorApply.cpp",
"aten/src/ATen/native/cuda/ReduceOps.cpp",
"aten/src/ATen/native/cuda/ScanKernels.cpp",
"aten/src/ATen/native/cuda/Sort.cpp",