mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use tensor cores for NT bmm (#86856)
Copy of internal diff. Pull Request resolved: https://github.com/pytorch/pytorch/pull/86856 Approved by: https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c0d47cb17
commit
943b20e7ae
@ -136,6 +136,113 @@ matmul_nested_helper(
|
||||
}
|
||||
}
|
||||
|
||||
Tensor matmul_with_bmm_nested(const Tensor& self, const Tensor& mat2) {
|
||||
// Tensor self = self_.contiguous();
|
||||
// Tensor mat2 = mat2_.contiguous();
|
||||
// self [N, n_heads, *, head_dim]
|
||||
// mat2 [N, n_heads, head_dim, *]
|
||||
const auto self_ptr = get_nested_tensor_impl(self);
|
||||
const auto mat2_ptr = get_nested_tensor_impl(mat2);
|
||||
// metadata for self
|
||||
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr);
|
||||
std::vector<IntArrayRef> self_strides = NestedTensor_get_strides(self_ptr);
|
||||
std::vector<int64_t> self_offsets = self_ptr->get_storage_offsets();
|
||||
auto opt = self_ptr->get_nested_size_tensor().options();
|
||||
|
||||
// metadata for mat2
|
||||
std::vector<IntArrayRef> mat2_sizes = NestedTensor_get_sizes(mat2_ptr);
|
||||
std::vector<IntArrayRef> mat2_strides = NestedTensor_get_strides(mat2_ptr);
|
||||
std::vector<int64_t> mat2_offsets = mat2_ptr->get_storage_offsets();
|
||||
auto opt2 = mat2_ptr->get_nested_size_tensor().options();
|
||||
|
||||
int64_t N = self_sizes.size();
|
||||
int64_t n_heads = self_sizes[0][0];
|
||||
|
||||
// viewed metadata for self
|
||||
auto self_new_sizes = at::empty({N * n_heads, 2}, opt);
|
||||
int64_t* self_new_sizes_ptr = self_new_sizes.data_ptr<int64_t>();
|
||||
|
||||
auto self_new_strides = at::empty({N * n_heads, 2}, opt);
|
||||
int64_t* self_new_strides_ptr = self_new_strides.data_ptr<int64_t>();
|
||||
std::vector<int64_t> self_new_offsets;
|
||||
|
||||
// viewed metadata for mat2
|
||||
auto mat2_new_sizes = at::empty({N * n_heads, 2}, opt2);
|
||||
int64_t* mat2_new_sizes_ptr = mat2_new_sizes.data_ptr<int64_t>();
|
||||
|
||||
auto mat2_new_strides = at::empty({N * n_heads, 2}, opt2);
|
||||
int64_t* mat2_new_strides_ptr = mat2_new_strides.data_ptr<int64_t>();
|
||||
std::vector<int64_t> mat2_new_offsets;
|
||||
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
const IntArrayRef& self_size_i = self_sizes[i];
|
||||
const IntArrayRef& self_stride_i = self_strides[i];
|
||||
int64_t self_offset = self_offsets[i];
|
||||
|
||||
const IntArrayRef& mat2_size_i = mat2_sizes[i];
|
||||
const IntArrayRef& mat2_stride_i = mat2_strides[i];
|
||||
int64_t mat2_offset = mat2_offsets[i];
|
||||
for (int64_t j = 0; j < n_heads; j++) {
|
||||
auto idx = (i * n_heads + j) * 2;
|
||||
self_new_sizes_ptr[idx] = self_size_i[1];
|
||||
self_new_sizes_ptr[idx + 1] = self_size_i[2];
|
||||
self_new_strides_ptr[idx] = self_stride_i[1];
|
||||
self_new_strides_ptr[idx + 1] = self_stride_i[2];
|
||||
self_new_offsets.push_back(self_offset);
|
||||
self_offset += self_stride_i[0];
|
||||
|
||||
mat2_new_sizes_ptr[idx] = mat2_size_i[1];
|
||||
mat2_new_sizes_ptr[idx + 1] = mat2_size_i[2];
|
||||
mat2_new_strides_ptr[idx] = mat2_stride_i[1];
|
||||
mat2_new_strides_ptr[idx + 1] = mat2_stride_i[2];
|
||||
mat2_new_offsets.push_back(mat2_offset);
|
||||
mat2_offset += mat2_stride_i[0];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// view self as [N * n_heads, *, head_dim] (collapse first 2 dims)
|
||||
auto viewed_self = create_nested_view_tensor(
|
||||
self, self_new_sizes, self_new_strides, std::vector<int64_t>(self_new_offsets));
|
||||
|
||||
// view mat2 as [N * n_heads, head_dim, *] (collapse first 2_dims)
|
||||
auto viewed_mat2 = create_nested_view_tensor(
|
||||
mat2, mat2_new_sizes, mat2_new_strides, std::vector<int64_t>(mat2_new_offsets));
|
||||
|
||||
// output [N * n_heads, *, *]
|
||||
auto bmm_output = at::bmm(viewed_self, viewed_mat2);
|
||||
|
||||
// generate metadata for viewing output as [N, n_heads, *, *]
|
||||
// output of bmm should be contiguous so stride calculations should hold
|
||||
auto out_new_sizes = at::empty({N, 3}, opt);
|
||||
auto out_new_strides = at::empty({N, 3}, opt);
|
||||
std::vector<int64_t> out_new_offsets;
|
||||
|
||||
int64_t* out_new_sizes_ptr = out_new_sizes.data_ptr<int64_t>();
|
||||
int64_t* out_new_strides_ptr = out_new_strides.data_ptr<int64_t>();
|
||||
|
||||
int64_t out_offset = 0;
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
out_new_offsets.push_back(out_offset);
|
||||
const IntArrayRef& self_size_i = self_sizes[i];
|
||||
const IntArrayRef& mat2_size_i = mat2_sizes[i];
|
||||
auto idx = i * 3;
|
||||
out_new_sizes_ptr[idx] = n_heads;
|
||||
out_new_sizes_ptr[idx + 1] = self_size_i[1];
|
||||
out_new_sizes_ptr[idx + 2] = mat2_size_i[2];
|
||||
out_new_strides_ptr[idx] = self_size_i[1] * mat2_size_i[2];
|
||||
out_new_strides_ptr[idx + 1] = mat2_size_i[2];
|
||||
out_new_strides_ptr[idx + 2] = 1;
|
||||
out_offset += n_heads * (self_size_i[1] * mat2_size_i[2]);
|
||||
}
|
||||
|
||||
auto viewed_out = create_nested_view_tensor(
|
||||
bmm_output, out_new_sizes, out_new_strides, std::vector<int64_t>(out_new_offsets));
|
||||
|
||||
return viewed_out;
|
||||
|
||||
}
|
||||
|
||||
// Note [nested tensor matmul]
|
||||
// This is really a generalized batched matmul dedicated to nested tensors,
|
||||
// where `self` and `mat2` have same number (>= 3) of dimensions.
|
||||
@ -193,6 +300,20 @@ Tensor matmul_nested(const Tensor& self, const Tensor& mat2) {
|
||||
self_dim_size,
|
||||
"second last dimension of mat2 has sizes",
|
||||
mat2_dim_size);
|
||||
|
||||
// use bmm inference-only fast path for [N, n_heads, *, head_dim] [N, n_heads, head_dim, *]
|
||||
if (self.is_cuda() &&
|
||||
self_dim == 4 && self.is_contiguous() &&
|
||||
mat2_dim == 4 && mat2.is_contiguous() &&
|
||||
!(GradMode::is_enabled() && (self.requires_grad() || mat2.requires_grad()))) {
|
||||
auto n_heads = self_sizes.select(0, 1).select(0, 0).item<int64_t>();
|
||||
auto self_first_dim_n_heads = at::all(self_sizes.select(1, 0) == n_heads).item<bool>();
|
||||
auto mat2_first_dim_n_heads = at::all(mat2_sizes.select(1, 0) == n_heads).item<bool>();
|
||||
if (self_first_dim_n_heads && mat2_first_dim_n_heads) {
|
||||
return matmul_with_bmm_nested(self, mat2);
|
||||
}
|
||||
}
|
||||
|
||||
// Construct output size from input sizes
|
||||
Tensor output_sizes = self_sizes.clone();
|
||||
// The last entry in every row of output_sizes should be last column of mat2_sizes
|
||||
|
||||
416
aten/src/ATen/native/nested/cuda/NestedTensorMatmul.cu
Normal file
416
aten/src/ATen/native/nested/cuda/NestedTensorMatmul.cu
Normal file
@ -0,0 +1,416 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/detail/KernelUtils.h>
|
||||
#include <ATen/cuda/detail/IndexUtils.cuh>
|
||||
#include <ATen/native/cuda/Loops.cuh>
|
||||
#include <ATen/native/cuda/MemoryAccess.cuh>
|
||||
#include <ATen/native/cuda/PersistentSoftmax.cuh>
|
||||
#include <ATen/native/cuda/block_reduce.cuh>
|
||||
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
#include <ATen/native/nested/NestedTensorUtils.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#ifndef _WIN32
|
||||
#include <cutlass/gemm/device/default_gemm_configuration.h>
|
||||
#include <cutlass/gemm/device/gemm_grouped.h>
|
||||
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
|
||||
#define BLOCK_DIM 256
|
||||
#define GRID_DIM_Y 16
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#ifndef _WIN32
|
||||
namespace {
|
||||
|
||||
template <
|
||||
typename scalar_t,
|
||||
unsigned int kPad,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename OpClass,
|
||||
typename Arch,
|
||||
typename ThreadBlockShape,
|
||||
typename WarpShape,
|
||||
typename InstructionShape>
|
||||
void gemm_grouped_cuda_internal(
|
||||
const std::vector<int64_t>& lda,
|
||||
const std::vector<int64_t>& ldb,
|
||||
const std::vector<int64_t>& ldd,
|
||||
const std::vector<scalar_t*>& aptr,
|
||||
const std::vector<scalar_t*>& bptr,
|
||||
const std::vector<scalar_t*>& dptr,
|
||||
const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
|
||||
const int problem_count,
|
||||
at::Device& device) {
|
||||
using Element = scalar_t;
|
||||
using ElementAcc = float;
|
||||
|
||||
using GemmConfiguration =
|
||||
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
OpClass,
|
||||
Arch,
|
||||
Element,
|
||||
Element,
|
||||
Element,
|
||||
ElementAcc>;
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||
Element,
|
||||
LayoutA,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
kPad,
|
||||
Element,
|
||||
LayoutB,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
kPad,
|
||||
Element,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAcc,
|
||||
OpClass,
|
||||
Arch,
|
||||
ThreadBlockShape,
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
typename GemmConfiguration::EpilogueOutputOp,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
|
||||
GemmConfiguration::kStages>::GemmKernel;
|
||||
|
||||
using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
|
||||
typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);
|
||||
|
||||
const int64_t gemm_coord_size =
|
||||
problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
|
||||
// Number of gmm args not including *problem_sizes
|
||||
at::Tensor gmm_args = at::empty(
|
||||
{problem_count * 6 + gemm_coord_size},
|
||||
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
|
||||
|
||||
// Obtain pointers for each argument (on host)
|
||||
int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
|
||||
int64_t* ldb_data = lda_data + problem_count;
|
||||
int64_t* ldd_data = lda_data + 2 * problem_count;
|
||||
int64_t* ptr_a_data = lda_data + 3 * problem_count;
|
||||
int64_t* ptr_b_data = lda_data + 4 * problem_count;
|
||||
int64_t* ptr_d_data = lda_data + 5 * problem_count;
|
||||
cutlass::gemm::GemmCoord* problem_sizes_data =
|
||||
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
|
||||
|
||||
// Set arguments into gmm_args from input args
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
problem_sizes_data[i] = gemm_sizes[i];
|
||||
lda_data[i] = lda[i];
|
||||
ldb_data[i] = ldb[i];
|
||||
ldd_data[i] = ldd[i];
|
||||
ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
|
||||
ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
|
||||
ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
|
||||
}
|
||||
const int threadblock_count =
|
||||
GemmGrouped::sufficient(problem_sizes_data, problem_count);
|
||||
|
||||
// Transfer arguments to GPU
|
||||
gmm_args = gmm_args.to(device, true);
|
||||
|
||||
// Obtain pointers for each of arguments (on GPU)
|
||||
lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
|
||||
ldb_data = lda_data + problem_count;
|
||||
ldd_data = lda_data + 2 * problem_count;
|
||||
ptr_a_data = lda_data + 3 * problem_count;
|
||||
ptr_b_data = lda_data + 4 * problem_count;
|
||||
ptr_d_data = lda_data + 5 * problem_count;
|
||||
problem_sizes_data =
|
||||
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
|
||||
|
||||
// Create GemmGrouped::Arguments using the arguments prepared above
|
||||
typename GemmGrouped::Arguments args(
|
||||
problem_sizes_data,
|
||||
problem_count,
|
||||
threadblock_count,
|
||||
epilogue_op,
|
||||
reinterpret_cast<Element**>(ptr_a_data),
|
||||
reinterpret_cast<Element**>(ptr_b_data),
|
||||
reinterpret_cast<Element**>(ptr_d_data),
|
||||
reinterpret_cast<Element**>(ptr_d_data),
|
||||
lda_data,
|
||||
ldb_data,
|
||||
ldd_data,
|
||||
ldd_data);
|
||||
|
||||
GemmGrouped gemm;
|
||||
cutlass::Status status =
|
||||
gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
|
||||
TORCH_CHECK(
|
||||
status != cutlass::Status::kErrorWorkspaceNull,
|
||||
"Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
|
||||
TORCH_CHECK(
|
||||
status != cutlass::Status::kErrorInternal,
|
||||
"Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
|
||||
TORCH_CHECK(
|
||||
status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
// Run CUTLASS group GEMM
|
||||
status = gemm.run(at::cuda::getCurrentCUDAStream());
|
||||
TORCH_CHECK(
|
||||
status == cutlass::Status::kSuccess,
|
||||
"Failed to run CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
bool group_gemm_dispatch(
|
||||
at::Device device,
|
||||
const std::vector<scalar_t*>& aptr,
|
||||
const std::vector<scalar_t*>& bptr,
|
||||
const std::vector<scalar_t*>& dptr,
|
||||
const std::vector<int64_t>& lda,
|
||||
const std::vector<int64_t>& ldb,
|
||||
const std::vector<int64_t>& ldd,
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
|
||||
int64_t ntensors) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool group_gemm_dispatch(
|
||||
at::Device device,
|
||||
const std::vector<float*>& aptr,
|
||||
const std::vector<float*>& bptr,
|
||||
const std::vector<float*>& dptr,
|
||||
const std::vector<int64_t>& lda,
|
||||
const std::vector<int64_t>& ldb,
|
||||
const std::vector<int64_t>& ldd,
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
|
||||
int64_t ntensors) {
|
||||
|
||||
gemm_grouped_cuda_internal<
|
||||
float,
|
||||
1,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::arch::OpClassSimt,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 8>,
|
||||
cutlass::gemm::GemmShape<64, 32, 8>,
|
||||
cutlass::gemm::GemmShape<1, 1, 1>>(
|
||||
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool group_gemm_dispatch(
|
||||
at::Device device,
|
||||
const std::vector<c10::Half*>& aptr_,
|
||||
const std::vector<c10::Half*>& bptr_,
|
||||
const std::vector<c10::Half*>& dptr_,
|
||||
const std::vector<int64_t>& lda,
|
||||
const std::vector<int64_t>& ldb,
|
||||
const std::vector<int64_t>& ldd,
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_sizes,
|
||||
int64_t ntensors) {
|
||||
|
||||
// Check alignment
|
||||
bool all_pad_8 = true;
|
||||
for (int i = 0; i < ntensors; i++) {
|
||||
all_pad_8 = all_pad_8 && (gemm_sizes[i].n() % 8 == 0);
|
||||
all_pad_8 = all_pad_8 && (gemm_sizes[i].k() % 8 == 0);
|
||||
|
||||
// Not sure if this is a requirement, on the safe side
|
||||
all_pad_8 = all_pad_8 && (lda[i] % 8 == 0);
|
||||
all_pad_8 = all_pad_8 && (ldb[i] % 8 == 0);
|
||||
all_pad_8 = all_pad_8 && (ldd[i] % 8 == 0);
|
||||
}
|
||||
|
||||
std::vector<cutlass::half_t*> aptr;
|
||||
std::vector<cutlass::half_t*> bptr;
|
||||
std::vector<cutlass::half_t*> dptr;
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
aptr.push_back(reinterpret_cast<cutlass::half_t*>(aptr_[i]));
|
||||
bptr.push_back(reinterpret_cast<cutlass::half_t*>(bptr_[i]));
|
||||
dptr.push_back(reinterpret_cast<cutlass::half_t*>(dptr_[i]));
|
||||
}
|
||||
if (all_pad_8) {
|
||||
gemm_grouped_cuda_internal<
|
||||
cutlass::half_t,
|
||||
8,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::arch::OpClassTensorOp,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 32>,
|
||||
cutlass::gemm::GemmShape<64, 64, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 16>>(
|
||||
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
|
||||
return true;
|
||||
} else {
|
||||
gemm_grouped_cuda_internal<
|
||||
cutlass::half_t,
|
||||
1,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::arch::OpClassSimt,
|
||||
cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<128, 128, 8>,
|
||||
cutlass::gemm::GemmShape<64, 32, 8>,
|
||||
cutlass::gemm::GemmShape<1, 1, 1>>(
|
||||
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
|
||||
return true;
|
||||
}
|
||||
// Did not perform GEMM
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
|
||||
if (self.is_nested() && !mat2.is_nested()) {
|
||||
AT_ERROR(
|
||||
"Expected both to be nested, but got a nested self and non-nested other");
|
||||
} else if (!self.is_nested() && mat2.is_nested()) {
|
||||
AT_ERROR(
|
||||
"Expected both to be nested, but got a non-nested self and nested other");
|
||||
}
|
||||
// dispatcher should have guaranteed that at least one is nested
|
||||
auto self_ptr = get_nested_tensor_impl(self);
|
||||
auto mat2_ptr = get_nested_tensor_impl(mat2);
|
||||
TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
|
||||
TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
|
||||
int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
|
||||
TORCH_CHECK(
|
||||
ntensors == ntensors2,
|
||||
"Expected size for the 1st dimension of batch2 tensor to be: ",
|
||||
ntensors,
|
||||
" but got: ",
|
||||
ntensors2,
|
||||
".");
|
||||
|
||||
// create a contiguous output
|
||||
const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
|
||||
Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
|
||||
int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
|
||||
|
||||
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr);
|
||||
std::vector<IntArrayRef> mat2_sizes = NestedTensor_get_sizes(mat2_ptr);
|
||||
|
||||
int64_t out_numel = 0;
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
|
||||
const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
|
||||
&mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
|
||||
TORCH_CHECK(
|
||||
self_size1 == mat2_size0,
|
||||
i,
|
||||
"-th nested matrices in batch cannot be multiplied (",
|
||||
self_size0,
|
||||
"x",
|
||||
self_size1,
|
||||
" and ",
|
||||
mat2_size0,
|
||||
"x",
|
||||
mat2_size1,
|
||||
")");
|
||||
out_sizemat_ptr[0] = self_size0;
|
||||
out_sizemat_ptr[1] = mat2_size1;
|
||||
out_sizemat_ptr += 2;
|
||||
out_numel += self_size0 * mat2_size1;
|
||||
}
|
||||
const Tensor &self_buffer = self_ptr->get_unsafe_storage_as_tensor();
|
||||
const Tensor &mat2_buffer = mat2_ptr->get_unsafe_storage_as_tensor();
|
||||
Tensor out_buffer = self_buffer.new_empty(out_numel);
|
||||
Tensor output = wrap_buffer(out_buffer, out_sizemat);
|
||||
auto out_ptr = get_nested_tensor_impl(output);
|
||||
|
||||
std::vector<IntArrayRef> self_strides = NestedTensor_get_strides(self_ptr);
|
||||
std::vector<IntArrayRef> mat2_strides = NestedTensor_get_strides(mat2_ptr);
|
||||
const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
|
||||
const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();
|
||||
const std::vector<int64_t>& out_offsets = out_ptr->get_storage_offsets();
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#ifndef _WIN32
|
||||
bool success = false;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
self.scalar_type(), "group_gemm_dispatch", [&] {
|
||||
std::vector<scalar_t*> aptr(ntensors);
|
||||
std::vector<scalar_t*> bptr(ntensors);
|
||||
std::vector<scalar_t*> dptr(ntensors);
|
||||
std::vector<int64_t> lda(ntensors);
|
||||
std::vector<int64_t> ldb(ntensors);
|
||||
std::vector<int64_t> ldd(ntensors);
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
|
||||
bool all_row_major = true;
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
const IntArrayRef& self_shape = self_sizes[i];
|
||||
const IntArrayRef& mat2_shape = mat2_sizes[i];
|
||||
const int64_t &self_size0 = self_shape[0];
|
||||
const int64_t &self_size1 = self_shape[1];
|
||||
const int64_t &mat2_size0 = mat2_shape[0];
|
||||
const int64_t &mat2_size1 = mat2_shape[1];
|
||||
gemm_sizes.push_back(
|
||||
cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
|
||||
aptr[i] = self_buffer.data_ptr<scalar_t>() + self_offsets[i];
|
||||
bptr[i] = mat2_buffer.data_ptr<scalar_t>() + mat2_offsets[i];
|
||||
dptr[i] = out_buffer.data_ptr<scalar_t>() + out_offsets[i];
|
||||
all_row_major = all_row_major && (self_strides[i][1] == 1);
|
||||
all_row_major = all_row_major && (mat2_strides[i][1] == 1);
|
||||
lda[i] = self_strides[i][0];
|
||||
ldb[i] = mat2_strides[i][0];
|
||||
ldd[i] = mat2_size1;
|
||||
}
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
if (all_row_major &&
|
||||
self.is_contiguous() &&
|
||||
mat2.is_contiguous() &&
|
||||
is_sm8x) {
|
||||
success = group_gemm_dispatch<scalar_t>(
|
||||
output.device(),
|
||||
aptr,
|
||||
bptr,
|
||||
dptr,
|
||||
lda,
|
||||
ldb,
|
||||
ldd,
|
||||
gemm_sizes,
|
||||
ntensors);
|
||||
}
|
||||
});
|
||||
if (success) {
|
||||
return output;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
std::vector<Tensor> output_unbind = output.unbind();
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
at::mm_out(
|
||||
output_unbind[i],
|
||||
self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
|
||||
mat2_buffer.as_strided(
|
||||
mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
@ -462,281 +462,5 @@ template void add_padding_kernelLauncher<c10::Half>(
|
||||
const int batch_size,
|
||||
const int output_batch_size);
|
||||
|
||||
namespace {
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#ifndef _WIN32
|
||||
template <typename scalar_t>
|
||||
void gemm_grouped_cuda_internal(
|
||||
const std::vector<int64_t>& lda,
|
||||
const std::vector<int64_t>& ldb,
|
||||
const std::vector<int64_t>& ldd,
|
||||
const std::vector<scalar_t*>& aptr,
|
||||
const std::vector<scalar_t*>& bptr,
|
||||
const std::vector<scalar_t*>& dptr,
|
||||
const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
|
||||
const int problem_count,
|
||||
at::Device& device) {
|
||||
using Element = scalar_t;
|
||||
using ElementAcc = float;
|
||||
using OpClass = cutlass::arch::OpClassSimt;
|
||||
|
||||
using GemmConfiguration =
|
||||
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
OpClass,
|
||||
cutlass::arch::Sm80,
|
||||
Element,
|
||||
Element,
|
||||
Element,
|
||||
ElementAcc>;
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||
Element,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
GemmConfiguration::kAlignmentA,
|
||||
Element,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::ComplexTransform::kNone,
|
||||
GemmConfiguration::kAlignmentB,
|
||||
Element,
|
||||
cutlass::layout::RowMajor,
|
||||
ElementAcc,
|
||||
OpClass,
|
||||
cutlass::arch::Sm80,
|
||||
typename GemmConfiguration::ThreadblockShape,
|
||||
typename GemmConfiguration::WarpShape,
|
||||
typename GemmConfiguration::InstructionShape,
|
||||
typename GemmConfiguration::EpilogueOutputOp,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
|
||||
GemmConfiguration::kStages>::GemmKernel;
|
||||
|
||||
using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
|
||||
typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);
|
||||
|
||||
const int64_t gemm_coord_size =
|
||||
problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
|
||||
// Number of gmm args not including *problem_sizes
|
||||
at::Tensor gmm_args = at::empty(
|
||||
{problem_count * 6 + gemm_coord_size},
|
||||
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
|
||||
|
||||
// Obtain pointers for each argument (on host)
|
||||
int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
|
||||
int64_t* ldb_data = lda_data + problem_count;
|
||||
int64_t* ldd_data = lda_data + 2 * problem_count;
|
||||
int64_t* ptr_a_data = lda_data + 3 * problem_count;
|
||||
int64_t* ptr_b_data = lda_data + 4 * problem_count;
|
||||
int64_t* ptr_d_data = lda_data + 5 * problem_count;
|
||||
cutlass::gemm::GemmCoord* problem_sizes_data =
|
||||
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
|
||||
|
||||
// Set arguments into gmm_args from input args
|
||||
for (int i = 0; i < problem_count; ++i) {
|
||||
problem_sizes_data[i] = gemm_sizes[i];
|
||||
lda_data[i] = lda[i];
|
||||
ldb_data[i] = ldb[i];
|
||||
ldd_data[i] = ldd[i];
|
||||
ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
|
||||
ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
|
||||
ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
|
||||
}
|
||||
const int threadblock_count =
|
||||
GemmGrouped::sufficient(problem_sizes_data, problem_count);
|
||||
|
||||
// Transfer arguments to GPU
|
||||
gmm_args = gmm_args.to(device, true);
|
||||
|
||||
// Obtain pointers for each of arguments (on GPU)
|
||||
lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
|
||||
ldb_data = lda_data + problem_count;
|
||||
ldd_data = lda_data + 2 * problem_count;
|
||||
ptr_a_data = lda_data + 3 * problem_count;
|
||||
ptr_b_data = lda_data + 4 * problem_count;
|
||||
ptr_d_data = lda_data + 5 * problem_count;
|
||||
problem_sizes_data =
|
||||
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
|
||||
|
||||
// Create GemmGrouped::Arguments using the arguments prepared above
|
||||
typename GemmGrouped::Arguments args(
|
||||
problem_sizes_data,
|
||||
problem_count,
|
||||
threadblock_count,
|
||||
epilogue_op,
|
||||
reinterpret_cast<Element**>(ptr_a_data),
|
||||
reinterpret_cast<Element**>(ptr_b_data),
|
||||
reinterpret_cast<Element**>(ptr_d_data),
|
||||
reinterpret_cast<Element**>(ptr_d_data),
|
||||
lda_data,
|
||||
ldb_data,
|
||||
ldd_data,
|
||||
ldd_data);
|
||||
|
||||
GemmGrouped gemm;
|
||||
cutlass::Status status =
|
||||
gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
|
||||
TORCH_CHECK(
|
||||
status != cutlass::Status::kErrorWorkspaceNull,
|
||||
"Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
|
||||
TORCH_CHECK(
|
||||
status != cutlass::Status::kErrorInternal,
|
||||
"Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
|
||||
TORCH_CHECK(
|
||||
status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
// Run CUTLASS group GEMM
|
||||
status = gemm.run(at::cuda::getCurrentCUDAStream());
|
||||
TORCH_CHECK(
|
||||
status == cutlass::Status::kSuccess,
|
||||
"Failed to run CUTLASS Grouped GEMM kernel.");
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
|
||||
if (self.is_nested() && !mat2.is_nested()) {
|
||||
AT_ERROR(
|
||||
"Expected both to be nested, but got a nested self and non-nested other");
|
||||
} else if (!self.is_nested() && mat2.is_nested()) {
|
||||
AT_ERROR(
|
||||
"Expected both to be nested, but got a non-nested self and nested other");
|
||||
}
|
||||
// TODO currently we only support contiguous NestedTensors
|
||||
auto self_contiguous = self.contiguous();
|
||||
auto mat2_contiguous = mat2.contiguous();
|
||||
|
||||
// dispatcher should have guaranteed that at least one is nested
|
||||
auto self_ptr = get_nested_tensor_impl(self_contiguous);
|
||||
auto mat2_ptr = get_nested_tensor_impl(mat2_contiguous);
|
||||
TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
|
||||
TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
|
||||
int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
|
||||
TORCH_CHECK(
|
||||
ntensors == ntensors2,
|
||||
"Expected size for the 1st dimension of batch2 tensor to be: ",
|
||||
ntensors,
|
||||
" but got: ",
|
||||
ntensors2,
|
||||
".");
|
||||
const Tensor &self_buffer = self_ptr->get_buffer(),
|
||||
&mat2_buffer = mat2_ptr->get_buffer();
|
||||
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
|
||||
mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
|
||||
self_strides = NestedTensor_get_strides(self_ptr),
|
||||
mat2_strides = NestedTensor_get_strides(mat2_ptr);
|
||||
const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
|
||||
const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();
|
||||
|
||||
// create a contiguous output
|
||||
int64_t out_numel = 0;
|
||||
int64_t a_numel = 0;
|
||||
int64_t b_numel = 0;
|
||||
const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
|
||||
Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
|
||||
int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
|
||||
std::vector<int64_t> output_offsets;
|
||||
std::vector<int64_t> a_offsets;
|
||||
std::vector<int64_t> b_offsets;
|
||||
std::vector<int64_t> lda;
|
||||
std::vector<int64_t> ldb;
|
||||
std::vector<int64_t> ldd;
|
||||
#ifndef USE_ROCM
|
||||
#ifndef _WIN32
|
||||
std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
|
||||
#endif
|
||||
#endif
|
||||
bool all_row_major = true;
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
|
||||
const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
|
||||
&mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
|
||||
TORCH_CHECK(
|
||||
self_size1 == mat2_size0,
|
||||
i,
|
||||
"-th nested matrices in batch cannot be multiplied (",
|
||||
self_size0,
|
||||
"x",
|
||||
self_size1,
|
||||
" and ",
|
||||
mat2_size0,
|
||||
"x",
|
||||
mat2_size1,
|
||||
")");
|
||||
out_sizemat_ptr[0] = self_size0;
|
||||
out_sizemat_ptr[1] = mat2_size1;
|
||||
out_sizemat_ptr += 2;
|
||||
output_offsets.push_back(out_numel);
|
||||
out_numel += self_size0 * mat2_size1;
|
||||
#ifndef USE_ROCM
|
||||
#ifndef _WIN32
|
||||
gemm_sizes.push_back(
|
||||
cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
|
||||
#endif
|
||||
#endif
|
||||
lda.push_back(self_strides[i][0]);
|
||||
ldb.push_back(mat2_strides[i][0]);
|
||||
ldd.push_back(mat2_size1);
|
||||
a_offsets.push_back(a_numel);
|
||||
b_offsets.push_back(b_numel);
|
||||
a_numel += self_size0 * self_strides[i][0];
|
||||
b_numel += mat2_size0 * mat2_strides[i][0];
|
||||
all_row_major = all_row_major && (self_strides[i][1] == 1);
|
||||
all_row_major = all_row_major && (mat2_strides[i][1] == 1);
|
||||
}
|
||||
Tensor out_buffer = self_buffer.new_empty(out_numel);
|
||||
Tensor output = wrap_buffer(out_buffer, out_sizemat);
|
||||
at::Device device = output.device();
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#ifndef _WIN32
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
if (is_sm8x && all_row_major) {
|
||||
if (self.dtype() == at::kFloat) {
|
||||
std::vector<float*> aptr;
|
||||
std::vector<float*> bptr;
|
||||
std::vector<float*> dptr;
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
aptr.push_back(self_buffer.data_ptr<float>() + a_offsets[i]);
|
||||
bptr.push_back(mat2_buffer.data_ptr<float>() + b_offsets[i]);
|
||||
dptr.push_back(out_buffer.data_ptr<float>() + output_offsets[i]);
|
||||
}
|
||||
gemm_grouped_cuda_internal<float>(
|
||||
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
|
||||
return output;
|
||||
}
|
||||
if (self.dtype() == at::kHalf) {
|
||||
std::vector<c10::Half*> aptr;
|
||||
std::vector<c10::Half*> bptr;
|
||||
std::vector<c10::Half*> dptr;
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
aptr.push_back(self_buffer.data_ptr<c10::Half>() + a_offsets[i]);
|
||||
bptr.push_back(mat2_buffer.data_ptr<c10::Half>() + b_offsets[i]);
|
||||
dptr.push_back(out_buffer.data_ptr<c10::Half>() + output_offsets[i]);
|
||||
}
|
||||
gemm_grouped_cuda_internal<c10::Half>(
|
||||
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
|
||||
return output;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
std::vector<Tensor> output_unbind = output.unbind();
|
||||
for (int64_t i = 0; i < ntensors; i++) {
|
||||
at::mm_out(
|
||||
output_unbind[i],
|
||||
self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
|
||||
mat2_buffer.as_strided(
|
||||
mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
@ -15,31 +16,38 @@ def bench(nt_a, nt_b, niter):
|
||||
nt_c = nt_a.bmm(nt_b)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter
|
||||
runtime = (start_event.elapsed_time(end_event)) / niter
|
||||
return runtime
|
||||
|
||||
|
||||
def sweep_n(ntensor, niter, dtype):
|
||||
print("n, dtype, ntensor, gflop, runtime, tflop/s")
|
||||
for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
nt_a = torch.nested_tensor(
|
||||
[torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
|
||||
def sweep_n(niter, dtype):
|
||||
for ntensor in [4, 8, 16, 32, 64, 128, 256]:
|
||||
tensors = [torch.randn(256, random.randint(100, 200)) for t in range(ntensor)]
|
||||
nt_a = torch.nested.nested_tensor(
|
||||
tensors,
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
)
|
||||
nt_b = torch.nested_tensor(
|
||||
[torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
|
||||
nt_b = torch.nested.nested_tensor(
|
||||
[t.t() for t in tensors],
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
)
|
||||
runtime = bench(nt_a, nt_b, niter)
|
||||
tflop = n * n * n * ntensor * 2 / 1e12
|
||||
print(n, dtype, ntensor, tflop, runtime, tflop / runtime)
|
||||
nt_a_size = torch.ops.aten._nested_tensor_size(nt_a)
|
||||
lengths = nt_a_size[:, 1]
|
||||
print(",".join(map(str, [ntensor, dtype, lengths.min().item(),
|
||||
lengths.float().mean().item(), lengths.max().item(), runtime])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
random.seed(123)
|
||||
parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
|
||||
parser.add_argument("--niter", default="10", type=int)
|
||||
parser.add_argument("--ntensor", default="20", type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
niter = args.niter
|
||||
ntensor = args.ntensor
|
||||
|
||||
sweep_n(ntensor, niter, torch.float32)
|
||||
sweep_n(ntensor, niter, torch.float16)
|
||||
print("ntensor,dtype,min_length,mean_length,max_length,runtime")
|
||||
sweep_n(niter, torch.float32)
|
||||
sweep_n(niter, torch.float16)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
# Owner(s): ["module: nestedtensor"]
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
import unittest
|
||||
import numpy as np
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
dtypesIfCUDA,
|
||||
@ -1224,6 +1224,16 @@ class TestNestedTensorDeviceType(TestCase):
|
||||
else:
|
||||
self.assertEqual(actual, expect)
|
||||
|
||||
# test tensorcore path
|
||||
nt0 = torch.nested.nested_tensor([torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype)
|
||||
nt1 = torch.nested.nested_tensor([torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype)
|
||||
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
||||
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
|
||||
if dtype == torch.float16:
|
||||
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
self.assertEqual(actual, expect)
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.float, torch.double, torch.float16)
|
||||
def test_bmm_cuda(self, device, dtype):
|
||||
@ -1235,15 +1245,48 @@ class TestNestedTensorDeviceType(TestCase):
|
||||
def test_bmm_cpu(self, device, dtype):
|
||||
self._test_bmm(device, dtype)
|
||||
|
||||
# TODO: Re-enable this test once bmm supports non-contiguous inputs.
|
||||
# # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
|
||||
# @dtypes(torch.float, torch.double)
|
||||
# def test_bmm_noncontiguous(self, device, dtype):
|
||||
# nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
|
||||
# nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
|
||||
# self.assertEqual(
|
||||
# nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
|
||||
# nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
|
||||
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
|
||||
@dtypes(torch.float, torch.double)
|
||||
def test_bmm_noncontiguous(self, device, dtype):
|
||||
nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
|
||||
nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
|
||||
self.assertEqual(
|
||||
nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
|
||||
nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
|
||||
|
||||
@dtypes(torch.float, torch.double)
|
||||
def test_matmul_with_bmm_path(self, device, dtype):
|
||||
def unbind_rebind_matmul(nt1, nt2):
|
||||
t1s = nt1.unbind()
|
||||
t2s = nt2.unbind()
|
||||
out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
|
||||
return torch.nested.nested_tensor(out_ts)
|
||||
|
||||
# [N, n_head, *, head_dim], [N, n_head, head_dim, *]
|
||||
N = np.random.randint(2, 5)
|
||||
n_heads = np.random.randint(2, 5)
|
||||
head_dim = 3
|
||||
t1s = []
|
||||
t2s = []
|
||||
for _ in range(N):
|
||||
seq_len1 = np.random.randint(2, 5)
|
||||
seq_len2 = np.random.randint(2, 5)
|
||||
t1s.append(torch.randn(n_heads, seq_len1, head_dim))
|
||||
t2s.append(torch.randn(n_heads, head_dim, seq_len2))
|
||||
nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
|
||||
nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
|
||||
self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))
|
||||
|
||||
# test with noncontiguous
|
||||
t3s = []
|
||||
t4s = []
|
||||
for _ in range(N):
|
||||
seq_len = np.random.randint(2, 5)
|
||||
t3s.append(torch.randn(seq_len, n_heads, head_dim))
|
||||
t4s.append(torch.randn(seq_len, n_heads, head_dim))
|
||||
nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2)
|
||||
nt4 = torch.nested.nested_tensor(t4s, device=device, dtype=dtype).transpose(1, 2).transpose(2, 3)
|
||||
self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))
|
||||
|
||||
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
|
||||
@dtypes(torch.float, torch.double)
|
||||
|
||||
Reference in New Issue
Block a user