mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-20 02:24:54 +08:00
do shape checking in addmm dtype out overload
This commit is contained in:
@ -1022,14 +1022,17 @@ Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::Sca
|
||||
Tensor _addmm_dtype_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha) {
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
|
||||
TORCH_CHECK(
|
||||
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||
Tensor result = at::empty({mat1.size(0), mat2.size(1)}, self.options().dtype(out_dtype));
|
||||
return _addmm_dtype_out_cuda(self, mat1, mat2, out_dtype, beta, alpha, result);
|
||||
}
|
||||
|
||||
Tensor& _addmm_dtype_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const at::ScalarType out_dtype, const Scalar& beta, const Scalar& alpha, Tensor &out) {
|
||||
// repeat dimensionality checks for direct calls to `out` overload
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
|
||||
TORCH_CHECK(
|
||||
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
|
||||
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
|
||||
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type());
|
||||
TORCH_CHECK(out_dtype == mat1.scalar_type() ||
|
||||
(out_dtype == at::ScalarType::Float && (mat1.scalar_type() == at::ScalarType::Half || mat1.scalar_type() == at::ScalarType::BFloat16)),
|
||||
|
||||
Reference in New Issue
Block a user