Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "Use CUTLASS GEMM for NT bmm [OSS-only] (#85894)"
This reverts commit ef58a132f223d5abf2bd3f8bee380aca6c29d17f. Reverted https://github.com/pytorch/pytorch/pull/85894 on behalf of https://github.com/DanilBaibak due to Break internal build
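For context: the operation affected here is bmm on nested tensors, which multiplies the i-th matrix of one nested tensor by the i-th matrix of the other; this revert removes the CUTLASS grouped-GEMM CUDA implementation (bmm_nested_cuda) and dispatches NestedTensorCUDA back to bmm_nested. A minimal sketch (not part of the commit), using the torch.nested APIs that appear in the test diff below:

    import torch

    # Two nested tensors holding matrices of varying sizes; bmm multiplies them pairwise.
    nt_a = torch.nested.nested_tensor([torch.randn(4, 6), torch.randn(7, 5)])
    nt_b = torch.nested.nested_tensor([torch.randn(6, 4), torch.randn(5, 7)])
    nt_c = nt_a.bmm(nt_b)  # nested result with components of shape (4, 4) and (7, 7)

    # Zero-padding both operands and using dense bmm gives the same values,
    # which is how test_bmm below checks the kernel.
    dense = torch.nested.to_padded_tensor(nt_a, 0.0).bmm(torch.nested.to_padded_tensor(nt_b, 0.0))
    assert torch.allclose(torch.nested.to_padded_tensor(nt_c, 0.0), dense)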
@@ -429,7 +429,6 @@ cu_library(
        "@cuda//:cublas",
        "@cuda//:cufft",
        "@cuda//:cusparse",
        "@cutlass",
    ],
    alwayslink = True,
)
@@ -1674,7 +1673,6 @@ cc_library(
    ] + if_cuda([
        ":torch_distributed_cuda",
        "@cuda//:nvToolsExt",
        "@cutlass",
    ]),
    alwayslink = True,
)
@@ -84,12 +84,6 @@ new_local_repository(
    path = "third_party/eigen",
)

new_local_repository(
    name = "cutlass",
    build_file = "//third_party:cutlass.BUILD",
    path = "third_party/cutlass",
)

new_local_repository(
    name = "fbgemm",
    build_file = "//third_party:fbgemm/BUILD.bazel",
@@ -433,7 +433,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()

if(USE_CUDA AND NOT USE_ROCM)
  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  if(USE_FLASH_ATTENTION)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  endif()
  if($ENV{ATEN_STATIC_CUDA})
    list(APPEND ATen_CUDA_DEPENDENCY_LIBS
      ${CUDA_LIBRARIES}
@@ -1174,8 +1174,7 @@
  dispatch:
    SparseCPU: bmm_sparse_cpu
    SparseCUDA: bmm_sparse_cuda
    NestedTensorCPU: bmm_nested
    NestedTensorCUDA: bmm_nested_cuda
    NestedTensorCPU, NestedTensorCUDA: bmm_nested
  tags: canonical

- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
@@ -101,5 +101,27 @@ std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int6
  return splits;
}

std::vector<IntArrayRef> NestedTensor_get_sizes(
    const NestedTensorImpl* self_ptr) {
  int64_t ntensors = self_ptr->size(0);
  std::vector<IntArrayRef> sizes(ntensors);
  if (ntensors == 0) {
    return sizes;
  }
  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
  int64_t orig_dim = sizemat.size(1);
  // nesting scalars has empty sizes
  if (orig_dim == 0) {
    return sizes;
  }
  const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();

  for (const auto i : c10::irange(ntensors)) {
    sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
    sizemat_ptr += orig_dim;
  }
  return sizes;
}

} // namespace native
} // namespace at
@@ -86,28 +86,8 @@ inline at::Tensor create_nested_view_tensor(
int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);

// The sizes of the underlying tensors
inline std::vector<IntArrayRef> NestedTensor_get_sizes(
    const NestedTensorImpl* self_ptr) {
  int64_t ntensors = self_ptr->size(0);
  std::vector<IntArrayRef> sizes(ntensors);
  if (ntensors == 0) {
    return sizes;
  }
  const Tensor& sizemat = self_ptr->get_nested_size_tensor();
  int64_t orig_dim = sizemat.size(1);
  // nesting scalars has empty sizes
  if (orig_dim == 0) {
    return sizes;
  }
  const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();

  for (const auto i : c10::irange(ntensors)) {
    sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
    sizemat_ptr += orig_dim;
  }
  return sizes;
}

std::vector<IntArrayRef> NestedTensor_get_sizes(
    const NestedTensorImpl* self_ptr);

TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
    const NestedTensorImpl& nt);
@@ -15,15 +15,6 @@
#include <c10/cuda/CUDAStream.h>

#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>

#ifndef USE_ROCM
#include <cutlass/gemm/device/default_gemm_configuration.h>
#include <cutlass/gemm/device/gemm_grouped.h>
#include <cutlass/gemm/kernel/default_gemm_grouped.h>
#endif

#include <ATen/NestedTensorImpl.h>

#define BLOCK_DIM 256
#define GRID_DIM_Y 16
@@ -347,8 +338,7 @@ __global__ void add_padding_3(
  const int i0 = i / (output_sizes_2 * output_sizes_3);
  const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
  const int i2 = i % output_sizes_3;
  if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
      i2 < sizes_i[2]) {
  if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
    const int offset = offsets[batch_id];
    const int input_offset =
        offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@@ -362,8 +352,7 @@ __global__ void add_padding_3(
  const int i0 = i / (output_sizes_2 * output_sizes_3);
  const int i1 = (i % (output_sizes_2 * output_sizes_3)) / output_sizes_3;
  const int i2 = i % output_sizes_3;
  if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] &&
      i2 < sizes_i[2]) {
  if (batch_id < batch_size && i0 < sizes_i[0] && i1 < sizes_i[1] && i2 < sizes_i[2]) {
    const int offset = offsets[batch_id];
    const int input_offset =
        offset + i0 * (sizes_i[1] * sizes_i[2]) + i1 * sizes_i[2] + i2;
@@ -460,269 +449,5 @@ template void add_padding_kernelLauncher<c10::Half>(
    const int batch_size,
    const int output_batch_size);

namespace {

#ifndef USE_ROCM
template <typename scalar_t>
void gemm_grouped_cuda_internal(
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    const std::vector<scalar_t*>& aptr,
    const std::vector<scalar_t*>& bptr,
    const std::vector<scalar_t*>& dptr,
    const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
    const int problem_count,
    at::Device& device) {
  using Element = scalar_t;
  using ElementAcc = float;
  using OpClass = cutlass::arch::OpClassSimt;

  using GemmConfiguration =
      typename cutlass::gemm::device::DefaultGemmConfiguration<
          OpClass,
          cutlass::arch::Sm80,
          Element,
          Element,
          Element,
          ElementAcc>;
  using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
      Element,
      cutlass::layout::RowMajor,
      cutlass::ComplexTransform::kNone,
      GemmConfiguration::kAlignmentA,
      Element,
      cutlass::layout::RowMajor,
      cutlass::ComplexTransform::kNone,
      GemmConfiguration::kAlignmentB,
      Element,
      cutlass::layout::RowMajor,
      ElementAcc,
      OpClass,
      cutlass::arch::Sm80,
      typename GemmConfiguration::ThreadblockShape,
      typename GemmConfiguration::WarpShape,
      typename GemmConfiguration::InstructionShape,
      typename GemmConfiguration::EpilogueOutputOp,
      cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
      GemmConfiguration::kStages>::GemmKernel;

  using GemmGrouped = typename cutlass::gemm::device::GemmGrouped<GemmKernel>;
  using EpilogueOutputOp = typename GemmGrouped::GemmKernel::Epilogue::OutputOp;
  typename EpilogueOutputOp::Params epilogue_op(/*alpha*/ 1, /*beta*/ 0);

  const int64_t gemm_coord_size =
      problem_count * ((int64_t)sizeof(cutlass::gemm::GemmCoord));
  // Number of gmm args not including *problem_sizes
  at::Tensor gmm_args = at::empty(
      {problem_count * 6 + gemm_coord_size},
      at::TensorOptions().dtype(at::kLong).pinned_memory(true));

  // Obtain pointers for each argument (on host)
  int64_t* lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
  int64_t* ldb_data = lda_data + problem_count;
  int64_t* ldd_data = lda_data + 2 * problem_count;
  int64_t* ptr_a_data = lda_data + 3 * problem_count;
  int64_t* ptr_b_data = lda_data + 4 * problem_count;
  int64_t* ptr_d_data = lda_data + 5 * problem_count;
  cutlass::gemm::GemmCoord* problem_sizes_data =
      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);

  // Set arguments into gmm_args from input args
  for (int i = 0; i < problem_count; ++i) {
    problem_sizes_data[i] = gemm_sizes[i];
    lda_data[i] = lda[i];
    ldb_data[i] = ldb[i];
    ldd_data[i] = ldd[i];
    ptr_a_data[i] = reinterpret_cast<int64_t>(aptr[i]);
    ptr_b_data[i] = reinterpret_cast<int64_t>(bptr[i]);
    ptr_d_data[i] = reinterpret_cast<int64_t>(dptr[i]);
  }
  const int threadblock_count =
      GemmGrouped::sufficient(problem_sizes_data, problem_count);

  // Transfer arguments to GPU
  gmm_args = gmm_args.to(device, true);

  // Obtain pointers for each of arguments (on GPU)
  lda_data = gmm_args.data_ptr<int64_t>(); // Base pointer
  ldb_data = lda_data + problem_count;
  ldd_data = lda_data + 2 * problem_count;
  ptr_a_data = lda_data + 3 * problem_count;
  ptr_b_data = lda_data + 4 * problem_count;
  ptr_d_data = lda_data + 5 * problem_count;
  problem_sizes_data =
      reinterpret_cast<cutlass::gemm::GemmCoord*>(lda_data + 6 * problem_count);

  // Create GemmGrouped::Arguments using the arguments prepared above
  typename GemmGrouped::Arguments args(
      problem_sizes_data,
      problem_count,
      threadblock_count,
      epilogue_op,
      reinterpret_cast<Element**>(ptr_a_data),
      reinterpret_cast<Element**>(ptr_b_data),
      reinterpret_cast<Element**>(ptr_d_data),
      reinterpret_cast<Element**>(ptr_d_data),
      lda_data,
      ldb_data,
      ldd_data,
      ldd_data);

  GemmGrouped gemm;
  cutlass::Status status =
      gemm.initialize(args, nullptr, at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      status != cutlass::Status::kErrorWorkspaceNull,
      "Failed to initialize CUTLASS Grouped GEMM kernel due to workspace.");
  TORCH_CHECK(
      status != cutlass::Status::kErrorInternal,
      "Failed to initialize CUTLASS Grouped GEMM kernel due to internal error.");
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "Failed to initialize CUTLASS Grouped GEMM kernel.");

  // Run CUTLASS group GEMM
  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      status == cutlass::Status::kSuccess,
      "Failed to run CUTLASS Grouped GEMM kernel.");

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif

} // namespace

Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
  if (self.is_nested() && !mat2.is_nested()) {
    AT_ERROR(
        "Expected both to be nested, but got a nested self and non-nested other");
  } else if (!self.is_nested() && mat2.is_nested()) {
    AT_ERROR(
        "Expected both to be nested, but got a non-nested self and nested other");
  }
  // dispatcher should have guaranteed that at least one is nested
  auto self_ptr = get_nested_tensor_impl(self);
  auto mat2_ptr = get_nested_tensor_impl(mat2);
  TORCH_CHECK(self_ptr->dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(mat2_ptr->dim() == 3, "batch2 must be a 3D tensor");
  int64_t ntensors = self_ptr->size(0), ntensors2 = mat2_ptr->size(0);
  TORCH_CHECK(
      ntensors == ntensors2,
      "Expected size for the 1st dimension of batch2 tensor to be: ",
      ntensors,
      " but got: ",
      ntensors2,
      ".");
  const Tensor &self_buffer = self_ptr->get_buffer(),
      &mat2_buffer = mat2_ptr->get_buffer();
  std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
      mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
      self_strides = NestedTensor_get_strides(self_ptr),
      mat2_strides = NestedTensor_get_strides(mat2_ptr);
  const std::vector<int64_t>& self_offsets = self_ptr->get_storage_offsets();
  const std::vector<int64_t>& mat2_offsets = mat2_ptr->get_storage_offsets();

  // create a contiguous output
  int64_t out_numel = 0;
  int64_t a_numel = 0;
  int64_t b_numel = 0;
  const Tensor& self_sizemat = self_ptr->get_nested_size_tensor();
  Tensor out_sizemat = self_sizemat.new_empty(self_sizemat.sizes());
  int64_t* out_sizemat_ptr = out_sizemat.data_ptr<int64_t>();
  std::vector<int64_t> output_offsets;
  std::vector<int64_t> a_offsets;
  std::vector<int64_t> b_offsets;
  std::vector<int64_t> lda;
  std::vector<int64_t> ldb;
  std::vector<int64_t> ldd;
#ifndef USE_ROCM
  std::vector<cutlass::gemm::GemmCoord> gemm_sizes;
#endif
  bool all_row_major = true;
  for (int64_t i = 0; i < ntensors; i++) {
    const IntArrayRef &self_shape = self_sizes[i], &mat2_shape = mat2_sizes[i];
    const int64_t &self_size0 = self_shape[0], &self_size1 = self_shape[1],
        &mat2_size0 = mat2_shape[0], &mat2_size1 = mat2_shape[1];
    TORCH_CHECK(
        self_size1 == mat2_size0,
        i,
        "-th nested matrices in batch cannot be multiplied (",
        self_size0,
        "x",
        self_size1,
        " and ",
        mat2_size0,
        "x",
        mat2_size1,
        ")");
    out_sizemat_ptr[0] = self_size0;
    out_sizemat_ptr[1] = mat2_size1;
    out_sizemat_ptr += 2;
    output_offsets.push_back(out_numel);
    out_numel += self_size0 * mat2_size1;
#ifndef USE_ROCM
    gemm_sizes.push_back(
        cutlass::gemm::GemmCoord(self_size0, mat2_size1, self_size1));
#endif
    lda.push_back(self_strides[i][0]);
    ldb.push_back(mat2_strides[i][0]);
    ldd.push_back(mat2_size1);
    a_offsets.push_back(a_numel);
    b_offsets.push_back(b_numel);
    a_numel += self_size0 * self_strides[i][0];
    b_numel += mat2_size0 * mat2_strides[i][0];
    all_row_major = all_row_major && (self_strides[i][1] == 1);
    all_row_major = all_row_major && (mat2_strides[i][1] == 1);
  }
  Tensor out_buffer = self_buffer.new_empty(out_numel);
  Tensor output = wrap_buffer(out_buffer, out_sizemat);
  at::Device device = output.device();

#ifndef USE_ROCM
  auto dprops = at::cuda::getCurrentDeviceProperties();
  bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  if (is_sm8x && all_row_major) {
    if (self.dtype() == at::kFloat) {
      std::vector<float*> aptr;
      std::vector<float*> bptr;
      std::vector<float*> dptr;
      for (int64_t i = 0; i < ntensors; i++) {
        aptr.push_back(self_buffer.data_ptr<float>() + a_offsets[i]);
        bptr.push_back(mat2_buffer.data_ptr<float>() + b_offsets[i]);
        dptr.push_back(out_buffer.data_ptr<float>() + output_offsets[i]);
      }
      gemm_grouped_cuda_internal<float>(
          lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
      return output;
    }
    if (self.dtype() == at::kHalf) {
      std::vector<c10::Half*> aptr;
      std::vector<c10::Half*> bptr;
      std::vector<c10::Half*> dptr;
      for (int64_t i = 0; i < ntensors; i++) {
        aptr.push_back(self_buffer.data_ptr<c10::Half>() + a_offsets[i]);
        bptr.push_back(mat2_buffer.data_ptr<c10::Half>() + b_offsets[i]);
        dptr.push_back(out_buffer.data_ptr<c10::Half>() + output_offsets[i]);
      }
      gemm_grouped_cuda_internal<c10::Half>(
          lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
      return output;
    }
  }
#endif
  std::vector<Tensor> output_unbind = output.unbind();
  for (int64_t i = 0; i < ntensors; i++) {
    at::mm_out(
        output_unbind[i],
        self_buffer.as_strided(self_sizes[i], self_strides[i], self_offsets[i]),
        mat2_buffer.as_strided(
            mat2_sizes[i], mat2_strides[i], mat2_offsets[i]));
  }
  return output;
}

} // namespace native
} // namespace at
@@ -1,45 +0,0 @@
import argparse

import torch


def bench(nt_a, nt_b, niter):
    # Warmup
    nt_c = nt_a.bmm(nt_b)

    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for iter in range(niter):
        nt_c = nt_a.bmm(nt_b)
    end_event.record()
    torch.cuda.synchronize()
    runtime = (start_event.elapsed_time(end_event) * 1.0e-3) / niter
    return runtime


def sweep_n(ntensor, niter, dtype):
    print("n, dtype, ntensor, gflop, runtime, tflop/s")
    for n in [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
        nt_a = torch.nested_tensor(
            [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
        )
        nt_b = torch.nested_tensor(
            [torch.randn(n, n).to(dtype).cuda() for t in range(ntensor)]
        )
        runtime = bench(nt_a, nt_b, niter)
        tflop = n * n * n * ntensor * 2 / 1e12
        print(n, dtype, ntensor, tflop, runtime, tflop / runtime)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    parser.add_argument("--niter", default="10", type=int)
    parser.add_argument("--ntensor", default="20", type=int)

    args = parser.parse_args()
    niter = args.niter
    ntensor = args.ntensor

    sweep_n(ntensor, niter, torch.float32)
    sweep_n(ntensor, niter, torch.float16)
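Usage note (not part of the commit): if the deleted benchmark above were saved locally (its path is not shown in this view), it would be driven by the two flags it defines, e.g. python <script>.py --niter 10 --ntensor 20. Note that the column printed under the "gflop" header is actually in TFLOPs (n * n * n * ntensor * 2 / 1e12), so the final column, tflop / runtime, reads as TFLOP/s.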
@@ -8,7 +8,6 @@ from torch.testing._internal.common_device_type import (
    dtypesIfCUDA,
    instantiate_device_type_tests,
    skipMeta,
    onlyCUDA,
    onlyCPU
)
from torch.testing._internal.common_dtype import floating_types_and_half
@@ -972,7 +971,9 @@ class TestNestedTensorDeviceType(TestCase):
            torch.nn.functional.softmax(nt_contiguous, -1),
            torch.nn.functional.softmax(nt_noncontiguous, -1))

    def _test_bmm(self, device, dtype):
    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_bmm(self, device, dtype):
        # error case: one is nested but the other is not
        nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
        t = torch.randn(4, device=device, dtype=dtype)
@@ -1058,31 +1059,16 @@
        nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype)
        actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
        expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
        if dtype == torch.float16:
            self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
        else:
            self.assertEqual(actual, expect)
        self.assertEqual(actual, expect)

    @onlyCUDA
    @dtypes(torch.float, torch.double, torch.float16)
    def test_bmm_cuda(self, device, dtype):
        self._test_bmm(device, dtype)

    @onlyCPU
    # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_bmm_cpu(self, device, dtype):
        self._test_bmm(device, dtype)

    # TODO: Re-enable this test once bmm supports non-contiguous inputs.
    # # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
    # @dtypes(torch.float, torch.double)
    # def test_bmm_noncontiguous(self, device, dtype):
    #     nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
    #     nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
    #     self.assertEqual(
    #         nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
    #         nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
    def test_bmm_noncontiguous(self, device, dtype):
        nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
        self.assertEqual(
            nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
            nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))

    # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
third_party/cutlass.BUILD (vendored, 11 lines)
@@ -1,11 +0,0 @@
# Description:
#   CUDA Templates for Linear Algebra Subroutines

load("@rules_cc//cc:defs.bzl", "cc_library")

cc_library(
    name = "cutlass",
    hdrs = glob(["include/**/*.h"]),
    includes = ["include/"],
    visibility = ["//visibility:public"],
)