This PR provides an initial CUTLASS implementation of the grouped GEMM API as described in this [document](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9). Any combination of 2d and 3d inputs is supported; a 2d input is treated as jagged, with the offsets of its jagged dimension given by the device tensor `offs`. Only H100 is supported, and only fp8_e4m3 inputs with bf16 output and rowwise scaling. Every dimension of each individual GEMM has to be a multiple of 16 (a CUTLASS limitation); I still need to add those checks, and for dynamic dimensions they will unfortunately have to be device asserts. I had to copy-paste CUTLASS's `Sm90RowBroadcast` and `Sm90ColBroadcast` structs with minor changes to support scales given as pointer arrays; ideally those changes should live in CUTLASS itself. I copied the schedules from the similar grouped GEMM in FBGEMM, but there is a lot of room to improve performance, especially for `fast_accum=False`. Next steps are perf tuning and extending coverage to B100; I don't know yet how the CUTLASS grouped GEMM example handles blockwise scaling on B100.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148531
Approved by: https://github.com/drisspg
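A minimal sketch of the call shape for the 2d (jagged) x 3d case, mirroring `test_grouped_gemm_2d_3d` in the test file below; the shapes and unit scales are illustrative only, chosen so every per-group dimension is a multiple of 16:

```python
import torch

m, n, k, n_groups = 16, 32, 16, 4
# 2d "a" is jagged along its rows; 3d "b" holds one (n, k) weight matrix per group.
a = torch.randn(m * n_groups, k, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n_groups, n, k, device="cuda").to(torch.float8_e4m3fn)
# offs gives the (exclusive) end row of each group in a.
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
# Rowwise scales for a, per-group columnwise scales for b.
scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.ones(n_groups, n, device="cuda", dtype=torch.float32)
out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                               out_dtype=torch.bfloat16, use_fast_accum=True)
```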
# Owner(s): ["module: linear algebra"]
|
|
|
|
import contextlib
|
|
import json
|
|
import math
|
|
import re
|
|
import tempfile
|
|
import unittest
|
|
from itertools import product
|
|
from functools import partial
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from torch.quantization._quantized_conversions import (
|
|
pack_int4_to_int8,
|
|
quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
|
|
)
|
|
|
|
from torch.testing import make_tensor
|
|
from torch.testing._internal.common_cuda import (
|
|
SM53OrLater,
|
|
SM89OrLater,
|
|
SM90OrLater,
|
|
_get_torch_cuda_version,
|
|
PLATFORM_SUPPORTS_FP8,
|
|
PLATFORM_SUPPORTS_MX_GEMM
|
|
)
|
|
from torch.testing._internal.common_device_type import (
|
|
dtypes,
|
|
instantiate_device_type_tests,
|
|
onlyCUDA,
|
|
tol as xtol,
|
|
toleranceOverride,
|
|
)
|
|
|
|
from torch.testing._internal.common_utils import (
|
|
IS_ARM64,
|
|
IS_JETSON,
|
|
IS_WINDOWS,
|
|
parametrize,
|
|
run_tests,
|
|
skipIfRocm,
|
|
skipIfRocmVersionLessThan,
|
|
TEST_CUDA,
|
|
TEST_WITH_ROCM,
|
|
TestCase,
|
|
)
|
|
|
|
_IS_SM8X = False
|
|
if TEST_CUDA:
|
|
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
|
|
|
|
# Protects against includes accidentally setting the default dtype
|
|
assert torch.get_default_dtype() is torch.float32
|
|
|
|
|
|
@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
|
|
class TestMatmulCuda(TestCase):
|
|
def setUp(self):
|
|
super(self.__class__, self).setUp()
|
|
torch.backends.cuda.matmul.allow_tf32 = False
|
|
|
|
def tearDown(self):
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
super(self.__class__, self).tearDown()
|
|
|
|
def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
|
|
#
|
|
# Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
|
|
# results from the CUDA invocation of torch.addmm and the CPU invocation
|
|
# (which does not use CUDA backend).
|
|
#
|
|
# Get dims
|
|
n, m, p = (size + 1, size, size + 2)
|
|
# Disable reduced precision reductions in BFloat16 to bypass some kernels
|
|
# which fail the threshold check
|
|
orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
|
orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
|
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
|
|
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
|
|
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
|
|
torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate
|
|
# Make random tensors on CPU (seed set on common_utils.py import)
|
|
# (Not using numpy because it does not support bfloat16)
|
|
make_arg = partial(make_tensor, dtype=dtype, device="cpu")
|
|
m_beta = make_arg(1)
|
|
m_input = make_arg((n, p))
|
|
m_1 = make_arg((n, m))
|
|
m_2 = make_arg((m, p))
|
|
# scale to abate overflows in fp16 accum
|
|
if fp16_accumulate:
|
|
m_1 = m_1 / 100
|
|
m_2 = m_2 / 100
|
|
# *(B)FLOAT16 Special Handling*
|
|
# Backend does not tensorize float16 on CPU,
|
|
        # and bfloat16 may present accuracy issues,
|
|
# so convert to float32 for these cases
|
|
# (but keep same for other types, e.g. float32 and int*)
|
|
if dtype == torch.float16 or dtype == torch.bfloat16:
|
|
m_beta = m_beta.to(dtype=torch.float32)
|
|
m_input = m_input.to(dtype=torch.float32)
|
|
m_1 = m_1.to(dtype=torch.float32)
|
|
m_2 = m_2.to(dtype=torch.float32)
|
|
# Get CPU result
|
|
res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
|
|
        # *(B)FLOAT16 Special Handling*
|
|
# Convert back to (b)float16
|
|
if dtype == torch.float16 or dtype == torch.bfloat16:
|
|
m_beta = m_beta.to(dtype=dtype)
|
|
m_input = m_input.to(dtype=dtype)
|
|
m_1 = m_1.to(dtype=dtype)
|
|
m_2 = m_2.to(dtype=dtype)
|
|
res_cpu = res_cpu.to(dtype=dtype)
|
|
# Move arg tensors to CUDA
|
|
m_beta = m_beta.to("cuda")
|
|
m_input = m_input.to("cuda")
|
|
m_1 = m_1.to("cuda")
|
|
m_2 = m_2.to("cuda")
|
|
# Get CUDA result
|
|
res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
|
|
# Move to CPU for comparison
|
|
res_cuda = res_cuda.to("cpu")
|
|
# Compare
|
|
self.assertEqual(res_cpu, res_cuda)
|
|
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
|
|
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
|
|
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate
|
|
|
|
@onlyCUDA
|
|
@skipIfRocmVersionLessThan((5, 2))
|
|
# imported 'tol' as 'xtol' to avoid aliasing in code above
|
|
@toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
|
|
torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
|
|
torch.float32: xtol(atol=1e-1, rtol=1e-1)})
|
|
@dtypes(torch.float16, torch.bfloat16, torch.float32)
|
|
@parametrize("size", [100, 1000, 10000])
|
|
def test_cublas_addmm(self, size: int, dtype: torch.dtype):
|
|
self.cublas_addmm(size, dtype, False)
|
|
|
|
@onlyCUDA
|
|
@skipIfRocmVersionLessThan((5, 2))
|
|
# imported 'tol' as 'xtol' to avoid aliasing in code above
|
|
@toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
|
|
torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
|
|
@dtypes(torch.float16, torch.bfloat16)
|
|
@parametrize("size", [100, 1000, 10000])
|
|
def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
|
|
self.cublas_addmm(size, dtype, True)
|
|
|
|
@onlyCUDA
|
|
@skipIfRocmVersionLessThan((5, 2))
|
|
# imported 'tol' as 'xtol' to avoid aliasing in code above
|
|
@toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
|
|
torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
|
|
@dtypes(torch.float16, torch.bfloat16)
|
|
@parametrize("size", [100, 1000, 10000])
|
|
def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype):
|
|
self.cublas_addmm(size, dtype, False, True)
|
|
|
|
@onlyCUDA
|
|
@skipIfRocm
|
|
def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
|
|
orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
|
|
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
|
x = torch.rand(32, 512, 512, device='cuda', dtype=torch.half)
|
|
w = torch.rand(512, 512, device='cuda', dtype=torch.half)
|
|
b = torch.rand(512, device='cuda', dtype=torch.half)
|
|
out = torch.nn.functional.linear(x, w, b)
|
|
out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu())
|
|
self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3)
|
|
|
|
a = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
|
|
b = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
|
|
c = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
|
|
out = torch.baddbmm(a, b, c)
|
|
out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu())
|
|
self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3)
|
|
torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate
|
|
|
|
@onlyCUDA
|
|
@toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
|
|
@dtypes(torch.float16)
|
|
def test_cublas_addmm_alignment(self, dtype):
|
|
device = 'cuda'
|
|
# perturb X, A, or B alignment
|
|
for idx in range(0, 3):
|
|
for offset in range(1, 3):
|
|
offsets = [0, 0, 0]
|
|
offsets[idx] = offset
|
|
x_offset, a_offset, b_offset = offsets
|
|
A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
|
|
A = A[a_offset:].reshape(5120, 2560)
|
|
X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
|
|
X = X[x_offset:].reshape(26, 1, 2560)
|
|
B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
|
|
B = B[b_offset:].reshape(5120)
|
|
out = torch.nn.functional.linear(X, A, B)
|
|
self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)
|
|
|
|
@onlyCUDA
|
|
@unittest.skipIf(IS_JETSON, "Too large for Jetson")
|
|
@toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
|
|
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
|
|
@parametrize(
|
|
"batch_size, N, M, P",
|
|
[(2, 100, 100, 100),
|
|
(2, 1000, 1000, 1000),
|
|
(1, 10000, 1000, 10000),
|
|
(1, 10000, 10000, 10000)],
|
|
name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
|
|
)
|
|
@skipIfRocm
|
|
def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
|
|
cpu_dtype = dtype
|
|
if dtype == torch.float16 or dtype == torch.bfloat16:
|
|
cpu_dtype = torch.float32
|
|
|
|
M1 = torch.rand((N, M), device=device, dtype=dtype)
|
|
M2 = torch.rand((M, P), device=device, dtype=dtype)
|
|
A = torch.rand((N, P), device=device, dtype=dtype)
|
|
|
|
def _convert_to_cpu(t):
|
|
return t.to(device='cpu', dtype=cpu_dtype)
|
|
M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])
|
|
|
|
# linear
|
|
out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
|
|
out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
|
|
self.assertEqual(out1_cpu, out1_gpu)
|
|
# test multiply the identity matrix
|
|
if N == M and M == P:
|
|
M2_eye = torch.eye(N, device=device, dtype=dtype)
|
|
out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
|
|
self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())
|
|
|
|
# baddbmm
|
|
def _expand_to_batch(t: torch.Tensor):
|
|
return t.expand((batch_size, ) + t.size())
|
|
alpha, beta = 1.0, 1.0
|
|
M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])
|
|
|
|
out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
|
|
out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
|
|
self.assertEqual(out2_cpu, out2_gpu)
|
|
# test multiply the identity matrix
|
|
if N == M and M == P:
|
|
M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
|
|
out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
|
|
self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())
|
|
|
|
# cross comparison
|
|
self.assertEqual(out1_gpu, out2_gpu[0])
|
|
|
|
|
|
f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices"
|
|
mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+"
|
|
|
|
if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName:
|
|
e4m3_type = torch.float8_e4m3fnuz
|
|
e5m2_type = torch.float8_e5m2fnuz
|
|
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
|
|
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
|
|
else:
|
|
e4m3_type = torch.float8_e4m3fn
|
|
e5m2_type = torch.float8_e5m2
|
|
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
|
|
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
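# For reference: E4M3_MAX_POS is 448.0 for e4m3fn (240.0 for the fnuz variant) and
# E5M2_MAX_POS is 57344.0.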
|
|
|
|
# avoid division by zero when calculating scale
|
|
EPS = 1e-12
|
|
|
|
def amax_to_scale(
|
|
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
|
|
):
|
|
""" Converts the amax value of a tensor to the fp8 scale.
|
|
Args:
|
|
amax: The amax value of the tensor.
|
|
float8_dtype: the float8 dtype.
|
|
orig_dtype: The original dtype of the tensor.
|
|
"""
|
|
scale = torch.empty_like(amax, dtype=torch.float32)
|
|
if float8_dtype == e4m3_type:
|
|
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
|
|
elif float8_dtype == e5m2_type:
|
|
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
|
|
else:
|
|
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
|
|
|
|
# Ensure the scale is representable in float16,
|
|
# this helps when amax is small. We are assuming that we don't need
|
|
# to care about this for float32/bfloat16
|
|
if orig_dtype is torch.float16:
|
|
res = torch.clamp(res, max=torch.finfo(torch.float16).max)
|
|
|
|
scale.copy_(res)
|
|
return scale
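# Example (assuming e4m3, whose max is 448): amax = 2.0 yields scale = 448 / 2.0 = 224,
# so multiplying the tensor by the scale maps its largest-magnitude element onto the fp8 max.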
|
|
|
|
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
|
|
if dim is None:
|
|
amax = torch.max(torch.abs(x))
|
|
else:
|
|
amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values
|
|
|
|
return amax_to_scale(amax, float8_dtype, x.dtype)
|
|
|
|
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
|
|
# naive implementation: dq -> op -> q
|
|
x_fp32 = x.to(torch.float) / x_scale
|
|
y_fp32 = y.to(torch.float) / y_scale
|
|
out_fp32 = torch.mm(x_fp32, y_fp32)
|
|
|
|
return out_fp32.to(out_dtype)
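# This is the reference path the tests compare torch._scaled_mm against: both operands are
# dequantized to fp32 with their scales, multiplied in fp32, and the result cast to out_dtype.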
|
|
|
|
def addmm_float8_unwrapped(
|
|
a_data: torch.Tensor,
|
|
a_scale: torch.Tensor,
|
|
b_data: torch.Tensor,
|
|
    b_scale: torch.Tensor,
|
|
output_dtype: torch.dtype,
|
|
output_scale: Optional[torch.Tensor],
|
|
bias: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
a_inverse_scale = a_scale.reciprocal()
|
|
b_inverse_scale = b_scale.reciprocal()
|
|
if output_dtype == torch.float32 and bias is not None:
|
|
# Bias is not supported by _scaled_mm when output is fp32
|
|
output = torch._scaled_mm(
|
|
a_data,
|
|
b_data,
|
|
scale_a=a_inverse_scale,
|
|
scale_b=b_inverse_scale,
|
|
scale_result=output_scale,
|
|
out_dtype=output_dtype,
|
|
)
|
|
output += bias
|
|
return output
|
|
output = torch._scaled_mm(
|
|
a_data,
|
|
b_data,
|
|
bias=bias,
|
|
scale_a=a_inverse_scale,
|
|
scale_b=b_inverse_scale,
|
|
scale_result=output_scale,
|
|
out_dtype=output_dtype,
|
|
)
|
|
return output
|
|
|
|
def mm_float8(
|
|
a: torch.Tensor,
|
|
b: torch.Tensor,
|
|
a_scale: torch.Tensor,
|
|
b_scale: torch.Tensor,
|
|
output_dtype: torch.dtype, # output dtype
|
|
output_scale: Optional[torch.Tensor] = None, # output scale, precomputed
|
|
) -> torch.Tensor:
|
|
return addmm_float8_unwrapped(
|
|
a, a_scale, b, b_scale, output_dtype, output_scale
|
|
)
|
|
|
|
def to_fp8_saturated(
|
|
x: torch.Tensor,
|
|
fp8_dtype: torch.dtype
|
|
):
|
|
if fp8_dtype == e4m3_type:
|
|
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
|
|
elif fp8_dtype == e5m2_type:
|
|
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
|
|
else:
|
|
raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
|
|
|
|
return x.to(fp8_dtype)
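# Note: the clamp above makes this a saturating cast, so values outside the representable
# fp8 range map to +/-MAX_POS rather than falling outside the format's finite range.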
|
|
|
|
# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
|
|
def ceil_div(a, b):
|
|
return (a + b - 1) // b
|
|
|
|
def to_blocked(input_matrix) -> torch.Tensor:
|
|
"""
|
|
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
|
|
|
|
See:
|
|
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
|
|
|
|
Args:
|
|
input_matrix: Input tensor of shape (H, W)
|
|
|
|
Returns:
|
|
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
|
|
"""
|
|
rows, cols = input_matrix.shape
|
|
n_row_blocks = ceil_div(rows, 128)
|
|
n_col_blocks = ceil_div(cols, 4)
|
|
|
|
# Calculate the padded shape
|
|
padded_rows = n_row_blocks * 128
|
|
padded_cols = n_col_blocks * 4
|
|
|
|
padded = input_matrix
|
|
    # Ideally we would use torch.nn.functional.pad, but it doesn't support float8_e8m0fnu for now
|
|
if (rows, cols) != (padded_rows, padded_cols):
|
|
padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
|
|
padded[:rows, :cols] = input_matrix
|
|
|
|
# Rearrange the blocks
|
|
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
|
|
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
|
|
|
return rearranged.flatten()
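# Worked example: a (256, 8) scale matrix has n_row_blocks = 2 and n_col_blocks = 2, so the
# swizzled result has 32*2 * 16*2 = 2048 elements, i.e. the (32*ceil_div(H,128), 16*ceil_div(W,4))
# layout from the docstring, flattened.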
|
|
|
|
def compute_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
"""Computes the error between two tensors in dB.
|
|
|
|
For more details see:
|
|
https://en.wikipedia.org/wiki/Signal-to-noise_ratio
|
|
|
|
Args:
|
|
x: The original tensor.
|
|
y: The tensor to compare to the original tensor.
|
|
"""
|
|
Ps = torch.norm(x)
|
|
Pn = torch.norm(x - y)
|
|
return 20 * torch.log10(Ps / Pn)
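# For example, an error norm that is 1% of the signal norm gives 20 * log10(100) = 40 dB;
# the mxfp8 tests below require the SQNR to exceed 22 dB.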
|
|
|
|
# largest power of 2 representable in `torch.float8_e4m3fn`
|
|
F8E4M3_LARGEST_POW2 = 8
|
|
# max value of `torch.float8_e4m3fn` (448)
|
|
F8E4M3_MAX_VAL = torch.finfo(torch.float8_e4m3fn).max
|
|
# exponent bias of `torch.float8_e8m0fnu`
|
|
F8E8M0_EXP_BIAS = 127
|
|
|
|
def data_to_mx_scale(x, block_size):
|
|
# simple implementation of https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
|
# section 6.3, not all edge cases (such as NaN) are handled/tested
|
|
orig_shape = x.shape
|
|
x = x.reshape(-1, block_size)
|
|
max_abs = torch.amax(torch.abs(x), 1)
|
|
largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs))
|
|
scale_e8m0_unbiased = largest_p2_lt_max_abs - F8E4M3_LARGEST_POW2
|
|
scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -1 * F8E8M0_EXP_BIAS, F8E8M0_EXP_BIAS)
|
|
scale_e8m0_biased = scale_e8m0_unbiased + F8E8M0_EXP_BIAS
|
|
scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
|
|
scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
|
|
return scale_e8m0_biased.reshape(orig_shape[0], -1)
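# Worked example: a block whose max_abs is 448 (the e4m3 max) has floor(log2(448)) = 8, so the
# unbiased exponent is 8 - F8E4M3_LARGEST_POW2 = 0 and the stored e8m0 byte is 127, i.e. a scale of 1.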
|
|
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
|
|
class TestFP8MatmulCuda(TestCase):
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def _test_tautological_mm(self, device: str = "cuda",
|
|
x_dtype: torch.dtype = e4m3_type,
|
|
y_dtype: torch.dtype = e4m3_type,
|
|
out_dtype: Optional[torch.dtype] = None,
|
|
size: int = 16) -> None:
|
|
x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
|
|
y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
|
|
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
|
|
if out_dtype is not None:
|
|
self.assertEqual(out_dtype, out_fp8.dtype)
|
|
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_basics(self, device) -> None:
|
|
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
|
|
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
|
|
# supported on ROCm but fails on CUDA
|
|
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext()
|
|
with ctx:
|
|
self._test_tautological_mm(device, e5m2_type, e5m2_type)
|
|
|
|
self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
|
|
self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
|
|
|
|
self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
|
|
self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
|
|
self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
|
|
|
|
with self.assertRaises(AssertionError if torch.version.hip else RuntimeError):
|
|
self._test_tautological_mm(device, out_dtype=e5m2_type)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_scale(self, device) -> None:
|
|
size = (16, 16)
|
|
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
|
# hipblaslt does not yet support mixed e4m3_type input
|
|
y_type = e4m3_type if torch.version.hip else e5m2_type
|
|
y = torch.full(size, .5, device=device, dtype=y_type).t()
|
|
scale_one = torch.tensor(1.0, device=device)
|
|
scale_a = torch.tensor(1.5, device=device)
|
|
scale_b = torch.tensor(0.66, device=device)
|
|
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one)
|
|
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
|
|
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
|
self.assertEqual(out_fp8, out_fp8_s)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
|
def test_scaled_mm_vs_emulated(self, base_dtype):
|
|
torch.manual_seed(42)
|
|
input_dtype = e4m3_type
|
|
output_dtype = base_dtype
|
|
compare_type = torch.float32
|
|
|
|
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
|
|
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
|
|
|
|
x_scale = tensor_to_scale(x, input_dtype).float()
|
|
y_scale = tensor_to_scale(y, input_dtype).float()
|
|
|
|
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
|
|
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
|
|
|
|
# Calculate actual F8 mm
|
|
out_scaled_mm = mm_float8(
|
|
x_fp8,
|
|
y_fp8,
|
|
a_scale=x_scale,
|
|
b_scale=y_scale,
|
|
output_dtype=output_dtype
|
|
)
|
|
|
|
# Calculate emulated F8 mm
|
|
out_emulated = mm_float8_emulated(
|
|
x_fp8,
|
|
x_scale,
|
|
y_fp8,
|
|
y_scale,
|
|
output_dtype
|
|
)
|
|
|
|
if output_dtype != base_dtype:
|
|
out_scaled_mm = out_scaled_mm.to(compare_type)
|
|
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
|
|
|
|
out_emulated = out_emulated.to(compare_type)
|
|
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
|
|
|
|
if base_dtype in {torch.bfloat16, torch.float16}:
|
|
atol, rtol = 7e-2, 7e-2
|
|
else:
|
|
atol, rtol = 3e-3, 3e-3
|
|
|
|
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
|
|
def test_scaled_mm_change_stride(self, base_dtype):
|
|
torch.manual_seed(42)
|
|
input_dtype = e4m3_type
|
|
output_dtype = base_dtype
|
|
compare_type = torch.float32
|
|
|
|
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
|
|
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
|
|
|
|
x.normal_()
|
|
y.normal_()
|
|
|
|
x_scale = tensor_to_scale(x, input_dtype).float()
|
|
y_scale = tensor_to_scale(y, input_dtype).float()
|
|
|
|
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
|
|
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
|
|
|
|
# Calculate actual F8 mm
|
|
out_scaled_mm = mm_float8(
|
|
x_fp8,
|
|
y_fp8,
|
|
a_scale=x_scale,
|
|
b_scale=y_scale,
|
|
output_dtype=output_dtype
|
|
)
|
|
|
|
# Calculate emulated F8 mm
|
|
out_emulated = mm_float8_emulated(
|
|
x_fp8,
|
|
x_scale,
|
|
y_fp8,
|
|
y_scale,
|
|
output_dtype
|
|
)
|
|
|
|
if output_dtype != base_dtype:
|
|
out_scaled_mm = out_scaled_mm.to(compare_type)
|
|
out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)
|
|
|
|
out_emulated = out_emulated.to(compare_type)
|
|
out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)
|
|
|
|
if base_dtype in {torch.bfloat16, torch.float16}:
|
|
atol, rtol = 7e-2, 7e-2
|
|
else:
|
|
atol, rtol = 3e-3, 3e-3
|
|
|
|
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_bias(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.ones((k, l), device=device).to(e4m3_type)
|
|
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
|
bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
|
|
outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
|
|
# this fails on ROCm currently because hipblaslt doesn't have amax op
|
|
out_fp32 = out_fp8.to(torch.float32)
|
|
outb_fp32 = outb_fp8.to(torch.float32)
|
|
difference = torch.abs(out_fp32 - outb_fp32)
|
|
self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@parametrize("bias", [True, False])
|
|
def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
|
|
x = torch.rand((17, 16), device=device).to(e4m3_type)
|
|
y = torch.rand((16, 16), device=device).to(e4m3_type).t()
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
input_bias = None
|
|
if bias:
|
|
input_bias = torch.rand((16,), device=device).to(torch.half)
|
|
_ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_bias_relu_edgecase(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
|
|
y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
|
|
bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
|
|
outb_fp32 = outb_fp8.to(torch.float32)
|
|
self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float32_output_errors_with_bias(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.rand((k, l), device=device).to(e4m3_type)
|
|
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Bias is not supported when out_dtype is set to Float32",
|
|
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
|
|
)
|
|
|
|
@unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_error_message_fp8_pre_sm89(self, device) -> None:
|
|
(k, l, m) = (16, 48, 32)
|
|
x = torch.rand((k, l), device=device).to(e4m3_type)
|
|
y = torch.rand((m, l), device=device).to(e4m3_type).t()
|
|
scale_a = torch.tensor(1.0, device=device)
|
|
scale_b = torch.tensor(1.0, device=device)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
|
|
lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
|
|
)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
def test_float8_scale_fast_accum(self, device) -> None:
|
|
size = (16, 16)
|
|
x = torch.full(size, .5, device=device, dtype=e4m3_type)
|
|
# hipblaslt does not yet support mixed e4m3_type input
|
|
y_type = e4m3_type if torch.version.hip else e5m2_type
|
|
y = torch.full(size, .5, device=device, dtype=y_type).t()
|
|
scale_a = torch.tensor(1.5, device=device)
|
|
scale_b = torch.tensor(0.66, device=device)
|
|
out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
|
|
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
|
|
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
|
|
self.assertEqual(out_fp8, out_fp8_s)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
|
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
|
|
@parametrize("use_fast_accum", [True, False])
|
|
def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
|
|
M, K, N = (1024, 512, 2048)
|
|
fill_value = 0.5
|
|
x = torch.full((M, K), fill_value, device=device)
|
|
y = torch.full((N, K), fill_value, device=device)
|
|
|
|
x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
|
|
y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)
|
|
|
|
x_fp8 = x.to(e4m3_type)
|
|
y_fp8 = y.to(e4m3_type).t()
|
|
|
|
out_fp8 = torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=x_scales,
|
|
scale_b=y_scales,
|
|
out_dtype=torch.bfloat16,
|
|
use_fast_accum=use_fast_accum,
|
|
)
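        # With unit scales and every input element equal to 0.5, each output element is the
        # K-term dot product K * 0.5 * 0.5, e.g. 512 * 0.25 = 128 for the shapes above.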
|
|
self.assertEqual(
|
|
out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
|
|
)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
|
def test_float8_error_messages(self, device) -> None:
|
|
M, K, N = (1024, 512, 2048)
|
|
fill_value = 0.5
|
|
x = torch.full((M, K), fill_value, device=device)
|
|
y = torch.full((N, K), fill_value, device=device)
|
|
|
|
x_fp8 = x.to(e4m3_type)
|
|
y_fp8 = y.to(e4m3_type).t()
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
|
|
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
|
|
),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((1, 1), device="cuda"),
|
|
scale_b=torch.ones((1, 2), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
|
|
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
|
|
),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((M, 1), device="cuda"),
|
|
scale_b=torch.ones((1, N + 1), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((M), device="cuda"),
|
|
scale_b=torch.ones((N, N), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
"Both scale_a and scale_b must be contiguous for RowWise scaling."
|
|
),
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=torch.ones((M, 1), device="cuda"),
|
|
scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
        # Note a regex pattern is used here instead of re.escape, to accommodate the fn vs fnuz type in the message.
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Expected b\.dtype\(\) == at::kFloat8_e4m3fnu?z? to be true, but got false\.",
|
|
):
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8.to(e5m2_type),
|
|
scale_a=torch.ones((M, 1), device="cuda"),
|
|
scale_b=torch.ones((1, N), device="cuda"),
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
|
@unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
|
|
@parametrize("base_dtype", [torch.bfloat16])
|
|
def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
|
|
torch.manual_seed(42)
|
|
input_dtype = e4m3_type
|
|
output_dtype = base_dtype
|
|
|
|
x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
|
|
y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()
|
|
|
|
x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
|
|
y_scales = tensor_to_scale(y, input_dtype, dim=0).float()
|
|
|
|
x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
|
|
y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
|
|
|
|
# Calculate actual F8 mm
|
|
out_scaled_mm = mm_float8(
|
|
x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
|
|
)
|
|
|
|
# Calculate emulated F8 mm
|
|
out_emulated = mm_float8_emulated(
|
|
x_fp8, x_scales, y_fp8, y_scales, output_dtype
|
|
)
|
|
|
|
if base_dtype in {torch.bfloat16, torch.float16}:
|
|
atol, rtol = 7e-2, 7e-2
|
|
else:
|
|
atol, rtol = 2e-3, 2e-3
|
|
|
|
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@parametrize("which_dim_zero", [0, 1, 2])
|
|
@parametrize("use_torch_compile", [False, True])
|
|
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
|
|
device = "cuda"
|
|
x_dtype, y_dtype = torch.float8_e4m3fn, torch.float8_e4m3fn
|
|
out_dtype = torch.bfloat16
|
|
M, K, N = 32, 32, 32
|
|
if which_dim_zero == 0:
|
|
M = 0
|
|
elif which_dim_zero == 1:
|
|
K = 0
|
|
elif which_dim_zero == 2:
|
|
N = 0
|
|
|
|
x_fp8 = torch.zeros(M, K, device=device).to(x_dtype)
|
|
y_fp8 = torch.zeros(N, K, device=device, dtype=y_dtype).t()
|
|
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
|
|
scale_a = torch.tensor(float('-inf'), device=device)
|
|
scale_b = torch.tensor(float('-inf'), device=device)
|
|
f = torch._scaled_mm
|
|
if use_torch_compile:
|
|
f = torch.compile(torch._scaled_mm)
|
|
out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
|
|
self.assertEqual(out_dtype, out_fp8.dtype)
|
|
self.assertEqual(out_fp32, out_fp8.to(torch.float))
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
|
|
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
|
|
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
|
|
def test_honor_sm_carveout(self) -> None:
|
|
torch.manual_seed(42)
|
|
|
|
x = torch.randn(8192, 2048, device="cuda", dtype=torch.float32)
|
|
y = torch.randn(8192, 2048, device="cuda", dtype=torch.float32).t()
|
|
x_scales = tensor_to_scale(x, e4m3_type, dim=1).reciprocal()
|
|
y_scales = tensor_to_scale(y, e4m3_type, dim=0).reciprocal()
|
|
x_fp8 = to_fp8_saturated(x / x_scales, e4m3_type)
|
|
y_fp8 = to_fp8_saturated(y / y_scales, e4m3_type)
|
|
|
|
with tempfile.NamedTemporaryFile() as f:
|
|
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
|
|
self.assertIsNone(torch._C._get_sm_carveout_experimental())
|
|
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
|
|
torch._C._set_sm_carveout_experimental(0)
|
|
self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
|
|
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
|
|
torch._C._set_sm_carveout_experimental(66)
|
|
self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
|
|
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
|
|
torch._C._set_sm_carveout_experimental(None)
|
|
self.assertIsNone(torch._C._get_sm_carveout_experimental())
|
|
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
|
|
|
|
prof.export_chrome_trace(f.name)
|
|
no_carveout, carveout_0, carveout_66, no_carveout_again = [
|
|
math.prod(evt.get("args", {}).get("grid", []))
|
|
for evt in json.load(open(f.name))["traceEvents"]
|
|
if evt.get("cat", "") == "kernel"
|
|
]
|
|
|
|
self.assertEqual(no_carveout, no_carveout_again)
|
|
self.assertNotEqual(no_carveout, carveout_66)
|
|
self.assertNotEqual(carveout_66, carveout_0)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
|
|
@parametrize("test_case_name", [
|
|
"a_eye_b_eye",
|
|
"a_ones_b_ones",
|
|
"a_ones_modified_b_ones",
|
|
"a_ones_b_ones_modified",
|
|
"a_scale_modified_b_ones",
|
|
"a_ones_b_scale_modified",
|
|
"data_random_scales_one",
|
|
"data_random_scales_from_data",
|
|
])
|
|
@parametrize("fast_accum", [False, True])
|
|
@parametrize("mkn", [
|
|
# Nice shapes
|
|
(128, 128, 128),
|
|
(256, 256, 256),
|
|
(128, 256, 512),
|
|
(256, 512, 128),
|
|
(512, 128, 256),
|
|
|
|
# Non block multiples
|
|
(65, 96, 112),
|
|
(197, 224, 272),
|
|
# K not multiple of 32
|
|
(197, 240, 272),
|
|
|
|
# Very unbalanced
|
|
(1023, 64, 48),
|
|
(31, 1024, 64),
|
|
(45, 96, 1024),
|
|
|
|
# Mixed large and small
|
|
(2, 1024, 128),
|
|
(127, 96, 1024),
|
|
(1025, 128, 96)
|
|
], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}")
|
|
def test_blockwise_mxfp8_numerics(self, test_case_name, fast_accum, mkn) -> None:
|
|
# inspiration: https://github.com/pytorch/ao/pull/1625
|
|
|
|
device = "cuda"
|
|
M, K, N = mkn
|
|
BLOCK_SIZE = 32
|
|
require_exact_match = True
|
|
|
|
def ceil_div(a, b):
|
|
return (a + b - 1) // b
|
|
|
|
if test_case_name == "a_eye_b_eye":
|
|
            if not ((M == K) and (M == N)):
                self.skipTest("this test is only defined for M == K == N, skipping")
|
|
A_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
|
|
B_ref = torch.eye(M, device=device, dtype=torch.bfloat16)
|
|
|
|
A = A_ref.to(torch.float8_e4m3fn)
|
|
B = B_ref.to(torch.float8_e4m3fn)
|
|
|
|
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
elif test_case_name == "a_ones_b_ones":
|
|
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
|
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
|
|
|
|
A = A_ref.to(torch.float8_e4m3fn)
|
|
B = B_ref.to(torch.float8_e4m3fn)
|
|
|
|
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
elif test_case_name == "a_ones_modified_b_ones":
|
|
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
|
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
|
|
|
|
A = A_ref.to(torch.float8_e4m3fn)
|
|
B = B_ref.to(torch.float8_e4m3fn)
|
|
|
|
A_ref[1][0:BLOCK_SIZE] = 2
|
|
A[1][0:BLOCK_SIZE] = 2
|
|
|
|
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
elif test_case_name == "a_ones_b_ones_modified":
|
|
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
|
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
|
|
|
|
A = A_ref.to(torch.float8_e4m3fn)
|
|
B = B_ref.to(torch.float8_e4m3fn)
|
|
|
|
B_ref[1][0:BLOCK_SIZE] = 2
|
|
B[1][0:BLOCK_SIZE] = 2
|
|
|
|
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
elif test_case_name == "a_scale_modified_b_ones":
|
|
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
|
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
|
|
|
|
A = A_ref.to(torch.float8_e4m3fn)
|
|
B = B_ref.to(torch.float8_e4m3fn)
|
|
|
|
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
|
|
A_ref[1][0:BLOCK_SIZE] = 4
|
|
A[1][0:BLOCK_SIZE] = 2
|
|
A_scale[1][0] = 2
|
|
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
elif test_case_name == "a_ones_b_scale_modified":
|
|
A_ref = torch.ones(M, K, device=device, dtype=torch.bfloat16)
|
|
B_ref = torch.ones(N, K, device=device, dtype=torch.bfloat16)
|
|
|
|
A = A_ref.to(torch.float8_e4m3fn)
|
|
B = B_ref.to(torch.float8_e4m3fn)
|
|
|
|
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
|
|
B_ref[1][0:BLOCK_SIZE] = 4
|
|
B[1][0:BLOCK_SIZE] = 2
|
|
B_scale[1][0] = 2
|
|
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
elif test_case_name == "data_random_scales_one":
|
|
require_exact_match = False
|
|
# scales all-ones, element data random while being exactly representable in float8_e4m3fn
|
|
|
|
# generate integers in [0, 255] and interpret as float8_e4m3fn
|
|
A_ref = torch.randint(0, 255, (M, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
|
|
B_ref = torch.randint(0, 255, (N, K), device=device, dtype=torch.uint8).view(torch.float8_e4m3fn).to(torch.bfloat16)
|
|
# modification: don't allow NaN values
|
|
A_ref[torch.isnan(A_ref)] = 0
|
|
B_ref[torch.isnan(B_ref)] = 0
|
|
|
|
A = A_ref.to(torch.float8_e4m3fn)
|
|
B = B_ref.to(torch.float8_e4m3fn)
|
|
|
|
A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
|
|
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
elif test_case_name == "data_random_scales_from_data":
|
|
            if K % BLOCK_SIZE != 0:
                self.skipTest(f"this test is only defined for K a multiple of {BLOCK_SIZE}, skipping")
|
|
require_exact_match = False
|
|
# random data, scales from data
|
|
A_ref = torch.randn((M, K), device=device, dtype=torch.bfloat16) * 1000
|
|
B_ref = torch.randn((N, K), device=device, dtype=torch.bfloat16) * 1000
|
|
|
|
# Calculate scales based on the inputs
|
|
A_scale = data_to_mx_scale(A_ref, BLOCK_SIZE)
|
|
B_scale = data_to_mx_scale(B_ref, BLOCK_SIZE)
|
|
|
|
max_val = F8E4M3_MAX_VAL
|
|
min_val = -1 * max_val
|
|
|
|
A = (A_ref.reshape(-1, BLOCK_SIZE) / A_scale.reshape(M * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(M, K)
|
|
A = A.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
|
|
B = (B_ref.reshape(-1, BLOCK_SIZE) / B_scale.reshape(N * ceil_div(K, BLOCK_SIZE), 1).float()).reshape(N, K)
|
|
B = B.clamp(min=min_val, max=max_val).to(torch.float8_e4m3fn)
|
|
|
|
# convert to swizzled format
|
|
A_scale = to_blocked(A_scale)
|
|
B_scale = to_blocked(B_scale)
|
|
|
|
C_ref = A_ref @ B_ref.t()
|
|
|
|
C = torch._scaled_mm(
|
|
A,
|
|
B.t(),
|
|
A_scale,
|
|
B_scale,
|
|
out_dtype=torch.bfloat16,
|
|
use_fast_accum=fast_accum,
|
|
)
|
|
|
|
if require_exact_match:
|
|
torch.testing.assert_close(C, C_ref, atol=0, rtol=0)
|
|
else:
|
|
sqnr = compute_error(C_ref, C)
|
|
assert sqnr.item() > 22.0
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
|
|
def test_blockwise_mxfloat8_error_messages(self, device) -> None:
|
|
M, K, N = (1024, 512, 2048)
|
|
BLOCK_SIZE_K = 32
|
|
BLOCK_SIZE_MN = 128
|
|
fill_value = 0.5
|
|
|
|
x = torch.full((M, K), fill_value, device=device)
|
|
y = torch.full((N, K), fill_value, device=device)
|
|
|
|
x_fp8 = x.to(e4m3_type)
|
|
y_fp8 = y.to(e4m3_type).t()
|
|
|
|
def ceil_div(a, b):
|
|
return (a + b - 1) // b
|
|
|
|
num_k_blocks = ceil_div(K, BLOCK_SIZE_K)
|
|
padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
|
|
expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
|
|
expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks
|
|
|
|
|
|
# Test wrong scale tensor size for scale_a with correct dtype
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
f"For BlockWise scaling: Expected scale_a size to be {expected_a_size} "
|
|
f"but got {expected_a_size - 1}"
|
|
),
|
|
):
|
|
incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=torch.float8_e8m0fnu)
|
|
correct_size_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=incorrect_size_a,
|
|
scale_b=correct_size_b,
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
# Test wrong scale tensor size for scale_b with correct dtype
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
f"For BlockWise scaling: Expected scale_b size to be {expected_b_size} "
|
|
f"but got {expected_b_size + 1}"
|
|
),
|
|
):
|
|
correct_size_a = torch.ones(expected_a_size, device=device, dtype=torch.float8_e8m0fnu)
|
|
incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=torch.float8_e8m0fnu)
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=correct_size_a,
|
|
scale_b=incorrect_size_b,
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
# Test non-contiguous scale tensors with correct dtype
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
re.escape(
|
|
"For BlockWise scaling: Both scale_a and scale_b must be contiguous"
|
|
),
|
|
):
|
|
non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=torch.float8_e8m0fnu)[::2]
|
|
contiguous_b = torch.ones(expected_b_size, device=device, dtype=torch.float8_e8m0fnu)
|
|
torch._scaled_mm(
|
|
x_fp8,
|
|
y_fp8,
|
|
scale_a=non_contiguous_a,
|
|
scale_b=contiguous_b,
|
|
out_dtype=torch.bfloat16,
|
|
)
|
|
|
|
def grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum):
|
|
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
|
|
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
|
|
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
|
|
self.assertEqual(out, out_ref)
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
|
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
|
|
@parametrize("fast_accum", [False, True])
|
|
@parametrize("strided", [False, True])
|
|
def test_grouped_gemm_2d_2d(self, fast_accum, strided):
|
|
device = "cuda"
|
|
m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16
|
|
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
|
|
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
|
|
scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
|
|
scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
|
|
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
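        # Both operands are jagged along K: group i spans columns [offs[i-1], offs[i]) of a and b,
        # and uses its own m entries of scale_a (rowwise) and n entries of scale_b (columnwise).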
|
|
out = torch._scaled_grouped_mm(a, b.t(), scale_a, scale_b, offs=offs,
|
|
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
|
|
offs_cpu = offs.cpu()
|
|
alist, blist, ascalelist, bscalelist = [], [], [], []
|
|
start = 0
|
|
for i in range(n_groups):
|
|
alist.append(a[:, start:offs_cpu[i]])
|
|
blist.append(b[:, start:offs_cpu[i]])
|
|
ascalelist.append(scale_a[i * m : (i + 1) * m])
|
|
bscalelist.append(scale_b[i * n : (i + 1) * n])
|
|
start = offs_cpu[i]
|
|
self.grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum)
|
|
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
|
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
|
|
@parametrize("fast_accum", [False, True])
|
|
@parametrize("strided", [False, True])
|
|
def test_grouped_gemm_2d_3d(self, fast_accum, strided):
|
|
device = "cuda"
|
|
s_int = int(strided)
|
|
m, n, k, n_groups = 16, 32, 16, 4
|
|
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
|
|
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
|
|
self.assertTrue(a.is_contiguous() is not strided)
|
|
self.assertTrue(b.is_contiguous() is not strided)
|
|
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
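        # a is jagged along its rows: group i is a[offs[i-1]:offs[i]] and is multiplied by b[i].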
|
|
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
|
|
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
|
|
|
|
out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
|
|
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
|
|
|
|
offs_cpu = offs.cpu()
|
|
alist, ascalelist, outlist = [], [], []
|
|
start = 0
|
|
for i in range(n_groups):
|
|
alist.append(a[start:offs_cpu[i]])
|
|
ascalelist.append(scale_a[start:offs_cpu[i]])
|
|
outlist.append(out[start:offs_cpu[i]])
|
|
start = offs_cpu[i]
|
|
self.grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
|
|
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
|
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
|
|
@parametrize("fast_accum", [False, True])
|
|
@parametrize("strided", [False, True])
|
|
def test_grouped_gemm_3d_3d(self, fast_accum, strided):
|
|
device = "cuda"
|
|
s_int = int(strided)
|
|
m, n, k, n_groups = 16, 32, 16, 4
|
|
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
|
|
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
|
|
self.assertTrue(a.is_contiguous() is not strided)
|
|
self.assertTrue(b.is_contiguous() is not strided)
|
|
scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
|
|
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
|
|
|
|
out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b,
|
|
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
|
|
|
|
self.grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum)
|
|
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
|
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
|
|
@parametrize("fast_accum", [False, True])
|
|
@parametrize("strided", [False, True])
|
|
def test_grouped_gemm_3d_2d(self, fast_accum, strided):
|
|
device = "cuda"
|
|
s_int = int(strided)
|
|
m, n, k, n_groups = 16, 32, 16, 4
|
|
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
|
|
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
|
|
self.assertTrue(a.is_contiguous() is not strided)
|
|
self.assertTrue(b.is_contiguous() is not strided)
|
|
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
|
|
scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
|
|
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
|
|
|
|
out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
|
|
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
|
|
offs_cpu = offs.cpu()
|
|
blist, bscalelist, outlist = [], [], []
|
|
start = 0
|
|
for i in range(n_groups):
|
|
blist.append(b[start:offs_cpu[i]])
|
|
bscalelist.append(scale_b[start:offs_cpu[i]])
|
|
outlist.append(out[:, start:offs_cpu[i]])
|
|
start = offs_cpu[i]
|
|
self.grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
|
|
|
|
|
|
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
|
|
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
|
|
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
|
|
class TestMixedDtypesLinearCuda(TestCase):
|
|
@dtypes(torch.float16, torch.bfloat16)
|
|
def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
|
|
version = _get_torch_cuda_version()
|
|
if version < (11, 8):
|
|
self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")
|
|
|
|
def run_test(
|
|
batch_shape,
|
|
m,
|
|
n,
|
|
k,
|
|
add_bias,
|
|
activation,
|
|
dtype,
|
|
dtypeq,
|
|
device,
|
|
rtol,
|
|
atol,
|
|
):
|
|
if not add_bias and activation != "none":
|
|
return
|
|
|
|
val_lo, val_hi = -1, 1
|
|
valq_lo, valq_hi = -2, 2
|
|
input = make_tensor(
|
|
*batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
|
|
)
|
|
weight = make_tensor(
|
|
n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
|
|
)
|
|
scale = make_tensor(
|
|
(n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
|
|
)
|
|
bias = (
|
|
make_tensor(
|
|
(n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
|
|
)
|
|
if add_bias
|
|
else None
|
|
)
|
|
|
|
input_ref = input.reshape(-1, input.shape[-1])
|
|
|
|
# First, test plain multiplication.
|
|
weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
|
|
weightq = (
|
|
pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
|
|
)
|
|
output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
|
|
output = torch.ops.aten._mixed_dtypes_linear(
|
|
input,
|
|
quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
|
|
weightq, dtypeq, transpose=False
|
|
),
|
|
scale,
|
|
)
|
|
torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)
|
|
|
|
# Second, test the linear operator itself.
|
|
weight_ref = weight.to(input.dtype) * scale.view(n, 1)
|
|
weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
|
|
bias_ref = bias.view(1, n) if add_bias else None
|
|
output_ref = torch.nn.functional.linear(
|
|
input_ref, weight_ref, bias=bias_ref
|
|
).reshape(*input.shape[:-1], n)
|
|
if activation == "relu":
|
|
relu = torch.nn.ReLU()
|
|
output_ref = relu(output_ref)
|
|
elif activation == "silu":
|
|
silu = torch.nn.SiLU()
|
|
output_ref = silu(output_ref)
|
|
output = torch.ops.aten._mixed_dtypes_linear(
|
|
input,
|
|
quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
|
|
weightq, dtypeq, transpose=True
|
|
),
|
|
scale,
|
|
bias=bias,
|
|
activation=activation,
|
|
)
|
|
torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)
|
|
|
|
dtypeqs = [torch.int8, torch.quint4x2]
|
|
batch_shapes = [[], [2], [2, 1]]
|
|
shapes = [
|
|
[8, 64, 64],
|
|
[8, 64, 128],
|
|
[8, 128, 64],
|
|
[8, 128, 128],
|
|
[8, 128, 192],
|
|
[8, 128, 256],
|
|
[8, 256, 128],
|
|
[8, 256, 384],
|
|
[8, 384, 256],
|
|
]
|
|
activations = [None, "relu", "silu"]
|
|
rtol, atol = 1e-3, 1e-3
|
|
if dtype == torch.bfloat16:
|
|
rtol, atol = 1e-2, 1e-3
|
|
for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
|
|
dtypeqs, batch_shapes, shapes, (False, True), activations
|
|
):
|
|
run_test(
|
|
batch_shape,
|
|
m,
|
|
n,
|
|
k,
|
|
add_bias,
|
|
activation,
|
|
dtype,
|
|
dtypeq,
|
|
device,
|
|
rtol,
|
|
atol,
|
|
)
|
|
|
|
instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
|
|
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
|
|
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
|
|
|
|
if __name__ == '__main__':
|
|
TestCase._default_dtype_check_enabled = True
|
|
run_tests()