Add bfloat16 support to torch.bmm(NST, NST) (#141380)

Adds bfloat16 support to torch.bmm(NST, NST), where NST is a NestedTensor with the torch.strided (default) layout.
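For reference, a minimal usage sketch of the newly supported path (shapes, names, and values below are illustrative, not taken from this PR):

import torch

# Components may have different first dimensions; dtype is bfloat16 and the
# layout is the default torch.strided.
a = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)],
    dtype=torch.bfloat16, device="cuda")
b = torch.nested.nested_tensor(
    [torch.randn(8, 4), torch.randn(8, 4)],
    dtype=torch.bfloat16, device="cuda")

out = torch.bmm(a, b)  # nested result with components of shape (3, 4) and (5, 4)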

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141380
Approved by: https://github.com/jbschlosser
Author: Christian Puhrsch
Date: 2024-11-23 04:18:45 +00:00
Committed by: PyTorch MergeBot
Parent: 66f2550328
Commit: 1f734bc90c
2 changed files with 68 additions and 2 deletions


@@ -280,6 +280,72 @@ bool group_gemm_dispatch(
  return false;
}

template <>
bool group_gemm_dispatch(
    at::Device device,
    const std::vector<c10::BFloat16*>& aptr_,
    const std::vector<c10::BFloat16*>& bptr_,
    const std::vector<c10::BFloat16*>& dptr_,
    const std::vector<int64_t>& lda,
    const std::vector<int64_t>& ldb,
    const std::vector<int64_t>& ldd,
    const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
    int64_t ntensors) {
  // Check alignment
  bool all_pad_8 = true;
  for (int i = 0; i < ntensors; i++) {
    all_pad_8 = all_pad_8 && (gemm_sizes[i].n() % 8 == 0);
    all_pad_8 = all_pad_8 && (gemm_sizes[i].k() % 8 == 0);
    // Not sure if this is a requirement, on the safe side
    all_pad_8 = all_pad_8 && (lda[i] % 8 == 0);
    all_pad_8 = all_pad_8 && (ldb[i] % 8 == 0);
    all_pad_8 = all_pad_8 && (ldd[i] % 8 == 0);
  }
  std::vector<cutlass::bfloat16_t*> aptr;
  aptr.reserve(ntensors);
  std::vector<cutlass::bfloat16_t*> bptr;
  bptr.reserve(ntensors);
  std::vector<cutlass::bfloat16_t*> dptr;
  dptr.reserve(ntensors);
  for (int64_t i = 0; i < ntensors; i++) {
    aptr.push_back(reinterpret_cast<cutlass::bfloat16_t*>(aptr_[i]));
    bptr.push_back(reinterpret_cast<cutlass::bfloat16_t*>(bptr_[i]));
    dptr.push_back(reinterpret_cast<cutlass::bfloat16_t*>(dptr_[i]));
  }
  if (all_pad_8) {
    gemm_grouped_cuda_internal<
        cutlass::bfloat16_t,
        8,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>>(
        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
    return true;
  } else {
    gemm_grouped_cuda_internal<
        cutlass::bfloat16_t,
        1,
        cutlass::layout::RowMajor,
        cutlass::layout::RowMajor,
        cutlass::arch::OpClassSimt,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 8>,
        cutlass::gemm::GemmShape<64, 32, 8>,
        cutlass::gemm::GemmShape<1, 1, 1>>(
        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
    return true;
  }
  // Did not perform GEMM
  return false;
}
} // namespace
#endif
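A note on the alignment check in the new specialization: the tensor-op (OpClassTensorOp, Sm80) kernel with alignment 8 is only taken when every component's n and k and all leading dimensions are multiples of 8; otherwise the code falls back to the SIMT kernel with alignment 1. Either way the result is the same, so the kernel choice is invisible from Python. A hedged sketch of shapes that would and would not satisfy the check (illustrative only, and subject to the other dispatch conditions such as a CUDA, non-ROCm, non-Windows build):

import torch

# Every n, k (and contiguous leading dimension) is a multiple of 8:
# eligible for the aligned tensor-op path.
a_aligned = torch.nested.nested_tensor(
    [torch.randn(4, 8), torch.randn(6, 8)], dtype=torch.bfloat16, device="cuda")
b_aligned = torch.nested.nested_tensor(
    [torch.randn(8, 16), torch.randn(8, 16)], dtype=torch.bfloat16, device="cuda")

# k = 5 and n = 3 are not multiples of 8: handled by the SIMT fallback.
a_odd = torch.nested.nested_tensor(
    [torch.randn(4, 5), torch.randn(6, 5)], dtype=torch.bfloat16, device="cuda")
b_odd = torch.nested.nested_tensor(
    [torch.randn(5, 3), torch.randn(5, 3)], dtype=torch.bfloat16, device="cuda")

# Both calls succeed; only the underlying CUTLASS kernel differs.
torch.bmm(a_aligned, b_aligned)
torch.bmm(a_odd, b_odd)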
@@ -343,7 +409,7 @@ Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
#ifndef USE_ROCM
#ifndef _WIN32
  bool success = false;
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
      self.scalar_type(), "group_gemm_dispatch", [&] {
        std::vector<scalar_t*> aptr(ntensors);
        std::vector<scalar_t*> bptr(ntensors);


@@ -1993,7 +1993,7 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
        self.assertEqual(actual, expect)

    @onlyCUDA
-    @dtypes(torch.float, torch.double, torch.float16)
+    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    def test_bmm_cuda(self, device, dtype):
        self._test_bmm(device, dtype)
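For completeness, a rough standalone equivalent of what the updated test now covers for bfloat16, assuming _test_bmm compares the nested bmm result against per-component matmuls (the helper's body is not shown in this diff, so the sketch below is hypothetical):

import torch

def check_nested_bmm_bfloat16(device="cuda"):
    a = torch.nested.nested_tensor(
        [torch.randn(2, 6), torch.randn(3, 6)], dtype=torch.bfloat16, device=device)
    b = torch.nested.nested_tensor(
        [torch.randn(6, 4), torch.randn(6, 4)], dtype=torch.bfloat16, device=device)
    out = torch.bmm(a, b)
    # Compare each nested component against a plain matmul of the unbound pieces.
    for o, x, y in zip(out.unbind(), a.unbind(), b.unbind()):
        torch.testing.assert_close(o, x @ y)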