Add bfloat16 support to torch.bmm(NST, NST) (#141380)
Adds bfloat16 support to torch.bmm(NST, NST), where NST is a NestedTensor with the torch.strided (default) layout.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141380
Approved by: https://github.com/jbschlosser
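For illustration, a minimal usage sketch of the path this change enables (shapes and names are illustrative, not taken from the PR): bmm of two strided-layout NestedTensors in bfloat16 on a CUDA device.

import torch

# Two strided-layout nested tensors whose components have matching inner dimensions.
a = torch.nested.nested_tensor(
    [torch.randn(2, 8), torch.randn(4, 8)],
    dtype=torch.bfloat16, device="cuda", layout=torch.strided)
b = torch.nested.nested_tensor(
    [torch.randn(8, 16), torch.randn(8, 16)],
    dtype=torch.bfloat16, device="cuda", layout=torch.strided)

# With this change, the bfloat16 case goes through the same CUDA nested-bmm
# dispatch as float/half; the result holds a (2, 16) and a (4, 16) matrix.
out = torch.bmm(a, b)
print([t.shape for t in out.unbind()])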
Committed by: PyTorch MergeBot
Parent: 66f2550328
Commit: 1f734bc90c
@@ -280,6 +280,72 @@ bool group_gemm_dispatch(
   return false;
 }
 
+template <>
+bool group_gemm_dispatch(
+    at::Device device,
+    const std::vector<c10::BFloat16*>& aptr_,
+    const std::vector<c10::BFloat16*>& bptr_,
+    const std::vector<c10::BFloat16*>& dptr_,
+    const std::vector<int64_t>& lda,
+    const std::vector<int64_t>& ldb,
+    const std::vector<int64_t>& ldd,
+    const std::vector<cutlass::gemm::GemmCoord>& gemm_sizes,
+    int64_t ntensors) {
+
+  // Check alignment
+  bool all_pad_8 = true;
+  for (int i = 0; i < ntensors; i++) {
+    all_pad_8 = all_pad_8 && (gemm_sizes[i].n() % 8 == 0);
+    all_pad_8 = all_pad_8 && (gemm_sizes[i].k() % 8 == 0);
+
+    // Not sure if this is a requirement, on the safe side
+    all_pad_8 = all_pad_8 && (lda[i] % 8 == 0);
+    all_pad_8 = all_pad_8 && (ldb[i] % 8 == 0);
+    all_pad_8 = all_pad_8 && (ldd[i] % 8 == 0);
+  }
+
+  std::vector<cutlass::bfloat16_t*> aptr;
+  aptr.reserve(ntensors);
+  std::vector<cutlass::bfloat16_t*> bptr;
+  bptr.reserve(ntensors);
+  std::vector<cutlass::bfloat16_t*> dptr;
+  dptr.reserve(ntensors);
+  for (int64_t i = 0; i < ntensors; i++) {
+    aptr.push_back(reinterpret_cast<cutlass::bfloat16_t*>(aptr_[i]));
+    bptr.push_back(reinterpret_cast<cutlass::bfloat16_t*>(bptr_[i]));
+    dptr.push_back(reinterpret_cast<cutlass::bfloat16_t*>(dptr_[i]));
+  }
+  if (all_pad_8) {
+    gemm_grouped_cuda_internal<
+        cutlass::bfloat16_t,
+        8,
+        cutlass::layout::RowMajor,
+        cutlass::layout::RowMajor,
+        cutlass::arch::OpClassTensorOp,
+        cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<128, 128, 32>,
+        cutlass::gemm::GemmShape<64, 64, 32>,
+        cutlass::gemm::GemmShape<16, 8, 16>>(
+        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
+    return true;
+  } else {
+    gemm_grouped_cuda_internal<
+        cutlass::bfloat16_t,
+        1,
+        cutlass::layout::RowMajor,
+        cutlass::layout::RowMajor,
+        cutlass::arch::OpClassSimt,
+        cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<128, 128, 8>,
+        cutlass::gemm::GemmShape<64, 32, 8>,
+        cutlass::gemm::GemmShape<1, 1, 1>>(
+        lda, ldb, ldd, aptr, bptr, dptr, gemm_sizes, ntensors, device);
+    return true;
+  }
+  // Did not perform GEMM
+  return false;
+}
+
 } // namespace
 
 #endif
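In short, the new bfloat16 specialization mirrors the existing float/half ones: it selects the CUTLASS tensor-core kernel (OpClassTensorOp, alignment 8) only when every GEMM in the group has n, k, and all leading dimensions divisible by 8, and otherwise falls back to the SIMT kernel with alignment 1. A minimal Python sketch of that predicate, assuming (m, n, k) problem sizes and per-tensor leading dimensions (the helper name and example values are illustrative, not part of the PR):

# Illustrative re-statement of the all_pad_8 check above.
def uses_tensor_core_path(gemm_sizes, lda, ldb, ldd):
    # gemm_sizes: list of (m, n, k) tuples, one per GEMM in the group.
    for (_, n, k), a, b, d in zip(gemm_sizes, lda, ldb, ldd):
        # bfloat16 is 2 bytes wide, so strides divisible by 8 keep rows
        # 16-byte aligned, which the alignment-8 TensorOp kernel expects.
        if n % 8 or k % 8 or a % 8 or b % 8 or d % 8:
            return False  # take the alignment-1 SIMT fallback instead
    return True

# (2x3) @ (3x5): n = 5 and k = 3 are not multiples of 8, so SIMT path.
print(uses_tensor_core_path([(2, 5, 3)], [3], [5], [5]))     # False
# (2x8) @ (8x16): everything divisible by 8, so TensorOp path.
print(uses_tensor_core_path([(2, 16, 8)], [8], [16], [16]))  # True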
@@ -343,7 +409,7 @@ Tensor bmm_nested_cuda(const Tensor& self, const Tensor& mat2) {
 #ifndef USE_ROCM
 #ifndef _WIN32
   bool success = false;
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
       self.scalar_type(), "group_gemm_dispatch", [&] {
         std::vector<scalar_t*> aptr(ntensors);
         std::vector<scalar_t*> bptr(ntensors);
@@ -1993,7 +1993,7 @@ class TestNestedTensorDeviceType(NestedTensorTestCase):
         self.assertEqual(actual, expect)
 
     @onlyCUDA
-    @dtypes(torch.float, torch.double, torch.float16)
+    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
     def test_bmm_cuda(self, device, dtype):
         self._test_bmm(device, dtype)
 
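The existing _test_bmm helper (not shown in this diff) now also runs under torch.bfloat16 on CUDA. As a rough sketch of the kind of check such a test performs, the nested result can be compared against per-component dense matmuls; the helper name, shapes, and tolerances below are illustrative, not the actual test code:

import torch

def check_nested_bmm_bf16(device="cuda", dtype=torch.bfloat16):
    a_list = [torch.randn(2, 8, device=device, dtype=dtype),
              torch.randn(4, 8, device=device, dtype=dtype)]
    b_list = [torch.randn(8, 16, device=device, dtype=dtype),
              torch.randn(8, 16, device=device, dtype=dtype)]
    a = torch.nested.nested_tensor(a_list, layout=torch.strided)
    b = torch.nested.nested_tensor(b_list, layout=torch.strided)

    actual = torch.bmm(a, b).unbind()
    expect = [x @ y for x, y in zip(a_list, b_list)]
    for got, ref in zip(actual, expect):
        # Loose tolerances: bfloat16 carries roughly 8 bits of mantissa.
        torch.testing.assert_close(got, ref, rtol=1e-2, atol=1e-2)

check_nested_bmm_bf16()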