Revert "Use CUTLASS GEMM for NT bmm [OSS-only] (#85894)"

This reverts commit ef58a132f223d5abf2bd3f8bee380aca6c29d17f.

Reverted https://github.com/pytorch/pytorch/pull/85894 on behalf of https://github.com/DanilBaibak due to Break internal build
This commit is contained in:
PyTorch MergeBot
2022-10-13 15:28:09 +00:00
parent b97ae59e29
commit d169f950da
10 changed files with 40 additions and 390 deletions

View File

@ -429,7 +429,6 @@ cu_library(
"@cuda//:cublas",
"@cuda//:cufft",
"@cuda//:cusparse",
"@cutlass",
],
alwayslink = True,
)
@ -1674,7 +1673,6 @@ cc_library(
] + if_cuda([
":torch_distributed_cuda",
"@cuda//:nvToolsExt",
"@cutlass",
]),
alwayslink = True,
)

View File

@ -84,12 +84,6 @@ new_local_repository(
path = "third_party/eigen",
)
new_local_repository(
name = "cutlass",
build_file = "//third_party:cutlass.BUILD",
path = "third_party/cutlass",
)
new_local_repository(
name = "fbgemm",
build_file = "//third_party:fbgemm/BUILD.bazel",

View File

@ -433,7 +433,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()
if(USE_CUDA AND NOT USE_ROCM)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
if(USE_FLASH_ATTENTION)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
endif()
if($ENV{ATEN_STATIC_CUDA})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}

View File

@ -1174,8 +1174,7 @@
dispatch:
SparseCPU: bmm_sparse_cpu
SparseCUDA: bmm_sparse_cuda
NestedTensorCPU: bmm_nested
NestedTensorCUDA: bmm_nested_cuda
NestedTensorCPU, NestedTensorCUDA: bmm_nested
tags: canonical
- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)

View File

@ -101,5 +101,27 @@ std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int6
return splits;
}
std::vector<IntArrayRef> NestedTensor_get_sizes(
const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> sizes(ntensors);
if (ntensors == 0) {
return sizes;
}
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
int64_t orig_dim = sizemat.size(1);
// nesting scalars has empty sizes
if (orig_dim == 0) {
return sizes;
}
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
for (const auto i : c10::irange(ntensors)) {
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
sizemat_ptr += orig_dim;
}
return sizes;
}
} // namespace native
} // namespace at

View File

@ -86,28 +86,8 @@ inline at::Tensor create_nested_view_tensor(
int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
// The sizes of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_sizes(
const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> sizes(ntensors);
if (ntensors == 0) {
return sizes;
}
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
int64_t orig_dim = sizemat.size(1);
// nesting scalars has empty sizes
if (orig_dim == 0) {
return sizes;
}
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
for (const auto i : c10::irange(ntensors)) {
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
sizemat_ptr += orig_dim;
}
return sizes;
}
std::vector<IntArrayRef> NestedTensor_get_sizes(
const NestedTensorImpl* self_ptr);
TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
const NestedTensorImpl& nt);

View File

@ -15,15 +15,6 @@
#include <c10/cuda/CUDAStream.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#ifndef USE_ROCM
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#endif
#include <ATen/NestedTensorImpl.h>
#define BLOCK_DIM 256
#define GRID_DIM_Y 16
@ -347,8 +338,7 @@ __global__ void add_padding_3(
const int i0 = i / (output_sizes_2 * output_sizes_3);
const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
const int i2 = i % output_sizes_3;
if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
i2 < sizes_i[2]) {
if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
const int offset = offsets[batch_id];
const int input_offset =
offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@ -362,8 +352,7 @@ __global__ void add_padding_3(
const int i0 = i / (output_sizes_2 * output_sizes_3);
const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
const int i2 = i % output_sizes_3;
if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
i2 < sizes_i[2]) {
if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
const int offset = offsets[batch_id];
const int input_offset =
offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@ -460,269 +449,5 @@ template void add_padding_kernelLauncher<c10::Half>(
const int batch_size,
const int output_batch_size);
namespace {
#ifndef USE_ROCM
template <typename scalar_t>
void gemm_grouped_cuda_internal(
const std::vector<int64_t>& lda,
const std::vector<int64_t>& ldb,
const std::vector<int64_t>& ldd,
const std::vector<scalar_t*>& aptr,
const std::vector<scalar_t*>& bptr,
const std::vector<scalar_t*>& dptr,
const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
const int problem_count,
at::Device& device) {
using Element = scalar_t;
using ElementAcc = float;
using OpClass = cutlass::arch::OpClassSimt;
using GemmConfiguration =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
cutlass::arch::Sm80,
Element,
Element,
Element,
ElementAcc>;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
Element,
cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone,
GemmConfiguration::kAlignmentA,
Element,
cutlass::layout::RowMajor,
cutlass::ComplexTransform::kNone,
GemmConfiguration::kAlignmentB,
Element,
cutlass::layout::RowMajor,
ElementAcc,
OpClass,
cutlass::arch::Sm80,
typename GemmConfiguration::ThreadblockShape,
typename GemmConfiguration::WarpShape,
typename GemmConfiguration::InstructionShape,
typename GemmConfiguration::EpilogueOutputOp,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
GemmConfiguration::kStages>::GemmKernel;
using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);
const int64_t gemm_coord_size =
problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
// Number of gmm args not including *problem_sizes
at::Tensor gmm_args = at::empty(
{problem_count * 6 + gemm_coord_size},
at::TensorOptions().dtype(at::kLong).pinned_memory(true));
// Obtain pointers for each argument (on host)
int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
int64_t* ldb_data = lda_data + problem_count;
int64_t* ldd_data = lda_data + 2 * problem_count;
int64_t* ptr_a_data = lda_data + 3 * problem_count;
int64_t* ptr_b_data = lda_data + 4 * problem_count;
int64_t* ptr_d_data = lda_data + 5 * problem_count;
cutlass::gemm::GemmCoord* problem_sizes_data =
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
// Set arguments into gmm_args from input args
for (int i = 0; i < problem_count; ++i) {
problem_sizes_data[i] = gemm_sizes[i];
lda_data[i] = lda[i];
ldb_data[i] = ldb[i];
ldd_data[i] = ldd[i];
ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
}
const int threadblock_count =
GemmGrouped::sufficient(problem_sizes_data, problem_count);
// Transfer arguments to GPU
gmm_args = gmm_args.to(device, true);
// Obtain pointers for each of arguments (on GPU)
lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
ldb_data = lda_data + problem_count;
ldd_data = lda_data + 2 * problem_count;
ptr_a_data = lda_data + 3 * problem_count;
ptr_b_data = lda_data + 4 * problem_count;
ptr_d_data = lda_data + 5 * problem_count;
problem_sizes_data =
reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
// Create GemmGrouped::Arguments using the arguments prepared above
typename GemmGrouped::Arguments args(
problem_sizes_data,
problem_count,
threadblock_count,
epilogue_op,
reinterpret_cast<Element**>(ptr_a_data),
reinterpret_cast<Element**>(ptr_b_data),
reinterpret_cast<Element**>(ptr_d_data),
reinterpret_cast<Element**>(ptr_d_data),
lda_data,
ldb_data,
ldd_data,
ldd_data);
GemmGrouped gemm;
cutlass::Status status =
gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
status != cutlass::Status::kErrorWorkspaceNull,
"Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
TORCH_CHECK(
status != cutlass::Status::kErrorInternal,
"Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"Failed to initialize CUTLASS Grouped GEMM kernel.");
// Run CUTLASS group GEMM
status = gemm.run(at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"Failed to run CUTLASS Grouped GEMM kernel.");
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
} // namespace
Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
if (self.is_nested() && !mat2.is_nested()) {
AT_ERROR(
"Expected both to be nested, but got a nested self and non-nested other");
} else if (!self.is_nested() && mat2.is_nested()) {
AT_ERROR(
"Expected both to be nested, but got a non-nested self and nested other");
}
// dispatcher should have guaranteed that at least one is nested
auto self_ptr = get_nested_tensor_impl(self);
auto mat2_ptr = get_nested_tensor_impl(mat2);
TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
TORCH_CHECK(
ntensors == ntensors2,
"Expected size for the 1st dimension of batch2 tensor to be: ",
ntensors,
" but got: ",
ntensors2,
".");
const Tensor &self_buffer = self_ptr->get_buffer(),
&mat2_buffer = mat2_ptr->get_buffer();
std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
self_strides = NestedTensor_get_strides(self_ptr),
mat2_strides = NestedTensor_get_strides(mat2_ptr);
const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();
// create a contiguous output
int64_t out_numel = 0;
int64_t a_numel = 0;
int64_t b_numel = 0;
const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
std::vector<int64_t> output_offsets;
std::vector<int64_t> a_offsets;
std::vector<int64_t> b_offsets;
std::vector<int64_t> lda;
std::vector<int64_t> ldb;
std::vector<int64_t> ldd;
#ifndef USE_ROCM
std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
#endif
bool all_row_major = true;
for (int64_t i = 0; i < ntensors; i++) {
const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
&mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
TORCH_CHECK(
self_size1 == mat2_size0,
i,
"-th nested matrices in batch cannot be multiplied (",
self_size0,
"x",
self_size1,
" and ",
mat2_size0,
"x",
mat2_size1,
")");
out_sizemat_ptr[0] = self_size0;
out_sizemat_ptr[1] = mat2_size1;
out_sizemat_ptr += 2;
output_offsets.push_back(out_numel);
out_numel += self_size0 * mat2_size1;
#ifndef USE_ROCM
gemm_sizes.push_back(
cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
#endif
lda.push_back(self_strides[i][0]);
ldb.push_back(mat2_strides[i][0]);
ldd.push_back(mat2_size1);
a_offsets.push_back(a_numel);
b_offsets.push_back(b_numel);
a_numel += self_size0 * self_strides[i][0];
b_numel += mat2_size0 * mat2_strides[i][0];
all_row_major = all_row_major && (self_strides[i][1] == 1);
all_row_major = all_row_major && (mat2_strides[i][1] == 1);
}
Tensor out_buffer = self_buffer.new_empty(out_numel);
Tensor output = wrap_buffer(out_buffer, out_sizemat);
at::Device device = output.device();
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
if (is_sm8x && all_row_major) {
if (self.dtype() == at::kFloat) {
std::vector<float*> aptr;
std::vector<float*> bptr;
std::vector<float*> dptr;
for (int64_t i = 0; i < ntensors; i++) {
aptr.push_back(self_buffer.data_ptr<float>() + a_offsets[i]);
bptr.push_back(mat2_buffer.data_ptr<float>() + b_offsets[i]);
dptr.push_back(out_buffer.data_ptr<float>() + output_offsets[i]);
}
gemm_grouped_cuda_internal<float>(
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
return output;
}
if (self.dtype() == at::kHalf) {
std::vector<c10::Half*> aptr;
std::vector<c10::Half*> bptr;
std::vector<c10::Half*> dptr;
for (int64_t i = 0; i < ntensors; i++) {
aptr.push_back(self_buffer.data_ptr<c10::Half>() + a_offsets[i]);
bptr.push_back(mat2_buffer.data_ptr<c10::Half>() + b_offsets[i]);
dptr.push_back(out_buffer.data_ptr<c10::Half>() + output_offsets[i]);
}
gemm_grouped_cuda_internal<c10::Half>(
lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
return output;
}
}
#endif
std::vector<Tensor> output_unbind = output.unbind();
for (int64_t i = 0; i < ntensors; i++) {
at::mm_out(
output_unbind[i],
self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
mat2_buffer.as_strided(
mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
}
return output;
}
} // namespace native
} // namespace at

View File

@ -1,45 +0,0 @@
import argparse
import torch
def bench(nt_a, nt_b, niter):
# Warmup
nt_c = nt_a.bmm(nt_b)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for iter in range(niter):
nt_c = nt_a.bmm(nt_b)
end_event.record()
torch.cuda.synchronize()
runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter
return runtime
def sweep_n(ntensor, niter, dtype):
print("n, dtype, ntensor, gflop, runtime, tflop/s")
for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
nt_a = torch.nested_tensor(
[torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
)
nt_b = torch.nested_tensor(
[torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
)
runtime = bench(nt_a, nt_b, niter)
tflop = n * n * n * ntensor * 2 / 1e12
print(n, dtype, ntensor, tflop, runtime, tflop / runtime)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
parser.add_argument("--niter", default="10", type=int)
parser.add_argument("--ntensor", default="20", type=int)
args = parser.parse_args()
niter = args.niter
ntensor = args.ntensor
sweep_n(ntensor, niter, torch.float32)
sweep_n(ntensor, niter, torch.float16)

View File

@ -8,7 +8,6 @@ from torch.testing._internal.common_device_type import (
dtypesIfCUDA,
instantiate_device_type_tests,
skipMeta,
onlyCUDA,
onlyCPU
)
from torch.testing._internal.common_dtype import floating_types_and_half
@ -972,7 +971,9 @@ class TestNestedTensorDeviceType(TestCase):
torch.nn.functional.softmax(nt_contiguous, -1),
torch.nn.functional.softmax(nt_noncontiguous, -1))
def _test_bmm(self, device, dtype):
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_bmm(self, device, dtype):
# error case: one is nested but the other is not
nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
t = torch.randn(4, device=device, dtype=dtype)
@ -1058,31 +1059,16 @@ class TestNestedTensorDeviceType(TestCase):
nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
if dtype == torch.float16:
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
else:
self.assertEqual(actual, expect)
self.assertEqual(actual, expect)
@onlyCUDA
@dtypes(torch.float, torch.double, torch.float16)
def test_bmm_cuda(self, device, dtype):
self._test_bmm(device, dtype)
@onlyCPU
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_bmm_cpu(self, device, dtype):
self._test_bmm(device, dtype)
# TODO: Re-enable this test once bmm supports non-contiguous inputs.
# # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
# @dtypes(torch.float, torch.double)
# def test_bmm_noncontiguous(self, device, dtype):
# nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
# nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
# self.assertEqual(
# nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
# nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
def test_bmm_noncontiguous(self, device, dtype):
nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
self.assertEqual(
nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
@dtypes(torch.float, torch.double)

View File

@ -1,11 +0,0 @@
# Description:
# CUDA Templates for Linear Algebra Subroutines
load("@rules_cc//cc:defs.bzl", "cc_library")
cc_library(
name = "cutlass",
hdrs = glob(["include/**/*.h"]),
includes = ["include/"],
visibility = ["//visibility:public"],
)