Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
enable float32 and float16 in torch._grouped_mm fallback (#162059)

Summary:

Enables `torch.float32` and `torch.float16` options in `torch._grouped_mm`. Note that the fast path is only enabled if `mat_a`, `mat_b`, and `out_dtype` are `torch.bfloat16`.

Saving for future PRs:
1. enabling testing on more platforms
2. supporting out_dtype != mat_a.dtype
3. opinfo
4. better compile support

Test Plan:

```bash
# on A100 and H100
pytest test/test_matmul_cuda.py -s -k test_grouped_gemm -x

# on H100
pytest test/test_matmul_cuda.py -s -k test_scaled_grouped_gemm -x
```

Reviewers:
Subscribers:
Tasks:
Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162059
Approved by: https://github.com/ngimel, https://github.com/eqy
ghstack dependencies: #161407, #161717
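For orientation, a minimal usage sketch of the newly allowed dtypes, modeled on the 2d-2d test changed below. The shapes, `offs` layout, and the assumption that the CUDA fallback handles `float16` all come from the test changes in this PR rather than from a documented API, so treat this as illustrative:

```python
import torch

# Hedged sketch: mirrors test_grouped_gemm_2d_2d with dtype=torch.float16.
# Groups are carved out of the shared k dimension via cumulative end offsets.
m, n, k, n_groups = 16, 32, 64, 4
dtype = torch.float16  # previously only torch.bfloat16 was accepted
a = torch.randn(m, k * n_groups, device="cuda", dtype=dtype)
b = torch.randn(n, k * n_groups, device="cuda", dtype=dtype)
offs = torch.arange(k, n_groups * k + 1, k, device="cuda", dtype=torch.int32)

out = torch._grouped_mm(a, b.t(), offs=offs, out_dtype=dtype)
print(out.shape, out.dtype)  # output dtype matches mat_a's dtype
```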
Committed by PyTorch MergeBot
Parent: 61fb632cfb · Commit: 9eadb37cdd
```diff
@@ -346,7 +346,8 @@ const std::optional<at::Tensor>& offs,
     const std::optional<at::Tensor>& bias,
     std::optional<c10::ScalarType> out_dtype) {
   _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
-  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
+  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
+  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
   _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
   return out;
 }
```
```diff
@@ -36,7 +36,7 @@ inline bool check_valid_strides_and_return_transposed(const Tensor& mat) {
 inline at::Tensor create_grouped_gemm_output_tensor(const Tensor& mat_a,
     const Tensor& mat_b,
     const std::optional<at::Tensor>& offs,
-    std::optional<c10::ScalarType> out_dtype
+    c10::ScalarType out_dtype
 ) {
   c10::SmallVector<int64_t, 3> out_size;
   const bool a_is_2d = mat_a.dim() == 2;
@@ -59,14 +59,11 @@ std::optional<c10::ScalarType> out_dtype
     }
   }
 
-  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
-  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
-
 #ifndef USE_ROCM
   // For TMA transfers, strides of output tensor have to be either
   // 1, or aligned to 16 bytes.
   const auto last_dim = out_size.size() - 1;
-  const auto alignment = 16 / c10::elementSize(out_dtype_);
+  const auto alignment = 16 / c10::elementSize(out_dtype);
   const int64_t size_padded = (out_size[last_dim] + alignment - 1) / alignment * alignment;
   std::vector<int64_t> out_stride;
   if (a_is_2d != b_is_2d) {
@@ -74,9 +71,9 @@ std::optional<c10::ScalarType> out_dtype
   } else {
     out_stride = {out_size[1] * size_padded, size_padded, 1};
   }
-  return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
+  return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype));
 #else
-  return at::empty(out_size, mat_a.options().dtype(out_dtype_));
+  return at::empty(out_size, mat_a.options().dtype(out_dtype));
 #endif
 }
 
```
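The non-ROCm branch above pads the innermost output dimension's stride so rows stay 16-byte aligned for TMA transfers, which is why the alignment is now derived from the actual output dtype instead of a hard-coded bf16. A small illustration of that arithmetic, using a hypothetical `padded_last_dim` helper (not PyTorch code, just the same formula):

```python
def padded_last_dim(last_dim_size: int, element_size: int) -> int:
    # Same padding formula as the C++ above: round the last dim up to a
    # multiple of (16 bytes / element size).
    alignment = 16 // element_size
    return (last_dim_size + alignment - 1) // alignment * alignment

print(padded_last_dim(33, 2))  # bf16/fp16: align to 8 elements -> 40
print(padded_last_dim(33, 4))  # fp32:      align to 4 elements -> 36
```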
```diff
@@ -84,8 +81,8 @@ inline void _grouped_mm_validate_inputs(const Tensor& mat_a, const Tensor& mat_b
     const std::optional<at::Tensor>& offs,
     const std::optional<at::Tensor>& bias,
     std::optional<c10::ScalarType> out_dtype) {
-  TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_a.scalar_type());
-  TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_b.scalar_type());
+  TORCH_CHECK((mat_a.dtype() == at::kBFloat16) || (mat_a.dtype() == at::kFloat) || (mat_a.dtype() == at::kHalf), "Expected mat_a to be Float32, BFloat16 or Float16 matrix, got ", mat_a.scalar_type());
+  TORCH_CHECK((mat_b.dtype() == at::kBFloat16) || (mat_b.dtype() == at::kFloat) || (mat_b.dtype() == at::kHalf), "Expected mat_b to be Float32, BFloat16 or Float16 matrix, got ", mat_b.scalar_type());
   TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
   TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
   const bool a_is_2d = mat_a.dim() == 2;
@@ -106,6 +103,14 @@ std::optional<c10::ScalarType> out_dtype) {
   TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
 }
 
+inline c10::ScalarType _resolve_grouped_mm_out_dtype(const Tensor& mat_a, const Tensor& mat_b,
+    std::optional<c10::ScalarType> out_dtype) {
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  // TODO(future PR): enable float32 output dtype for bfloat16 and float16 inputs
+  TORCH_CHECK(out_dtype_ == mat_a.dtype(), "Grouped gemm output dtype must match `mat_a` dtype");
+  return out_dtype_;
+}
+
 
 inline void _grouped_mm_fallback(const Tensor& mat_a, const Tensor& mat_b,
     const std::optional<at::Tensor>& offs,
```
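The new `_resolve_grouped_mm_out_dtype` defaults the output dtype to `mat_a`'s dtype and, for now, rejects any mismatch (float32 outputs for bf16/fp16 inputs are left to a future PR, per the TODO). A hedged sketch of what that implies at the Python call site; the shapes follow the 2d-3d pattern used in the tests (offsets split `mat_a`'s rows) and are illustrative only:

```python
import torch

a = torch.randn(16, 64, device="cuda", dtype=torch.float16)
b = torch.randn(4, 32, 64, device="cuda", dtype=torch.float16)
offs = torch.tensor([4, 8, 12, 16], device="cuda", dtype=torch.int32)

# Omitting out_dtype resolves it to mat_a's dtype (float16 here).
out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs)

# A mismatched out_dtype is expected to fail the new TORCH_CHECK.
try:
    torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.float32)
except RuntimeError as err:
    print(err)  # "Grouped gemm output dtype must match `mat_a` dtype"
```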
```diff
@@ -1655,7 +1655,10 @@ bool use_fast_accum) {
   check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
   check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
 
-  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
+  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
+  TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
+
+  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
 
 #ifndef USE_ROCM
   TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
@@ -1698,9 +1701,14 @@ const std::optional<at::Tensor>& bias,
     std::optional<c10::ScalarType> out_dtype) {
 #ifndef USE_ROCM
   _grouped_mm_validate_inputs(mat_a, mat_b, offs, bias, out_dtype);
-  bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
-
-  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
+  bool a_b_and_out_are_bf16 = (
+    mat_a.dtype() == at::kBFloat16 &&
+    mat_b.dtype() == at::kBFloat16 &&
+    out_dtype.value_or(at::kBFloat16) == at::kBFloat16
+  );
+  bool use_fast_path = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true) && a_b_and_out_are_bf16;
+  const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
+  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
   if (use_fast_path) {
     // fast path, no d2h sync needed
     at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
```
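With this change the SM90/SM100 fast kernel is used only when both inputs and the resolved output dtype are bfloat16; float32 and float16 now take the fallback. A hedged Python-side restatement of the gating predicate (`uses_bf16_fast_path` is a hypothetical helper, not a PyTorch API):

```python
import torch

def uses_bf16_fast_path(mat_a: torch.Tensor, mat_b: torch.Tensor, out_dtype=None) -> bool:
    # Mirrors `a_b_and_out_are_bf16` above; the real gate additionally
    # requires an SM90/SM100 device (_scaled_mm_allowed_device).
    resolved = out_dtype if out_dtype is not None else torch.bfloat16
    return (
        mat_a.dtype == torch.bfloat16
        and mat_b.dtype == torch.bfloat16
        and resolved == torch.bfloat16
    )
```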
```diff
@@ -315,9 +315,9 @@ class TestMatmulCuda(TestCase):
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])
-    def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major):
+    @dtypes(torch.bfloat16, torch.float32, torch.float16)
+    def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype):
         device = "cuda"
-        dtype = torch.bfloat16
         m, n, k, n_groups = 16, 32, 64, 4
         if a_row_major:
             a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups]
@@ -334,7 +334,7 @@ class TestMatmulCuda(TestCase):
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
 
         f = torch._grouped_mm
-        out = f(a, b.t(), offs=offs, out_dtype=torch.bfloat16)
+        out = f(a, b.t(), offs=offs, out_dtype=dtype)
         gO = torch.rand_like(out)
         out.backward(gO)
         offs_cpu = offs.cpu()
```
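With the `@dtypes` parametrization, each grouped-gemm test now runs once per dtype (bfloat16, float32, float16), so the Test Plan command `pytest test/test_matmul_cuda.py -s -k test_grouped_gemm -x` exercises all three. A narrower filter such as `-k "test_grouped_gemm_2d_2d and float16"` should select a single dtype variant, though the exact instantiated test names depend on the device/dtype instantiation machinery.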
```diff
@@ -354,9 +354,9 @@ class TestMatmulCuda(TestCase):
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])
-    def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major):
+    @dtypes(torch.bfloat16, torch.float32, torch.float16)
+    def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype):
         device = "cuda"
-        dtype = torch.bfloat16
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 64, 4
         if a_row_major:
@@ -388,7 +388,7 @@ class TestMatmulCuda(TestCase):
             offs[0] = offs[1]
 
         f = torch._grouped_mm
-        out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16)
+        out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype)
         gO = torch.rand_like(out)
         if not check_zero_size:
             out.backward(gO)
@@ -411,9 +411,9 @@ class TestMatmulCuda(TestCase):
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])
-    def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major):
+    @dtypes(torch.bfloat16, torch.float32, torch.float16)
+    def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype):
         device = "cuda"
-        dtype = torch.bfloat16
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 64, 4
         if a_row_major:
@@ -435,7 +435,7 @@ class TestMatmulCuda(TestCase):
         self.assertTrue(b_contig.is_contiguous() is not strided)
 
         f = torch._grouped_mm
-        out = f(a, b.transpose(-2, -1), out_dtype=torch.bfloat16)
+        out = f(a, b.transpose(-2, -1), out_dtype=dtype)
         gO = torch.rand_like(out)
         out.backward(gO)
         self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
@@ -446,9 +446,9 @@ class TestMatmulCuda(TestCase):
     @parametrize("strided", [False, True])
     @parametrize("a_row_major", [False, True])
     @parametrize("b_row_major", [False, True])
-    def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major):
+    @dtypes(torch.bfloat16, torch.float32, torch.float16)
+    def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
         device = "cuda"
-        dtype = torch.bfloat16
         s_int = int(strided)
         m, n, k, n_groups = 16, 32, 64, 4
         if a_row_major:
@@ -477,7 +477,7 @@ class TestMatmulCuda(TestCase):
             offs[0] = offs[1]
 
         f = torch._grouped_mm
-        out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16)
+        out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype)
         gO = torch.rand_like(out)
         if not check_zero_size:
             out.backward(gO)
```
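As a sanity check on what the fallback must compute, here is a minimal, hedged reference for the 2d-2d case exercised above (this is illustrative only and is not the `grouped_mm_helper` used by the test class): each group multiplies a `k`-slice of `a` by the matching `k`-slice of `b.t()`, with `offs` holding cumulative end offsets along `k`.

```python
import torch

def grouped_mm_2d_2d_reference(a, b_t, offs):
    # a: (m, k_total), b_t: (k_total, n), offs: cumulative end offsets along k.
    outs, start = [], 0
    for end in offs.tolist():
        outs.append(a[:, start:end].float() @ b_t[start:end, :].float())
        start = end
    return torch.stack(outs)  # one (m, n) block per group, in float32
```

Accumulating in float32 keeps the reference usable for all three parametrized dtypes when comparing against `torch._grouped_mm`'s output.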