Use tensor cores for NT bmm (#86856)

Copy of internal diff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86856
Approved by: https://github.com/drisspg
This commit is contained in:
Christian Puhrsch
2022-11-02 21:51:40 +00:00
committed by PyTorch MergeBot
parent 1c0d47cb17
commit 943b20e7ae
5 changed files with 613 additions and 301 deletions

View File

@ -136,6 +136,113 @@ matmul_nested_helper(
}
}
Tensor matmul_with_bmm_nested(const Tensor& self, const Tensor& mat2) {
// Tensor self = self_.contiguous();
// Tensor mat2 = mat2_.contiguous();
// self [N, n_heads, *, head_dim]
// mat2 [N, n_heads, head_dim, *]
const auto self_ptr = get_nested_tensor_impl(self);
const auto mat2_ptr = get_nested_tensor_impl(mat2);
// metadata for self
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr);
std::vector<IntArrayRef> self_strides = NestedTensor_get_strides(self_ptr);
std::vector<int64_t> self_offsets = self_ptr->get_storage_offsets();
auto opt = self_ptr->get_nested_size_tensor().options();
// metadata for mat2
std::vector<IntArrayRef> mat2_sizes = NestedTensor_get_sizes(mat2_ptr);
std::vector<IntArrayRef> mat2_strides = NestedTensor_get_strides(mat2_ptr);
std::vector<int64_t> mat2_offsets = mat2_ptr->get_storage_offsets();
auto opt2 = mat2_ptr->get_nested_size_tensor().options();
int64_t N = self_sizes.size();
int64_t n_heads = self_sizes[0][0];
// viewed metadata for self
auto self_new_sizes = at::empty({N * n_heads, 2}, opt);
int64_t* self_new_sizes_ptr = self_new_sizes.data_ptr<int64_t>();
auto self_new_strides = at::empty({N * n_heads, 2}, opt);
int64_t* self_new_strides_ptr = self_new_strides.data_ptr<int64_t>();
std::vector<int64_t> self_new_offsets;
// viewed metadata for mat2
auto mat2_new_sizes = at::empty({N * n_heads, 2}, opt2);
int64_t* mat2_new_sizes_ptr = mat2_new_sizes.data_ptr<int64_t>();
auto mat2_new_strides = at::empty({N * n_heads, 2}, opt2);
int64_t* mat2_new_strides_ptr = mat2_new_strides.data_ptr<int64_t>();
std::vector<int64_t> mat2_new_offsets;
for (int64_t i = 0; i < N; i++) {
const IntArrayRef& self_size_i = self_sizes[i];
const IntArrayRef& self_stride_i = self_strides[i];
int64_t self_offset = self_offsets[i];
const IntArrayRef& mat2_size_i = mat2_sizes[i];
const IntArrayRef& mat2_stride_i = mat2_strides[i];
int64_t mat2_offset = mat2_offsets[i];
for (int64_t j = 0; j < n_heads; j++) {
auto idx = (i * n_heads + j) * 2;
self_new_sizes_ptr[idx] = self_size_i[1];
self_new_sizes_ptr[idx + 1] = self_size_i[2];
self_new_strides_ptr[idx] = self_stride_i[1];
self_new_strides_ptr[idx + 1] = self_stride_i[2];
self_new_offsets.push_back(self_offset);
self_offset += self_stride_i[0];
mat2_new_sizes_ptr[idx] = mat2_size_i[1];
mat2_new_sizes_ptr[idx + 1] = mat2_size_i[2];
mat2_new_strides_ptr[idx] = mat2_stride_i[1];
mat2_new_strides_ptr[idx + 1] = mat2_stride_i[2];
mat2_new_offsets.push_back(mat2_offset);
mat2_offset += mat2_stride_i[0];
}
}
// view self as [N * n_heads, *, head_dim] (collapse first 2 dims)
auto viewed_self = create_nested_view_tensor(
self, self_new_sizes, self_new_strides, std::vector<int64_t>(self_new_offsets));
// view mat2 as [N * n_heads, head_dim, *] (collapse first 2_dims)
auto viewed_mat2 = create_nested_view_tensor(
mat2, mat2_new_sizes, mat2_new_strides, std::vector<int64_t>(mat2_new_offsets));
// output [N * n_heads, *, *]
auto bmm_output = at::bmm(viewed_self, viewed_mat2);
// generate metadata for viewing output as [N, n_heads, *, *]
// output of bmm should be contiguous so stride calculations should hold
auto out_new_sizes = at::empty({N, 3}, opt);
auto out_new_strides = at::empty({N, 3}, opt);
std::vector<int64_t> out_new_offsets;
int64_t* out_new_sizes_ptr = out_new_sizes.data_ptr<int64_t>();
int64_t* out_new_strides_ptr = out_new_strides.data_ptr<int64_t>();
int64_t out_offset = 0;
for (int64_t i = 0; i < N; i++) {
out_new_offsets.push_back(out_offset);
const IntArrayRef& self_size_i = self_sizes[i];
const IntArrayRef& mat2_size_i = mat2_sizes[i];
auto idx = i * 3;
out_new_sizes_ptr[idx] = n_heads;
out_new_sizes_ptr[idx + 1] = self_size_i[1];
out_new_sizes_ptr[idx + 2] = mat2_size_i[2];
out_new_strides_ptr[idx] = self_size_i[1] * mat2_size_i[2];
out_new_strides_ptr[idx + 1] = mat2_size_i[2];
out_new_strides_ptr[idx + 2] = 1;
out_offset += n_heads * (self_size_i[1] * mat2_size_i[2]);
}
auto viewed_out = create_nested_view_tensor(
bmm_output, out_new_sizes, out_new_strides, std::vector<int64_t>(out_new_offsets));
return viewed_out;
}
// Note [nested tensor matmul]
// This is really a generalized batched matmul dedicated to nested tensors,
// where `self` and `mat2` have same number (>= 3) of dimensions.
@ -193,6 +300,20 @@ Tensor matmul_nested(const Tensor& self, const Tensor& mat2) {
self_dim_size,
"second last dimension of mat2 has sizes",
mat2_dim_size);
// use bmm inference-only fast path for [N, n_heads, *, head_dim] [N, n_heads, head_dim, *]
if (self.is_cuda() &&
self_dim == 4 && self.is_contiguous() &&
mat2_dim == 4 && mat2.is_contiguous() &&
!(GradMode::is_enabled() && (self.requires_grad() || mat2.requires_grad()))) {
auto n_heads = self_sizes.select(0, 1).select(0, 0).item<int64_t>();
auto self_first_dim_n_heads = at::all(self_sizes.select(1, 0) == n_heads).item<bool>();
auto mat2_first_dim_n_heads = at::all(mat2_sizes.select(1, 0) == n_heads).item<bool>();
if (self_first_dim_n_heads && mat2_first_dim_n_heads) {
return matmul_with_bmm_nested(self, mat2);
}
}
// Construct output size from input sizes
Tensor output_sizes = self_sizes.clone();
// The last entry in every row of output_sizes should be last column of mat2_sizes

View File

@ -0,0 +1,416 @@
#include <type_traits>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/PersistentSoftmax.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#ifndef USE_ROCM
#ifndef _WIN32
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#endif
#endif
#include <ATen/NestedTensorImpl.h>
#define BLOCK_DIM 256
#define GRID_DIM_Y 16
namespace at {
namespace native {
#ifndef USE_ROCM
#ifndef _WIN32
namespace {
template <
typename scalar_t,
unsigned int kPad,
typename LayoutA,
typename LayoutB,
typename OpClass,
typename Arch,
typename ThreadBlockShape,
typename WarpShape,
typename InstructionShape>
void gemm_grouped_cuda_internal(
const std::vector<int64_t>& lda,
const std::vector<int64_t>& ldb,
const std::vector<int64_t>& ldd,
const std::vector<scalar_t*>& aptr,
const std::vector<scalar_t*>& bptr,
const std::vector<scalar_t*>& dptr,
const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
const int problem_count,
at::Device& device) {
using Element = scalar_t;
using ElementAcc = float;
using GemmConfiguration =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
Arch,
Element,
Element,
Element,
ElementAcc>;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
Element,
LayoutA,
cutlass::ComplexTransform::kNone,
kPad,
Element,
LayoutB,
cutlass::ComplexTransform::kNone,
kPad,
Element,
cutlass::layout::RowMajor,
ElementAcc,
OpClass,
Arch,
ThreadBlockShape,
WarpShape,
InstructionShape,
typename GemmConfiguration::EpilogueOutputOp,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
GemmConfiguration::kStages>::GemmKernel;
using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);
const int64_t gemm_coord_size =
problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
// Number of gmm args not including *problem_sizes
at::Tensor gmm_args = at::empty(
{problem_count * 6 + gemm_coord_size},
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
// Obtain pointers for each argument (on host)
int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
int64_t* ldb_data = lda_data + problem_count;
int64_t* ldd_data = lda_data + 2 * problem_count;
int64_t* ptr_a_data = lda_data + 3 * problem_count;
int64_t* ptr_b_data = lda_data + 4 * problem_count;
int64_t* ptr_d_data = lda_data + 5 * problem_count;
cutlass::gemm::GemmCoord* problem_sizes_data =
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
// Set arguments into gmm_args from input args
for (int i = 0; i < problem_count; ++i) {
problem_sizes_data[i] = gemm_sizes[i];
lda_data[i] = lda[i];
ldb_data[i] = ldb[i];
ldd_data[i] = ldd[i];
ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
}
const int threadblock_count =
GemmGrouped::sufficient(problem_sizes_data, problem_count);
// Transfer arguments to GPU
gmm_args = gmm_args.to(device, true);
// Obtain pointers for each of arguments (on GPU)
lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
ldb_data = lda_data + problem_count;
ldd_data = lda_data + 2 * problem_count;
ptr_a_data = lda_data + 3 * problem_count;
ptr_b_data = lda_data + 4 * problem_count;
ptr_d_data = lda_data + 5 * problem_count;
problem_sizes_data =
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
// Create GemmGrouped::Arguments using the arguments prepared above
typename GemmGrouped::Arguments args(
problem_sizes_data,
problem_count,
threadblock_count,
epilogue_op,
reinterpret_cast<Element**>(ptr_a_data),
reinterpret_cast<Element**>(ptr_b_data),
reinterpret_cast<Element**>(ptr_d_data),
reinterpret_cast<Element**>(ptr_d_data),
lda_data,
ldb_data,
ldd_data,
ldd_data);
GemmGrouped gemm;
cutlass::Status status =
gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
status != cutlass::Status::kErrorWorkspaceNull,
"Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
TORCH_CHECK(
status != cutlass::Status::kErrorInternal,
"Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"Failed to initialize CUTLASS Grouped GEMM kernel.");
// Run CUTLASS group GEMM
status = gemm.run(at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"Failed to run CUTLASS Grouped GEMM kernel.");
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t>
bool group_gemm_dispatch(
at::Device device,
const std::vector<scalar_t*>& aptr,
const std::vector<scalar_t*>& bptr,
const std::vector<scalar_t*>& dptr,
const std::vector<int64_t>& lda,
const std::vector<int64_t>& ldb,
const std::vector<int64_t>& ldd,
std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
int64_t ntensors) {
return false;
}
template <>
bool group_gemm_dispatch(
at::Device device,
const std::vector<float*>& aptr,
const std::vector<float*>& bptr,
const std::vector<float*>& dptr,
const std::vector<int64_t>& lda,
const std::vector<int64_t>& ldb,
const std::vector<int64_t>& ldd,
std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
int64_t ntensors) {
gemm_grouped_cuda_internal<
float,
1,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 8>,
cutlass::gemm::GemmShape<64, 32, 8>,
cutlass::gemm::GemmShape<1, 1, 1>>(
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
return true;
}
template <>
bool group_gemm_dispatch(
at::Device device,
const std::vector<c10::Half*>& aptr_,
const std::vector<c10::Half*>& bptr_,
const std::vector<c10::Half*>& dptr_,
const std::vector<int64_t>& lda,
const std::vector<int64_t>& ldb,
const std::vector<int64_t>& ldd,
std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
int64_t ntensors) {
// Check alignment
bool all_pad_8 = true;
for (int i = 0; i < ntensors; i++) {
all_pad_8 = all_pad_8 && (gemm_sizes[i].n() % 8 == 0);
all_pad_8 = all_pad_8 && (gemm_sizes[i].k() % 8 == 0);
// Not sure if this is a requirement, on the safe side
all_pad_8 = all_pad_8 && (lda[i] % 8 == 0);
all_pad_8 = all_pad_8 && (ldb[i] % 8 == 0);
all_pad_8 = all_pad_8 && (ldd[i] % 8 == 0);
}
std::vector<cutlass::half_t*> aptr;
std::vector<cutlass::half_t*> bptr;
std::vector<cutlass::half_t*> dptr;
for (int64_t i = 0; i < ntensors; i++) {
aptr.push_back(reinterpret_cast<cutlass::half_t*>(aptr_[i]));
bptr.push_back(reinterpret_cast<cutlass::half_t*>(bptr_[i]));
dptr.push_back(reinterpret_cast<cutlass::half_t*>(dptr_[i]));
}
if (all_pad_8) {
gemm_grouped_cuda_internal<
cutlass::half_t,
8,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>>(
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
return true;
} else {
gemm_grouped_cuda_internal<
cutlass::half_t,
1,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
cutlass::arch::OpClassSimt,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 8>,
cutlass::gemm::GemmShape<64, 32, 8>,
cutlass::gemm::GemmShape<1, 1, 1>>(
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
return true;
}
// Did not perform GEMM
return false;
}
} // namespace
#endif
#endif
Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
if (self.is_nested() && !mat2.is_nested()) {
AT_ERROR(
"Expected both to be nested, but got a nested self and non-nested other");
} else if (!self.is_nested() && mat2.is_nested()) {
AT_ERROR(
"Expected both to be nested, but got a non-nested self and nested other");
}
// dispatcher should have guaranteed that at least one is nested
auto self_ptr = get_nested_tensor_impl(self);
auto mat2_ptr = get_nested_tensor_impl(mat2);
TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
TORCH_CHECK(
ntensors == ntensors2,
"Expected size for the 1st dimension of batch2 tensor to be: ",
ntensors,
" but got: ",
ntensors2,
".");
// create a contiguous output
const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr);
std::vector<IntArrayRef> mat2_sizes = NestedTensor_get_sizes(mat2_ptr);
int64_t out_numel = 0;
for (int64_t i = 0; i < ntensors; i++) {
const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
&mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
TORCH_CHECK(
self_size1 == mat2_size0,
i,
"-th nested matrices in batch cannot be multiplied (",
self_size0,
"x",
self_size1,
" and ",
mat2_size0,
"x",
mat2_size1,
")");
out_sizemat_ptr[0] = self_size0;
out_sizemat_ptr[1] = mat2_size1;
out_sizemat_ptr += 2;
out_numel += self_size0 * mat2_size1;
}
const Tensor &self_buffer = self_ptr->get_unsafe_storage_as_tensor();
const Tensor &mat2_buffer = mat2_ptr->get_unsafe_storage_as_tensor();
Tensor out_buffer = self_buffer.new_empty(out_numel);
Tensor output = wrap_buffer(out_buffer, out_sizemat);
auto out_ptr = get_nested_tensor_impl(output);
std::vector<IntArrayRef> self_strides = NestedTensor_get_strides(self_ptr);
std::vector<IntArrayRef> mat2_strides = NestedTensor_get_strides(mat2_ptr);
const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();
const std::vector<int64_t>& out_offsets = out_ptr->get_storage_offsets();
#ifndef USE_ROCM
#ifndef _WIN32
bool success = false;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
self.scalar_type(), "group_gemm_dispatch", [&] {
std::vector<scalar_t*> aptr(ntensors);
std::vector<scalar_t*> bptr(ntensors);
std::vector<scalar_t*> dptr(ntensors);
std::vector<int64_t> lda(ntensors);
std::vector<int64_t> ldb(ntensors);
std::vector<int64_t> ldd(ntensors);
std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
bool all_row_major = true;
for (int64_t i = 0; i < ntensors; i++) {
const IntArrayRef& self_shape = self_sizes[i];
const IntArrayRef& mat2_shape = mat2_sizes[i];
const int64_t &self_size0 = self_shape[0];
const int64_t &self_size1 = self_shape[1];
const int64_t &mat2_size0 = mat2_shape[0];
const int64_t &mat2_size1 = mat2_shape[1];
gemm_sizes.push_back(
cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
aptr[i] = self_buffer.data_ptr<scalar_t>() + self_offsets[i];
bptr[i] = mat2_buffer.data_ptr<scalar_t>() + mat2_offsets[i];
dptr[i] = out_buffer.data_ptr<scalar_t>() + out_offsets[i];
all_row_major = all_row_major && (self_strides[i][1] == 1);
all_row_major = all_row_major && (mat2_strides[i][1] == 1);
lda[i] = self_strides[i][0];
ldb[i] = mat2_strides[i][0];
ldd[i] = mat2_size1;
}
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
if (all_row_major &&
self.is_contiguous() &&
mat2.is_contiguous() &&
is_sm8x) {
success = group_gemm_dispatch<scalar_t>(
output.device(),
aptr,
bptr,
dptr,
lda,
ldb,
ldd,
gemm_sizes,
ntensors);
}
});
if (success) {
return output;
}
#endif
#endif
std::vector<Tensor> output_unbind = output.unbind();
for (int64_t i = 0; i < ntensors; i++) {
at::mm_out(
output_unbind[i],
self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
mat2_buffer.as_strided(
mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
}
return output;
}
} // namespace native
} // namespace at

View File

@ -462,281 +462,5 @@ template void add_padding_kernelLauncher<c10::Half>(
const int batch_size,
const int output_batch_size);
namespace {
#ifndef USE_ROCM
#ifndef _WIN32
template <typename scalar_t>
void gemm_grouped_cuda_internal(
const std::vector<int64_t>& lda,
const std::vector<int64_t>& ldb,
const std::vector<int64_t>& ldd,
const std::vector<scalar_t*>& aptr,
const std::vector<scalar_t*>& bptr,
const std::vector<scalar_t*>& dptr,
const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
const int problem_count,
at::Device& device) {
using Element = scalar_t;
using ElementAcc = float;
using OpClass = cutlass::arch::OpClassSimt;
using GemmConfiguration =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
cutlass::arch::Sm80,
Element,
Element,
Element,
ElementAcc>;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
Element,
cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone,
GemmConfiguration::kAlignmentA,
Element,
cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone,
GemmConfiguration::kAlignmentB,
Element,
cutlass::layout::RowMajor,
ElementAcc,
OpClass,
cutlass::arch::Sm80,
typename GemmConfiguration::ThreadblockShape,
typename GemmConfiguration::WarpShape,
typename GemmConfiguration::InstructionShape,
typename GemmConfiguration::EpilogueOutputOp,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
GemmConfiguration::kStages>::GemmKernel;
using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);
const int64_t gemm_coord_size =
problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
// Number of gmm args not including *problem_sizes
at::Tensor gmm_args = at::empty(
{problem_count * 6 + gemm_coord_size},
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
// Obtain pointers for each argument (on host)
int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
int64_t* ldb_data = lda_data + problem_count;
int64_t* ldd_data = lda_data + 2 * problem_count;
int64_t* ptr_a_data = lda_data + 3 * problem_count;
int64_t* ptr_b_data = lda_data + 4 * problem_count;
int64_t* ptr_d_data = lda_data + 5 * problem_count;
cutlass::gemm::GemmCoord* problem_sizes_data =
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
// Set arguments into gmm_args from input args
for (int i = 0; i < problem_count; ++i) {
problem_sizes_data[i] = gemm_sizes[i];
lda_data[i] = lda[i];
ldb_data[i] = ldb[i];
ldd_data[i] = ldd[i];
ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
}
const int threadblock_count =
GemmGrouped::sufficient(problem_sizes_data, problem_count);
// Transfer arguments to GPU
gmm_args = gmm_args.to(device, true);
// Obtain pointers for each of arguments (on GPU)
lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
ldb_data = lda_data + problem_count;
ldd_data = lda_data + 2 * problem_count;
ptr_a_data = lda_data + 3 * problem_count;
ptr_b_data = lda_data + 4 * problem_count;
ptr_d_data = lda_data + 5 * problem_count;
problem_sizes_data =
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
// Create GemmGrouped::Arguments using the arguments prepared above
typename GemmGrouped::Arguments args(
problem_sizes_data,
problem_count,
threadblock_count,
epilogue_op,
reinterpret_cast<Element**>(ptr_a_data),
reinterpret_cast<Element**>(ptr_b_data),
reinterpret_cast<Element**>(ptr_d_data),
reinterpret_cast<Element**>(ptr_d_data),
lda_data,
ldb_data,
ldd_data,
ldd_data);
GemmGrouped gemm;
cutlass::Status status =
gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
status != cutlass::Status::kErrorWorkspaceNull,
"Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
TORCH_CHECK(
status != cutlass::Status::kErrorInternal,
"Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"Failed to initialize CUTLASS Grouped GEMM kernel.");
// Run CUTLASS group GEMM
status = gemm.run(at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"Failed to run CUTLASS Grouped GEMM kernel.");
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
#endif
} // namespace
Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
if (self.is_nested() && !mat2.is_nested()) {
AT_ERROR(
"Expected both to be nested, but got a nested self and non-nested other");
} else if (!self.is_nested() && mat2.is_nested()) {
AT_ERROR(
"Expected both to be nested, but got a non-nested self and nested other");
}
// TODO currently we only support contiguous NestedTensors
auto self_contiguous = self.contiguous();
auto mat2_contiguous = mat2.contiguous();
// dispatcher should have guaranteed that at least one is nested
auto self_ptr = get_nested_tensor_impl(self_contiguous);
auto mat2_ptr = get_nested_tensor_impl(mat2_contiguous);
TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
TORCH_CHECK(
ntensors == ntensors2,
"Expected size for the 1st dimension of batch2 tensor to be: ",
ntensors,
" but got: ",
ntensors2,
".");
const Tensor &self_buffer = self_ptr->get_buffer(),
&mat2_buffer = mat2_ptr->get_buffer();
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
self_strides = NestedTensor_get_strides(self_ptr),
mat2_strides = NestedTensor_get_strides(mat2_ptr);
const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();
// create a contiguous output
int64_t out_numel = 0;
int64_t a_numel = 0;
int64_t b_numel = 0;
const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
std::vector<int64_t> output_offsets;
std::vector<int64_t> a_offsets;
std::vector<int64_t> b_offsets;
std::vector<int64_t> lda;
std::vector<int64_t> ldb;
std::vector<int64_t> ldd;
#ifndef USE_ROCM
#ifndef _WIN32
std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
#endif
#endif
bool all_row_major = true;
for (int64_t i = 0; i < ntensors; i++) {
const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
&mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
TORCH_CHECK(
self_size1 == mat2_size0,
i,
"-th nested matrices in batch cannot be multiplied (",
self_size0,
"x",
self_size1,
" and ",
mat2_size0,
"x",
mat2_size1,
")");
out_sizemat_ptr[0] = self_size0;
out_sizemat_ptr[1] = mat2_size1;
out_sizemat_ptr += 2;
output_offsets.push_back(out_numel);
out_numel += self_size0 * mat2_size1;
#ifndef USE_ROCM
#ifndef _WIN32
gemm_sizes.push_back(
cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
#endif
#endif
lda.push_back(self_strides[i][0]);
ldb.push_back(mat2_strides[i][0]);
ldd.push_back(mat2_size1);
a_offsets.push_back(a_numel);
b_offsets.push_back(b_numel);
a_numel += self_size0 * self_strides[i][0];
b_numel += mat2_size0 * mat2_strides[i][0];
all_row_major = all_row_major && (self_strides[i][1] == 1);
all_row_major = all_row_major && (mat2_strides[i][1] == 1);
}
Tensor out_buffer = self_buffer.new_empty(out_numel);
Tensor output = wrap_buffer(out_buffer, out_sizemat);
at::Device device = output.device();
#ifndef USE_ROCM
#ifndef _WIN32
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
if (is_sm8x && all_row_major) {
if (self.dtype() == at::kFloat) {
std::vector<float*> aptr;
std::vector<float*> bptr;
std::vector<float*> dptr;
for (int64_t i = 0; i < ntensors; i++) {
aptr.push_back(self_buffer.data_ptr<float>() + a_offsets[i]);
bptr.push_back(mat2_buffer.data_ptr<float>() + b_offsets[i]);
dptr.push_back(out_buffer.data_ptr<float>() + output_offsets[i]);
}
gemm_grouped_cuda_internal<float>(
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
return output;
}
if (self.dtype() == at::kHalf) {
std::vector<c10::Half*> aptr;
std::vector<c10::Half*> bptr;
std::vector<c10::Half*> dptr;
for (int64_t i = 0; i < ntensors; i++) {
aptr.push_back(self_buffer.data_ptr<c10::Half>() + a_offsets[i]);
bptr.push_back(mat2_buffer.data_ptr<c10::Half>() + b_offsets[i]);
dptr.push_back(out_buffer.data_ptr<c10::Half>() + output_offsets[i]);
}
gemm_grouped_cuda_internal<c10::Half>(
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
return output;
}
}
#endif
#endif
std::vector<Tensor> output_unbind = output.unbind();
for (int64_t i = 0; i < ntensors; i++) {
at::mm_out(
output_unbind[i],
self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
mat2_buffer.as_strided(
mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
}
return output;
}
} // namespace native
} // namespace at

View File

@ -1,4 +1,5 @@
import argparse
import random
import torch
@ -15,31 +16,38 @@ def bench(nt_a, nt_b, niter):
nt_c = nt_a.bmm(nt_b)
end_event.record()
torch.cuda.synchronize()
runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter
runtime = (start_event.elapsed_time(end_event)) / niter
return runtime
def sweep_n(ntensor, niter, dtype):
print("n, dtype, ntensor, gflop, runtime, tflop/s")
for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
nt_a = torch.nested_tensor(
[torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
def sweep_n(niter, dtype):
for ntensor in [4, 8, 16, 32, 64, 128, 256]:
tensors = [torch.randn(256, random.randint(100, 200)) for t in range(ntensor)]
nt_a = torch.nested.nested_tensor(
tensors,
dtype=dtype,
device="cuda",
)
nt_b = torch.nested_tensor(
[torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
nt_b = torch.nested.nested_tensor(
[t.t() for t in tensors],
dtype=dtype,
device="cuda",
)
runtime = bench(nt_a, nt_b, niter)
tflop = n * n * n * ntensor * 2 / 1e12
print(n, dtype, ntensor, tflop, runtime, tflop / runtime)
nt_a_size = torch.ops.aten._nested_tensor_size(nt_a)
lengths = nt_a_size[:, 1]
print(",".join(map(str, [ntensor, dtype, lengths.min().item(),
lengths.float().mean().item(), lengths.max().item(), runtime])))
if __name__ == "__main__":
random.seed(123)
parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
parser.add_argument("--niter", default="10", type=int)
parser.add_argument("--ntensor", default="20", type=int)
args = parser.parse_args()
niter = args.niter
ntensor = args.ntensor
sweep_n(ntensor, niter, torch.float32)
sweep_n(ntensor, niter, torch.float16)
print("ntensor,dtype,min_length,mean_length,max_length,runtime")
sweep_n(niter, torch.float32)
sweep_n(niter, torch.float16)

View File

@ -1,9 +1,9 @@
# Owner(s): ["module: nestedtensor"]
import unittest
import torch
import torch.nn
import unittest
import numpy as np
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
@ -1224,6 +1224,16 @@ class TestNestedTensorDeviceType(TestCase):
else:
self.assertEqual(actual, expect)
# test tensorcore path
nt0 = torch.nested.nested_tensor([torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype)
nt1 = torch.nested.nested_tensor([torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
if dtype == torch.float16:
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
else:
self.assertEqual(actual, expect)
@onlyCUDA
@dtypes(torch.float, torch.double, torch.float16)
def test_bmm_cuda(self, device, dtype):
@ -1235,15 +1245,48 @@ class TestNestedTensorDeviceType(TestCase):
def test_bmm_cpu(self, device, dtype):
self._test_bmm(device, dtype)
# TODO: Re-enable this test once bmm supports non-contiguous inputs.
# # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
# @dtypes(torch.float, torch.double)
# def test_bmm_noncontiguous(self, device, dtype):
# nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
# nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
# self.assertEqual(
# nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
# nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_bmm_noncontiguous(self, device, dtype):
nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
self.assertEqual(
nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
@dtypes(torch.float, torch.double)
def test_matmul_with_bmm_path(self, device, dtype):
def unbind_rebind_matmul(nt1, nt2):
t1s = nt1.unbind()
t2s = nt2.unbind()
out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
return torch.nested.nested_tensor(out_ts)
# [N, n_head, *, head_dim], [N, n_head, head_dim, *]
N = np.random.randint(2, 5)
n_heads = np.random.randint(2, 5)
head_dim = 3
t1s = []
t2s = []
for _ in range(N):
seq_len1 = np.random.randint(2, 5)
seq_len2 = np.random.randint(2, 5)
t1s.append(torch.randn(n_heads, seq_len1, head_dim))
t2s.append(torch.randn(n_heads, head_dim, seq_len2))
nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))
# test with noncontiguous
t3s = []
t4s = []
for _ in range(N):
seq_len = np.random.randint(2, 5)
t3s.append(torch.randn(seq_len, n_heads, head_dim))
t4s.append(torch.randn(seq_len, n_heads, head_dim))
nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2)
nt4 = torch.nested.nested_tensor(t4s, device=device, dtype=dtype).transpose(1, 2).transpose(2, 3)
self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
@dtypes(torch.float, torch.double)