mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
# Motivation This PR is a part of RFC #114848, and it is a successor PR of #116249 and #116019. This PR depends on oneDNN compilation in #116249. Some runtime support is needed in #116019. Aten operators like `addmm` and `baddmm` are defined in `Blas.cpp` in `aten/src/ATen/native/mkldnn/xpu/`. Alongside the files that provide the core functionality, `BlasImpl.h`, `Utils.h` and other files provide basic utilities for them. For instance, `Utils.h` provides common memory-descriptor query utilities for `Matmul.h`, and these utility functions will also be used in other primitives, like `convolution`. `BlasImpl.h` is a header file that provides helpers for handling shape-info processing in matmul-related operators. It helps not only basic GEMM operators like `addmm` and `baddmm` but also fusion operators used in `torch.compile`, like `linear_pointwise` in #117824. In the next stage, we will continue to complete the oneDNN support by enabling `matmul fusion` and `convolution` related code. Co-authored-by: xiaolil1 <xiaoli.liu@intel.com> Co-authored-by: lei,zhenyuan <zhenyuan.lei@intel.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/117202 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/malfet ghstack dependencies: #117098, #117112
1149 lines
46 KiB
Python
1149 lines
46 KiB
Python
# Owner(s): ["module: intel"]
|
|
|
|
import itertools
|
|
import math
|
|
import random
|
|
from functools import partial
|
|
from itertools import product
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
from torch.testing import make_tensor
|
|
from torch.testing._internal.common_device_type import (
|
|
dtypes,
|
|
instantiate_device_type_tests,
|
|
precisionOverride,
|
|
)
|
|
from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase
|
|
|
|
|
|
class TestBasicGEMM(TestCase):
|
|
def _test_addmm_addmv(
|
|
self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False, activation=None
|
|
):
|
|
dtype = t.dtype
|
|
numpy_dtype = dtype
|
|
if dtype in {torch.bfloat16, torch.half}:
|
|
numpy_dtype = torch.float
|
|
if dtype.is_complex:
|
|
alpha = 0.9 + 0.3j if alpha is None else alpha
|
|
beta = 0.5 + 0.6j if beta is None else beta
|
|
else:
|
|
alpha = 1.2 if alpha is None else alpha
|
|
beta = 0.8 if beta is None else beta
|
|
if activation == "gelu":
|
|
res1 = f(t, m, v, alpha=alpha, beta=beta, use_gelu=True)
|
|
else:
|
|
res1 = f(t, m, v, alpha=alpha, beta=beta)
|
|
res2 = torch.full_like(res1, math.nan)
|
|
if transpose_out:
|
|
res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
|
|
if activation == "gelu":
|
|
f(t, m, v, alpha=alpha, beta=beta, out=res2, use_gelu=True)
|
|
else:
|
|
f(t, m, v, alpha=alpha, beta=beta, out=res2)
|
|
m.to(numpy_dtype).cpu().numpy()
|
|
v.to(numpy_dtype).cpu().numpy()
|
|
res3 = alpha * (
|
|
m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()
|
|
)
|
|
if beta != 0:
|
|
res3 += (beta * t).to(numpy_dtype).cpu().numpy()
|
|
if activation == "relu":
|
|
res3 = res3 * (res3 > 0)
|
|
elif activation == "gelu":
|
|
res3_t = torch.from_numpy(res3).to(dtype)
|
|
approximate = "tanh" if t.is_cuda else "none"
|
|
res3_t = torch.nn.functional.gelu(res3_t, approximate=approximate)
|
|
res3 = res3_t.to(numpy_dtype).cpu().numpy()
|
|
else:
|
|
assert activation is None, f"unsupported activation {activation}"
|
|
res3 = torch.from_numpy(res3).to(dtype)
|
|
self.assertEqual(res1, res2)
|
|
self.assertEqual(res1, res3)
|
|
|
|
    def _test_addmm_impl(self, func, activation, device, dtype):
        """Drive ``_test_addmm_addmv`` for ``func`` over a range of operand
        layouts: contiguous, vector bias, 0-strided (expanded), NaN input
        with beta=0, and all combinations of transposed operands/output.
        """
        M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # vector-shaped bias and beta=1 result in epilogue fusion in CUDA
        V = torch.randn(25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation)

        # Test 0-strided
        M = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 25)
            .to(device)
        )
        m1 = (
            torch.randn(10, 1, device="cpu", dtype=torch.float32)
            .to(dtype)
            .expand(10, 50)
            .to(device)
        )
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, activation=activation)

        # Test beta=0, M=nan
        M = (
            torch.full((10, 25), math.nan, device="cpu", dtype=torch.float32)
            .to(dtype)
            .to(device)
        )
        m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device)
        m2 = torch.randn(50, 25, device="cpu", dtype=torch.float32).to(dtype).to(device)
        self._test_addmm_addmv(func, M, m1, m2, beta=0, activation=activation)

        # Test transpose
        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):

            def maybe_transpose(cond, m):
                # Transpose + contiguous-clone + transpose back: same logical
                # shape, but transposed memory layout.
                if not cond:
                    return m
                return m.t().clone(memory_format=torch.contiguous_format).t()

            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
            self._test_addmm_addmv(
                func, M, m1, m2, transpose_out=t4, activation=activation
            )

            if t1:
                # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1)
                self._test_addmm_addmv(
                    func,
                    V,
                    m1,
                    m2,
                    beta=1,
                    transpose_out=t4,
                    activation=activation,
                )
|
|
|
|
    @precisionOverride(
        {
            torch.float: 1e-4,
            torch.half: 1e-1,
        }
    )
    @dtypes(torch.float32, torch.half)
    def test_addmm(self, device, dtype):
        """torch.addmm without an activation epilogue, via the shared driver."""
        self._test_addmm_impl(torch.addmm, None, device, dtype)
|
|
|
|
    @precisionOverride({torch.bfloat16: 1e-0, torch.half: 1e-3, torch.float: 1e-4})
    @dtypes(torch.bfloat16, torch.half, torch.float)
    def test_addmv(self, device, dtype):
        """torch.addmv over combinations of contiguous, expanded (0-strided)
        and transposed operands, plus the beta=0-with-NaN-input case."""
        # have to use torch.randn(...).to(bfloat16) instead of
        # torch.randn(..., dtype=bfloat16). randn does not support
        # bfloat16 yet.
        # "*0.2" to reduce errors for low precision
        ts = [
            0.2 * torch.randn(50, device=device).to(dtype),
            0.2 * torch.randn(1, device=device).to(dtype).expand(50),
        ]
        vs = [
            0.2 * torch.randn(100, device=device).to(dtype),
            0.2
            * torch.ones(1, device=device)
            .to(dtype)
            .expand(100),  # to reduce errors for low precision
        ]
        ms = [
            # 0d
            0.2
            * torch.ones((), device=device)
            .to(dtype)
            .expand(50, 100),  # to reduce errors for low precision
            # 1d
            0.2 * torch.randn((1, 100), device=device).to(dtype).expand(50, 100),
            # this initialization reduces errors for low precision for broadcasted matrices
            # by making sure that intermediate and result values are exactly representable
            # in low precision type
            0.2
            * torch.randint(3, (50, 1), dtype=torch.float, device=device)
            .to(dtype)
            .expand(50, 100),
            # 2d
            0.2 * torch.randn((50, 100), device=device).to(dtype),
            0.2 * torch.randn((100, 50), device=device).to(dtype).t(),
        ]
        # Exercise every (matrix, vector, bias) layout combination.
        for m, v, t in itertools.product(ms, vs, ts):
            self._test_addmm_addmv(torch.addmv, t, m, v)
        # Test beta=0, t=nan
        t = torch.full((50,), math.nan, device=device).to(dtype)
        for m, v in itertools.product(ms, vs):
            self._test_addmm_addmv(torch.addmv, t, m, v, beta=0)
|
|
|
|
@dtypes(
|
|
torch.half,
|
|
torch.float32,
|
|
)
|
|
def test_mm(self, device, dtype):
|
|
def _test_mm(n, m, p, dtype, genf):
|
|
# helper function
|
|
def matrixmultiply(mat1, mat2):
|
|
n = mat1.size(0)
|
|
m = mat1.size(1)
|
|
p = mat2.size(1)
|
|
dtype_ = torch.float if dtype == torch.half else dtype
|
|
if dtype == torch.half:
|
|
mat1 = mat1.float()
|
|
mat2 = mat2.float()
|
|
res = torch.zeros(n, p, dtype=dtype_, device=device)
|
|
for i, j in iter_indices(res):
|
|
res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m))
|
|
return res.half() if dtype == torch.half else res
|
|
|
|
# contiguous case
|
|
mat1 = genf(n, m)
|
|
mat2 = genf(m, p)
|
|
res = torch.mm(mat1, mat2)
|
|
|
|
res2 = matrixmultiply(mat1, mat2)
|
|
self.assertEqual(res, res2)
|
|
|
|
# non contiguous case 1
|
|
mat1 = genf(n, m)
|
|
mat2 = genf(p, m).t()
|
|
res = torch.mm(mat1, mat2)
|
|
|
|
res2 = matrixmultiply(mat1, mat2)
|
|
self.assertEqual(res, res2)
|
|
|
|
# non contiguous case 2
|
|
mat1 = genf(m, n).t()
|
|
mat2 = genf(m, p)
|
|
res = torch.mm(mat1, mat2)
|
|
|
|
res2 = matrixmultiply(mat1, mat2)
|
|
self.assertEqual(res, res2)
|
|
|
|
# non contiguous case 3
|
|
mat1 = genf(m, n).t()
|
|
mat2 = genf(p, m).t()
|
|
res = torch.mm(mat1, mat2)
|
|
|
|
res2 = matrixmultiply(mat1, mat2)
|
|
self.assertEqual(res, res2)
|
|
|
|
# test with zero stride
|
|
mat1 = genf(n, m)
|
|
mat2 = genf(m, 1).expand(m, p)
|
|
res = torch.mm(mat1, mat2)
|
|
|
|
res2 = matrixmultiply(mat1, mat2)
|
|
self.assertEqual(res, res2)
|
|
|
|
# explicitly exercise the _out variant in torch.mm().
|
|
# contiguous case
|
|
mat1 = genf(n, m)
|
|
mat2 = genf(m, p)
|
|
res = genf(n, p)
|
|
torch.mm(mat1, mat2, out=res)
|
|
|
|
res2 = matrixmultiply(mat1, mat2)
|
|
self.assertEqual(res, res2)
|
|
|
|
# explicitly exercise the _out variant in torch.mm().
|
|
# non contiguous case 3
|
|
mat1 = genf(m, n).t()
|
|
mat2 = genf(p, m).t()
|
|
res = genf(n, p)
|
|
torch.mm(mat1, mat2, out=res)
|
|
|
|
res2 = matrixmultiply(mat1, mat2)
|
|
self.assertEqual(res, res2)
|
|
|
|
def genf_int(x, y):
|
|
return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
|
|
|
|
def genf_bfloat(x, y):
|
|
return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype) * 0.1
|
|
|
|
def genf_float(x, y):
|
|
return torch.randn(x, y, dtype=dtype, device=device)
|
|
|
|
def genf_Half(x, y):
|
|
return torch.randn(x, y, dtype=dtype, device=device)
|
|
|
|
for n, m, p in [(20, 10, 15), (15, 20, 10), (25, 18, 10)]:
|
|
if (dtype == torch.int32) or (dtype == torch.int64):
|
|
genf = genf_int
|
|
elif dtype == torch.bfloat16:
|
|
genf = genf_bfloat
|
|
elif dtype == torch.half:
|
|
genf = genf_Half
|
|
else:
|
|
genf = genf_float
|
|
|
|
_test_mm(n, m, p, dtype, genf)
|
|
|
|
    @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_bmm(self, device, dtype):
        """torch.bmm over transposed, broadcast-expanded and zero-sized batch
        inputs, with every output-layout permutation, against numpy."""
        batch_sizes = [1, 10]
        M, N, O = 23, 15, 12
        # bfloat16 reference is accumulated in float32 for accuracy.
        numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32

        def invert_perm(p):
            # Inverse of a 3-element permutation p.
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_inputs(num_batches):
            # transposed tensors
            for perm1, perm2 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=2
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-0.1, high=0.1
                )
                # permute + contiguous + inverse permute: same logical shape,
                # different memory layout.
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                yield b1, b2
            # broadcasting tensors
            for b1, b2, b3, b4, b5, b6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if b1 else 1, M if b2 else 1, N if b3 else 1)
                shape2 = (num_batches if b4 else 1, N if b5 else 1, O if b6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-0.1, high=0.1
                ).expand(num_batches, N, O)
                yield b1, b2
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                # z1/z3 are shared between the two shapes so that batch and
                # inner dimensions stay compatible.
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = torch.randn(shape1, dtype=dtype, device=device)
                b2 = torch.randn(shape2, dtype=dtype, device=device)
                yield b1, b2

        for num_batches in batch_sizes:
            for (b1, b2), perm3 in itertools.product(
                generate_inputs(num_batches), itertools.permutations((0, 1, 2))
            ):
                res1 = torch.bmm(b1, b2)
                # out= buffer is NaN-filled and given a permuted memory layout.
                res2 = (
                    torch.full(
                        (num_batches, M, O), math.nan, dtype=dtype, device=device
                    )
                    .permute(perm3)
                    .contiguous()
                    .permute(invert_perm(perm3))
                )
                torch.bmm(b1, b2, out=res2)
                expect = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                self.assertEqual(expect, res1)
                self.assertEqual(expect, res2)

                if self.device_type == "cuda":
                    # check that mixed arguments are rejected
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cpu()))
                    self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cpu(), b2))
                    self.assertRaises(
                        RuntimeError, lambda: torch.bmm(b1, b2, out=res2.cpu())
                    )
|
|
|
|
def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor):
|
|
getattr(out_tensor, func + "_")(b1, b2)
|
|
self.assertEqual(out_tensor, ref)
|
|
res3 = out_tensor.clone()
|
|
|
|
with self.assertWarnsOnceRegex(
|
|
UserWarning, f"This overload of {func}_ is deprecated"
|
|
):
|
|
getattr(out_tensor, func + "_")(1, b1, b2)
|
|
self.assertEqual(out_tensor, ref * 2),
|
|
getattr(res3, func + "_")(b1, b2, beta=1)
|
|
self.assertEqual(out_tensor, res3)
|
|
|
|
with self.assertWarnsOnceRegex(
|
|
UserWarning, f"This overload of {func}_ is deprecated"
|
|
):
|
|
getattr(out_tensor, func + "_")(1.0, 0.5, b1, b2)
|
|
self.assertEqual(out_tensor, ref * 2.5)
|
|
getattr(res3, func + "_")(b1, b2, beta=1.0, alpha=0.5)
|
|
self.assertEqual(out_tensor, res3)
|
|
|
|
with self.assertWarnsOnceRegex(
|
|
UserWarning, f"This overload of {func} is deprecated"
|
|
):
|
|
self.assertEqual(out_tensor, getattr(torch, func)(1, out_tensor, 0, b1, b2))
|
|
|
|
res4 = getattr(torch, func)(out_tensor, b1, b2, beta=1, alpha=0.5)
|
|
self.assertEqual(res4, ref * 3),
|
|
|
|
nan = torch.full_like(out_tensor, math.nan)
|
|
res5 = getattr(torch, func)(nan, b1, b2, beta=0, alpha=1)
|
|
self.assertEqual(res5, ref)
|
|
|
|
if b1.is_complex():
|
|
res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1j, alpha=0.5j)
|
|
self.assertEqual(res6, out_tensor * 0.1j + 0.5j * ref)
|
|
else:
|
|
res6 = getattr(torch, func)(out_tensor, b1, b2, beta=0.1, alpha=0.5)
|
|
self.assertEqual(res6, out_tensor * 0.1 + 0.5 * ref)
|
|
|
|
res7 = torch.full_like(out_tensor, math.nan)
|
|
getattr(torch, func)(nan, b1, b2, beta=0, out=res7)
|
|
self.assertEqual(res7, ref)
|
|
|
|
@precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
|
|
@dtypes(torch.float32, torch.bfloat16, torch.half)
|
|
def test_addbmm(self, device, dtype):
|
|
num_batches = 2
|
|
M, N, O = 16, 17, 18
|
|
|
|
is_supported = True
|
|
|
|
if not is_supported:
|
|
b1 = make_tensor(
|
|
(num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
|
|
)
|
|
b2 = make_tensor(
|
|
(num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
|
|
)
|
|
t = make_tensor((M, O), dtype=dtype, device=device, low=-1, high=1)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED",
|
|
lambda: torch.addbmm(t, b1, b2),
|
|
)
|
|
return
|
|
|
|
def invert_perm(p):
|
|
d = {x: i for i, x in enumerate(p)}
|
|
return (d[0], d[1], d[2])
|
|
|
|
def generate_tensor():
|
|
numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32
|
|
# transposed tensors
|
|
for perm1, perm2 in itertools.product(
|
|
itertools.permutations((0, 1, 2)), repeat=2
|
|
):
|
|
for perm3 in itertools.permutations((0, 1)):
|
|
b1 = (
|
|
make_tensor(
|
|
(num_batches, M, N),
|
|
dtype=dtype,
|
|
device=device,
|
|
low=-1,
|
|
high=1,
|
|
)
|
|
* 0.1
|
|
)
|
|
b2 = (
|
|
make_tensor(
|
|
(num_batches, N, O),
|
|
dtype=dtype,
|
|
device=device,
|
|
low=-1,
|
|
high=1,
|
|
)
|
|
* 0.1
|
|
)
|
|
b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
|
|
b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
|
|
ref = (
|
|
torch.from_numpy(
|
|
b1.to(numpy_dtype).cpu().numpy()
|
|
@ b2.to(numpy_dtype).cpu().numpy()
|
|
)
|
|
.to(device=device, dtype=dtype)
|
|
.sum(0)
|
|
)
|
|
out_tensor = (
|
|
torch.zeros_like(ref).permute(perm3).contiguous().permute(perm3)
|
|
)
|
|
yield b1, b2, ref, out_tensor
|
|
# broadcasting tensors
|
|
for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
|
|
shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
|
|
shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
|
|
b1 = (
|
|
make_tensor(
|
|
shape1, dtype=dtype, device=device, low=-1, high=1
|
|
).expand(num_batches, M, N)
|
|
* 0.1
|
|
)
|
|
b2 = (
|
|
make_tensor(
|
|
shape2, dtype=dtype, device=device, low=-1, high=1
|
|
).expand(num_batches, N, O)
|
|
* 0.1
|
|
)
|
|
ref = (
|
|
torch.from_numpy(
|
|
b1.to(numpy_dtype).cpu().numpy()
|
|
@ b2.to(numpy_dtype).cpu().numpy()
|
|
)
|
|
.to(device=device, dtype=dtype)
|
|
.sum(0)
|
|
)
|
|
out_tensor = torch.zeros_like(ref)
|
|
yield b1, b2, ref, out_tensor
|
|
# zero-sized tensors
|
|
for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
|
|
shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
|
|
shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
|
|
b1 = (
|
|
make_tensor(shape1, dtype=dtype, device=device, low=-1, high=1)
|
|
* 0.1
|
|
)
|
|
b2 = (
|
|
make_tensor(shape2, dtype=dtype, device=device, low=-1, high=1)
|
|
* 0.1
|
|
)
|
|
ref = (
|
|
torch.from_numpy(
|
|
b1.to(numpy_dtype).cpu().numpy()
|
|
@ b2.to(numpy_dtype).cpu().numpy()
|
|
)
|
|
.to(device=device, dtype=dtype)
|
|
.sum(0)
|
|
)
|
|
out_tensor = torch.zeros_like(ref)
|
|
yield b1, b2, ref, out_tensor
|
|
|
|
for b1, b2, ref, out_tensor in generate_tensor():
|
|
self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
|
|
|
|
    @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5})
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_baddbmm(self, device, dtype):
        """torch.baddbmm (via the shared addbmm/baddbmm driver) over
        transposed, broadcast-expanded and zero-sized batch inputs."""
        num_batches = 10
        M, N, O = 12, 8, 50

        def invert_perm(p):
            # Inverse of a 3-element permutation p.
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2])

        def generate_tensor():
            # Reduced-precision references are accumulated in float32.
            numpy_dtype = (
                dtype if dtype not in [torch.bfloat16, torch.half] else torch.float32
            )
            # transposed tensors
            for perm1, perm2, perm3 in itertools.product(
                itertools.permutations((0, 1, 2)), repeat=3
            ):
                b1 = make_tensor(
                    (num_batches, M, N), dtype=dtype, device=device, low=-1, high=1
                )
                b2 = make_tensor(
                    (num_batches, N, O), dtype=dtype, device=device, low=-1, high=1
                )
                # permute + contiguous + inverse permute: same logical shape,
                # different memory layout.
                b1 = b1.permute(perm1).contiguous().permute(invert_perm(perm1))
                b2 = b2.permute(perm2).contiguous().permute(invert_perm(perm2))
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                out_tensor = (
                    out_tensor.permute(perm3).contiguous().permute(invert_perm(perm3))
                )
                yield b1, b2, ref, out_tensor
            # broadcasting tensors
            for s1, s2, s3, s4, s5, s6 in itertools.product((True, False), repeat=6):
                shape1 = (num_batches if s1 else 1, M if s2 else 1, N if s3 else 1)
                shape2 = (num_batches if s4 else 1, N if s5 else 1, O if s6 else 1)
                b1 = make_tensor(
                    shape1, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, M, N)
                b2 = make_tensor(
                    shape2, dtype=dtype, device=device, low=-1, high=1
                ).expand(num_batches, N, O)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor
            # zero-sized tensors
            for z1, z2, z3, z4 in itertools.product((True, False), repeat=4):
                # z1/z3 are shared so batch and inner dims stay compatible.
                shape1 = (num_batches if z1 else 0, M if z2 else 0, N if z3 else 0)
                shape2 = (num_batches if z1 else 0, N if z3 else 0, O if z4 else 0)
                b1 = make_tensor(shape1, dtype=dtype, device=device, low=-2, high=2)
                b2 = make_tensor(shape2, dtype=dtype, device=device, low=-2, high=2)
                ref = torch.from_numpy(
                    b1.to(numpy_dtype).cpu().numpy() @ b2.to(numpy_dtype).cpu().numpy()
                ).to(device=device, dtype=dtype)
                out_tensor = torch.zeros_like(ref)
                yield b1, b2, ref, out_tensor

        for b1, b2, ref, out_tensor in generate_tensor():
            self._test_addbmm_baddbmm("baddbmm", b1, b2, ref, out_tensor)
|
|
|
|
def test_tensordot(self, device):
|
|
a = torch.arange(60.0, device=device).reshape(3, 4, 5)
|
|
b = torch.arange(24.0, device=device).reshape(4, 3, 2)
|
|
c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
|
|
cn = torch.from_numpy(
|
|
np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=([1, 0], [0, 1]))
|
|
)
|
|
self.assertEqual(c, cn)
|
|
|
|
cout = torch.zeros((5, 2), device=device)
|
|
torch.tensordot(a, b, dims=([1, 0], [0, 1]), out=cout).cpu()
|
|
self.assertEqual(c, cout)
|
|
|
|
a = torch.randn(2, 3, 4, 5, device=device)
|
|
b = torch.randn(4, 5, 6, 7, device=device)
|
|
c = torch.tensordot(a, b, dims=2).cpu()
|
|
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
|
|
torch.tensordot(a, b, dims=-1)
|
|
|
|
self.assertEqual(c, cn)
|
|
c = torch.tensordot(a, b).cpu()
|
|
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
|
|
self.assertEqual(c, cn)
|
|
|
|
a = torch.tensordot(torch.tensor(0.0), torch.tensor(0.0), 0)
|
|
an = torch.from_numpy(
|
|
np.tensordot(
|
|
np.zeros((), dtype=np.float32), np.zeros((), dtype=np.float32), 0
|
|
)
|
|
)
|
|
self.assertEqual(a, an)
|
|
|
|
@dtypes(torch.float)
|
|
@precisionOverride({torch.float32: 1e-4})
|
|
def test_1_sized_with_0_strided(self, device, dtype):
|
|
a = make_tensor((8, 1, 64), dtype=dtype, device=device)
|
|
a_strided = torch.as_strided(a, size=[8, 1, 64], stride=[64, 0, 1])
|
|
b = make_tensor((8, 64, 512), dtype=dtype, device=device)
|
|
b_strided = torch.as_strided(b, size=[8, 64, 512], stride=[64, 1, 512])
|
|
res = torch.bmm(a_strided, b_strided)
|
|
expect = torch.from_numpy(a_strided.cpu().numpy() @ b_strided.cpu().numpy()).to(
|
|
device=device, dtype=dtype
|
|
)
|
|
self.assertEqual(expect, res)
|
|
|
|
def _select_broadcastable_dims(self, dims_full=None):
|
|
# select full dimensionality
|
|
if dims_full is None:
|
|
dims_full = []
|
|
ndims = random.randint(1, 4)
|
|
dims_full = [random.randint(1, 8) for _ in range(ndims)]
|
|
else:
|
|
ndims = len(dims_full)
|
|
|
|
# select actual dimensions for ops:
|
|
# larger: full ndims, individual sizes may be reduced
|
|
# smaller: possibly reduced ndims, sizes may be reduced
|
|
smaller_ndims = random.randint(1, ndims)
|
|
dims_small = []
|
|
dims_large = []
|
|
for i in range(ndims - 1, -1, -1):
|
|
j = random.randint(1, 3)
|
|
if j == 1: # no reduced singleton dimension
|
|
ds = dims_full[i]
|
|
dl = dims_full[i]
|
|
elif j == 2: # larger may have reduced singleton dimension
|
|
ds = dims_full[i]
|
|
dl = 1 if len(dims_small) < smaller_ndims else dims_full[i]
|
|
elif j == 3: # smaller may have reduced singleton dimension
|
|
ds = 1
|
|
dl = dims_full[i]
|
|
dims_large = [dl] + dims_large
|
|
if len(dims_small) < smaller_ndims:
|
|
dims_small = [ds] + dims_small
|
|
return (dims_small, dims_large, dims_full)
|
|
|
|
    def test_broadcast_fused_matmul(self, device):
        """Fused matmul-family ops must give identical results whether the
        bias-like first argument is broadcast implicitly or pre-expanded."""
        fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"]

        for fn in fns:
            batch_dim = random.randint(1, 8)
            n_dim = random.randint(1, 8)
            m_dim = random.randint(1, 8)
            p_dim = random.randint(1, 8)

            def dims_full_for_fn():
                # Returns (t0_dims_full, t1_dims, t2_dims) for the current fn
                # (closure over fn and the *_dim sizes above).
                if fn == "baddbmm":
                    return (
                        [batch_dim, n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addbmm":
                    return (
                        [n_dim, p_dim],
                        [batch_dim, n_dim, m_dim],
                        [batch_dim, m_dim, p_dim],
                    )
                elif fn == "addmm":
                    return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim])
                elif fn == "addmv":
                    return ([n_dim], [n_dim, m_dim], [m_dim])
                elif fn == "addr":
                    return ([n_dim, m_dim], [n_dim], [m_dim])
                else:
                    raise AssertionError("unknown function")

            (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn()
            (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full)

            t0_small = torch.randn(*t0_dims_small, device=device).float()
            t1 = torch.randn(*t1_dims, device=device).float()
            t2 = torch.randn(*t2_dims, device=device).float()

            t0_full = t0_small.expand(*t0_dims_full).to(device)

            fntorch = getattr(torch, fn)
            # Broadcast path and pre-expanded path must agree.
            r0 = fntorch(t0_small, t1, t2)
            r1 = fntorch(t0_full, t1, t2)
            self.assertEqual(r0, r1)
|
|
|
|
@dtypes(torch.float32)
|
|
def test_strided_mm_bmm(self, device, dtype):
|
|
# Tests strided view case with stride smaller than corresponding dimension size
|
|
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=dtype, device=device)
|
|
new_shape = [2, 2, 2]
|
|
new_stride = [3, 1, 1]
|
|
sx = torch.as_strided(x, size=new_shape, stride=new_stride)
|
|
|
|
torch_fn = lambda x: torch.bmm(x, x) # noqa: E731
|
|
np_fn = lambda x: np.matmul(x, x) # noqa: E731
|
|
self.compare_with_numpy(torch_fn, np_fn, sx)
|
|
|
|
torch_fn = lambda x: torch.mm(x, x) # noqa: E731
|
|
self.compare_with_numpy(torch_fn, np_fn, sx[0])
|
|
|
|
def test_mm_empty_inputs_mixed_dtype_errors(self, device):
|
|
a = torch.randint(0, 10, [1, 10], dtype=torch.int16, device=device)
|
|
b = torch.randn(10, 20, dtype=torch.float32, device=device)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "expected .* and .* to have the same dtype, but got:"
|
|
):
|
|
torch.mm(a, b)
|
|
|
|
def test_matmul_45724(self, device):
|
|
# https://github.com/pytorch/pytorch/issues/45724
|
|
a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
|
|
b = torch.rand(65537, 64, 22, device=device, dtype=torch.half)
|
|
c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device=device)
|
|
cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).half()
|
|
torch.matmul(a, b, out=c)
|
|
self.assertEqual(c, cpu_result)
|
|
|
|
@dtypes(
|
|
torch.int16,
|
|
torch.int32,
|
|
torch.int64,
|
|
torch.float16,
|
|
torch.float32,
|
|
torch.float64,
|
|
)
|
|
def test_baddbmm_input_dtypes_compatibility(self, device, dtype):
|
|
batch1 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
|
|
batch2 = torch.rand((1, 2, 2), dtype=torch.float32, device=device)
|
|
input_tensor = torch.rand((1, 2, 2), device=device).to(dtype)
|
|
if dtype != torch.float32:
|
|
with self.assertRaisesRegex(RuntimeError, "Input dtypes must be the same"):
|
|
y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0)
|
|
else:
|
|
out = torch.randn((1, 2, 2), dtype=dtype, device=device).fill_(torch.nan)
|
|
y_ref = torch.bmm(batch1, batch2)
|
|
y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
|
|
self.assertEqual(out, y_ref)
|
|
|
|
@dtypes(torch.float)
|
|
def test_baddbmm_nan_input_with_zero_beta(self, device, dtype):
|
|
for shape in [[3, 2, 2], [2, 20, 20]]:
|
|
mat1, mat2 = (
|
|
torch.randn(shape, dtype=dtype, device=device) for _ in range(2)
|
|
)
|
|
inputs = [
|
|
torch.randn(shape, dtype=dtype, device=device),
|
|
torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
|
|
]
|
|
outs = [
|
|
None,
|
|
torch.randn(shape, dtype=dtype, device=device),
|
|
torch.randn(shape, dtype=dtype, device=device).fill_(torch.nan),
|
|
]
|
|
options = itertools.product(inputs, outs)
|
|
for input, out in options:
|
|
y_ref = torch.bmm(mat1, mat2)
|
|
y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out)
|
|
self.assertEqual(y_ref, y)
|
|
|
|
@dtypes(torch.float)
|
|
def test_addmm_sizes(self, device, dtype):
|
|
for m in [0, 1, 25]:
|
|
for n in [0, 1, 10]:
|
|
for k in [0, 1, 8]:
|
|
M = torch.randn(n, m, device=device).to(dtype)
|
|
m1 = torch.randn(n, k, device=device).to(dtype)
|
|
m2 = torch.randn(k, m, device=device).to(dtype)
|
|
self._test_addmm_addmv(torch.addmm, M, m1, m2)
|
|
|
|
m1 = torch.randn(n, k + 1, device=device).to(dtype)
|
|
m2 = torch.randn(k, m, device=device).to(dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
f"{n}x{k + 1}.*{k}x{m}",
|
|
lambda: torch.addmm(M, m1, m2),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2)
|
|
)
|
|
|
|
    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 5e-2,
            torch.half: 5e-2,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addmm_gelu(self, device, dtype):
        """addmm with a fused gelu epilogue, via the shared addmm driver."""
        self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype)
|
|
|
|
    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 5e-2,
            torch.half: 5e-2,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.float32, torch.bfloat16, torch.half)
    def test_addmm_relu(self, device, dtype):
        """addmm with a fused relu epilogue, via the shared addmm driver."""
        self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype)
|
|
|
|
    @dtypes(torch.float, torch.bfloat16, torch.half)
    def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype):
        """addmv with non-trivial BLAS layout parameters: row-/column-major A,
        strided x (incx), strided y (incy) and a padded leading dim (lda)."""
        # tests (o, s)*(s). o is output size, s is summed size.
        o = 5
        s = 3
        a_data = torch.arange(1, o * s + 1, device=device, dtype=dtype).view(o, s)
        x_data = torch.arange(1, s + 1, 1, device=device, dtype=dtype)
        y_data = torch.ones(o, device=device, dtype=dtype)
        # NOTE(review): `control` is not referenced in this method as visible
        # here -- possibly left over from a variant that asserted against it.
        control = torch.tensor(
            [15.0, 33.0, 51.0, 69.0, 87.0], device=device, dtype=dtype
        )

        def _test(row_major, incx, incy, lda_tail):
            # NaN-filled storages larger than the logical views: any
            # out-of-bounds read would poison the result and fail comparison.
            if row_major:
                a_storage = torch.full(
                    (o, s + lda_tail), float("nan"), device=device, dtype=dtype
                )
            else:
                a_storage = torch.full(
                    (s, o + lda_tail), float("nan"), device=device, dtype=dtype
                ).permute(1, 0)
            a = a_storage[:o, :s].copy_(a_data)

            x_storage = torch.full((s, incx), float("nan"), device=device, dtype=dtype)
            x = x_storage[:, 0].copy_(x_data)

            y_storage = torch.full((o, incy), float("nan"), device=device, dtype=dtype)
            y = y_storage[:, 0].copy_(y_data)

            self._test_addmm_addmv(torch.addmv, y, a, x)

        for row_major, incx, incy, lda_tail in itertools.product(
            (False, True), (1, 2), (1, 2), (0, 1)
        ):
            _test(row_major, incx, incy, lda_tail)
|
|
|
|
    @precisionOverride(
        {
            torch.double: 1e-8,
            torch.float: 1e-4,
            torch.bfloat16: 0.6,
            torch.half: 1e-1,
            torch.cfloat: 1e-4,
            torch.cdouble: 1e-8,
        }
    )
    @dtypes(torch.bfloat16, torch.half, torch.float32)
    def test_corner_cases_of_cublasltmatmul(self, device, dtype):
        """Smoke-test matmul shapes that historically hit cublasLt corner
        cases (leading dims much larger than rows, dims > 65535).  Only
        checks that the calls complete without error."""
        # common case
        M = torch.randn(128, device=device).to(dtype)
        m1 = torch.randn(2048, 2400, device=device).to(dtype)
        m2 = torch.randn(128, 2400, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)
        # Ntrans_B has ld >> rows
        m1 = torch.rand([128, 2400]).to(dtype).to(device).t()
        m2 = torch.rand([2048, 25272]).to(dtype).to(device).t()[21940:24340]
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2.t(), m1)
        # trans_A has ld >> rows
        m1 = torch.rand([128, 25272]).to(dtype).to(device)[:, 21940:24340].t()
        m2 = torch.randn(2048, 2400, device=device).to(dtype)
        M = torch.rand([128]).to(dtype).to(device)
        torch.addmm(M, m2, m1)
        # large tensor dim > 65535
        M = torch.randn(16, device=device).to(dtype)
        m1 = torch.randn(32, 131071, device=device).to(dtype)
        m2 = torch.randn(16, 131071, device=device).to(dtype)
        torch.nn.functional.linear(m1, m2, M)
|
|
|
|
def test_blas_empty(self, device):
|
|
def fn(torchfn, *args, test_out=False, **kwargs):
|
|
def call_torch_fn(*args, **kwargs):
|
|
return torchfn(
|
|
*tuple(
|
|
torch.randn(shape, device=device)
|
|
if isinstance(shape, tuple)
|
|
else shape
|
|
for shape in args
|
|
),
|
|
**kwargs,
|
|
)
|
|
|
|
result = call_torch_fn(*args, **kwargs)
|
|
if not test_out:
|
|
return result
|
|
else:
|
|
out = torch.full_like(result, math.nan)
|
|
out1 = call_torch_fn(*args, **kwargs, out=out)
|
|
return out
|
|
|
|
# mm, addmm
|
|
self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape)
|
|
self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape)
|
|
self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape)
|
|
self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape)
|
|
self.assertEqual(
|
|
torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))
|
|
)
|
|
self.assertEqual(
|
|
torch.zeros((5, 6), device=device),
|
|
fn(torch.mm, (5, 0), (0, 6), test_out=True),
|
|
)
|
|
|
|
self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape)
|
|
self.assertEqual((0, 1), fn(torch.addmm, (1,), (0, 17), (17, 1)).shape)
|
|
t = torch.randn((5, 6), device=device)
|
|
self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6)))
|
|
self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True))
|
|
|
|
# mv, addmv
|
|
self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape)
|
|
self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape)
|
|
self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,)))
|
|
self.assertEqual(
|
|
torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)
|
|
)
|
|
|
|
self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape)
|
|
t = torch.randn((3,), device=device)
|
|
self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,)))
|
|
self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True))
|
|
|
|
# bmm, baddbmm
|
|
self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape)
|
|
self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape)
|
|
self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape)
|
|
self.assertEqual(
|
|
torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))
|
|
)
|
|
self.assertEqual(
|
|
torch.zeros((3, 5, 6), device=device),
|
|
fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True),
|
|
)
|
|
|
|
self.assertEqual(
|
|
(0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape
|
|
)
|
|
self.assertEqual(
|
|
(3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape
|
|
)
|
|
self.assertEqual(
|
|
(0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape
|
|
)
|
|
self.assertEqual(
|
|
(3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape
|
|
)
|
|
c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5)
|
|
self.assertEqual(
|
|
-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)
|
|
) # Issue #33467
|
|
self.assertEqual(
|
|
-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)
|
|
) # Issue #33467
|
|
|
|
# addbmm
|
|
self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape)
|
|
self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape)
|
|
t = torch.randn((5, 6), device=device)
|
|
self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6)))
|
|
self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True))
|
|
|
|
# matmul
|
|
self.assertEqual(torch.tensor(0.0, device=device), fn(torch.matmul, (0,), (0,)))
|
|
self.assertEqual(
|
|
torch.tensor(0.0, device=device),
|
|
fn(torch.matmul, (0,), (0,), test_out=True),
|
|
)
|
|
self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape)
|
|
self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape)
|
|
self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape)
|
|
self.assertEqual(
|
|
torch.zeros((5, 3, 4), device=device),
|
|
fn(torch.matmul, (5, 3, 0), (5, 0, 4)),
|
|
)
|
|
self.assertEqual(
|
|
torch.zeros((5, 3, 4), device=device),
|
|
fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True),
|
|
)
|
|
|
|
# dot
|
|
self.assertEqual(torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,)))
|
|
self.assertEqual(
|
|
torch.tensor(0.0, device=device), fn(torch.dot, (0,), (0,), test_out=True)
|
|
)
|
|
|
|
def test_large_bmm_backward(self, device):
|
|
A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
|
|
B = torch.randn([1, 1024, 65536], device=device, requires_grad=True)
|
|
G = torch.randn([1024, 2, 65536], device=device)
|
|
|
|
# Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
|
|
(A @ B).backward(G)
|
|
|
|
def test_large_bmm_mm_backward(self, device):
|
|
A = torch.randn([1024, 2, 1024], device=device).mT.contiguous().mT
|
|
B = torch.randn([1024, 65536], device=device, requires_grad=True)
|
|
G = torch.randn([1024, 2, 65536], device=device)
|
|
|
|
# Should not create an intermediary tensor of size [1024, 1024, 65536] (256GB of memory) and OOM
|
|
(A @ B).backward(G)
|
|
|
|
def check_single_matmul(self, x, y):
|
|
def assertEqual(answer, expected):
|
|
if x.dtype.is_floating_point or x.dtype.is_complex:
|
|
k = max(x.shape[-1], 1) # Scale the atol with the size of the matrix
|
|
self.assertEqual(
|
|
answer,
|
|
expected,
|
|
msg=f"{x.shape} x {y.shape} = {answer.shape}",
|
|
atol=k * 5e-5,
|
|
rtol=1e-4,
|
|
)
|
|
else:
|
|
self.assertEqual(
|
|
answer, expected, msg=f"{x.shape} x {y.shape} = {answer.shape}"
|
|
)
|
|
|
|
# test x @ y
|
|
expected = np.matmul(x.cpu(), y.cpu())
|
|
ans = torch.matmul(x, y)
|
|
self.assertTrue(ans.is_contiguous())
|
|
assertEqual(ans, expected)
|
|
|
|
# test out
|
|
out = torch.empty_like(ans)
|
|
ans = torch.matmul(x, y, out=out)
|
|
self.assertIs(ans, out)
|
|
self.assertTrue(ans.is_contiguous())
|
|
assertEqual(ans, expected)
|
|
|
|
def gen_sizes_matmul(self, x_dim, y_dim=4, matrix_size=4, batch_size=3):
|
|
"""
|
|
Generates sequences of tuples (x, y) of with size(x) = x_dim and
|
|
size(y) <= y_dim that are compatible wrt. matmul
|
|
"""
|
|
assert x_dim >= 1
|
|
assert y_dim >= 2
|
|
x = x_dim
|
|
for y in range(1, y_dim + 1):
|
|
for batch, mn in product(
|
|
product(range(batch_size), repeat=max(x - 2, y - 2, 0)),
|
|
product(range(matrix_size), repeat=min(y, 2)),
|
|
):
|
|
if x == 1:
|
|
size_x = mn[:1]
|
|
size_y = batch + mn
|
|
yield size_x, size_y
|
|
else:
|
|
for k in range(matrix_size):
|
|
size_x = (k,) + mn[:1]
|
|
if x > 2:
|
|
size_x = batch[-(x - 2) :] + size_x
|
|
size_y = mn
|
|
if y > 2:
|
|
size_y = batch[-(y - 2) :] + size_y
|
|
yield size_x, size_y
|
|
|
|
@dtypes(torch.float)
|
|
def test_matmul_small_brute_force_1d_Nd(self, device, dtype):
|
|
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
|
|
|
for (size_x, size_y), nctg_x, nctg_y in product(
|
|
self.gen_sizes_matmul(1), (True, False), (True, False)
|
|
):
|
|
x = make_arg(size_x, noncontiguous=nctg_x)
|
|
y = make_arg(size_y, noncontiguous=nctg_y)
|
|
self.check_single_matmul(x, y)
|
|
|
|
@dtypes(torch.float)
|
|
def test_matmul_small_brute_force_2d_Nd(self, device, dtype):
|
|
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
|
|
|
for (size_x, size_y), nctg_x, nctg_y in product(
|
|
self.gen_sizes_matmul(2), (True, False), (True, False)
|
|
):
|
|
x = make_arg(size_x, noncontiguous=nctg_x)
|
|
y = make_arg(size_y, noncontiguous=nctg_y)
|
|
self.check_single_matmul(x, y)
|
|
|
|
@dtypes(torch.float)
|
|
def test_matmul_small_brute_force_3d_Nd(self, device, dtype):
|
|
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
|
|
|
for (size_x, size_y), nctg_x, nctg_y in product(
|
|
self.gen_sizes_matmul(3), (True, False), (True, False)
|
|
):
|
|
x = make_arg(size_x, noncontiguous=nctg_x)
|
|
y = make_arg(size_y, noncontiguous=nctg_y)
|
|
self.check_single_matmul(x, y)
|
|
|
|
@dtypes(torch.float)
|
|
def test_matmul_out_kernel_errors_with_autograd(self, device, dtype):
|
|
a = torch.empty(
|
|
(256, 512), device=device, dtype=dtype, requires_grad=True
|
|
).unsqueeze(0)
|
|
b = torch.empty(
|
|
(4, 128, 512), device=device, dtype=dtype, requires_grad=True
|
|
).transpose(-1, -2)
|
|
c = torch.empty((256, 4, 128), device=device, dtype=dtype).movedim(1, 0)
|
|
|
|
torch.matmul(a.detach(), b.detach(), out=c)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"functions with out=... arguments don't support automatic differentiation",
|
|
):
|
|
torch.matmul(a, b, out=c)
|
|
|
|
with torch.no_grad():
|
|
torch.matmul(a, b, out=c)
|
|
|
|
|
|
# Generate per-device test classes from TestBasicGEMM, restricted to XPU
# (e.g. TestBasicGEMMXPU), and register them in this module's namespace so
# the unittest discovery machinery picks them up.
instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu")


if __name__ == "__main__":
    # Entry point when the file is run directly rather than via pytest.
    run_tests()
|