mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Revert "Use CUTLASS GEMM for NT bmm [OSS-only] (#85894)"
This reverts commit ef58a132f223d5abf2bd3f8bee380aca6c29d17f. Reverted https://github.com/pytorch/pytorch/pull/85894 on behalf of https://github.com/DanilBaibak because it breaks the internal build.
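For context before the diff: PR #85894 routed torch.bmm over CUDA nested tensors through a CUTLASS grouped GEMM, and this revert returns both CPU and CUDA to the shared bmm_nested dispatch. A minimal usage sketch of the operation involved, assuming a PyTorch build with nested-tensor support (shapes here are made up for illustration; only torch.nested.nested_tensor, Tensor.bmm, and torch.nested.to_padded_tensor, all of which appear in the diff below, are used):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Two nested "batches" whose entries have different shapes.
a = torch.nested.nested_tensor(
    [torch.randn(4, 6), torch.randn(7, 5)], device=device, dtype=torch.float
)
b = torch.nested.nested_tensor(
    [torch.randn(6, 4), torch.randn(5, 7)], device=device, dtype=torch.float
)

# Pairwise matrix products: the result entries have shapes (4, 4) and (7, 7).
c = a.bmm(b)

# The tests touched by this commit compare against dense bmm on zero-padded copies.
expect = torch.nested.to_padded_tensor(a, 0.0).bmm(torch.nested.to_padded_tensor(b, 0.0))
actual = torch.nested.to_padded_tensor(c, 0.0)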
@@ -429,7 +429,6 @@ cu_library(
         "@cuda//:cublas",
         "@cuda//:cufft",
         "@cuda//:cusparse",
-        "@cutlass",
     ],
     alwayslink = True,
 )
@@ -1674,7 +1673,6 @@ cc_library(
     ] + if_cuda([
         ":torch_distributed_cuda",
         "@cuda//:nvToolsExt",
-        "@cutlass",
     ]),
     alwayslink = True,
 )
@@ -84,12 +84,6 @@ new_local_repository(
     path = "third_party/eigen",
 )
 
-new_local_repository(
-    name = "cutlass",
-    build_file = "//third_party:cutlass.BUILD",
-    path = "third_party/cutlass",
-)
-
 new_local_repository(
     name = "fbgemm",
     build_file = "//third_party:fbgemm/BUILD.bazel",
@@ -433,7 +433,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
 endif()
 
 if(USE_CUDA AND NOT USE_ROCM)
-  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+  if(USE_FLASH_ATTENTION)
+    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+  endif()
   if($ENV{ATEN_STATIC_CUDA})
     list(APPEND ATen_CUDA_DEPENDENCY_LIBS
       ${CUDA_LIBRARIES}
@@ -1174,8 +1174,7 @@
   dispatch:
     SparseCPU: bmm_sparse_cpu
     SparseCUDA: bmm_sparse_cuda
-    NestedTensorCPU: bmm_nested
-    NestedTensorCUDA: bmm_nested_cuda
+    NestedTensorCPU, NestedTensorCUDA: bmm_nested
   tags: canonical
 
 - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
@@ -101,5 +101,27 @@ std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int6
   return splits;
 }
 
+std::vector<IntArrayRef> NestedTensor_get_sizes(
+    const NestedTensorImpl* self_ptr) {
+  int64_t ntensors = self_ptr->size(0);
+  std::vector<IntArrayRef> sizes(ntensors);
+  if (ntensors == 0) {
+    return sizes;
+  }
+  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
+  int64_t orig_dim = sizemat.size(1);
+  // nesting scalars has empty sizes
+  if (orig_dim == 0) {
+    return sizes;
+  }
+  const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
+
+  for (const auto i : c10::irange(ntensors)) {
+    sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
+    sizemat_ptr += orig_dim;
+  }
+  return sizes;
+}
+
 } // namespace native
 } // namespace at
@@ -86,28 +86,8 @@ inline at::Tensor create_nested_view_tensor(
 int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
 
 // The sizes of the underlying tensors
-inline std::vector<IntArrayRef> NestedTensor_get_sizes(
-    const NestedTensorImpl* self_ptr) {
-  int64_t ntensors = self_ptr->size(0);
-  std::vector<IntArrayRef> sizes(ntensors);
-  if (ntensors == 0) {
-    return sizes;
-  }
-  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
-  int64_t orig_dim = sizemat.size(1);
-  // nesting scalars has empty sizes
-  if (orig_dim == 0) {
-    return sizes;
-  }
-  const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
-
-  for (const auto i : c10::irange(ntensors)) {
-    sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
-    sizemat_ptr += orig_dim;
-  }
-  return sizes;
-}
-
+std::vector<IntArrayRef> NestedTensor_get_sizes(
+    const NestedTensorImpl* self_ptr);
 
 TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
     const NestedTensorImpl& nt);
@@ -15,15 +15,6 @@
 #include <c10/cuda/CUDAStream.h>
 
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
-#include <ATen/native/nested/NestedTensorUtils.h>
-
-#ifndef USE_ROCM
-#include <cutlass/gemm/device/default_gemm_configuration.h>
-#include <cutlass/gemm/device/gemm_grouped.h>
-#include <cutlass/gemm/kernel/default_gemm_grouped.h>
-#endif
-
-#include <ATen/NestedTensorImpl.h>
 
 #define BLOCK_DIM 256
 #define GRID_DIM_Y 16
@@ -347,8 +338,7 @@ __global__ void add_padding_3(
     const int i0 = i / (output_sizes_2 * output_sizes_3);
     const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
     const int i2 = i % output_sizes_3;
-    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
-        i2 < sizes_i[2]) {
+    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
       const int offset = offsets[batch_id];
       const int input_offset =
           offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@@ -362,8 +352,7 @@ __global__ void add_padding_3(
     const int i0 = i / (output_sizes_2 * output_sizes_3);
     const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
     const int i2 = i % output_sizes_3;
-    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
-        i2 < sizes_i[2]) {
+    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
       const int offset = offsets[batch_id];
       const int input_offset =
           offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@@ -460,269 +449,5 @@ template void add_padding_kernelLauncher<c10::Half>(
     const int batch_size,
     const int output_batch_size);
 
-namespace {
-
-#ifndef USE_ROCM
-template <typename scalar_t>
-void gemm_grouped_cuda_internal(
-    const std::vector<int64_t>& lda,
-    const std::vector<int64_t>& ldb,
-    const std::vector<int64_t>& ldd,
-    const std::vector<scalar_t*>& aptr,
-    const std::vector<scalar_t*>& bptr,
-    const std::vector<scalar_t*>& dptr,
-    const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
-    const int problem_count,
-    at::Device& device) {
-  using Element = scalar_t;
-  using ElementAcc = float;
-  using OpClass = cutlass::arch::OpClassSimt;
-
-  using GemmConfiguration =
-      typename cutlass::gemm::device::DefaultGemmConfiguration<
-          OpClass,
-          cutlass::arch::Sm80,
-          Element,
-          Element,
-          Element,
-          ElementAcc>;
-
-  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
-      Element,
-      cutlass::layout::RowMajor,
-      cutlass::ComplexTransform::kNone,
-      GemmConfiguration::kAlignmentA,
-      Element,
-      cutlass::layout::RowMajor,
-      cutlass::ComplexTransform::kNone,
-      GemmConfiguration::kAlignmentB,
-      Element,
-      cutlass::layout::RowMajor,
-      ElementAcc,
-      OpClass,
-      cutlass::arch::Sm80,
-      typename GemmConfiguration::ThreadblockShape,
-      typename GemmConfiguration::WarpShape,
-      typename GemmConfiguration::InstructionShape,
-      typename GemmConfiguration::EpilogueOutputOp,
-      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
-      GemmConfiguration::kStages>::GemmKernel;
-
-  using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
-  using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
-  typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);
-
-  const int64_t gemm_coord_size =
-      problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
-  // Number of gmm args not including *problem_sizes
-  at::Tensor gmm_args = at::empty(
-      {problem_count * 6 + gemm_coord_size},
-      at::TensorOptions().dtype(at::kLong).pinned_memory(true));
-
-  // Obtain pointers for each argument (on host)
-  int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
-  int64_t* ldb_data = lda_data + problem_count;
-  int64_t* ldd_data = lda_data + 2 * problem_count;
-  int64_t* ptr_a_data = lda_data + 3 * problem_count;
-  int64_t* ptr_b_data = lda_data + 4 * problem_count;
-  int64_t* ptr_d_data = lda_data + 5 * problem_count;
-  cutlass::gemm::GemmCoord* problem_sizes_data =
-      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
-
-  // Set arguments into gmm_args from input args
-  for (int i = 0; i < problem_count; ++i) {
-    problem_sizes_data[i] = gemm_sizes[i];
-    lda_data[i] = lda[i];
-    ldb_data[i] = ldb[i];
-    ldd_data[i] = ldd[i];
-    ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
-    ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
-    ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
-  }
-  const int threadblock_count =
-      GemmGrouped::sufficient(problem_sizes_data, problem_count);
-
-  // Transfer arguments to GPU
-  gmm_args = gmm_args.to(device, true);
-
-  // Obtain pointers for each of arguments (on GPU)
-  lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
-  ldb_data = lda_data + problem_count;
-  ldd_data = lda_data + 2 * problem_count;
-  ptr_a_data = lda_data + 3 * problem_count;
-  ptr_b_data = lda_data + 4 * problem_count;
-  ptr_d_data = lda_data + 5 * problem_count;
-  problem_sizes_data =
-      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
-
-  // Create GemmGrouped::Arguments using the arguments prepared above
-  typename GemmGrouped::Arguments args(
-      problem_sizes_data,
-      problem_count,
-      threadblock_count,
-      epilogue_op,
-      reinterpret_cast<Element**>(ptr_a_data),
-      reinterpret_cast<Element**>(ptr_b_data),
-      reinterpret_cast<Element**>(ptr_d_data),
-      reinterpret_cast<Element**>(ptr_d_data),
-      lda_data,
-      ldb_data,
-      ldd_data,
-      ldd_data);
-
-  GemmGrouped gemm;
-  cutlass::Status status =
-      gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
-  TORCH_CHECK(
-      status != cutlass::Status::kErrorWorkspaceNull,
-      "Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
-  TORCH_CHECK(
-      status != cutlass::Status::kErrorInternal,
-      "Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
-  TORCH_CHECK(
-      status == cutlass::Status::kSuccess,
-      "Failed to initialize CUTLASS Grouped GEMM kernel.");
-
-  // Run CUTLASS group GEMM
-  status = gemm.run(at::cuda::getCurrentCUDAStream());
-  TORCH_CHECK(
-      status == cutlass::Status::kSuccess,
-      "Failed to run CUTLASS Grouped GEMM kernel.");
-
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-#endif
-
-} // namespace
-
-Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
-  if (self.is_nested() && !mat2.is_nested()) {
-    AT_ERROR(
-        "Expected both to be nested, but got a nested self and non-nested other");
-  } else if (!self.is_nested() && mat2.is_nested()) {
-    AT_ERROR(
-        "Expected both to be nested, but got a non-nested self and nested other");
-  }
-  // dispatcher should have guaranteed that at least one is nested
-  auto self_ptr = get_nested_tensor_impl(self);
-  auto mat2_ptr = get_nested_tensor_impl(mat2);
-  TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
-  TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
-  int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
-  TORCH_CHECK(
-      ntensors == ntensors2,
-      "Expected size for the 1st dimension of batch2 tensor to be: ",
-      ntensors,
-      " but got: ",
-      ntensors2,
-      ".");
-  const Tensor &self_buffer = self_ptr->get_buffer(),
-               &mat2_buffer = mat2_ptr->get_buffer();
-  std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
-                           mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
-                           self_strides = NestedTensor_get_strides(self_ptr),
-                           mat2_strides = NestedTensor_get_strides(mat2_ptr);
-  const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
-  const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();
-
-  // create a contiguous output
-  int64_t out_numel = 0;
-  int64_t a_numel = 0;
-  int64_t b_numel = 0;
-  const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
-  Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
-  int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
-  std::vector<int64_t> output_offsets;
-  std::vector<int64_t> a_offsets;
-  std::vector<int64_t> b_offsets;
-  std::vector<int64_t> lda;
-  std::vector<int64_t> ldb;
-  std::vector<int64_t> ldd;
-#ifndef USE_ROCM
-  std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
-#endif
-  bool all_row_major = true;
-  for (int64_t i = 0; i < ntensors; i++) {
-    const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
-    const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
-                  &mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
-    TORCH_CHECK(
-        self_size1 == mat2_size0,
-        i,
-        "-th nested matrices in batch cannot be multiplied (",
-        self_size0,
-        "x",
-        self_size1,
-        " and ",
-        mat2_size0,
-        "x",
-        mat2_size1,
-        ")");
-    out_sizemat_ptr[0] = self_size0;
-    out_sizemat_ptr[1] = mat2_size1;
-    out_sizemat_ptr += 2;
-    output_offsets.push_back(out_numel);
-    out_numel += self_size0 * mat2_size1;
-#ifndef USE_ROCM
-    gemm_sizes.push_back(
-        cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
-#endif
-    lda.push_back(self_strides[i][0]);
-    ldb.push_back(mat2_strides[i][0]);
-    ldd.push_back(mat2_size1);
-    a_offsets.push_back(a_numel);
-    b_offsets.push_back(b_numel);
-    a_numel += self_size0 * self_strides[i][0];
-    b_numel += mat2_size0 * mat2_strides[i][0];
-    all_row_major = all_row_major && (self_strides[i][1] == 1);
-    all_row_major = all_row_major && (mat2_strides[i][1] == 1);
-  }
-  Tensor out_buffer = self_buffer.new_empty(out_numel);
-  Tensor output = wrap_buffer(out_buffer, out_sizemat);
-  at::Device device = output.device();
-
-#ifndef USE_ROCM
-  auto dprops = at::cuda::getCurrentDeviceProperties();
-  bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-  if (is_sm8x && all_row_major) {
-    if (self.dtype() == at::kFloat) {
-      std::vector<float*> aptr;
-      std::vector<float*> bptr;
-      std::vector<float*> dptr;
-      for (int64_t i = 0; i < ntensors; i++) {
-        aptr.push_back(self_buffer.data_ptr<float>() + a_offsets[i]);
-        bptr.push_back(mat2_buffer.data_ptr<float>() + b_offsets[i]);
-        dptr.push_back(out_buffer.data_ptr<float>() + output_offsets[i]);
-      }
-      gemm_grouped_cuda_internal<float>(
-          lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
-      return output;
-    }
-    if (self.dtype() == at::kHalf) {
-      std::vector<c10::Half*> aptr;
-      std::vector<c10::Half*> bptr;
-      std::vector<c10::Half*> dptr;
-      for (int64_t i = 0; i < ntensors; i++) {
-        aptr.push_back(self_buffer.data_ptr<c10::Half>() + a_offsets[i]);
-        bptr.push_back(mat2_buffer.data_ptr<c10::Half>() + b_offsets[i]);
-        dptr.push_back(out_buffer.data_ptr<c10::Half>() + output_offsets[i]);
-      }
-      gemm_grouped_cuda_internal<c10::Half>(
-          lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
-      return output;
-    }
-  }
-#endif
-  std::vector<Tensor> output_unbind = output.unbind();
-  for (int64_t i = 0; i < ntensors; i++) {
-    at::mm_out(
-        output_unbind[i],
-        self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
-        mat2_buffer.as_strided(
-            mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
-  }
-  return output;
-}
-
 } // namespace native
 } // namespace at
@@ -1,45 +0,0 @@
-import argparse
-
-import torch
-
-
-def bench(nt_a, nt_b, niter):
-    # Warmup
-    nt_c = nt_a.bmm(nt_b)
-
-    torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    start_event.record()
-    for iter in range(niter):
-        nt_c = nt_a.bmm(nt_b)
-    end_event.record()
-    torch.cuda.synchronize()
-    runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter
-    return runtime
-
-
-def sweep_n(ntensor, niter, dtype):
-    print("n, dtype, ntensor, gflop, runtime, tflop/s")
-    for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
-        nt_a = torch.nested_tensor(
-            [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
-        )
-        nt_b = torch.nested_tensor(
-            [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
-        )
-        runtime = bench(nt_a, nt_b, niter)
-        tflop = n * n * n * ntensor * 2 / 1e12
-        print(n, dtype, ntensor, tflop, runtime, tflop / runtime)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
-    parser.add_argument("--niter", default="10", type=int)
-    parser.add_argument("--ntensor", default="20", type=int)
-
-    args = parser.parse_args()
-    niter = args.niter
-    ntensor = args.ntensor
-
-    sweep_n(ntensor, niter, torch.float32)
-    sweep_n(ntensor, niter, torch.float16)
@@ -8,7 +8,6 @@ from torch.testing._internal.common_device_type import (
     dtypesIfCUDA,
     instantiate_device_type_tests,
     skipMeta,
-    onlyCUDA,
     onlyCPU
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
@@ -972,7 +971,9 @@ class TestNestedTensorDeviceType(TestCase):
             torch.nn.functional.softmax(nt_contiguous, -1),
             torch.nn.functional.softmax(nt_noncontiguous, -1))
 
-    def _test_bmm(self, device, dtype):
+    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
+    @dtypes(torch.float, torch.double)
+    def test_bmm(self, device, dtype):
         # error case: one is nested but the other is not
         nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
         t = torch.randn(4, device=device, dtype=dtype)
@@ -1058,31 +1059,16 @@ class TestNestedTensorDeviceType(TestCase):
         nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype)
         actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
         expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
-        if dtype == torch.float16:
-            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
-        else:
-            self.assertEqual(actual, expect)
-
-    @onlyCUDA
-    @dtypes(torch.float, torch.double, torch.float16)
-    def test_bmm_cuda(self, device, dtype):
-        self._test_bmm(device, dtype)
-
-    @onlyCPU
+        self.assertEqual(actual, expect)
+
     # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
     @dtypes(torch.float, torch.double)
-    def test_bmm_cpu(self, device, dtype):
-        self._test_bmm(device, dtype)
-
-    # TODO: Re-enable this test once bmm supports non-contiguous inputs.
-    # # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
-    # @dtypes(torch.float, torch.double)
-    # def test_bmm_noncontiguous(self, device, dtype):
-    #     nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
-    #     nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
-    #     self.assertEqual(
-    #         nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
-    #         nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
+    def test_bmm_noncontiguous(self, device, dtype):
+        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
+        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
+        self.assertEqual(
+            nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
+            nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
 
     # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
     @dtypes(torch.float, torch.double)
third_party/cutlass.BUILD (vendored): 11 lines deleted
@@ -1,11 +0,0 @@
-# Description:
-#   CUDA Templates for Linear Algebra Subroutines
-
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-cc_library(
-    name = "cutlass",
-    hdrs = glob(["include/**/*.h"]),
-    includes = ["include/"],
-    visibility = ["//visibility:public"],
-)