Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
torch._scaled_mm: support dims of size 0 for tensorwise scaling (#140967)
Summary: Ensures we support dims of size 0 properly in `torch._scaled_mm`, following the behavior of `torch.mm`. For now this only enables support for tensorwise scaling; rowwise scaling can be tackled in a future PR.

Test Plan:
```
python test/test_matmul_cuda.py -k test_zero_dim
```

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140967
Approved by: https://github.com/eqy, https://github.com/drisspg
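For context, a minimal sketch of the behavior this change enables (assumes a CUDA device with FP8 support; the shapes, dtypes, and scale values are illustrative, mirroring the test added below):

```python
import torch

# K = 0: multiplying (32, 0) by (0, 32). After this change torch._scaled_mm
# returns a (32, 32) tensor of zeros instead of erroring, matching torch.mm.
M, K, N = 32, 0, 32
device = "cuda"

x_fp8 = torch.zeros(M, K, device=device).to(torch.float8_e4m3fn)
# mat2 must be column-major, hence building an (N, K) buffer and transposing.
y_fp8 = torch.zeros(N, K, device=device, dtype=torch.float8_e4m3fn).t()

scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)        # torch.Size([32, 32])
print(out.abs().sum())  # 0 -- the K == 0 branch zero-initializes the output
```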
Committed by: PyTorch MergeBot
Parent: 6e61ff4fd3
Commit: 3d5fe0ce78
```diff
@@ -1050,6 +1050,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   IntArrayRef mat2_sizes = mat2.sizes();
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
 
+  // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels
+  // do not support this case).
+  if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) {
+    // `out` was created with `at::empty`. In the case where we are multiplying
+    // MxK by KxN and K is the zero dim, we need to initialize here to properly
+    // return a tensor of zeros.
+    if (mat1_sizes[1] == 0) {
+      out.zero_();
+    }
+
+    return out;
+  }
+
   // We are doing row-wise scaling
   if (scaling_choice == ScalingType::RowWise) {
     TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
```
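The `out.zero_()` branch exists because `out` comes from `at::empty`, whose memory is uninitialized, while standard matmul semantics say a reduction over an empty K dimension is an empty sum, i.e. zero. A quick illustration with plain `torch.mm` (CPU floats, purely for intuition):

```python
import torch

a = torch.randn(4, 0)  # M x K with K = 0
b = torch.randn(0, 3)  # K x N

# Each output element is a sum over zero products -- an empty sum -- so the
# result is a fully materialized (4, 3) tensor of zeros, not an empty tensor.
out = torch.mm(a, b)
assert out.shape == (4, 3)
assert (out == 0).all()
```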
```diff
@@ -700,6 +700,33 @@ class TestFP8MatmulCuda(TestCase):
 
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @parametrize("which_dim_zero", [0, 1, 2])
+    @parametrize("use_torch_compile", [False, True])
+    def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
+        device = "cuda"
+        x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
+        out_dtype = torch.bfloat16
+        M, K, N = 32, 32, 32
+        if which_dim_zero == 0:
+            M = 0
+        elif which_dim_zero == 1:
+            K = 0
+        elif which_dim_zero == 2:
+            N = 0
+
+        x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
+        y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
+        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
+        scale_a = torch.tensor(float('-inf'), device=device)
+        scale_b = torch.tensor(float('-inf'), device=device)
+        f = torch._scaled_mm
+        if use_torch_compile:
+            f = torch.compile(torch._scaled_mm)
+        out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
+        self.assertEqual(out_dtype, out_fp8.dtype)
+        self.assertEqual(out_fp32, out_fp8.to(torch.float))
+
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
```
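A note on the test design: the `-inf` scales appear deliberate. With a zero dim, every output element is either zero-initialized or never computed, so the scale values can never reach the result; any finite garbage there would fail the comparison against `torch.mm`. The `use_torch_compile=True` variant is what exercises the `meta_scaled_mm` changes below, since tracing runs the meta kernel rather than the CUDA one. A hypothetical standalone version of that path (not part of the PR):

```python
import torch

# Routing through torch.compile makes shape checking happen in
# meta_scaled_mm (patched below) instead of the eager CUDA kernel.
compiled_mm = torch.compile(torch._scaled_mm)

M, K, N = 0, 32, 32  # M = 0 this time: empty output, nothing to zero-fill
x = torch.zeros(M, K, device="cuda").to(torch.float8_e4m3fn)
y = torch.zeros(N, K, device="cuda", dtype=torch.float8_e4m3fn).t()
one = torch.tensor(1.0, device="cuda")

out = compiled_mm(x, y, one, one, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([0, 32])
```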
```diff
@@ -5589,13 +5589,16 @@ def meta_scaled_mm(
     def is_col_major(stride):
         return stride[0] == 1 and stride[1] > 1
 
+    def has_zero_dim(tensor_2d):
+        return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0
+
     torch._check(
-        is_row_major(self.stride()),
-        lambda: "self must be row_major",
+        is_row_major(self.stride()) or has_zero_dim(self),
+        lambda: f"self must be row_major, got stride {self.stride()}",
     )
     torch._check(
-        is_col_major(mat2.stride()),
-        lambda: "mat2 must be col_major",
+        is_col_major(mat2.stride()) or has_zero_dim(mat2),
+        lambda: f"mat2 must be col_major, got stride {mat2.stride()}",
     )
     torch._check(
         self.size(1) % 16 == 0,
```
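The `has_zero_dim` escape hatch is needed because a contiguous tensor with a zero dim has degenerate strides: `torch.empty(32, 0)` has stride `(1, 1)`, which fails the strict `is_col_major` predicate above (and, by symmetry, `is_row_major`) even though the layout is perfectly valid. A minimal sketch of the case this unblocks, using meta tensors to hit only the shape/stride checks (an assumption about the tracing path, not code from the PR):

```python
import torch

# Meta tensors carry shapes and strides but no data, so this exercises only
# meta_scaled_mm's checks -- the same code torch.compile runs while tracing.
x = torch.empty(32, 0, device="meta", dtype=torch.float8_e4m3fn)      # stride (1, 1)
y = torch.empty(32, 0, device="meta", dtype=torch.float8_e4m3fn).t()  # shape (0, 32)
scale = torch.empty((), device="meta")

# Previously the stride checks rejected this: stride (1, 1) counts as neither
# row-major nor col-major under the predicates above.
out = torch._scaled_mm(x, y, scale, scale, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([32, 32])
```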