port addmv to structured kernels (#55746)

Summary:
Per title
I've revamped the size checks a bit to provide a better error message when `self` has the wrong size, and added a check that the in-place variant's `self` has the correct size.

Ref: https://github.com/pytorch/pytorch/issues/55070
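
A minimal sketch (not part of the diff), using the public Python API, of the shape contract the new meta check enforces: mat must be 2-D, vec 1-D, and self either a single element or 1-D with mat.size(0) elements; with this change the in-place variant raises instead of silently resizing a wrongly sized self.

import torch

mat = torch.randn(3, 4)
vec = torch.randn(4)

bias = torch.randn(3)          # matches mat.size(0)
out = torch.addmv(bias, mat, vec, beta=1, alpha=1)   # shape (3,)

bias1 = torch.randn(1)         # 1-element self is broadcast to mat.size(0)
out = torch.addmv(bias1, mat, vec)                   # also shape (3,)

wrong = torch.randn(5)
# wrong.addmv_(mat, vec)   # in-place variant: now raises a size-mismatch error,
#                          # since the 1-D output must have size mat.size(0) == 3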

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55746

Reviewed By: ezyang

Differential Revision: D27782980

Pulled By: ngimel

fbshipit-source-id: 6ba949b682b8fd1170d0304da0ed348dd1a7b8c7
Author: Natalia Gimelshein
Date: 2021-04-15 15:52:11 -07:00
Committed by: Facebook GitHub Bot
parent 8e82e932f3
commit 3fbca31be3
6 changed files with 116 additions and 115 deletions

View File

@@ -369,18 +369,16 @@ static std::vector<Dimname> compute_matmul_outnames(
return result;
}
void propagate_names_for_addmv(
Tensor& result,
std::vector<Dimname> propagate_names_for_addmv(
const Tensor& mat,
const Tensor& vec,
const Tensor& bias) {
if (!result.has_names() && !mat.has_names() &&
if (!mat.has_names() &&
!vec.has_names() && !bias.has_names()) {
return;
return std::vector<Dimname>{};
}
auto mv_outnames = compute_matmul_outnames(mat.names(), vec.names());
auto add_outnames = unify_from_right(mv_outnames, bias.names());
propagate_names(result, add_outnames);
return unify_from_right(mv_outnames, bias.names());
}
void propagate_names_for_addmm(
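
A minimal sketch (not part of the diff), assuming the named-tensor prototype API: the names returned here are the matmul outnames of mat and vec, unified from the right with the bias names, which the new meta function then passes to set_output.

import torch

mat = torch.randn(3, 4, names=('N', 'C'))
vec = torch.randn(4, names=('C',))
bias = torch.randn(3, names=('N',))

out = torch.addmv(bias, mat, vec)
print(out.names)   # ('N',) -- matmul outnames unified with the bias names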

View File

@@ -146,8 +146,7 @@ TORCH_API void propagate_names_for_addmm(
const Tensor& m2,
const Tensor& bias);
TORCH_API void propagate_names_for_addmv(
Tensor& result,
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
const Tensor& mat,
const Tensor& vec,
const Tensor& bias);

View File

@@ -4,7 +4,24 @@
#include <ATen/NamedTensorUtils.h>
#include <ATen/ScalarOps.h>
namespace at { namespace native {
namespace at {
namespace meta {
TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha) {
TORCH_CHECK((mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1),
"vector + matrix @ vector expected, got ", self.dim(), ", ", mat.dim(), ", ", vec.dim());
TORCH_CHECK(mat.size(1) == vec.size(0) && (mat.size(0) == self.numel() || self.numel() == 1),
"size mismatch, got ", self.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, mat.options(), names);
auto result = maybe_get_output(0);
//this check can fire only for the inplace op; for all other variants the result is guaranteed to have the correct size
TORCH_CHECK(result.dim() == 1 && result.sizes()[0] == mat.sizes()[0], "output of addmv operation should be 1D with ",
"size equal to mat.size(0), yet got output size ", result.sizes(), " and mat.size(0) ", mat.size(0));
}
}
namespace native {
template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda, scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);
@@ -19,86 +36,69 @@ constexpr inline bool lda_cond(int64_t m, int64_t n, int64_t lda) {
return n == 1 || lda >= std::max<int64_t>(1L, m);
}
Tensor &addmv_impl_cpu(Tensor& result, const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_) {
auto r_stride = result.stride(0);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, mat.scalar_type(), "addmv_impl_cpu", [&] {
auto beta = beta_.to<scalar_t>();
auto alpha = alpha_.to<scalar_t>();
if (mat.stride(0) == 1 && lda_cond(mat.size(0), mat.size(1), mat.stride(1))) {
gemv<scalar_t>('n', mat.size(0), mat.size(1), alpha, mat.data_ptr<scalar_t>(), mat.stride(1),
vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
}
else if (mat.stride(1) == 1 && lda_cond(mat.size(1), mat.size(0), mat.stride(0))) {
gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, mat.data_ptr<scalar_t>(), mat.stride(0),
vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
}
else {
Tensor cmat = mat.contiguous();
gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, cmat.data_ptr<scalar_t>(), cmat.stride(0),
vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
}
});
return result;
}
Tensor &addmv_out(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha, Tensor& result) {
{ // scope of NoNamesGuard
at::NoNamesGuard guard;
result.resize_({mat.size(0)});
Tensor self_ = self;
if (self.dim() == 0 || self.size(0) == 1) {
self_ = self.expand({mat.size(0)});
}
TORCH_CHECK((mat.dim() == 2 && vec.dim() == 1 && self_.dim() == 1),
"vector + matrix @ vector expected, got ", self_.dim(), ", ", mat.dim(), ", ", vec.dim());
TORCH_CHECK((mat.size(1) == vec.size(0) && mat.size(0) == self_.size(0)),
"size mismatch, get ", self_.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0));
TORCH_IMPL_FUNC(addmv_out_cpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
auto betaval = beta_.toComplexDouble();
if (mat.numel() == 0) {
// shortcut for an empty matrix
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (beta.toComplexDouble() == 0.0) {
if (betaval == 0.0) {
result.zero_();
} else {
at::cpu::mul_out(
result,
const_cast<Tensor&>(result),
self,
at::native::scalar_tensor(
beta, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */));
beta_, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */));
}
} else {
if (!result.is_same(self_)) {
at::native::copy_(result, self_);
if (!result.is_same(*self_) && betaval != 0.0) { //if beta is 0, the contents of result are ignored
at::native::copy_(const_cast<Tensor&>(result), *self_);
}
if (result.numel() != 0) {
at::_addmv_impl_(result, self_, mat, vec, beta, alpha);
auto r_stride = result.stride(0);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, mat.scalar_type(), "addmv_impl_cpu", [&] {
auto beta = beta_.to<scalar_t>();
auto alpha = alpha_.to<scalar_t>();
if (mat.stride(0) == 1 && lda_cond(mat.size(0), mat.size(1), mat.stride(1))) {
gemv<scalar_t>('n', mat.size(0), mat.size(1), alpha, mat.data_ptr<scalar_t>(), mat.stride(1),
vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
}
else if (mat.stride(1) == 1 && lda_cond(mat.size(1), mat.size(0), mat.stride(0))) {
gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, mat.data_ptr<scalar_t>(), mat.stride(0),
vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
}
else {
Tensor cmat = mat.contiguous();
gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, cmat.data_ptr<scalar_t>(), cmat.stride(0),
vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
}
});
}
}
} // scope of NoNamesGuard
at::namedinference::propagate_names_for_addmv(result, mat, vec, self);
return result;
}
Tensor addmv(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha) {
Tensor result = at::empty({mat.size(0)}, mat.options());
return native::addmv_out(self, mat, vec, beta, alpha, result);
}
Tensor &addmv_(Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha) {
return native::addmv_out(self, mat, vec, beta, alpha, self);
}
Tensor &mv_out(const Tensor &self, const Tensor &vec, Tensor& result) {
return native::addmv_out(result, self, vec, 0, 1, result);
//self arg sent to addmv_out cannot be resized
//here we use result as the self argument for addmv, and result is user-supplied and can be the wrong size
//that's not a hard error, because we allow resizing result, but it becomes a hard error
//in addmv, because addmv expects self to satisfy proper conditions
//to avoid this, supply a correctly sized self; its contents don't matter because beta is 0
if (result.dim() > 1 || (result.numel() != self.size(0) || result.numel() !=1)) {
Tensor self_addmv = at::empty({self.size(0)}, self.options());
return at::addmv_out(result, self_addmv, self, vec, 0, 1);
}
return at::addmv_out(result, result, self, vec, 0, 1);
}
Tensor mv(const Tensor &self, const Tensor &vec) {
Tensor result = at::empty({self.size(0)}, self.options());
return native::mv_out(self, vec, result);
//inplace version is more efficient if we can use it
return at::addmv_(result, self, vec, 0, 1);
}
inline void dot_check(const Tensor& self, const Tensor& other) {
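
A minimal sketch (not part of the diff) of the beta == 0 contract that the empty-matrix shortcut and the reworked mv()/mv_out() rely on: when beta is zero the contents of self are ignored, so NaNs and Infs in it must not propagate into the result.

import torch

mat = torch.randn(3, 4)
vec = torch.randn(4)
garbage = torch.full((3,), float('nan'))   # self is ignored when beta == 0

out = torch.addmv(garbage, mat, vec, beta=0, alpha=1)
assert torch.allclose(out, torch.mv(mat, vec))   # no NaN propagation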

View File

@@ -1,38 +1,60 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDABlas.h>
#include <c10/util/MaybeOwned.h>
namespace at { namespace native {
Tensor &addmv_impl_cuda(Tensor& result, const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_) {
auto r_stride = result.stride(0);
auto vec_stride = vec.stride(0);
TORCH_IMPL_FUNC(addmv_out_cuda)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
auto betaval = beta_.toComplexDouble();
if (mat.numel() == 0) {
// shortcut for an empty matrix
// By definition, when beta==0, values in self should be ignored. nans and infs
// should not propagate
if (betaval == 0.0) {
result.zero_();
} else {
at::mul_out(
const_cast<Tensor&>(result),
self,
at::native::scalar_tensor(
beta_, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */));
}
} else {
if (!result.is_same(*self_) && betaval != 0.0) { //if beta is 0, result contents will be zeroed later
at::native::copy_(const_cast<Tensor&>(result), *self_);
}
if (result.numel() != 0) {
auto r_stride = result.stride(0);
auto vec_stride = vec.stride(0);
// Check for contiguity of `vec` and update `vec_stride` accordingly
const auto vec_contiguous = vec_stride == 0 ? vec.contiguous() : vec;
vec_stride = vec_contiguous.stride(0);
// Check for contiguity of `vec` and update `vec_stride` accordingly
const auto vec_contiguous = vec_stride == 0 ? vec.contiguous() : vec;
vec_stride = vec_contiguous.stride(0);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "addmv_impl_cuda", [&] {
auto beta = beta_.to<scalar_t>();
auto alpha = alpha_.to<scalar_t>();
if (mat.stride(0) == 1 && mat.stride(1) >= std::max<int64_t>(1, mat.size(0))) {
at::cuda::blas::gemv<scalar_t>('n',
mat.size(0), mat.size(1), alpha, mat.data_ptr<scalar_t>(), mat.stride(1), vec_contiguous.data_ptr<scalar_t>(),
vec_stride, beta, result.data_ptr<scalar_t>(), r_stride);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, mat.scalar_type(), "addmv_impl_cuda", [&] {
auto beta = beta_.to<scalar_t>();
auto alpha = alpha_.to<scalar_t>();
if (mat.stride(0) == 1 && mat.stride(1) >= std::max<int64_t>(1, mat.size(0))) {
at::cuda::blas::gemv<scalar_t>('n',
mat.size(0), mat.size(1), alpha, mat.data_ptr<scalar_t>(), mat.stride(1), vec_contiguous.data_ptr<scalar_t>(),
vec_stride, beta, result.data_ptr<scalar_t>(), r_stride);
}
else if (mat.stride(1) == 1 && mat.stride(0) >= std::max<int64_t>(1, mat.size(1))) {
at::cuda::blas::gemv<scalar_t>('t',
mat.size(1), mat.size(0), alpha, mat.data_ptr<scalar_t>(), mat.stride(0),
vec_contiguous.data_ptr<scalar_t>(), vec_stride, beta, result.data_ptr<scalar_t>(), r_stride);
}
else {
Tensor cmat = mat.contiguous();
at::cuda::blas::gemv<scalar_t>('t',
mat.size(1), mat.size(0), alpha, cmat.data_ptr<scalar_t>(), cmat.stride(0),
vec_contiguous.data_ptr<scalar_t>(), vec_stride, beta, result.data_ptr<scalar_t>(), r_stride);
}
});
}
else if (mat.stride(1) == 1 && mat.stride(0) >= std::max<int64_t>(1, mat.size(1))) {
at::cuda::blas::gemv<scalar_t>('t',
mat.size(1), mat.size(0), alpha, mat.data_ptr<scalar_t>(), mat.stride(0),
vec_contiguous.data_ptr<scalar_t>(), vec_stride, beta, result.data_ptr<scalar_t>(), r_stride);
}
else {
Tensor cmat = mat.contiguous();
at::cuda::blas::gemv<scalar_t>('t',
mat.size(1), mat.size(0), alpha, cmat.data_ptr<scalar_t>(), cmat.stride(0),
vec_contiguous.data_ptr<scalar_t>(), vec_stride, beta, result.data_ptr<scalar_t>(), r_stride);
}
});
return result;
}
}
}} // namespace at::native

View File

@@ -378,24 +378,18 @@
CompositeExplicitAutograd: add_
- func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
structured_delegate: addmv.out
variants: function, method
dispatch:
CPU, CUDA: addmv
- func: addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
structured_delegate: addmv.out
variants: function, method
dispatch:
CPU, CUDA: addmv_
- func: addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
CPU, CUDA: addmv_out
- func: _addmv_impl_(Tensor(a!) self, Tensor self2, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
dispatch:
CPU: addmv_impl_cpu
CUDA: addmv_impl_cuda
CPU: addmv_out_cpu
CUDA: addmv_out_cuda
- func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function, method
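
A minimal sketch (not part of the diff): with structured_delegate pointing at addmv.out, the function, method, and in-place variants all route through the same meta/impl pair as the out= overload, so all of them should agree.

import torch

mat, vec, bias = torch.randn(3, 4), torch.randn(4), torch.randn(3)

a = torch.addmv(bias, mat, vec)             # functional variant
b = bias.clone().addmv_(mat, vec)           # in-place variant
c = torch.empty(3)
torch.addmv(bias, mat, vec, out=c)          # out= variant

assert torch.allclose(a, b) and torch.allclose(a, c)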

View File

@@ -64,19 +64,6 @@ allow_list = [
("aten::fake_quantize_per_tensor_affine_backward", datetime.date(2021, 2, 20)),
("aten::fake_quantize_per_channel_affine_backward", datetime.date(2021, 2, 20)),
("aten::rowwise_prune", datetime.date(9999, 1, 1)),
("aten::_foreach_mul_", datetime.date(2021, 4, 2)),
("aten::_foreach_addcdiv_", datetime.date(2021, 4, 2)),
("aten::_foreach_div", datetime.date(2021, 4, 2)),
("aten::_foreach_addcmul_", datetime.date(2021, 4, 2)),
("aten::_foreach_sub", datetime.date(2021, 4, 2)),
("aten::_foreach_add", datetime.date(2021, 4, 2)),
("aten::_foreach_sub_", datetime.date(2021, 4, 2)),
("aten::_foreach_add_", datetime.date(2021, 4, 2)),
("aten::_foreach_mul", datetime.date(2021, 4, 2)),
("aten::_foreach_div_", datetime.date(2021, 4, 2)),
("aten::_foreach_addcdiv", datetime.date(2021, 4, 2)),
("aten::_foreach_addcmul", datetime.date(2021, 4, 2)),
("aten::mkldnn_linear", datetime.date(2021, 3, 2)),
("aten::_mode*", datetime.date(2021, 5, 2)),
("aten::linalg_multi_dot", datetime.date(2021, 3, 25)),
("aten::coalesce", datetime.date(2021, 4, 15)),
@@ -87,6 +74,7 @@ allow_list = [
("aten::assert_async", datetime.date(2021, 5, 1)),
("aten::cumprod_backward", datetime.date(2021, 5, 1)),
("aten::_triangular_solve_helper", datetime.date(9999, 1, 1)),
("aten::_addmv_impl_", datetime.date(2021, 5, 15)),
("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
]