# Owner(s): ["module: linear algebra"]

import contextlib
import unittest
from itertools import product
from functools import partial

import torch

from torch.quantization._quantized_conversions import (
    pack_int4_to_int8,
    quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_BF16,
    SM53OrLater,
    SM80OrLater,
    SM90OrLater,
    SM100OrLater,
    xfailIfSM120OrLater,
    _get_torch_cuda_version,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    tol as xtol,
    toleranceOverride,
)
from torch.testing._internal.common_utils import (
    IS_JETSON,
    IS_WINDOWS,
    MI200_ARCH,
    NAVI_ARCH,
    getRocmVersion,
    isRocmArchAnyOf,
    parametrize,
    run_tests,
    runOnRocmArch,
    skipIfRocm,
    TEST_CUDA,
    TEST_WITH_ROCM,
    TestCase,
    decorateIf,
)
from torch._inductor.test_case import TestCase as InductorTestCase

_IS_SM8X = False
if TEST_CUDA:
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32


def xfailIfSM100OrLaterNonRTXAndCondition(condition_fn):
    """
    Conditionally xfail tests on SM100+ datacenter SKUs based on a condition function.
    The condition function receives the test parameters dict and returns True to xfail.
    """
    computeCapabilityCheck = SM100OrLater and torch.cuda.get_device_capability()[0] != 12
    return decorateIf(
        unittest.expectedFailure,
        lambda params: computeCapabilityCheck and condition_fn(params)
    )


@contextlib.contextmanager
def blas_library_context(backend):
    prev_backend = torch.backends.cuda.preferred_blas_library()
    torch.backends.cuda.preferred_blas_library(backend)
    try:
        yield
    finally:
        torch.backends.cuda.preferred_blas_library(prev_backend)


class TestMatmulCuda(InductorTestCase):
    def setUp(self):
        super().setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super().tearDown()

    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False):
        #
        # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
        # results from the CUDA invocation of torch.addmm and the CPU invocation
        # (which does not use CUDA backend).
        #
        # Get dims
        n, m, p = (size + 1, size, size + 2)
        # Disable reduced precision reductions in BFloat16 to bypass some kernels
        # which fail the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate
        # Make random tensors on CPU (seed set on common_utils.py import)
        # (Not using numpy because it does not support bfloat16)
        make_arg = partial(make_tensor, dtype=dtype, device="cpu")
        m_beta = make_arg(1)
        m_input = make_arg((n, p))
        m_1 = make_arg((n, m))
        m_2 = make_arg((m, p))
        # scale to abate overflows in fp16 accum
        if fp16_accumulate:
            m_1 = m_1 / 100
            m_2 = m_2 / 100
        # *(B)FLOAT16 Special Handling*
        # Backend does not tensorize float16 on CPU,
        # and bfloat16 may present accuracy issues,
        # so convert to float32 for these cases
        # (but keep same for other types, e.g. float32 and int*)
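        # For example, float16 keeps only a 10-bit mantissa, so a CPU
        # reference computed directly in float16 would itself accumulate
        # rounding error over a size-10000 reduction; building the reference
        # in float32 and casting the result back keeps the baseline trustworthy.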
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=torch.float32)
            m_input = m_input.to(dtype=torch.float32)
            m_1 = m_1.to(dtype=torch.float32)
            m_2 = m_2.to(dtype=torch.float32)
        # Get CPU result
        res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # *(B)FLOAT16 Special Handling*
        # Convert back to (b)float16
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=dtype)
            m_input = m_input.to(dtype=dtype)
            m_1 = m_1.to(dtype=dtype)
            m_2 = m_2.to(dtype=dtype)
            res_cpu = res_cpu.to(dtype=dtype)
        # Move arg tensors to CUDA
        m_beta = m_beta.to("cuda")
        m_input = m_input.to("cuda")
        m_1 = m_1.to("cuda")
        m_2 = m_2.to("cuda")
        # Get CUDA result
        res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # Move to CPU for comparison
        res_cuda = res_cuda.to("cpu")
        # Compare
        self.assertEqual(res_cpu, res_cuda)
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate

    @onlyCUDA
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            if (TEST_WITH_ROCM and backend == "cublas" and isRocmArchAnyOf(NAVI_ARCH)
                    and getRocmVersion() < (6, 4) and dtype == torch.float16 and size >= 10000):
                self.skipTest(f"failed on Navi for ROCm6.3 due to hipblas backend, dtype={dtype} and size={size}")
            self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @xfailIfSM100OrLaterNonRTXAndCondition(
        lambda params: params.get('dtype') == torch.bfloat16 and params.get('size') == 10000
    )
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @dtypes(torch.float16)
    # m == 4 chooses OUTPUT_TYPE reduction on H200
    # m == 8 chooses OUTPUT_TYPE reduction on A100
    @parametrize("small_size", [4, 8])
    @parametrize("size", [32768])
    @parametrize("backend", ["cublaslt", "cublas"])
    def test_cublas_addmm_no_reduced_precision(self, small_size: int, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            torch.backends.cuda.preferred_blas_library(backend)
            orig_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
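            # Why this catches a reduced-precision reduction: every product in
            # the k-dimension below is +/-65504 (the float16 max), so a float16
            # accumulator overflows to inf as soon as two like-signed products
            # are added (65504 + 65504 -> inf in float16), while a float32
            # (OUTPUT_TYPE) accumulator holds the large partial sums and the
            # +1/-1 halves of k cancel back to 0.0, which the assertion checks.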
            m1 = torch.full((small_size, size), 65504.0, dtype=dtype, device='cuda')
            m2 = torch.ones((size, small_size), dtype=dtype, device='cuda')
            m2[size // 2:, :] = -1.0
            b = torch.zeros((small_size,), dtype=dtype, device='cuda')
            out = torch.addmm(b, m1, m2, beta=1.0)
            self.assertEqual(out.sum().item(), 0.0)
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_precision

    @onlyCUDA
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype, backend):
        with blas_library_context(backend):
            self.cublas_addmm(size, dtype, False, True)

    @onlyCUDA
    def test_cublas_and_lt_reduced_precision_fp16_accumulate(self):
        orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation
        torch.backends.cuda.matmul.allow_fp16_accumulation = True
        x = torch.rand(32, 512, 512, device='cuda', dtype=torch.half)
        w = torch.rand(512, 512, device='cuda', dtype=torch.half)
        b = torch.rand(512, device='cuda', dtype=torch.half)
        out = torch.nn.functional.linear(x, w, b)
        out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu())
        self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3)

        a = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
        b = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
        c = torch.rand(16, 128, 128, device='cuda', dtype=torch.half)
        out = torch.baddbmm(a, b, c)
        out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu())
        self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3)
        torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            if runOnRocmArch(MI200_ARCH) and dtype == torch.float16:
                self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu(), atol=1e-4, rtol=0.001)
            else:
                self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())
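
        # baddbmm computes beta * A + alpha * (M1 @ M2) independently per
        # batch entry, and below every batch entry is an expanded view of the
        # same matrices, so batch 0 of the baddbmm result must agree with the
        # linear result above (checked by the cross comparison at the end).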
        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            if runOnRocmArch(MI200_ARCH) and dtype == torch.float16:
                self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu(), atol=1e-4, rtol=0.001)
            else:
                self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])

    @onlyCUDA
    @skipIfRocm
    @parametrize("shape", [2**i for i in range(5, 14)])
    @dtypes(torch.float, torch.half, torch.bfloat16)
    def test_cublas_deterministic(self, device, shape, dtype):
        inp = torch.randn(shape, shape, device=device, dtype=dtype)
        first = torch.matmul(inp, inp)
        for _ in range(10):
            self.assertEqual(first, torch.matmul(inp, inp), atol=0., rtol=0.)

    def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist):
        for a, b, gO, agrad, bgrad, out in zip(alist, blist, gOlist, agradlist, bgradlist, outlist):
            a = a.clone().detach().requires_grad_()
            b = b.clone().detach().requires_grad_()
            out_ref = torch.mm(a, b.t())
            out_ref.backward(gO)
            self.assertEqual(out, out_ref)
            if agrad is not None:
                self.assertEqual(agrad, a.grad)
                self.assertEqual(bgrad, b.grad)

    @xfailIfSM120OrLater
    @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
    @parametrize("strided", [False, True])
    @parametrize("a_row_major", [False, True])
    @parametrize("b_row_major", [False, True])
    @dtypes(torch.bfloat16, torch.float32, torch.float16)
    def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype):
        device = "cuda"
        m, n, k, n_groups = 16, 32, 64, 4
        if a_row_major:
            a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups]
        else:
            a = torch.randn(k * n_groups + k * int(strided), m, device=device, dtype=dtype).t()[:, :k * n_groups]
        if b_row_major:
            b = torch.randn(n, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups]
        else:
            b = torch.randn(k * n_groups + k * int(strided), n, device=device, dtype=dtype).t()[:, :k * n_groups]
        a.requires_grad_(True)
        b.requires_grad_(True)
        offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)

        f = torch._grouped_mm
        out = f(a, b.t(), offs=offs, out_dtype=dtype)
        gO = torch.rand_like(out)
        out.backward(gO)
        offs_cpu = offs.cpu()
        alist, blist, agradlist, bgradlist = [], [], [], []
        start = 0
        for i in range(n_groups):
            alist.append(a[:, start:offs_cpu[i]])
            blist.append(b[:, start:offs_cpu[i]])
            agradlist.append(a.grad[:, start:offs_cpu[i]])
            bgradlist.append(b.grad[:, start:offs_cpu[i]])
            start = offs_cpu[i]
        self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out)
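
    # For the 2d/2d grouped gemm above, `offs` holds cumulative group
    # boundaries along K: with k = 64 and n_groups = 4, offs = [64, 128, 192, 256],
    # so group i multiplies a[:, start:offs[i]] with b[:, start:offs[i]].t(),
    # where start is the previous boundary (0 for the first group); this is
    # exactly the slicing the reference loop performs.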

    @xfailIfSM120OrLater
    @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
    @parametrize("strided", [False, True])
    @parametrize("a_row_major", [False, True])
    @parametrize("b_row_major", [False, True])
    @dtypes(torch.bfloat16, torch.float32, torch.float16)
    def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype):
        device = "cuda"
        s_int = int(strided)
        m, n, k, n_groups = 16, 32, 64, 4
        if a_row_major:
            a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k]
        else:
            a = torch.randn(k, (m + 2 * s_int) * n_groups, device=device, dtype=dtype).t()[:m * n_groups, :]
        if b_row_major:
            b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int),
                            device=device, dtype=dtype)[::(1 + s_int), :, :k]
        else:
            b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n,
                            device=device, dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k]
        a.requires_grad_(True)
        b.requires_grad_(True)

        a_contig = a if a_row_major else a.t()
        self.assertTrue(a_contig.is_contiguous() is not strided)
        b_contig = b if b_row_major else b.transpose(-2, -1)
        self.assertTrue(b_contig.is_contiguous() is not strided)

        for check_zero_size in (False, True):
            if check_zero_size and n_groups <= 1:
                continue
            a.grad = None
            b.grad = None
            offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)
            if check_zero_size:
                offs[0] = offs[1]

            f = torch._grouped_mm
            out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype)
            gO = torch.rand_like(out)
            if not check_zero_size:
                out.backward(gO)
            offs_cpu = offs.cpu()
            alist, agradlist, gOlist, outlist = [], [], [], []
            bgradlist = [None] * n_groups if check_zero_size else b.grad
            start = 0
            for i in range(n_groups):
                alist.append(a[start:offs_cpu[i]])
                agradlist.append(None if check_zero_size else a.grad[start:offs_cpu[i]])
                outlist.append(out[start:offs_cpu[i]])
                gOlist.append(gO[start:offs_cpu[i]])
                start = offs_cpu[i]
            self.grouped_mm_helper(alist, b, gOlist, agradlist, bgradlist, outlist)

    @xfailIfSM120OrLater
    @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
    @parametrize("strided", [False, True])
    @parametrize("a_row_major", [False, True])
    @parametrize("b_row_major", [False, True])
    @dtypes(torch.bfloat16, torch.float32, torch.float16)
    def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype):
        device = "cuda"
        s_int = int(strided)
        m, n, k, n_groups = 16, 32, 64, 4
        if a_row_major:
            a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int),
                            device=device, dtype=dtype)[::(1 + s_int), :, :k]
        else:
            a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m,
                            device=device, dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k]
        if b_row_major:
            b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int),
                            device=device, dtype=dtype)[::(1 + s_int), :, :k]
        else:
            b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n,
                            device=device, dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k]
        a.requires_grad_(True)
        b.requires_grad_(True)

        a_contig = a if a_row_major else a.transpose(-2, -1)
        self.assertTrue(a_contig.is_contiguous() is not strided)
        b_contig = b if b_row_major else b.transpose(-2, -1)
        self.assertTrue(b_contig.is_contiguous() is not strided)

        f = torch._grouped_mm
        out = f(a, b.transpose(-2, -1), out_dtype=dtype)
        gO = torch.rand_like(out)
        out.backward(gO)
        self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out)
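
    # The `strided` variants in these grouped-gemm tests over-allocate by a
    # factor of (1 + s_int) and then slice, e.g.
    # torch.randn(8, 16, 128)[::2, :, :64] keeps a logical (4, 16, 64) shape
    # but leaves gaps in memory, so the kernels are exercised on
    # non-contiguous inputs; the is_contiguous() assertions verify that the
    # layout really is strided exactly when requested.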

    @xfailIfSM120OrLater
    @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater")
    @parametrize("strided", [False, True])
    @parametrize("a_row_major", [False, True])
    @parametrize("b_row_major", [False, True])
    @dtypes(torch.bfloat16, torch.float32, torch.float16)
    def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype):
        if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]:
            self.skipTest("failed using hipblaslt on rocm 6.4.2")
        device = "cuda"
        s_int = int(strided)
        m, n, k, n_groups = 16, 32, 64, 4
        if a_row_major:
            a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int),
                            device=device, dtype=dtype)[::(1 + s_int), :, :k]
        else:
            a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m,
                            device=device, dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k]
        if b_row_major:
            b = torch.randn(n * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k]
        else:
            b = torch.randn(k, n * (n_groups + s_int), device=device, dtype=dtype).transpose(-2, -1)[:n * n_groups, :]
        a.requires_grad_(True)
        b.requires_grad_(True)

        a_contig = a if a_row_major else a.transpose(-2, -1)
        self.assertTrue(a_contig.is_contiguous() is not strided)
        b_contig = b if b_row_major else b.transpose(-2, -1)
        self.assertTrue(b_contig.is_contiguous() is not strided)

        for check_zero_size in (False, True):
            if check_zero_size and n_groups <= 1:
                continue
            offs = torch.arange(n, n_groups * n + 1, n, device=device, dtype=torch.int32)
            if check_zero_size:
                offs[0] = offs[1]

            f = torch._grouped_mm
            out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype)
            gO = torch.rand_like(out)
            if not check_zero_size:
                out.backward(gO)
            offs_cpu = offs.cpu()
            blist, outlist, bgradlist, gOlist = [], [], [], []
            agradlist = [None] * n_groups if check_zero_size else a.grad
            start = 0
            for i in range(n_groups):
                blist.append(b[start:offs_cpu[i]])
                bgradlist.append(b.grad[start:offs_cpu[i]])
                outlist.append(out[:, start:offs_cpu[i]])
                gOlist.append(gO[:, start:offs_cpu[i]])
                start = offs_cpu[i]
            self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    # TODO(future PR): enable compile for torch._grouped_mm fallback path
    @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90")
    @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"])
    @parametrize("a_row_major", [False, True])
    @parametrize("b_row_major", [False, True])
    @parametrize("max_autotune", [False, True])
    def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune):
        device = "cuda"
        dtype_AB = torch.bfloat16
        dtype_offset = torch.int32
        align = 16 // dtype_AB.itemsize

        f_ref = torch._grouped_mm
        options = {}
        if max_autotune:
            options.update(
                {
                    "max_autotune": True,
                    "max_autotune_gemm_backends": "TRITON",
                }
            )
        f = torch.compile(
            f_ref,
            options=options,
        )
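        # Each branch below rounds the packed dimension up to a multiple of
        # `align` and then slices back to the logical size. For bfloat16
        # (itemsize 2), align = 16 // 2 = 8 elements, i.e. 16 bytes; m = 3
        # rounds up as (3 + 8 - 1) // 8 * 8 == 8, so the allocation keeps
        # 16-byte-aligned strides while the tensor still reports size 3.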
        if op == "2d/2d":
            m, n = 3, 7
            m_align = (m + align - 1) // align * align
            n_align = (n + align - 1) // align * align
            if not a_row_major and not b_row_major:
                offs = torch.tensor([0, 1, 6, 6, 7], device=device, dtype=dtype_offset)
            else:
                offs = torch.tensor([0, 8, 16, 16, 27], device=device, dtype=dtype_offset)
            ngroups = offs.shape[0]
            k = offs[-1]
            k_align = (k + align - 1) // align * align
            if a_row_major:
                A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k]
            else:
                A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :]
            if b_row_major:
                B = torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k]
            else:
                B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :]
        elif op == "2d/3d":
            n, k = 7, 259  # k is larger here, to validate iterating over k tiles on an op
            n_align = (n + align - 1) // align * align
            k_align = (k + align - 1) // align * align
            if a_row_major:
                offs = torch.tensor([0, 1, 3, 3, 5], device=device, dtype=dtype_offset)
            else:
                offs = torch.tensor([0, 8, 16, 16, 19], device=device, dtype=dtype_offset)
            ngroups = offs.shape[0]
            m = offs[-1]
            m_align = (m + align - 1) // align * align
            if a_row_major:
                A = torch.randn(m, k_align, device=device, dtype=dtype_AB)[:, :k]
            else:
                A = torch.randn(k, m_align, device=device, dtype=dtype_AB).t()[:m, :]
            if b_row_major:
                B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[:, :, :k]
            else:
                B = torch.randn(ngroups, k, n_align, device=device, dtype=dtype_AB).transpose(
                    -2, -1
                )[:, :n, :]
        elif op == "3d/2d":
            m, k = 3, 13
            m_align = (m + align - 1) // align * align
            k_align = (k + align - 1) // align * align
            offs = torch.tensor([0, 8, 16, 16, 19], device=device, dtype=dtype_offset)
            ngroups = offs.shape[0]
            n = offs[-1]
            n_align = (n + align - 1) // align * align
            if a_row_major:
                A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[:, :, :k]
            else:
                A = torch.randn(ngroups, k, m_align, device=device, dtype=dtype_AB).transpose(
                    -2, -1
                )[:, :m, :]
            if b_row_major:
                B = torch.randn(n, k_align, device=device, dtype=dtype_AB)[:, :k]
            else:
                B = torch.randn(k, n_align, device=device, dtype=dtype_AB).t()[:n, :]
        elif op == "3d/3d":
            offs = None
            ngroups = 5
            m, n, k = 3, 7, 13
            m_align = (m + align - 1) // align * align
            n_align = (n + align - 1) // align * align
            k_align = (k + align - 1) // align * align
            if a_row_major:
                A = torch.randn(ngroups, m, k_align, device=device, dtype=dtype_AB)[:, :, :k]
            else:
                A = torch.randn(ngroups, k, m_align, device=device, dtype=dtype_AB).transpose(
                    -2, -1
                )[:, :m, :]
            if b_row_major:
                B = torch.randn(ngroups, n, k_align, device=device, dtype=dtype_AB)[:, :, :k]
            else:
                B = torch.randn(ngroups, k, n_align, device=device, dtype=dtype_AB).transpose(
                    -2, -1
                )[:, :n, :]
        else:
            raise AssertionError(f"Invalid op: {op}")

        C_ref = f_ref(A, B.transpose(-2, -1), offs=offs)
        C = f(A, B.transpose(-2, -1), offs=offs)
        torch.testing.assert_close(C, C_ref)

    @onlyCUDA
    @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
    @parametrize("M", [1, 32, 64])
    @parametrize("N", [1, 32, 64])
    @parametrize("K", [1, 32, 64])
    @parametrize("batch_size", [None, 1, 16])
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_mm_bmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
        if torch.version.hip:
            msg = "accuracy regression in hipblas and hipblaslt in ROCm 7.0 for certain shapes"
            if input_dtype == torch.bfloat16 and N == 1 and K == 32 and batch_size:
                raise unittest.SkipTest(msg)
            if input_dtype == torch.bfloat16 and N == 1 and K == 64 and batch_size:
                raise unittest.SkipTest(msg)
            if input_dtype == torch.float16 and M == 32 and N == 1 and K == 64 and batch_size == 1:
                raise unittest.SkipTest(msg)
            if input_dtype == torch.float16 and M == 64 and N == 1 and K == 64 and batch_size == 1:
                raise unittest.SkipTest(msg)
        device = "cuda"
        dtype = input_dtype
        with blas_library_context(backend):
            def create_inputs(B=None):
                if B is None:
                    a = torch.randn(M, K, device=device, dtype=dtype)
                    b = torch.randn(K, N, device=device, dtype=dtype)
                else:
                    a = torch.randn(B, M, K, device=device, dtype=dtype)
                    b = torch.randn(B, K, N, device=device, dtype=dtype)
                return a, b

            a, b = create_inputs(batch_size)
            a_fp32, b_fp32 = a.to(torch.float32), b.to(torch.float32)
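            # The loop below exercises the out_dtype overload: for
            # low-precision inputs, torch.mm(a, b, out_dtype=torch.float32) is
            # compared against a plain float32 matmul of upcast operands
            # (torch.mm(a_fp32, b_fp32)), while out_dtype equal to the input
            # dtype must match the ordinary mm/bmm result.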
            output_dtypes = [torch.float32]
            if input_dtype != torch.float32:
                output_dtypes.append(input_dtype)

            for output_dtype in output_dtypes:
                # Catch edge case of incompat with bfloat16 and major version < 8
                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
                    if output_dtype == torch.bfloat16:
                        continue
                    if batch_size:
                        with self.assertRaises(RuntimeError):
                            torch.bmm(a, b, out_dtype=output_dtype)
                    else:
                        with self.assertRaises(RuntimeError):
                            torch.mm(a, b, out_dtype=output_dtype)
                else:
                    if batch_size:
                        out = torch.bmm(a, b, out_dtype=output_dtype)
                        baseline = torch.bmm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.bmm(a, b)
                    else:
                        out = torch.mm(a, b, out_dtype=output_dtype)
                        baseline = torch.mm(a_fp32, b_fp32) if output_dtype == torch.float32 else torch.mm(a, b)
                    self.assertEqual(out.dtype, output_dtype)
                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)

    @onlyCUDA
    @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
    @parametrize("M", [1, 32, 64])
    @parametrize("N", [1, 32, 64])
    @parametrize("K", [1, 32, 64])
    @parametrize("batch_size", [None, 1, 32])
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_addmm_baddmm_dtype_overload(self, input_dtype, M, N, K, batch_size, backend):
        if torch.version.hip:
            msg = "accuracy regression in hipblas and hipblaslt in ROCm 7.0 for certain shapes"
            if input_dtype == torch.bfloat16 and N == 1 and K == 32 and batch_size:
                raise unittest.SkipTest(msg)
            if input_dtype == torch.bfloat16 and N == 1 and K == 64 and batch_size:
                raise unittest.SkipTest(msg)
            if input_dtype == torch.float16 and M == 32 and N == 1 and K == 64 and batch_size == 1:
                raise unittest.SkipTest(msg)
            if input_dtype == torch.float16 and M == 64 and N == 1 and K == 64 and batch_size == 1:
                raise unittest.SkipTest(msg)
        device = "cuda"
        dtype = input_dtype
        with blas_library_context(backend):
            def create_inputs(B=None):
                if B is None:
                    a = torch.randn(M, K, device=device, dtype=dtype)
                    b = torch.randn(K, N, device=device, dtype=dtype)
                    c = torch.randn(M, N, device=device, dtype=dtype)
                else:
                    a = torch.randn(B, M, K, device=device, dtype=dtype)
                    b = torch.randn(B, K, N, device=device, dtype=dtype)
                    c = torch.randn(B, M, N, device=device, dtype=dtype)
                return a, b, c

            a, b, c = create_inputs(batch_size)
            a_fp32, b_fp32, c_fp32 = a.to(torch.float32), b.to(torch.float32), c.to(torch.float32)

            output_dtypes = [torch.float32]
            if input_dtype != torch.float32:
                output_dtypes.append(input_dtype)

            for output_dtype in output_dtypes:
                # Catch edge case of incompat with bfloat16 and major version < 8
                if input_dtype == torch.bfloat16 and not PLATFORM_SUPPORTS_BF16:
                    if output_dtype == torch.bfloat16:
                        continue
                    if batch_size:
                        with self.assertRaises(RuntimeError):
                            torch.baddbmm(c, a, b, out_dtype=output_dtype)
                    else:
                        with self.assertRaises(RuntimeError):
                            torch.addmm(c, a, b, out_dtype=output_dtype)
                else:
                    if batch_size:
                        out = torch.baddbmm(c, a, b, out_dtype=output_dtype)
                        if output_dtype == torch.float32:
                            baseline = torch.baddbmm(c_fp32, a_fp32, b_fp32)
                        else:
                            baseline = torch.baddbmm(c, a, b)
                    else:
                        out = torch.addmm(c, a, b, out_dtype=output_dtype)
                        if output_dtype == torch.float32:
                            baseline = torch.addmm(c_fp32, a_fp32, b_fp32)
                        else:
                            baseline = torch.addmm(c, a, b)
                    self.assertEqual(out.dtype, output_dtype)
                    torch.testing.assert_close(out, baseline, atol=1e-3, rtol=1e-3)

    @onlyCUDA
    @parametrize("batch_size", [1, 32])
    @parametrize("backend", ["cublas", "cublaslt"])
    def test_fp16_accum_and_fp32_out_failure(self, batch_size, backend):
        M, N, K = 32, 32, 32
        device = "cuda"
        dtype = torch.float16
        with blas_library_context(backend):
            torch.backends.cuda.preferred_blas_library(backend)
            orig_fp16_accum = torch.backends.cuda.matmul.allow_fp16_accumulation
            torch.backends.cuda.matmul.allow_fp16_accumulation = True

            def create_inputs():
                a = torch.randn(M, K, device=device, dtype=dtype)
                b = torch.randn(K, N, device=device, dtype=dtype)
                c = torch.randn(M, N, device=device, dtype=dtype)
                return a, b, c

            def expand(tensor):
                return tensor.unsqueeze(0).expand(batch_size, *tensor.shape)

            a, b, c = create_inputs()
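            # With allow_fp16_accumulation enabled the accumulator is float16,
            # which is incompatible with a requested float32 out_dtype; each
            # out_dtype=torch.float32 call below is therefore expected to raise
            # instead of returning a float16-accumulated "float32" result.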
            with self.assertRaises(Exception):
                torch.baddbmm(expand(c), expand(a), expand(b), out_dtype=torch.float32)
            with self.assertRaises(Exception):
                torch.addmm(c, a, b, out_dtype=torch.float32)
            with self.assertRaises(Exception):
                torch.bmm(expand(a), expand(b), out_dtype=torch.float32)
            with self.assertRaises(Exception):
                torch.mm(a, b, out_dtype=torch.float32)
            torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accum

    @onlyCUDA
    @parametrize("ops", [("mm", torch.mm), ("bmm", torch.bmm), ("addmm", torch.addmm), ("baddbmm", torch.baddbmm)])
    def test_input_dimension_checking_out_dtype(self, ops):
        op_name, op = ops
        B = 2
        M, N, K = 32, 32, 32

        def is_addmm():
            return "add" in op_name

        def is_batched():
            return "bmm" in op_name

        if is_batched():
            a = torch.randn(B, M, K, device="cuda", dtype=torch.bfloat16)
            mismatch_k_b = torch.randn(B, K + 1, N, device="cuda", dtype=torch.bfloat16)
            c = torch.randn(B, M, N, device="cuda", dtype=torch.bfloat16)
            extra_dim_b = a.clone().unsqueeze(0)
            mismatch_k_err = "Expected size for first two dimensions of batch2 tensor to be"
            extra_dim_err = "batch2 must be a 3D tensor"
        else:
            a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
            mismatch_k_b = torch.randn(K + 1, N, device="cuda", dtype=torch.bfloat16)
            c = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
            extra_dim_b = a.clone().unsqueeze(0)
            mismatch_k_err = "mat1 and mat2 shapes cannot be multiplied"
            extra_dim_err = "mat2 must be a matrix, got 3-D tensor"

        # Test mismatch K
        with self.assertRaisesRegex(RuntimeError, mismatch_k_err):
            if is_addmm():
                op(c, a, mismatch_k_b, out_dtype=torch.float32)
            else:
                op(a, mismatch_k_b, out_dtype=torch.float32)

        # Test extra dimension
        with self.assertRaisesRegex(RuntimeError, extra_dim_err):
            if is_addmm():
                op(c, a, extra_dim_b, out_dtype=torch.float32)
            else:
                op(a, extra_dim_b, out_dtype=torch.float32)

        if is_batched():
            with self.assertRaisesRegex(RuntimeError, "Expected size for first two dimensions of batch2 tensor to be"):
                # Test mismatch B for bmm/baddbmm
                mismatch_batch_dim_b = torch.randn(B + 1, K, N, device="cuda", dtype=torch.bfloat16)
                if is_addmm():
                    op(c, a, mismatch_batch_dim_b, out_dtype=torch.float32)
                else:
                    op(a, mismatch_batch_dim_b, out_dtype=torch.float32)


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            if not add_bias and activation != "none":
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            input = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
                )
                if add_bias
                else None
            )

            input_ref = input.reshape(-1, input.shape[-1])
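
            # The reference computation dequantizes explicitly: with the
            # per-output-channel scale s (shape (n,)), the effective weight is
            # W.T * s broadcast across rows, so output = input @ (W.T * s);
            # the fused _mixed_dtypes_linear kernel is expected to match this.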
            # First, test plain multiplication.
            weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(input.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*input.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )


instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()