diff --git a/BUILD.bazel b/BUILD.bazel
index 172a31723a0b..df780db33f7b 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -429,7 +429,6 @@ cu_library(
         "@cuda//:cublas",
         "@cuda//:cufft",
         "@cuda//:cusparse",
-        "@cutlass",
     ],
     alwayslink = True,
 )
@@ -1674,7 +1673,6 @@ cc_library(
     ] + if_cuda([
         ":torch_distributed_cuda",
         "@cuda//:nvToolsExt",
-        "@cutlass",
     ]),
     alwayslink = True,
 )
diff --git a/WORKSPACE b/WORKSPACE
index e8591f291abd..61abbdac2b23 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -84,12 +84,6 @@ new_local_repository(
     path = "third_party/eigen",
 )
 
-new_local_repository(
-    name = "cutlass",
-    build_file = "//third_party:cutlass.BUILD",
-    path = "third_party/cutlass",
-)
-
 new_local_repository(
     name = "fbgemm",
     build_file = "//third_party:fbgemm/BUILD.bazel",
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 629db87dc15d..e23ea710df9c 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -433,7 +433,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
 endif()
 
 if(USE_CUDA AND NOT USE_ROCM)
-  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+  if(USE_FLASH_ATTENTION)
+    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+  endif()
   if($ENV{ATEN_STATIC_CUDA})
     list(APPEND ATen_CUDA_DEPENDENCY_LIBS
       ${CUDA_LIBRARIES}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index ef5e19653182..5cb8815fc207 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1174,8 +1174,7 @@
   dispatch:
     SparseCPU: bmm_sparse_cpu
     SparseCUDA: bmm_sparse_cuda
-    NestedTensorCPU: bmm_nested
-    NestedTensorCUDA: bmm_nested_cuda
+    NestedTensorCPU, NestedTensorCUDA: bmm_nested
   tags: canonical
 
 - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.cpp b/aten/src/ATen/native/nested/NestedTensorUtils.cpp
index 6801a640a371..7810d6f652af 100644
--- a/aten/src/ATen/native/nested/NestedTensorUtils.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorUtils.cpp
@@ -101,5 +101,27 @@ std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int6
   return splits;
 }
 
+std::vector<IntArrayRef> NestedTensor_get_sizes(
+    const NestedTensorImpl* self_ptr) {
+  int64_t ntensors = self_ptr->size(0);
+  std::vector<IntArrayRef> sizes(ntensors);
+  if (ntensors == 0) {
+    return sizes;
+  }
+  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
+  int64_t orig_dim = sizemat.size(1);
+  // nesting scalars has empty sizes
+  if (orig_dim == 0) {
+    return sizes;
+  }
+  const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
+
+  for (const auto i : c10::irange(ntensors)) {
+    sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
+    sizemat_ptr += orig_dim;
+  }
+  return sizes;
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.h b/aten/src/ATen/native/nested/NestedTensorUtils.h
index 84ee31f1f24f..19f55b3f998b 100644
--- a/aten/src/ATen/native/nested/NestedTensorUtils.h
+++ b/aten/src/ATen/native/nested/NestedTensorUtils.h
@@ -86,28 +86,8 @@ inline at::Tensor create_nested_view_tensor(
 int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);
 
 // The sizes of the underlying tensors
-inline std::vector<IntArrayRef> NestedTensor_get_sizes(
-    const NestedTensorImpl* self_ptr) {
-  int64_t ntensors = self_ptr->size(0);
-  std::vector<IntArrayRef> sizes(ntensors);
-  if (ntensors == 0) {
-    return sizes;
-  }
-  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
-  int64_t orig_dim = sizemat.size(1);
-  // nesting scalars has empty sizes
-  if (orig_dim == 0) {
-    return sizes;
-  }
-  const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();
-
-  for (const auto i : c10::irange(ntensors)) {
-    sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
-    sizemat_ptr += orig_dim;
-  }
-  return sizes;
-}
-
+std::vector<IntArrayRef> NestedTensor_get_sizes(
+    const NestedTensorImpl* self_ptr);
 TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
     const NestedTensorImpl& nt);
 
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
index 49ee9ca7d3a9..dd5e9b80ca6b 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu
@@ -15,15 +15,6 @@
 
 #include
 #include
-#include
-
-#ifndef USE_ROCM
-#include
-#include
-#include
-#endif
-
-#include
 
 #define BLOCK_DIM 256
 #define GRID_DIM_Y 16
@@ -347,8 +338,7 @@ __global__ void add_padding_3(
     const int i0 = i / (output_sizes_2 * output_sizes_3);
     const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
     const int i2 = i % output_sizes_3;
-    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
-        i2 < sizes_i[2]) {
+    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
       const int offset = offsets[batch_id];
       const int input_offset =
           offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@@ -362,8 +352,7 @@
     const int i0 = i / (output_sizes_2 * output_sizes_3);
     const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
     const int i2 = i % output_sizes_3;
-    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
-        i2 < sizes_i[2]) {
+    if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
       const int offset = offsets[batch_id];
       const int input_offset =
           offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@@ -460,269 +449,5 @@ template void add_padding_kernelLauncher(
     const int batch_size,
     const int output_batch_size);
 
-namespace {
-
-#ifndef USE_ROCM
-template <typename scalar_t>
-void gemm_grouped_cuda_internal(
-    const std::vector<int64_t>& lda,
-    const std::vector<int64_t>& ldb,
-    const std::vector<int64_t>& ldd,
-    const std::vector<scalar_t*>& aptr,
-    const std::vector<scalar_t*>& bptr,
-    const std::vector<scalar_t*>& dptr,
-    const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
-    const int problem_count,
-    at::Device& device) {
-  using Element = scalar_t;
-  using ElementAcc = float;
-  using OpClass = cutlass::arch::OpClassSimt;
-
-  using GemmConfiguration =
-      typename cutlass::gemm::device::DefaultGemmConfiguration<
-          OpClass,
-          cutlass::arch::Sm80,
-          Element,
-          Element,
-          Element,
-          ElementAcc>;
-
-  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
-      Element,
-      cutlass::layout::RowMajor,
-      cutlass::ComplexTransform::kNone,
-      GemmConfiguration::kAlignmentA,
-      Element,
-      cutlass::layout::RowMajor,
-      cutlass::ComplexTransform::kNone,
-      GemmConfiguration::kAlignmentB,
-      Element,
-      cutlass::layout::RowMajor,
-      ElementAcc,
-      OpClass,
-      cutlass::arch::Sm80,
-      typename GemmConfiguration::ThreadblockShape,
-      typename GemmConfiguration::WarpShape,
-      typename GemmConfiguration::InstructionShape,
-      typename GemmConfiguration::EpilogueOutputOp,
-      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
-      GemmConfiguration::kStages>::GemmKernel;
-
-  using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
-  using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
-  typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);
-
-  const int64_t gemm_coord_size =
-      problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
-  // Number of gmm args not including *problem_sizes
-  at::Tensor gmm_args = at::empty(
-      {problem_count * 6 + gemm_coord_size},
-      at::TensorOptions().dtype(at::kLong).pinned_memory(true));
-
-  // Obtain pointers for each argument (on host)
-  int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
-  int64_t* ldb_data = lda_data + problem_count;
-  int64_t* ldd_data = lda_data + 2 * problem_count;
-  int64_t* ptr_a_data = lda_data + 3 * problem_count;
-  int64_t* ptr_b_data = lda_data + 4 * problem_count;
-  int64_t* ptr_d_data = lda_data + 5 * problem_count;
-  cutlass::gemm::GemmCoord* problem_sizes_data =
-      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
-
-  // Set arguments into gmm_args from input args
-  for (int i = 0; i < problem_count; ++i) {
-    problem_sizes_data[i] = gemm_sizes[i];
-    lda_data[i] = lda[i];
-    ldb_data[i] = ldb[i];
-    ldd_data[i] = ldd[i];
-    ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
-    ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
-    ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
-  }
-  const int threadblock_count =
-      GemmGrouped::sufficient(problem_sizes_data, problem_count);
-
-  // Transfer arguments to GPU
-  gmm_args = gmm_args.to(device, true);
-
-  // Obtain pointers for each of arguments (on GPU)
-  lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
-  ldb_data = lda_data + problem_count;
-  ldd_data = lda_data + 2 * problem_count;
-  ptr_a_data = lda_data + 3 * problem_count;
-  ptr_b_data = lda_data + 4 * problem_count;
-  ptr_d_data = lda_data + 5 * problem_count;
-  problem_sizes_data =
-      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);
-
-  // Create GemmGrouped::Arguments using the arguments prepared above
-  typename GemmGrouped::Arguments args(
-      problem_sizes_data,
-      problem_count,
-      threadblock_count,
-      epilogue_op,
-      reinterpret_cast<Element**>(ptr_a_data),
-      reinterpret_cast<Element**>(ptr_b_data),
-      reinterpret_cast<Element**>(ptr_d_data),
-      reinterpret_cast<Element**>(ptr_d_data),
-      lda_data,
-      ldb_data,
-      ldd_data,
-      ldd_data);
-
-  GemmGrouped gemm;
-  cutlass::Status status =
-      gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
-  TORCH_CHECK(
-      status != cutlass::Status::kErrorWorkspaceNull,
-      "Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
-  TORCH_CHECK(
-      status != cutlass::Status::kErrorInternal,
-      "Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
-  TORCH_CHECK(
-      status == cutlass::Status::kSuccess,
-      "Failed to initialize CUTLASS Grouped GEMM kernel.");
-
-  // Run CUTLASS group GEMM
-  status = gemm.run(at::cuda::getCurrentCUDAStream());
-  TORCH_CHECK(
-      status == cutlass::Status::kSuccess,
-      "Failed to run CUTLASS Grouped GEMM kernel.");
-
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-#endif
-
-} // namespace
-
-Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
-  if (self.is_nested() && !mat2.is_nested()) {
-    AT_ERROR(
-        "Expected both to be nested, but got a nested self and non-nested other");
-  } else if (!self.is_nested() && mat2.is_nested()) {
-    AT_ERROR(
-        "Expected both to be nested, but got a non-nested self and nested other");
-  }
-  // dispatcher should have guaranteed that at least one is nested
-  auto self_ptr = get_nested_tensor_impl(self);
-  auto mat2_ptr = get_nested_tensor_impl(mat2);
-  TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
-  TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
-  int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
-  TORCH_CHECK(
-      ntensors == ntensors2,
-      "Expected size for the 1st dimension of batch2 tensor to be: ",
-      ntensors,
-      " but got: ",
-      ntensors2,
-      ".");
-  const Tensor &self_buffer = self_ptr->get_buffer(),
-               &mat2_buffer = mat2_ptr->get_buffer();
-  std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
-                           mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
-                           self_strides = NestedTensor_get_strides(self_ptr),
-                           mat2_strides = NestedTensor_get_strides(mat2_ptr);
-  const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
-  const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();
-
-  // create a contiguous output
-  int64_t out_numel = 0;
-  int64_t a_numel = 0;
-  int64_t b_numel = 0;
-  const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
-  Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
-  int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
-  std::vector<int64_t> output_offsets;
-  std::vector<int64_t> a_offsets;
-  std::vector<int64_t> b_offsets;
-  std::vector<int64_t> lda;
-  std::vector<int64_t> ldb;
-  std::vector<int64_t> ldd;
-#ifndef USE_ROCM
-  std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
-#endif
-  bool all_row_major = true;
-  for (int64_t i = 0; i < ntensors; i++) {
-    const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
-    const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
-                  &mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
-    TORCH_CHECK(
-        self_size1 == mat2_size0,
-        i,
-        "-th nested matrices in batch cannot be multiplied (",
-        self_size0,
-        "x",
-        self_size1,
-        " and ",
-        mat2_size0,
-        "x",
-        mat2_size1,
-        ")");
-    out_sizemat_ptr[0] = self_size0;
-    out_sizemat_ptr[1] = mat2_size1;
-    out_sizemat_ptr += 2;
-    output_offsets.push_back(out_numel);
-    out_numel += self_size0 * mat2_size1;
-#ifndef USE_ROCM
-    gemm_sizes.push_back(
-        cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
-#endif
-    lda.push_back(self_strides[i][0]);
-    ldb.push_back(mat2_strides[i][0]);
-    ldd.push_back(mat2_size1);
-    a_offsets.push_back(a_numel);
-    b_offsets.push_back(b_numel);
-    a_numel += self_size0 * self_strides[i][0];
-    b_numel += mat2_size0 * mat2_strides[i][0];
-    all_row_major = all_row_major && (self_strides[i][1] == 1);
-    all_row_major = all_row_major && (mat2_strides[i][1] == 1);
-  }
-  Tensor out_buffer = self_buffer.new_empty(out_numel);
-  Tensor output = wrap_buffer(out_buffer, out_sizemat);
-  at::Device device = output.device();
-
-#ifndef USE_ROCM
-  auto dprops = at::cuda::getCurrentDeviceProperties();
-  bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-  if (is_sm8x && all_row_major) {
-    if (self.dtype() == at::kFloat) {
-      std::vector<float*> aptr;
-      std::vector<float*> bptr;
-      std::vector<float*> dptr;
-      for (int64_t i = 0; i < ntensors; i++) {
-        aptr.push_back(self_buffer.data_ptr<float>() + a_offsets[i]);
-        bptr.push_back(mat2_buffer.data_ptr<float>() + b_offsets[i]);
-        dptr.push_back(out_buffer.data_ptr<float>() + output_offsets[i]);
-      }
-      gemm_grouped_cuda_internal<float>(
-          lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
-      return output;
-    }
-    if (self.dtype() == at::kHalf) {
-      std::vector<c10::Half*> aptr;
-      std::vector<c10::Half*> bptr;
-      std::vector<c10::Half*> dptr;
-      for (int64_t i = 0; i < ntensors; i++) {
-        aptr.push_back(self_buffer.data_ptr<c10::Half>() + a_offsets[i]);
-        bptr.push_back(mat2_buffer.data_ptr<c10::Half>() + b_offsets[i]);
-        dptr.push_back(out_buffer.data_ptr<c10::Half>() + output_offsets[i]);
-      }
-      gemm_grouped_cuda_internal<c10::Half>(
-          lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
-      return output;
-    }
-  }
-#endif
-  std::vector<Tensor> output_unbind = output.unbind();
-  for (int64_t i = 0; i < ntensors; i++) {
-    at::mm_out(
-        output_unbind[i],
-        self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
-        mat2_buffer.as_strided(
-            mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
-  }
-  return output;
-}
-
 } // namespace native
 } // namespace at
diff --git a/benchmarks/nested/nested_bmm_bench.py b/benchmarks/nested/nested_bmm_bench.py
deleted file mode 100644
index 311b23395efd..000000000000
--- a/benchmarks/nested/nested_bmm_bench.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import argparse
-
-import torch
-
-
-def bench(nt_a, nt_b, niter):
-    # Warmup
-    nt_c = nt_a.bmm(nt_b)
-
-    torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    start_event.record()
-    for iter in range(niter):
-        nt_c = nt_a.bmm(nt_b)
-    end_event.record()
-    torch.cuda.synchronize()
-    runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter
-    return runtime
-
-
-def sweep_n(ntensor, niter, dtype):
-    print("n, dtype, ntensor, gflop, runtime, tflop/s")
-    for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
-        nt_a = torch.nested_tensor(
-            [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
-        )
-        nt_b = torch.nested_tensor(
-            [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
-        )
-        runtime = bench(nt_a, nt_b, niter)
-        tflop = n * n * n * ntensor * 2 / 1e12
-        print(n, dtype, ntensor, tflop, runtime, tflop / runtime)
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
-    parser.add_argument("--niter", default="10", type=int)
-    parser.add_argument("--ntensor", default="20", type=int)
-
-    args = parser.parse_args()
-    niter = args.niter
-    ntensor = args.ntensor
-
-    sweep_n(ntensor, niter, torch.float32)
-    sweep_n(ntensor, niter, torch.float16)
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index a708da254fb7..d61661c56d82 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -8,7 +8,6 @@ from torch.testing._internal.common_device_type import (
     dtypesIfCUDA,
     instantiate_device_type_tests,
     skipMeta,
-    onlyCUDA,
     onlyCPU
 )
 from torch.testing._internal.common_dtype import floating_types_and_half
@@ -972,7 +971,9 @@ class TestNestedTensorDeviceType(TestCase):
             torch.nn.functional.softmax(nt_contiguous, -1),
             torch.nn.functional.softmax(nt_noncontiguous, -1))
 
-    def _test_bmm(self, device, dtype):
+    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
+    @dtypes(torch.float, torch.double)
+    def test_bmm(self, device, dtype):
         # error case: one is nested but the other is not
         nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
         t = torch.randn(4, device=device, dtype=dtype)
@@ -1058,31 +1059,16 @@ class TestNestedTensorDeviceType(TestCase):
         nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype)
         actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
         expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
-        if dtype == torch.float16:
-            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
-        else:
-            self.assertEqual(actual, expect)
+        self.assertEqual(actual, expect)
 
-    @onlyCUDA
-    @dtypes(torch.float, torch.double, torch.float16)
-    def test_bmm_cuda(self, device, dtype):
-        self._test_bmm(device, dtype)
-
-    @onlyCPU
     # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
     @dtypes(torch.float, torch.double)
-    def test_bmm_cpu(self, device, dtype):
-        self._test_bmm(device, dtype)
-
-    # TODO: Re-enable this test once bmm supports non-contiguous inputs.
-    # # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
-    # @dtypes(torch.float, torch.double)
-    # def test_bmm_noncontiguous(self, device, dtype):
-    #     nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
-    #     nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
-    #     self.assertEqual(
-    #         nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
-    #         nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
+    def test_bmm_noncontiguous(self, device, dtype):
+        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
+        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
+        self.assertEqual(
+            nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
+            nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
 
     # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
     @dtypes(torch.float, torch.double)
diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD
deleted file mode 100644
index bd928c5fc1a1..000000000000
--- a/third_party/cutlass.BUILD
+++ /dev/null
@@ -1,11 +0,0 @@
-# Description:
-#   CUDA Templates for Linear Algebra Subroutines
-
-load("@rules_cc//cc:defs.bzl", "cc_library")
-
-cc_library(
-    name = "cutlass",
-    hdrs = glob(["include/**/*.h"]),
-    includes = ["include/"],
-    visibility = ["//visibility:public"],
-)
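
The updated test_bmm above checks nested bmm against bmm on the zero-padded dense tensors, which both NestedTensorCPU and NestedTensorCUDA now reach through the common bmm_nested kernel (see the native_functions.yaml hunk). A minimal standalone sketch of that same check follows; the shapes, dtype, and device pick are arbitrary illustrative choices and are not part of this patch:

import torch

# Nested bmm should agree with bmm on the zero-padded dense equivalents;
# this mirrors the comparison performed in test_bmm.
device = "cuda" if torch.cuda.is_available() else "cpu"
nt0 = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 7)], device=device, dtype=torch.float32
)
nt1 = torch.nested.nested_tensor(
    [torch.randn(4, 6), torch.randn(7, 5)], device=device, dtype=torch.float32
)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
    torch.nested.to_padded_tensor(nt1, 0.0)
)
torch.testing.assert_close(actual, expect)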