torch._scaled_mm: support dims of size 0 for tensorwise scaling (#140967)

Summary:

Ensures we properly support dims of size 0 in `torch._scaled_mm`, following the behavior of `torch.mm`.

For now this only enables support for tensorwise scaling; we can tackle rowwise in a future PR.
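As a quick illustration of the intended semantics (a sketch with assumed shapes and scale values; requires an fp8-capable GPU such as H100):

```python
import torch

# Illustrative only: with M == 0, the output should be an empty (0, N)
# tensor, matching what torch.mm returns for the same shapes.
M, K, N = 0, 32, 16
x = torch.zeros(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major (M, K)
y = torch.zeros(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # col-major (K, N)
scale_a = torch.tensor(1.0, device="cuda")  # tensorwise scales
scale_b = torch.tensor(1.0, device="cuda")
out = torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([0, 16])
```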

Test Plan:

```
python test/test_matmul_cuda.py -k test_zero_dim
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140967
Approved by: https://github.com/eqy, https://github.com/drisspg
vasiliy
2024-11-27 04:07:49 +00:00
committed by PyTorch MergeBot
parent 6e61ff4fd3
commit 3d5fe0ce78
3 changed files with 47 additions and 4 deletions

aten/src/ATen/native/cuda/Blas.cpp

@@ -1050,6 +1050,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   IntArrayRef mat2_sizes = mat2.sizes();
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});

+  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels
+  // do not support this case).
+  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
+    // `out` was created with `at::empty`. In the case where we are multiplying
+    // MxK by KxN and K is the zero dim, we need to initialize here to properly
+    // return a tensor of zeros.
+    if (mat1_sizes[1] == 0) {
+      out.zero_();
+    }
+    return out;
+  }
+
   // We are doing row-wise scaling
   if (scaling_choice == ScalingType::RowWise) {
     TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
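For reference, the `out.zero_()` call mirrors `torch.mm`'s semantics: a zero-sized K is an empty reduction, so the product is defined as all zeros, whereas `at::empty` leaves the buffer uninitialized. A CPU-runnable sketch of that reference behavior:

```python
import torch

# torch.mm semantics that the early return mirrors: reducing over a
# zero-sized K dim is an empty sum, so the result is an M x N tensor of
# zeros -- not uninitialized memory, which is what a bare at::empty holds.
a = torch.randn(4, 0)  # M x K with K == 0
b = torch.randn(0, 5)  # K x N
print(torch.mm(a, b))  # 4 x 5 tensor of zeros

# When M or N is zero instead, the output itself is empty, so there is
# nothing to initialize.
print(torch.mm(torch.randn(0, 3), torch.randn(3, 5)).shape)  # torch.Size([0, 5])
```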

test/test_matmul_cuda.py

@@ -700,6 +700,33 @@ class TestFP8MatmulCuda(TestCase):
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @parametrize("which_dim_zero", [0, 1, 2])
+    @parametrize("use_torch_compile", [False, True])
+    def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
+        device = "cuda"
+        x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
+        out_dtype = torch.bfloat16
+        M, K, N = 32, 32, 32
+        if which_dim_zero == 0:
+            M = 0
+        elif which_dim_zero == 1:
+            K = 0
+        elif which_dim_zero == 2:
+            N = 0
+
+        x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
+        y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
+        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
+        scale_a = torch.tensor(float('-inf'), device=device)
+        scale_b = torch.tensor(float('-inf'), device=device)
+        f = torch._scaled_mm
+        if use_torch_compile:
+            f = torch.compile(torch._scaled_mm)
+        out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
+        self.assertEqual(out_dtype, out_fp8.dtype)
+        self.assertEqual(out_fp32, out_fp8.to(torch.float))
+
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
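One reading of the test design (my inference; the PR does not state it): the scales are set to `-inf` so that any value actually produced by the fp8 kernel would be poisoned with inf/NaN, meaning an exact match against the fp32 emulation confirms the early-return path is taken. A hypothetical standalone version of the K == 0 case:

```python
import torch

# Hypothetical standalone check mirroring the test: with K == 0 the result
# must be exact zeros regardless of the scale values, since no gemm runs.
M, K, N = 32, 0, 32
x = torch.zeros(M, K, device="cuda").to(torch.float8_e4m3fn)
y = torch.zeros(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale = torch.tensor(float("-inf"), device="cuda")
out = torch._scaled_mm(x, y, scale, scale, out_dtype=torch.bfloat16)
assert torch.equal(out, torch.zeros(M, N, device="cuda", dtype=torch.bfloat16))
```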

torch/_meta_registrations.py

@@ -5589,13 +5589,16 @@ def meta_scaled_mm(
     def is_col_major(stride):
         return stride[0] == 1 and stride[1] > 1

+    def has_zero_dim(tensor_2d):
+        return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0
+
     torch._check(
-        is_row_major(self.stride()),
-        lambda: "self must be row_major",
+        is_row_major(self.stride()) or has_zero_dim(self),
+        lambda: f"self must be row_major, got stride {self.stride()}",
     )
     torch._check(
-        is_col_major(mat2.stride()),
-        lambda: "mat2 must be col_major",
+        is_col_major(mat2.stride()) or has_zero_dim(mat2),
+        lambda: f"mat2 must be col_major, got stride {mat2.stride()}",
     )
     torch._check(
         self.size(1) % 16 == 0,
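For context, `meta_scaled_mm` runs on fake tensors during `torch.compile` tracing, so its layout checks must also admit zero-sized inputs for the compiled path in the new test to pass. A standalone sketch of the relaxed predicates follows; `is_row_major` is not shown in the hunk, so its symmetric definition here is an assumption:

```python
import torch

def is_row_major(stride):
    # Assumed by symmetry with is_col_major from the diff.
    return stride[1] == 1 and stride[0] > 1

def is_col_major(stride):
    # Copied from the diff.
    return stride[0] == 1 and stride[1] > 1

def has_zero_dim(tensor_2d):
    return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0

# A zero-sized dim can produce degenerate contiguous strides like (1, 1),
# which fail both layout predicates even though layout is irrelevant for an
# empty matmul -- hence the `or has_zero_dim(...)` escape hatch.
t = torch.empty(32, 0)           # K == 0
print(t.stride())                # (1, 1)
print(is_row_major(t.stride()))  # False
print(has_zero_dim(t))           # True, so the check passes anyway
```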