Revert "Add support for 32KB multi_tensor_apply kernel arguments (#134373)"
This reverts commit 08184aa85cf183198ebdf2fd7a49fe7bc4842c13. Reverted https://github.com/pytorch/pytorch/pull/134373 on behalf of https://github.com/drisspg; see https://github.com/pytorch/pytorch/issues/135126 for more details ([comment](https://github.com/pytorch/pytorch/pull/134373#issuecomment-2329839011))
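Background, for readers of the diff below: the reverted change let each multi_tensor_apply launch carry roughly 8x more per-launch metadata by compiling every kernel twice, once against the classic ~4KB kernel-argument limit and once against the 32KB limit available on newer CUDA stacks (CUDA 12.1+, driver >= 530, arch >= Volta), then picking the variant at runtime. The standalone C++ sketch that follows is not part of the PyTorch sources; the capacity numbers come from the diff, while all other names (MetadataCapacity, supports_large_kernel_arg_stub, launch_batched_op) are illustrative placeholders for the compile-both, pick-at-runtime pattern.

#include <cstdio>

// Capacity table selected by a compile-time flag; the two specializations
// mirror the 4KB vs. 32KB argument budgets used in the reverted change.
template <bool LargeKernelArg>
struct MetadataCapacity;

template <>
struct MetadataCapacity<false> {  // classic 4KB kernel-argument limit
  static constexpr int max_tensors = 110;
  static constexpr int max_blocks = 320;
};

template <>
struct MetadataCapacity<true> {  // 32KB limit (CUDA 12.1+, driver >= 530, arch >= Volta)
  static constexpr int max_tensors = 770;
  static constexpr int max_blocks = 2240;
};

// Placeholder for the real runtime probe (driver version, GPU arch, graph
// capture state); in the reverted code this is supports_large_kernel_arg().
bool supports_large_kernel_arg_stub() {
  return false;
}

template <bool LargeKernelArg>
void launch_batched_op(int n_tensors) {
  using Cap = MetadataCapacity<LargeKernelArg>;
  // A real launch would chunk n_tensors into groups of at most Cap::max_tensors.
  std::printf("budget: %d tensors / %d blocks per launch (processing %d tensors)\n",
              Cap::max_tensors, Cap::max_blocks, n_tensors);
}

int main() {
  const int n_tensors = 500;
  if (supports_large_kernel_arg_stub()) {
    launch_batched_op<true>(n_tensors);   // fewer launches with the 32KB budget
  } else {
    launch_batched_op<false>(n_tensors);  // conservative 4KB fallback
  }
  return 0;
}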
@ -157,24 +157,21 @@ void _amp_foreach_non_finite_check_and_unscale_cuda_(TensorList scaled_grads,
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
|
||||
// multi_tensor_apply guards onto tensor_lists[0][0], no need to guard explicitly.
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(tensor_lists,
|
||||
UnaryOpFunctor<scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
|
||||
// There is a slight asymmetry here with the TensorIterator kernel above.
|
||||
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
|
||||
if (!isfinite_ensure_cuda_math(val)) {
|
||||
*found_inf_ptr = 1.f;
|
||||
}
|
||||
// Every thread accesses inv_scale, but it will hit in cache.
|
||||
const auto inv_scale_val = *inv_scale_ptr;
|
||||
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
|
||||
});
|
||||
});
|
||||
multi_tensor_apply<1>(tensor_lists,
|
||||
UnaryOpFunctor<scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
[found_inf_ptr, inv_scale_ptr] GPU_LAMBDA (opmath_t val) -> opmath_t {
|
||||
// There is a slight asymmetry here with the TensorIterator kernel above.
|
||||
// MTA Functors ensure val comes in as opmath_t rather than scalar_t.
|
||||
if (!isfinite_ensure_cuda_math(val)) {
|
||||
*found_inf_ptr = 1.f;
|
||||
}
|
||||
// Every thread accesses inv_scale, but it will hit in cache.
|
||||
const auto inv_scale_val = *inv_scale_ptr;
|
||||
return static_cast<opmath_t>(inv_scale_val == 1.f ? val : val * inv_scale_val);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -41,18 +41,15 @@ std::vector<Tensor> foreach_tensor_list_op(
|
||||
tensor_lists.emplace_back(std::move(vec_res));
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
BinaryOpListAlphaFunctor<
|
||||
T,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 2,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
alpha.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
BinaryOpListAlphaFunctor<
|
||||
T,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 2>(),
|
||||
Op<opmath_t>(),
|
||||
alpha.to<opmath_t>());
|
||||
|
||||
return tensor_lists[2];
|
||||
}
|
||||
@ -67,18 +64,15 @@ void foreach_tensor_list_op_(
|
||||
tensor_lists.emplace_back(tensors2.vec());
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
BinaryOpListAlphaFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
alpha.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
BinaryOpListAlphaFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>(),
|
||||
alpha.to<opmath_t>());
|
||||
increment_version(tensors1);
|
||||
}
|
||||
|
||||
@ -337,15 +331,13 @@ template <
|
||||
typename src_t,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
int res_arg_index>
|
||||
struct CopyFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
@ -428,17 +420,14 @@ void foreach_tensor_copy_list_kernel_cuda_(
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
|
||||
if constexpr (std::is_same_v<scalar_t, src_t>) {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
UnaryOpFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1,
|
||||
large_kernel_arg>(),
|
||||
Copy<opmath_t, opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
UnaryOpFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Copy<opmath_t, opmath_t>());
|
||||
} else {
|
||||
// Ref:
|
||||
// https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
|
||||
@ -446,18 +435,15 @@ void foreach_tensor_copy_list_kernel_cuda_(
|
||||
TORCH_WARN_ONCE(
|
||||
"Casting complex values to real discards the imaginary part");
|
||||
}
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
CopyFunctor<
|
||||
scalar_t,
|
||||
src_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1,
|
||||
large_kernel_arg>(),
|
||||
Copy<scalar_t, src_t>());
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
CopyFunctor<
|
||||
scalar_t,
|
||||
src_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Copy<scalar_t, src_t>());
|
||||
}
|
||||
});
|
||||
});
|
||||
|
@ -36,18 +36,15 @@ std::vector<Tensor> foreach_binary_op(
|
||||
tensor_lists.emplace_back(std::move(vec_res));
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
return tensor_lists[1];
|
||||
}
|
||||
|
||||
@ -57,18 +54,15 @@ void foreach_binary_op_(TensorList tensors, const Scalar& scalar) {
|
||||
tensor_lists.emplace_back(tensors.vec());
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarFunctor<
|
||||
T,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarFunctor<
|
||||
T,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
increment_version(tensors);
|
||||
}
|
||||
|
||||
|
@ -36,19 +36,16 @@ std::vector<Tensor> foreach_binary_op(
|
||||
tensor_lists.emplace_back(vec_res);
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1,
|
||||
large_kernel_arg>(),
|
||||
multi_tensor_apply<2, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
|
||||
Op<opmath_t>());
|
||||
});
|
||||
Op<opmath_t>());
|
||||
return tensor_lists[1];
|
||||
}
|
||||
|
||||
@ -58,18 +55,15 @@ void foreach_binary_op_(TensorList tensors, at::ArrayRef<Scalar> scalars) {
|
||||
tensor_lists.emplace_back(tensors.vec());
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<
|
||||
T,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<1, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
BinaryOpScalarListFunctor<
|
||||
T,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>());
|
||||
increment_version(tensors);
|
||||
}
|
||||
|
||||
|
@ -46,19 +46,16 @@ std::vector<Tensor> foreach_binary_op(
|
||||
tensor_lists.emplace_back(std::move(vec_res));
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarTensorFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.data_ptr<T>(),
|
||||
alpha.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarTensorFunctor<
|
||||
T,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.data_ptr<T>(),
|
||||
alpha.to<opmath_t>());
|
||||
return tensor_lists[1];
|
||||
}
|
||||
|
||||
@ -84,19 +81,16 @@ void foreach_binary_op_(
|
||||
tensor_lists.emplace_back(tensors.vec());
|
||||
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarTensorFunctor<
|
||||
T,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.data_ptr<T>(),
|
||||
alpha.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
BinaryOpScalarTensorFunctor<
|
||||
T,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.data_ptr<T>(),
|
||||
alpha.to<opmath_t>());
|
||||
increment_version(tensors);
|
||||
}
|
||||
|
||||
|
@ -18,10 +18,10 @@ inline void increment_version(TensorList tensors) {
|
||||
}
|
||||
|
||||
// Initializes args and checks if all args are aligned
|
||||
template <int depth, typename T, bool large_kernel_arg>
|
||||
template <int depth, typename T>
|
||||
__device__ bool init_args(
|
||||
T** args,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
const int64_t chunk_idx,
|
||||
const int64_t chunk_size,
|
||||
const int64_t tensor_loc) {
|
||||
@ -38,10 +38,10 @@ __device__ bool init_args(
|
||||
}
|
||||
|
||||
// Initializes args and checks if all args are aligned
|
||||
template <int depth, typename T, typename T2, bool large_kernel_arg>
|
||||
template <int depth, typename T, typename T2>
|
||||
__device__ bool init_args(
|
||||
T** args,
|
||||
TensorListScalarListMetadata<T2, depth, large_kernel_arg>& tl,
|
||||
TensorListScalarListMetadata<T2, depth>& tl,
|
||||
const int64_t chunk_idx,
|
||||
const int64_t chunk_size,
|
||||
const int64_t tensor_loc) {
|
||||
@ -57,10 +57,10 @@ __device__ bool init_args(
|
||||
return all_aligned;
|
||||
}
|
||||
|
||||
template <int depth, typename T, bool large_kernel_arg>
|
||||
template <int depth, typename T>
|
||||
__device__ bool init_args(
|
||||
T** args,
|
||||
FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
FusedOptimizerTensorListMetadata<depth>& tl,
|
||||
const int64_t chunk_idx,
|
||||
const int64_t chunk_size,
|
||||
const int64_t tensor_loc) {
|
||||
@ -203,19 +203,13 @@ __device__ __forceinline__ void pointwise_op_scalar(
|
||||
//
|
||||
// Binary Functors
|
||||
//
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct BinaryOpScalarFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op,
|
||||
opmath_t scalar) {
|
||||
const int tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
@ -233,19 +227,13 @@ struct BinaryOpScalarFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct BinaryOpScalarListFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
|
||||
TensorListScalarListMetadata<opmath_t, depth>& tl,
|
||||
Op op) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
@ -263,19 +251,13 @@ struct BinaryOpScalarListFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct BinaryOpListAlphaFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op,
|
||||
opmath_t alpha) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
@ -321,19 +303,13 @@ struct BinaryOpListAlphaFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct BinaryOpScalarTensorFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op,
|
||||
T* scalar,
|
||||
opmath_t alpha) {
|
||||
@ -385,17 +361,11 @@ struct BinaryOpScalarTensorFunctor {
|
||||
// Unary Functors
|
||||
//
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct ZeroFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<1, large_kernel_arg>& tl) {
|
||||
TensorListMetadata<1>& tl) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
auto n = tl.numel_for_tensor[tensor_loc];
|
||||
@ -431,19 +401,13 @@ struct ZeroFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct UnaryOpFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
@ -489,19 +453,13 @@ struct UnaryOpFunctor {
|
||||
// Pointwise Functors
|
||||
//
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct PointwiseOpScalarFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op,
|
||||
opmath_t scalar) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
@ -519,19 +477,13 @@ struct PointwiseOpScalarFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct PointwiseOpScalarListFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListScalarListMetadata<opmath_t, depth, large_kernel_arg>& tl,
|
||||
TensorListScalarListMetadata<opmath_t, depth>& tl,
|
||||
Op op) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
@ -549,14 +501,13 @@ struct PointwiseOpScalarListFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int depth, bool large_kernel_arg>
|
||||
template <typename T, int depth>
|
||||
struct PointwiseOpListFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
@ -601,19 +552,13 @@ struct PointwiseOpListFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct TernaryOpListFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op) {
|
||||
static_assert(depth == 3 || depth == 4, "");
|
||||
static_assert(depth >= r_args_depth, "");
|
||||
@ -661,19 +606,13 @@ struct TernaryOpListFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename T,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index,
|
||||
bool large_kernel_arg>
|
||||
template <typename T, int depth, int r_args_depth, int res_arg_index>
|
||||
struct TernaryOpScalarFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using opmath_t = at::opmath_type<T>;
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op,
|
||||
opmath_t alpha) {
|
||||
static_assert(depth == 2 || depth == 3, "");
|
||||
|
@ -46,18 +46,15 @@ std::vector<Tensor> foreach_pointwise_op(
|
||||
"foreach_pointwise_op_cuda",
|
||||
[&]() {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<4>(
|
||||
tensor_lists,
|
||||
PointwiseOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<4>(
|
||||
tensor_lists,
|
||||
PointwiseOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
});
|
||||
|
||||
return tensor_lists[3];
|
||||
@ -81,18 +78,15 @@ void foreach_pointwise_op_(
|
||||
"foreach_pointwise_op__cuda",
|
||||
[&]() {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
PointwiseOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
PointwiseOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>(),
|
||||
scalar.to<opmath_t>());
|
||||
});
|
||||
increment_version(input);
|
||||
}
|
||||
@ -116,18 +110,15 @@ void foreach_pointwise_op_(
|
||||
"foreach_pointwise_op__cuda",
|
||||
[&]() {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<3, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<3, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
increment_version(input);
|
||||
}
|
||||
@ -158,18 +149,15 @@ std::vector<Tensor> foreach_pointwise_op(
|
||||
"foreach_pointwise_op_cuda",
|
||||
[&]() {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<4, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<4, opmath_t>(
|
||||
tensor_lists,
|
||||
scalars,
|
||||
PointwiseOpScalarListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
|
||||
return tensor_lists[3];
|
||||
|
@ -50,13 +50,11 @@ template <
|
||||
typename T,
|
||||
int depth = 1,
|
||||
int r_args_depth = 1,
|
||||
int res_arg_index = 0,
|
||||
bool large_kernel_arg = false>
|
||||
int res_arg_index = 0>
|
||||
struct LpMaxFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
T* output_per_tensor_ptr,
|
||||
const int max_chunks_per_tensor) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
@ -180,13 +178,11 @@ std::vector<Tensor> foreach_tensor_max_cuda(TensorList tensors) {
|
||||
tensor_lists[0][0].scalar_type(),
|
||||
"foreach_tensor_max_cuda_scalar_type",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpMaxFunctor<scalar_t, 1, 1, 0, large_kernel_arg>(),
|
||||
output_per_tensor.mutable_data_ptr<scalar_t>(),
|
||||
max_chunks_per_tensor);
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpMaxFunctor<scalar_t>(),
|
||||
output_per_tensor.mutable_data_ptr<scalar_t>(),
|
||||
max_chunks_per_tensor);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(
|
||||
@ -243,14 +239,12 @@ template <
|
||||
typename out_t,
|
||||
int depth = 1,
|
||||
int r_args_depth = 1,
|
||||
int res_arg_index = 0,
|
||||
bool large_kernel_arg = false>
|
||||
int res_arg_index = 0>
|
||||
struct LpNormFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
using out_opmath_t = typename at::opmath_type<out_t>;
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
out_opmath_t* output_per_tensor_ptr,
|
||||
const int max_chunks_per_tensor) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
@ -482,50 +476,23 @@ std::vector<Tensor> foreach_tensor_norm_cuda(
|
||||
output_dtype, "foreach_tensor_norm_cuda_out_dtype", [&]() {
|
||||
using out_opmath_t = typename at::opmath_type<out_t>;
|
||||
if (p == static_cast<double>(1)) {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpNormFunctor<
|
||||
scalar_t,
|
||||
NormType::L1,
|
||||
out_t,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
large_kernel_arg>(),
|
||||
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
|
||||
max_chunks_per_tensor);
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpNormFunctor<scalar_t, NormType::L1, out_t>(),
|
||||
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
|
||||
max_chunks_per_tensor);
|
||||
} else if (p == static_cast<double>(2)) {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpNormFunctor<
|
||||
scalar_t,
|
||||
NormType::L2,
|
||||
out_t,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
large_kernel_arg>(),
|
||||
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
|
||||
max_chunks_per_tensor);
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpNormFunctor<scalar_t, NormType::L2, out_t>(),
|
||||
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
|
||||
max_chunks_per_tensor);
|
||||
} else if (p == std::numeric_limits<double>::infinity()) {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpNormFunctor<
|
||||
scalar_t,
|
||||
NormType::LInf,
|
||||
out_t,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
large_kernel_arg>(),
|
||||
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
|
||||
max_chunks_per_tensor);
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
LpNormFunctor<scalar_t, NormType::LInf, out_t>(),
|
||||
output_per_tensor.mutable_data_ptr<out_opmath_t>(),
|
||||
max_chunks_per_tensor);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(
|
||||
|
@ -46,17 +46,14 @@ std::vector<at::Tensor> foreach_tensor_lerp_ternary_cuda(
|
||||
"foreach_tensor_lerp_ternary_cuda",
|
||||
[&]() {
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<4>(
|
||||
tensor_lists,
|
||||
TernaryOpListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3,
|
||||
large_kernel_arg>(),
|
||||
LerpFunctor<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<4>(
|
||||
tensor_lists,
|
||||
TernaryOpListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 4,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 3>(),
|
||||
LerpFunctor<opmath_t>());
|
||||
});
|
||||
|
||||
return tensor_lists[3];
|
||||
@ -80,17 +77,14 @@ void foreach_tensor_lerp_ternary_cuda_(
|
||||
"foreach_tensor_lerp_ternary_cuda_",
|
||||
[&]() {
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
TernaryOpListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
LerpFunctor<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
TernaryOpListFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 3,
|
||||
/* res_arg_index */ 0>(),
|
||||
LerpFunctor<opmath_t>());
|
||||
});
|
||||
increment_version(tensors1);
|
||||
}
|
||||
@ -119,18 +113,15 @@ std::vector<at::Tensor> foreach_tensor_lerp_list_cuda(
|
||||
"foreach_tensor_lerp_scalar_cuda",
|
||||
[&]() {
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
TernaryOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 2,
|
||||
large_kernel_arg>(),
|
||||
LerpFunctor<opmath_t>(),
|
||||
weight.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
TernaryOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 3,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 2>(),
|
||||
LerpFunctor<opmath_t>(),
|
||||
weight.to<opmath_t>());
|
||||
});
|
||||
|
||||
return tensor_lists[2];
|
||||
@ -154,18 +145,15 @@ void foreach_tensor_lerp_list_cuda_(
|
||||
"foreach_tensor_lerp_scalar_cuda_",
|
||||
[&]() {
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
TernaryOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
LerpFunctor<opmath_t>(),
|
||||
weight.to<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
TernaryOpScalarFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 2,
|
||||
/* res_arg_index */ 0>(),
|
||||
LerpFunctor<opmath_t>(),
|
||||
weight.to<opmath_t>());
|
||||
});
|
||||
}
|
||||
} // namespace at::native
|
||||
|
@ -56,17 +56,14 @@ std::vector<Tensor> foreach_unary_op(TensorList tensors) {
|
||||
tensor_lists.emplace_back(std::move(vec_res));
|
||||
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
UnaryOpFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
UnaryOpFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Op<opmath_t>());
|
||||
|
||||
return tensor_lists[1];
|
||||
}
|
||||
@ -76,17 +73,14 @@ void foreach_unary_op_(TensorList tensors) {
|
||||
std::vector<std::vector<at::Tensor>> tensor_lists;
|
||||
tensor_lists.emplace_back(tensors.vec());
|
||||
using opmath_t = typename at::opmath_type<scalar_t>;
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
UnaryOpFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>(),
|
||||
Op<opmath_t>());
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
UnaryOpFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>(),
|
||||
Op<opmath_t>());
|
||||
increment_version(tensors);
|
||||
}
|
||||
|
||||
@ -401,16 +395,13 @@ void foreach_tensor_zero_cuda_(TensorList tensors) {
|
||||
tensors[0].scalar_type(),
|
||||
"foreach_zero_cuda_",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
ZeroFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0,
|
||||
large_kernel_arg>());
|
||||
});
|
||||
multi_tensor_apply<1>(
|
||||
tensor_lists,
|
||||
ZeroFunctor<
|
||||
scalar_t,
|
||||
/* depth */ 1,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 0>());
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -56,15 +56,14 @@ C10_DEVICE __forceinline__ void sgd_math(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int depth, bool large_kernel_arg>
|
||||
template <typename scalar_t, int depth>
|
||||
struct FusedSgdMathFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
static_assert(
|
||||
depth == 2 || depth == 3,
|
||||
"depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
|
||||
C10_DEVICE __forceinline__ void operator()(
|
||||
const int chunk_size,
|
||||
TensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
TensorListMetadata<depth>& tl,
|
||||
const double weight_decay,
|
||||
const double momentum,
|
||||
const float* lr_ptr,
|
||||
@ -173,21 +172,19 @@ void _fused_sgd_with_momentum_kernel_cuda_(
|
||||
params[0].scalar_type(),
|
||||
"fused_sgd_with_momentum_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr_ptr,
|
||||
lr,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
is_first_step,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 3>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr_ptr,
|
||||
lr,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
is_first_step,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -249,21 +246,19 @@ void _fused_sgd_with_momentum_kernel_cuda_(
|
||||
params[0].scalar_type(),
|
||||
"fused_sgd_with_momentum_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 3, large_kernel_arg>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr.data_ptr<float>(),
|
||||
1.0,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
is_first_step,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply<3>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 3>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr.data_ptr<float>(),
|
||||
1.0,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
is_first_step,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -317,21 +312,19 @@ void _fused_sgd_kernel_cuda_(
|
||||
params[0].scalar_type(),
|
||||
"fused_sgd_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr_ptr,
|
||||
lr,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
/* is_first_step */ false,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 2>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr_ptr,
|
||||
lr,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
/* is_first_step */ false,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -411,21 +404,19 @@ void _fused_sgd_kernel_cuda_(
|
||||
params[0].scalar_type(),
|
||||
"fused_sgd_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 2, large_kernel_arg>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr.data_ptr<float>(),
|
||||
1.0,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
/* is_first_step */ false,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
FusedSgdMathFunctor<scalar_t, 2>(),
|
||||
weight_decay,
|
||||
momentum,
|
||||
lr.data_ptr<float>(),
|
||||
1.0,
|
||||
dampening,
|
||||
nesterov,
|
||||
maximize,
|
||||
/* is_first_step */ false,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1,25 +0,0 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>

#include <cuda_runtime.h>

namespace at::native {

bool supports_large_kernel_arg() {
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && CUDART_VERSION >= 12010
  static std::optional<bool> supports_large_kernel_arg_ = std::nullopt;
  if (!supports_large_kernel_arg_.has_value()) {
    int driver_ver = 0;
    AT_CUDA_CHECK(cudaDriverGetVersion(&driver_ver));
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    supports_large_kernel_arg_ = (driver_ver >= 12010) && prop->major >= 7;
  }
  const bool is_capturing = at::cuda::currentStreamCaptureStatusMayInitCtx() !=
      at::cuda::CaptureStatus::None;
  return !is_capturing && *supports_large_kernel_arg_;
#else
  return false;
#endif
}

} // namespace at::native
@ -8,105 +8,20 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
// NOTE: [32KB kernel argument size support]
// 32KB kernel argument size support has three requirements:
// - CUDART_VERSION >= 12010
// - Driver version >= 530
// - GPU arch >= VOLTA
//
// Due to minor version compatibility, it is possible for binaries built with
// CUDART_VERSION >= 12010 to run with driver version < 530. Since driver
// version can only be checked at runtime, if CUDART_VERSION >= 12010, we have
// to build both 4KB and 32KB kernels and determine the appropriate kernel to
// dispatch at runtime.
//
// - If CUDART_VERSION < 12010, only 4KB kernels will be instantiated.
//
// - If CUDART_VERSION >= 12010:
//   - Host code:
//     - We always instantiate the launching stub for both 4KB and 32KB kernels.
//   - Device code:
//     - If __CUDA_ARCH__ >= 700, we always instantiate both 4KB and 32KB
//       kernels.
//     - If __CUDA_ARCH__ < 700, it's not possible to even compile an empty
//       32KB kernel (formal parameter space overflowed). Thus, we only
//       instantiate a declaration for 32KB kernels. This is valid as long as
//       the declaration-only kernel is not launched.
//
// - At runtime, we dispatch to the 32KB kernel if driver version >= 530 and
//   GPU arch >= VOLTA.
//
// - TODO(yifu): once there's a CUDART version that is not compatible with any
//   driver version below 530, we can determine at compile time to not compile
//   the kernels for 4KB kernel argument size.
//
// https://developer.nvidia.com/blog/cuda-12-1-supports-large-kernel-parameters/
bool supports_large_kernel_arg();

namespace {

static constexpr int64_t kILP = 4;
static constexpr int64_t kChunkSize = 65536;
static constexpr int64_t kBlockSize = 512;

// MSVC has a problem with constexpr and can't handle passing them to templates
// as arguments. We need to replace it with const static.
// https://github.com/state-spaces/mamba/issues/12#issuecomment-1848835662
#if !defined(_WIN32)
#define SWITCH_TYPE constexpr bool
#else
#define SWITCH_TYPE const static bool
#endif

#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDART_VERSION) && \
    CUDART_VERSION >= 12010
#define DISPATCH_MULTI_TENSOR_APPLY(...)             \
  if (at::native::supports_large_kernel_arg()) {     \
    SWITCH_TYPE large_kernel_arg C10_UNUSED = true;  \
    __VA_ARGS__();                                   \
  } else {                                           \
    SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
    __VA_ARGS__();                                   \
  }
#else
#define DISPATCH_MULTI_TENSOR_APPLY(...)             \
  do {                                               \
    SWITCH_TYPE large_kernel_arg C10_UNUSED = false; \
    __VA_ARGS__();                                   \
  } while (0);
#endif

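// For reference, the call sites removed elsewhere in this revert use the macro
// above as in the following fragment (copied in shape from the ZeroFunctor
// call site in this diff; not a standalone translation unit). The macro binds a
// local `large_kernel_arg` constant and invokes the lambda once, so the
// functor's final template argument selects the 4KB or 32KB metadata layout:
//
//   DISPATCH_MULTI_TENSOR_APPLY([&]() {
//     multi_tensor_apply<1>(
//         tensor_lists,
//         ZeroFunctor<
//             scalar_t,
//             /* depth */ 1,
//             /* r_args_depth */ 1,
//             /* res_arg_index */ 0,
//             large_kernel_arg>());
//   });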
template <bool large_kernel_arg>
struct DepthToMaxConfig;

template <>
struct DepthToMaxConfig<false> {
  static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
  static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
  static constexpr int depth_to_max_tensors_scalarlist[5] =
      {96, 64, 48, 36, 30};
  static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
      72,
      60};
  using TensorIdxType = unsigned char;
};

template <>
struct DepthToMaxConfig<true> {
  // TODO(yifu): These values are not yet optimally tuned. I simply multiplied
  // the values tuned for the 4KB kernel argument size limit by 7 (the kernel
  // argument size limit increased by 8x, but we need to change the type of
  // block_to_tensor from unsigned char to uint16_t to support a larger number
  // of tensors).
  static constexpr int depth_to_max_tensors[5] = {770, 448, 336, 252, 210};
  static constexpr int depth_to_max_blocks[5] = {2240, 2240, 2240, 2240, 2240};
  static constexpr int depth_to_max_tensors_scalarlist[5] =
      {672, 448, 336, 252, 210};
  static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
      504,
      420};
  using TensorIdxType = uint16_t;
};
// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
|
||||
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
|
||||
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
|
||||
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
|
||||
static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
|
||||
static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {
|
||||
72,
|
||||
60};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_aligned(T* p) {
|
||||
@ -123,101 +38,73 @@ __device__ __forceinline__ void load_store(
|
||||
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
|
||||
}
|
||||
|
||||
template <int n, bool large_kernel_arg>
|
||||
template <int n>
|
||||
struct TensorListMetadata {
|
||||
using Conf = DepthToMaxConfig<large_kernel_arg>;
|
||||
const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
|
||||
int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
|
||||
typename Conf::TensorIdxType
|
||||
block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
|
||||
const void* addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]];
|
||||
int start_tensor_this_launch;
|
||||
};
|
||||
|
||||
template <typename scalar_vals_t, int n, bool large_kernel_arg>
|
||||
template <typename scalar_vals_t, int n>
|
||||
struct TensorListScalarListMetadata {
|
||||
using Conf = DepthToMaxConfig<large_kernel_arg>;
|
||||
const void* addresses[n][Conf::depth_to_max_tensors_scalarlist[n - 1]];
|
||||
int64_t numel_for_tensor[Conf::depth_to_max_tensors_scalarlist[n - 1]];
|
||||
scalar_vals_t scalar_vals[Conf::depth_to_max_tensors_scalarlist[n - 1]];
|
||||
typename Conf::TensorIdxType
|
||||
block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
|
||||
const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]];
|
||||
int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]];
|
||||
scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]];
|
||||
};
|
||||
|
||||
// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of
|
||||
// 4kb with `c10::complex<double>`
|
||||
template <bool large_kernel_arg>
|
||||
struct TensorListScalarListMetadata<c10::complex<double>, 1, large_kernel_arg> {
|
||||
using Conf = DepthToMaxConfig<large_kernel_arg>;
|
||||
const void*
|
||||
addresses[1][Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
|
||||
int64_t numel_for_tensor
|
||||
[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
|
||||
template <>
|
||||
struct TensorListScalarListMetadata<c10::complex<double>, 1> {
|
||||
const void* addresses[1]
|
||||
[depth_to_max_tensors_scalarlist_of_complex_double[0]];
|
||||
int64_t
|
||||
numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
|
||||
c10::complex<double>
|
||||
scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[0]];
|
||||
typename Conf::TensorIdxType
|
||||
block_to_tensor[Conf::depth_to_max_blocks[1 - 1]];
|
||||
int block_to_chunk[Conf::depth_to_max_blocks[1 - 1]];
|
||||
scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[1 - 1]];
|
||||
};
|
||||
|
||||
template <bool large_kernel_arg>
|
||||
struct TensorListScalarListMetadata<c10::complex<double>, 2, large_kernel_arg> {
|
||||
using Conf = DepthToMaxConfig<large_kernel_arg>;
|
||||
const void*
|
||||
addresses[2][Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
|
||||
int64_t numel_for_tensor
|
||||
[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
|
||||
template <>
|
||||
struct TensorListScalarListMetadata<c10::complex<double>, 2> {
|
||||
const void* addresses[2]
|
||||
[depth_to_max_tensors_scalarlist_of_complex_double[1]];
|
||||
int64_t
|
||||
numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
|
||||
c10::complex<double>
|
||||
scalar_vals[Conf::depth_to_max_tensors_scalarlist_of_complex_double[1]];
|
||||
typename Conf::TensorIdxType
|
||||
block_to_tensor[Conf::depth_to_max_blocks[2 - 1]];
|
||||
int block_to_chunk[Conf::depth_to_max_blocks[2 - 1]];
|
||||
scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[2 - 1]];
|
||||
};
|
||||
|
||||
// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
|
||||
// whose each element is `at::Tensor` of 1 element representing the number of
|
||||
// `step`s called so far.
|
||||
template <int n, bool large_kernel_arg>
|
||||
template <int n>
|
||||
struct FusedOptimizerTensorListMetadata {
|
||||
using Conf = DepthToMaxConfig<large_kernel_arg>;
|
||||
const void* addresses[n][Conf::depth_to_max_tensors[n - 1]];
|
||||
int64_t numel_for_tensor[Conf::depth_to_max_tensors[n - 1]];
|
||||
const void*
|
||||
state_steps_addresses[Conf::depth_to_max_tensors_scalarlist[n - 1]];
|
||||
typename Conf::TensorIdxType
|
||||
block_to_tensor[Conf::depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[Conf::depth_to_max_blocks[n - 1]];
|
||||
const void* addresses[n][depth_to_max_tensors[n - 1]];
|
||||
int64_t numel_for_tensor[depth_to_max_tensors[n - 1]];
|
||||
const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]];
|
||||
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
|
||||
int block_to_chunk[depth_to_max_blocks[n - 1]];
|
||||
int start_tensor_this_launch;
|
||||
};
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
C10_LAUNCH_BOUNDS_1(kBlockSize)
|
||||
__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
|
||||
multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
|
||||
__global__ void multi_tensor_apply_kernel(
|
||||
T tensorListMeta,
|
||||
U callable,
|
||||
ArgTypes... args) {
|
||||
// Hand the chunk information to the user-supplied functor to process however
|
||||
// it likes.
|
||||
callable(kChunkSize, tensorListMeta, args...);
|
||||
}
|
||||
#else
|
||||
// When compiling device code with __CUDA_ARCH__ < 700, we only instantiate a
|
||||
// declaration for the 32KB kernels.
|
||||
// For details see: [32KB kernel argument size support]
|
||||
#pragma nv_diag_suppress 114 // Function was referenced but not defined
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
C10_LAUNCH_BOUNDS_1(kBlockSize)
|
||||
__global__ typename std::enable_if<U::use_large_kernel_arg, void>::type
|
||||
multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args);
|
||||
#pragma nv_diag_default 114 // Function was referenced but not defined
|
||||
#endif
|
||||
|
||||
template <typename T, typename U, typename... ArgTypes>
|
||||
C10_LAUNCH_BOUNDS_1(kBlockSize)
|
||||
__global__ typename std::enable_if<!U::use_large_kernel_arg, void>::type
|
||||
multi_tensor_apply_kernel(T tensorListMeta, U callable, ArgTypes... args) {
|
||||
callable(kChunkSize, tensorListMeta, args...);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -246,10 +133,7 @@ void multi_tensor_apply(
|
||||
"Number of tensor lists has to match the depth.");
|
||||
const size_t n_tensors = tensor_lists[0].size();
|
||||
using scalar_vals_t = typename T::opmath_t;
|
||||
TensorListScalarListMetadata<scalar_vals_t, depth, T::use_large_kernel_arg>
|
||||
tensorListMeta;
|
||||
|
||||
using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
|
||||
TensorListScalarListMetadata<scalar_vals_t, depth> tensorListMeta;
|
||||
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
@ -283,11 +167,10 @@ void multi_tensor_apply(
|
||||
// a tensor is not considered full unless all its chunks have been
|
||||
// processed
|
||||
const bool tensors_full =
|
||||
(loc_tensor_info ==
|
||||
Conf::depth_to_max_tensors_scalarlist[depth - 1] &&
|
||||
(loc_tensor_info == depth_to_max_tensors_scalarlist[depth - 1] &&
|
||||
chunk == chunks - 1);
|
||||
const bool blocks_full =
|
||||
(loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
|
||||
(loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
|
||||
if (tensors_full || blocks_full) {
|
||||
multi_tensor_apply_kernel<<<
|
||||
@ -340,11 +223,9 @@ void multi_tensor_apply(
|
||||
tensor_lists.size() == depth,
|
||||
"Number of tensor lists has to match the depth.");
|
||||
const size_t n_tensors = tensor_lists[0].size();
|
||||
TensorListMetadata<depth, T::use_large_kernel_arg> tensorListMeta;
|
||||
TensorListMetadata<depth> tensorListMeta;
|
||||
tensorListMeta.start_tensor_this_launch = 0;
|
||||
|
||||
using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
|
||||
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
for (size_t t = 0; t < n_tensors; t++) {
|
||||
@ -369,10 +250,10 @@ void multi_tensor_apply(
|
||||
loc_block_info++;
|
||||
|
||||
const bool tensors_full =
|
||||
(loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
|
||||
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks - 1);
|
||||
const bool blocks_full =
|
||||
(loc_block_info == Conf::depth_to_max_blocks[depth - 1]);
|
||||
(loc_block_info == depth_to_max_blocks[depth - 1]);
|
||||
|
||||
if (tensors_full || blocks_full) {
|
||||
multi_tensor_apply_kernel<<<
|
||||
@ -423,10 +304,7 @@ void multi_tensor_apply_for_fused_optimizer(
|
||||
tensor_lists.size() == depth,
|
||||
"Number of tensor lists has to match the depth");
|
||||
const auto num_tensors = tensor_lists[0].size();
|
||||
FusedOptimizerTensorListMetadata<depth, T::use_large_kernel_arg>
|
||||
tensorListMeta;
|
||||
|
||||
using Conf = DepthToMaxConfig<T::use_large_kernel_arg>;
|
||||
FusedOptimizerTensorListMetadata<depth> tensorListMeta;
|
||||
|
||||
int loc_block_info = 0;
|
||||
int loc_tensor_info = 0;
|
||||
@ -455,10 +333,9 @@ void multi_tensor_apply_for_fused_optimizer(
|
||||
loc_block_info++;
|
||||
|
||||
const auto tensor_full =
|
||||
(loc_tensor_info == Conf::depth_to_max_tensors[depth - 1] &&
|
||||
(loc_tensor_info == depth_to_max_tensors[depth - 1] &&
|
||||
chunk == chunks - 1);
|
||||
const auto blocks_full =
|
||||
loc_block_info == Conf::depth_to_max_blocks[depth - 1];
|
||||
const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
|
||||
|
||||
if (tensor_full || blocks_full) {
|
||||
multi_tensor_apply_kernel<<<
|
||||
|
@ -42,26 +42,19 @@ void _fused_adam_amsgrad_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adam_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
5,
|
||||
ADAM_MODE::ORIGINAL,
|
||||
true,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -100,26 +93,19 @@ void _fused_adam_amsgrad_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adam_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
5,
|
||||
ADAM_MODE::ORIGINAL,
|
||||
true,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ORIGINAL, true>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -37,26 +37,19 @@ void _fused_adam_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adam_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
4,
|
||||
ADAM_MODE::ORIGINAL,
|
||||
false,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -90,26 +83,19 @@ void _fused_adam_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adam_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
4,
|
||||
ADAM_MODE::ORIGINAL,
|
||||
false,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ORIGINAL, false>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -102,21 +102,15 @@ C10_DEVICE inline void adam_math(
|
||||
// parameter updates accordingly. To be functionally on par with `torch.optim`
|
||||
// optimizers and `_multi_tensor` ones, the kernel below writes out gradients
|
||||
// only when `grad_scale_ptr != nullptr`.
|
||||
template <
|
||||
typename scalar_type,
|
||||
int depth,
|
||||
ADAM_MODE adam_mode,
|
||||
bool amsgrad,
|
||||
bool large_kernel_arg>
|
||||
template <typename scalar_type, int depth, ADAM_MODE adam_mode, bool amsgrad>
|
||||
struct FusedAdamMathFunctor {
|
||||
static constexpr bool use_large_kernel_arg = large_kernel_arg;
|
||||
static_assert(
|
||||
depth == 4 || depth == 5,
|
||||
"depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
|
||||
using opmath_t = at::opmath_type<scalar_type>;
|
||||
C10_DEVICE __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
FusedOptimizerTensorListMetadata<depth, large_kernel_arg>& tl,
|
||||
FusedOptimizerTensorListMetadata<depth>& tl,
|
||||
const float* lr_ptr,
|
||||
const double& lr,
|
||||
const double& beta1,
|
||||
|
@ -43,26 +43,19 @@ void _fused_adamw_amsgrad_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adamw_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
5,
|
||||
ADAM_MODE::ADAMW,
|
||||
true,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -101,26 +94,19 @@ void _fused_adamw_amsgrad_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adamw_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
5,
|
||||
ADAM_MODE::ADAMW,
|
||||
true,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<5>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 5, ADAM_MODE::ADAMW, true>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -38,26 +38,19 @@ void _fused_adamw_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adamw_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
4,
|
||||
ADAM_MODE::ADAMW,
|
||||
false,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
|
||||
lr_ptr, // unused
|
||||
lr,
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
@ -91,26 +84,19 @@ void _fused_adamw_cuda_impl_(
|
||||
params[0].scalar_type(),
|
||||
"fused_adamw_kernel_cuda",
|
||||
[&]() {
|
||||
DISPATCH_MULTI_TENSOR_APPLY([&]() {
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<
|
||||
scalar_t,
|
||||
4,
|
||||
ADAM_MODE::ADAMW,
|
||||
false,
|
||||
large_kernel_arg>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
multi_tensor_apply_for_fused_optimizer<4>(
|
||||
tensor_lists,
|
||||
state_steps,
|
||||
FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
|
||||
lr_ptr,
|
||||
1.0, // unused
|
||||
beta1,
|
||||
beta2,
|
||||
weight_decay,
|
||||
eps,
|
||||
maximize,
|
||||
grad_scale_ptr,
|
||||
found_inf_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1455,7 +1455,6 @@ aten_cuda_cu_source_list = [
|
||||
"aten/src/ATen/native/cuda/Equal.cpp",
|
||||
"aten/src/ATen/native/cuda/GridSampler.cpp",
|
||||
"aten/src/ATen/native/cuda/IndexKernel.cpp",
|
||||
"aten/src/ATen/native/cuda/MultiTensorApply.cpp",
|
||||
"aten/src/ATen/native/cuda/ReduceOps.cpp",
|
||||
"aten/src/ATen/native/cuda/ScanKernels.cpp",
|
||||
"aten/src/ATen/native/cuda/Sort.cpp",
|
||||
|