Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21556 — optimize the batch matrix-multiplication op when broadcasting the second input. Reviewed By: houseroad. Differential Revision: D15728914. fbshipit-source-id: c60441d69d4997dd32a3566780496c7ccda5e67a
326 lines · 9.4 KiB · C++
#ifndef CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
#define CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_

#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

// BatchMatMulOp computes Y = op(A) * op(B) where the last two dimensions of
// each input are the matrix dimensions and any leading dimensions are batch
// dimensions. op(X) is X or X^T depending on the "trans_a"/"trans_b" args.
//
// Supported shapes:
//   - both inputs 1-D: dot product, output shape {1};
//   - exactly one input 1-D: that input is treated as a vector and the op
//     reduces to (batched) matrix-vector products;
//   - both inputs >= 2-D: (batched) matrix-matrix products, with numpy-style
//     broadcasting of the batch dimensions when the "broadcast" arg is set.
//
// The implementation dispatches to the cheapest available BLAS-style kernel
// for each case (Dot / Gemv / Gemm / GemmStridedBatched / GemmBatched).
template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Reads the operator arguments:
  //   trans_a   (bool, default false): transpose the last two dims of A.
  //   trans_b   (bool, default false): transpose the last two dims of B.
  //   broadcast (bool, default false): allow broadcasting of batch dims;
  //             mismatched batch dims are an error unless this is set.
  template <class... Args>
  explicit BatchMatMulOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "trans_a", trans_a_, false),
        OP_SINGLE_ARG(bool, "trans_b", trans_b_, false),
        OP_SINGLE_ARG(bool, "broadcast", broadcast_, false) {}

  // Only float is dispatched here; other dtypes fail the type dispatch.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  // Shape-dispatched implementation; see the class comment for the cases.
  // Inputs:  Input(0) = A, Input(1) = B.  Output: Output(0) = Y.
  // Returns true on success; shape mismatches raise via CAFFE_ENFORCE*.
  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);
    const int A_ndim = A.dim();
    const int B_ndim = B.dim();
    const std::vector<std::int64_t> A_dims = A.sizes().vec();
    const std::vector<std::int64_t> B_dims = B.sizes().vec();
    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();

    // Case 1: vector . vector -> scalar output of shape {1}.
    if (A_ndim == 1 && B_ndim == 1) {
      CAFFE_ENFORCE_EQ(A.numel(), B.numel());
      auto* Y = Output(0, {1}, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      math::Dot<T, Context>(A.numel(), A_data, B_data, Y_data, &context_);
      return true;
    }
    // Case 2: A is a vector of length N; contract it against the
    // N-sized matrix dimension of (possibly batched) B.
    if (A_ndim == 1) {
      const int N = A.numel();
      // The dimension of B being contracted must match N; which one it is
      // depends on trans_b_.
      if (trans_b_) {
        CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], N);
      } else {
        CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], N);
      }
      // Output shape = B's dims with the contracted dimension removed.
      std::vector<std::int64_t> Y_dims(B_ndim - 1);
      if (trans_b_) {
        std::copy_n(B_dims.cbegin(), B_ndim - 1, Y_dims.begin());
      } else {
        std::copy_n(B_dims.cbegin(), B_ndim - 2, Y_dims.begin());
        Y_dims.back() = B_dims.back();
      }
      auto* Y = Output(0, Y_dims, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      if (trans_b_) {
        // B^T case: all batches of B laid out contiguously form one
        // (B.numel()/N) x N row-major matrix, so a single GEMV suffices.
        const int M = B.numel() / N;
        math::Gemv<T, Context, Engine>(
            CblasNoTrans, M, N, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
      } else {
        const int M = B_dims[B_ndim - 1];
        const int batch_size = B.numel() / (M * N);
        if (batch_size == 1) {
          // Single N x M matrix: y = B^T * a via a transposed GEMV.
          math::Gemv<T, Context, Engine>(
              CblasTrans, N, M, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
        } else {
          // Batched case: compute B_i^T * a for each batch i as a strided
          // batched GEMM with a 1-column right operand. Stride 0 on A reuses
          // the same vector for every batch (the broadcast trick).
          math::GemmStridedBatched<T, Context, Engine>(
              CblasTrans,
              CblasNoTrans,
              batch_size,
              M,
              1,
              N,
              1.0f,
              B_data,
              M * N,
              A_data,
              0,
              0.0f,
              Y_data,
              M,
              &context_);
        }
      }
      return true;
    }
    // Case 3: B is a vector of length N; contract it against the
    // N-sized matrix dimension of (possibly batched) A. Mirror of case 2.
    if (B_ndim == 1) {
      const int N = B.numel();
      if (trans_a_) {
        CAFFE_ENFORCE_EQ(A_dims[A_ndim - 2], N);
      } else {
        CAFFE_ENFORCE_EQ(A_dims[A_ndim - 1], N);
      }
      // Output shape = A's dims with the contracted (last matrix) dim removed.
      const std::vector<std::int64_t> Y_dims(
          A_dims.cbegin(), A_dims.cbegin() + A_ndim - 1);
      auto* Y = Output(0, Y_dims, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      if (trans_a_) {
        const int M = A_dims[A_ndim - 1];
        const int batch_size = A.numel() / (M * N);
        if (batch_size == 1) {
          // y = A^T * b via a transposed GEMV on the stored N x M matrix.
          math::Gemv<T, Context, Engine>(
              CblasTrans, N, M, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
        } else {
          // Batched A_i^T * b; stride 0 on B broadcasts the vector.
          math::GemmStridedBatched<T, Context, Engine>(
              CblasTrans,
              CblasNoTrans,
              batch_size,
              M,
              1,
              N,
              1.0f,
              A_data,
              M * N,
              B_data,
              0,
              0.0f,
              Y_data,
              M,
              &context_);
        }
      } else {
        // No transpose: all batches of A form one contiguous
        // (A.numel()/N) x N matrix, so a single GEMV suffices.
        const int M = A.numel() / N;
        math::Gemv<T, Context, Engine>(
            CblasNoTrans, M, N, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
      }
      return true;
    }

    // Case 4: both inputs are >= 2-D. Y has matrix dims M x N with inner
    // (contracted) dimension K; which stored dims these are depends on the
    // transpose flags.
    const int M = trans_a_ ? A_dims[A_ndim - 1] : A_dims[A_ndim - 2];
    const int K = trans_a_ ? A_dims[A_ndim - 2] : A_dims[A_ndim - 1];
    if (trans_b_) {
      CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], K);
    } else {
      CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], K);
    }
    const int N = trans_b_ ? B_dims[B_ndim - 2] : B_dims[B_ndim - 1];
    // Broadcast only the batch (leading) dims; the trailing matrix dims of Y
    // are filled in explicitly below.
    const int ndim = std::max(A_ndim, B_ndim);
    std::vector<std::int64_t> A_broadcast_dims(ndim);
    std::vector<std::int64_t> B_broadcast_dims(ndim);
    std::vector<std::int64_t> Y_broadcast_dims(ndim);
    math::utils::ComputeBroadcastBinaryOpDims(
        A_ndim - 2,
        A_dims.data(),
        B_ndim - 2,
        B_dims.data(),
        A_broadcast_dims.data(),
        B_broadcast_dims.data(),
        Y_broadcast_dims.data());
    Y_broadcast_dims[ndim - 2] = M;
    Y_broadcast_dims[ndim - 1] = N;
    auto* Y = Output(0, Y_broadcast_dims, at::dtype<T>());
    T* Y_data = Y->template mutable_data<T>();

    // Batch dims differ between A and B only if broadcasting is in play;
    // that is allowed only when the "broadcast" arg was set.
    const int batch_dim = ndim - 2;
    const bool is_broadcast_dims = !std::equal(
        A_broadcast_dims.cbegin(),
        A_broadcast_dims.cbegin() + batch_dim,
        B_broadcast_dims.cbegin());
    if (is_broadcast_dims) {
      CAFFE_ENFORCE(broadcast_);
    }

    // Total number of matrices in each operand / in the output.
    const std::int64_t A_batch_size = std::accumulate(
        A_broadcast_dims.cbegin(),
        A_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    const std::int64_t B_batch_size = std::accumulate(
        B_broadcast_dims.cbegin(),
        B_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    const std::int64_t Y_batch_size = std::accumulate(
        Y_broadcast_dims.cbegin(),
        Y_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    // Empty output: nothing to compute.
    if (Y_batch_size == 0) {
      return true;
    }
    if (A_batch_size == 1 && B_batch_size == 1) {
      // Plain single matrix-matrix product.
      math::Gemm<T, Context, Engine>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          M,
          N,
          K,
          1.0f,
          A_data,
          B_data,
          0.0f,
          Y_data,
          &context_);
    } else if (A_batch_size == 1) {
      // A is a single matrix broadcast against batched B.
      if (M == 1 && trans_b_) {
        // A is effectively a length-K vector and each B_i is N x K stored
        // row-major, so all batches of B form one contiguous
        // (B_batch_size*N) x K matrix: a single GEMV computes every batch.
        math::Gemv<T, Context, Engine>(
            CblasNoTrans,
            B_batch_size * N,
            K,
            1.0f,
            B_data,
            A_data,
            0.0f,
            Y_data,
            &context_);
      } else {
        // Strided batched GEMM with stride 0 on A to reuse the single
        // A matrix for every batch.
        math::GemmStridedBatched<T, Context, Engine>(
            trans_a_ ? CblasTrans : CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            Y_batch_size,
            M,
            N,
            K,
            1.0f,
            A_data,
            0,
            B_data,
            K * N,
            0.0f,
            Y_data,
            M * N,
            &context_);
      }
    } else if (B_batch_size == 1) {
      // B is a single matrix broadcast against batched A.
      if (!trans_a_) {
        // All batches of A form one contiguous (A_batch_size*M) x K matrix,
        // so a single GEMM covers every batch (the D15728914 optimization).
        math::Gemm<T, Context, Engine>(
            CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            A_batch_size * M,
            N,
            K,
            1.0f,
            A_data,
            B_data,
            0.0f,
            Y_data,
            &context_);
      } else {
        // Transposed A prevents the flattening trick; fall back to a strided
        // batched GEMM with stride 0 on B to reuse the single B matrix.
        math::GemmStridedBatched<T, Context, Engine>(
            CblasTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            Y_batch_size,
            M,
            N,
            K,
            1.0f,
            A_data,
            M * K,
            B_data,
            0,
            0.0f,
            Y_data,
            M * N,
            &context_);
      }
    } else if (!is_broadcast_dims) {
      // Equal batch dims on both sides: uniform strides, one strided
      // batched GEMM.
      math::GemmStridedBatched<T, Context, Engine>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          Y_batch_size,
          M,
          N,
          K,
          1.0f,
          A_data,
          M * K,
          B_data,
          K * N,
          0.0f,
          Y_data,
          M * N,
          &context_);
    } else {
      // General broadcast: per-batch strides are non-uniform, so build
      // explicit pointer arrays by walking the output batch index space and
      // mapping each output batch back to its (possibly repeated) source
      // matrix in A and B.
      std::vector<const T*> A_ptr(Y_batch_size);
      std::vector<const T*> B_ptr(Y_batch_size);
      std::vector<T*> Y_ptr(Y_batch_size);
      std::vector<std::int64_t> index(batch_dim);
      for (std::int64_t i = 0; i < Y_batch_size; ++i) {
        const std::int64_t A_index = math::utils::GetIndexFromDims(
            batch_dim, A_broadcast_dims.data(), index.data());
        const std::int64_t B_index = math::utils::GetIndexFromDims(
            batch_dim, B_broadcast_dims.data(), index.data());
        A_ptr[i] = A_data + A_index * M * K;
        B_ptr[i] = B_data + B_index * K * N;
        Y_ptr[i] = Y_data + i * M * N;
        math::utils::IncreaseIndexInDims(
            batch_dim, Y_broadcast_dims.data(), index.data());
      }
      math::GemmBatched<T, Context, Engine>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          Y_batch_size,
          M,
          N,
          K,
          1.0f,
          A_ptr.data(),
          B_ptr.data(),
          0.0f,
          Y_ptr.data(),
          &context_);
    }
    return true;
  }

 private:
  const bool trans_a_;  // transpose the last two dims of A
  const bool trans_b_;  // transpose the last two dims of B
  const bool broadcast_;  // permit batch-dim broadcasting
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
|