move BatchPermutationOp to caffe2/operators
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31350
Reviewed By: houseroad
Differential Revision: D19053527
fbshipit-source-id: 50d11f137d0f5c07e8ad899a3a84d56a042bbc32
Committed by: Facebook Github Bot
Parent: 0b8332efb4
Commit: d9c3913dfc
caffe2/operators/batch_permutation_op.cc (new file, 169 lines)
@@ -0,0 +1,169 @@
#include "caffe2/operators/batch_permutation_op.h"

#include <cstring>
#include <vector>

#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif

namespace caffe2 {

template <bool forwards>
void batch_permutation_loop(
    const int N,
    const int K,
    const float* src,
    const int* indices,
    float* dst) {
  long numBytes = K * sizeof(float);
  if (forwards) {
#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd
#else
#pragma omp parallel for
#endif
#endif
    for (int n = 0; n < N; n++) {
      int origIdx = n * K;
      int permuteIdx = indices[n] * K;
      std::memcpy(dst + origIdx, src + permuteIdx, numBytes);
    }
  } else {
    std::vector<int> backward_indices(N);
    for (size_t i = 0; i < N; ++i) {
      backward_indices[indices[i]] = i;
    }
    for (int n = 0; n < N; n++) {
      int permuteIdx = n * K;
      int origIdx = backward_indices[n] * K;
      std::memcpy(dst + permuteIdx, src + origIdx, numBytes);
    }
  }
}

template <>
bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& indices = Input(1);

  CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      X.dim32(0) == indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  auto* Y = Output(0, X.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(X.dim32(0), 0);
  batch_permutation_loop<true>(
      X.dim32(0),
      X.numel() / X.dim32(0),
      X.data<float>(),
      indices.data<int>(),
      Y->mutable_data<float>());
  return true;
}

template <>
bool BatchPermutationGradientOp<float, CPUContext>::RunOnDevice() {
  auto& indices = Input(0);
  auto& dY = Input(1);

  auto* dX = Output(0, dY.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(dY.dim32(0), 0);
  batch_permutation_loop<false>(
      dY.dim32(0),
      dY.numel() / dY.dim32(0),
      dY.data<float>(),
      indices.data<int>(),
      dX->mutable_data<float>());
  return true;
}

#ifdef CAFFE2_USE_MKLDNN
REGISTER_IDEEP_OPERATOR(
    BatchPermutation,
    IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif

REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BatchPermutationGradient,
    BatchPermutationGradientOp<float, CPUContext>);

// Input: X, indices; Output: Y
OPERATOR_SCHEMA(BatchPermutation)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Batch permutation of an input tensor X given input indices. The first dimension
of X equals the batch size N. The indices tensor stores a permutation of the N
batch entries. The output Y is a tensor of the same shape as X, with its
batch entries re-ordered according to the indices.

Example of batch permutation on a 2-D tensor with batch size 4:
  X = [
    [1, 5, 2, 3, 4, 6, 0],
    [4, 3, 3, 5, 2, 3, 1],
    [2, 2, 3, 6, 0, 0, 1],
    [0, 0, 1, 1, 2, 2, 3]
  ]
  indices = [2, 0, 1, 3]
  Y = [
    [2, 2, 3, 6, 0, 0, 1],
    [1, 5, 2, 3, 4, 6, 0],
    [4, 3, 3, 5, 2, 3, 1],
    [0, 0, 1, 1, 2, 2, 3]
  ]

Example of batch permutation on a 3-D tensor with batch size 4:
  X = [
    [[1, 5, 2], [3, 4, 6, 0]],
    [[4, 3, 3], [5, 2, 3, 1]],
    [[2, 2, 3], [6, 0, 0, 1]],
    [[0, 0, 1], [1, 2, 2, 3]]
  ]
  indices = [2, 0, 1, 3]
  Y = [
    [[2, 2, 3], [6, 0, 0, 1]],
    [[1, 5, 2], [3, 4, 6, 0]],
    [[4, 3, 3], [5, 2, 3, 1]],
    [[0, 0, 1], [1, 2, 2, 3]]
  ]
)DOC")
    .Input(0, "X", "Input tensor, where 1st dimension equals batch size")
    .Input(1, "indices", "Input indices of batch to permute")
    .Output(0, "Y", "Output permuted tensor");
// Input: indices, dY (aka "gradOutput"); Output: dX (aka "gradInput")
OPERATOR_SCHEMA(BatchPermutationGradient).NumInputs(2).NumOutputs(1);

class GetBatchPermutationGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "BatchPermutationGradient",
        "",
        vector<string>{I(1), GO(0)},
        vector<string>{GI(0)});
  }
};

REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);

} // namespace caffe2

using BatchPermutationOpFloatCPU =
    caffe2::BatchPermutationOp<float, caffe2::CPUContext>;

C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
    BatchPermutation,
    "_caffe2::BatchPermutation(Tensor X, Tensor indices) -> Tensor",
    BatchPermutationOpFloatCPU);
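Note: the CPU loop above and the CUDA kernel in the next file implement the same gather/scatter semantics, with each batch entry treated as a contiguous block of K = numel / N values: the forward pass copies entry indices[n] of X into entry n of Y, and the backward pass routes gradients through the inverse permutation. A minimal NumPy sketch of those semantics (illustrative only, not part of the operator code):

import numpy as np

def batch_permutation_forward(X, indices):
    # Y[n] = X[indices[n]]: gather batch entries of X according to indices.
    return X[indices]

def batch_permutation_backward(dY, indices):
    # dX[indices[n]] = dY[n]: scatter dY through the inverse permutation.
    dX = np.empty_like(dY)
    dX[indices] = dY
    return dX

X = np.arange(12, dtype=np.float32).reshape(4, 3)
indices = np.array([2, 0, 1, 3])
Y = batch_permutation_forward(X, indices)  # Y[0] == X[2], Y[1] == X[0], ...
assert np.array_equal(batch_permutation_backward(Y, indices), X)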
caffe2/operators/batch_permutation_op.cu (new file, 113 lines)
@@ -0,0 +1,113 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/batch_permutation_op.h"

namespace caffe2 {

namespace {
template <bool forward>
__global__ void BatchPermutationKernel(
    int N,
    int K,
    const float* src,
    const int* indices,
    float* dst) {
  if (forward) {
    CUDA_1D_KERNEL_LOOP(index, N * K) {
      int k = index % K;
      int n = index / K;
      int idx = indices[n];
      CUDA_KERNEL_ASSERT(idx >= 0);
      CUDA_KERNEL_ASSERT(idx < N);
      dst[index] = src[idx * K + k];
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * K) {
      int k = index % K;
      int n = index / K;

      // NOTE: an alternative implementation if we want to align the index with
      // the output tensor (rather than the input tensor).
      // int idx = -1;
      // for (size_t i = 0; i < N; ++i) {
      //   if (indices[i] == n) {
      //     idx = i;
      //   }
      // }
      // CUDA_KERNEL_ASSERT(idx >= 0);
      // CUDA_KERNEL_ASSERT(idx < N);
      // dst[index] = src[idx * K + k];

      int idx = indices[n];
      CUDA_KERNEL_ASSERT(idx >= 0);
      CUDA_KERNEL_ASSERT(idx < N);
      dst[idx * K + k] = src[index];
    }
  }
}
} // namespace

template <>
bool BatchPermutationOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& indices = Input(1);

  CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      X.dim32(0) == indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  auto* Y = Output(0, X.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(X.dim32(0), 0);
  BatchPermutationKernel<true>
      <<<CAFFE_GET_BLOCKS(X.numel()),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          X.dim32(0),
          X.numel() / X.dim32(0),
          X.data<float>(),
          indices.data<int>(),
          Y->mutable_data<float>());

  return true;
}

template <>
bool BatchPermutationGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& indices = Input(0);
  auto& dY = Input(1);
  auto* dX = Output(0, dY.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(dY.dim32(0), 0);
  BatchPermutationKernel<false>
      <<<CAFFE_GET_BLOCKS(dY.numel()),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          dY.dim32(0),
          dY.numel() / dY.dim32(0),
          dY.data<float>(),
          indices.data<int>(),
          dX->mutable_data<float>());

  return true;
}

REGISTER_CUDA_OPERATOR(
    BatchPermutation,
    BatchPermutationOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
    BatchPermutationGradient,
    BatchPermutationGradientOp<float, CUDAContext>);
} // namespace caffe2

using BatchPermutationOpFloatCUDA =
    caffe2::BatchPermutationOp<float, caffe2::CUDAContext>;

C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(BatchPermutation, BatchPermutationOpFloatCUDA);
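Note: the C10_EXPORT_CAFFE2_OP_TO_C10_CPU/_CUDA macros register the operator under the schema given in the string literal, so after the move it stays reachable from PyTorch through the torch.ops namespace. A hedged usage sketch, assuming a PyTorch build with the Caffe2 operators enabled (tensor values are illustrative):

import torch

X = torch.arange(12, dtype=torch.float32).reshape(4, 3)
indices = torch.tensor([2, 0, 1, 3], dtype=torch.int32)

# Exported schema: "_caffe2::BatchPermutation(Tensor X, Tensor indices) -> Tensor"
Y = torch.ops._caffe2.BatchPermutation(X, indices)
print(Y)  # batch entries of X reordered as X[2], X[0], X[1], X[3]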
caffe2/operators/batch_permutation_op.h (new file, 37 lines)
@@ -0,0 +1,37 @@
#ifndef BATCHPERMUTATION_OP_H_
#define BATCHPERMUTATION_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(BatchPermutation)

namespace caffe2 {

template <typename T, class Context>
class BatchPermutationOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit BatchPermutationOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice();
};

template <typename T, class Context>
class BatchPermutationGradientOp final : public Operator<Context> {
 public:
  BatchPermutationGradientOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice();
};

} // namespace caffe2

#endif // BATCHPERMUTATION_OP_H_
caffe2/operators/batch_permutation_op_gpu_test.cc (new file, 269 lines)
@@ -0,0 +1,269 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/flags.h"
#include "caffe2/operators/batch_permutation_op.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#include "gtest/gtest.h"

namespace caffe2 {
namespace {

// Add the vector as an input to a Workspace depending on the context of the
// workspace

template <typename T>
void AddInputCPU(
    const vector<int64_t>& shape,
    const vector<T>& values,
    const string& name,
    Workspace* ws) {
  Blob* blob = ws->CreateBlob(name);
  auto* tensor = BlobGetMutableTensor(blob, CPU);
  tensor->Resize(shape);
  EigenVectorMap<T> tensor_vec(tensor->mutable_data<T>(), tensor->numel());
  tensor_vec.array() = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>{
      values.data(), static_cast<int>(values.size())};
}

template <typename T>
void AddInputGPU(
    const vector<int64_t>& shape,
    const vector<T>& values,
    const string& name,
    Workspace* ws) {
  Tensor tmp(shape, CPU);
  EigenVectorMap<T> tmp_vec(tmp.mutable_data<T>(), tmp.numel());
  tmp_vec.array() = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>{
      values.data(), static_cast<int>(values.size())};

  Blob* blob = ws->CreateBlob(name);
  auto* tensor = BlobGetMutableTensor(blob, CUDA);
  tensor->CopyFrom(tmp);
}

// Overload 4 different signatures for AddInput because clang does not allow
// template <typename T>
// void AddInput<CPUContext>(...) {...}

template <typename T, class Context>
void AddInput(
    const vector<int64_t>& shape,
    const vector<T>& values,
    const string& name,
    Workspace* ws);

template <>
void AddInput<int, CPUContext>(
    const vector<int64_t>& shape,
    const vector<int>& values,
    const string& name,
    Workspace* ws) {
  AddInputCPU<int>(shape, values, name, ws);
}

template <>
void AddInput<float, CPUContext>(
    const vector<int64_t>& shape,
    const vector<float>& values,
    const string& name,
    Workspace* ws) {
  AddInputCPU<float>(shape, values, name, ws);
}

template <>
void AddInput<int, CUDAContext>(
    const vector<int64_t>& shape,
    const vector<int>& values,
    const string& name,
    Workspace* ws) {
  AddInputGPU<int>(shape, values, name, ws);
}

template <>
void AddInput<float, CUDAContext>(
    const vector<int64_t>& shape,
    const vector<float>& values,
    const string& name,
    Workspace* ws) {
  AddInputGPU<float>(shape, values, name, ws);
}

template <class Context>
DeviceTypeProto GetDeviceType() {
  return PROTO_CPU;
}
template <>
DeviceTypeProto GetDeviceType<CUDAContext>() {
  return PROTO_CUDA;
}

// Create a BatchPermutationOp with the given inputs (actual values are
// generated sequentially) and run it
template <class Context>
void CreateAndRun(
    TensorCPU* outResult,
    int N,
    vector<int64_t>& shape,
    vector<float>& features,
    vector<int> indices) {
  Workspace ws;

  AddInput<float, Context>(shape, features, "X", &ws);
  AddInput<int, Context>(vector<int64_t>{N}, indices, "indices", &ws);

  OperatorDef def;
  def.set_name("test");
  def.set_type("BatchPermutation");
  def.add_input("X");
  def.add_input("indices");
  def.add_output("Y");
  def.mutable_device_option()->set_device_type(GetDeviceType<Context>());
  unique_ptr<OperatorBase> op = CreateOperator(def, &ws);

  EXPECT_NE(nullptr, op.get());
  EXPECT_TRUE(op->Run());

  Blob* Y_blob = ws.GetBlob("Y");
  EXPECT_NE(nullptr, Y_blob);

  auto& Y = Y_blob->Get<Tensor>();
  outResult->CopyFrom(Y);
}

// Create a BatchPermutationOp with the given inputs (actual values are
// generated sequentially) and run it
template <class Context>
void CreateAndRunGradient(
    TensorCPU* outResult,
    int N,
    vector<int64_t>& shape,
    vector<float>& features,
    vector<int> indices) {
  Workspace ws;

  AddInput<float, Context>(shape, features, "dY", &ws);
  AddInput<int, Context>(vector<int64_t>{N}, indices, "indices", &ws);

  OperatorDef def;
  def.set_name("test");
  def.set_type("BatchPermutationGradient");
  def.add_input("indices");
  def.add_input("dY");
  def.add_output("dX");
  def.mutable_device_option()->set_device_type(GetDeviceType<Context>());
  unique_ptr<OperatorBase> op = CreateOperator(def, &ws);

  EXPECT_NE(nullptr, op.get());
  EXPECT_TRUE(op->Run());

  Blob* Y_blob = ws.GetBlob("dX");
  EXPECT_NE(nullptr, Y_blob);

  auto& Y = Y_blob->Get<Tensor>();
  outResult->CopyFrom(Y);
}

// Check that the CPU and GPU implementations provide the exact same results
void CheckCPUGPUEqual(vector<int64_t> shape, vector<int> indices) {
  // Prepare input data
  EXPECT_GT(shape.size(), 1);
  int N = shape[0];
  int input_size = 1;
  for (auto k : shape) {
    input_size *= k;
  }
  int K = input_size / N;
  vector<float> features(input_size);
  std::iota(features.begin(), features.end(), 0);

  // CPU outputs
  Tensor y_cpu{CPU};
  Tensor y_cpu_grad{CPU};

  // CPU BatchPermutation
  CreateAndRun<CPUContext>(&y_cpu, N, shape, features, indices);

  // CPU BatchPermutationGradient
  CreateAndRunGradient<CPUContext>(&y_cpu_grad, N, shape, features, indices);

  // Check CPU output values
  for (auto i = 0; i < indices.size(); ++i) {
    for (auto k = 0; k < K; ++k) {
      EXPECT_NEAR(
          y_cpu.data<float>()[indices[i] * K + k], features[i * K + k], 1e4);
      EXPECT_NEAR(
          y_cpu_grad.data<float>()[i * K + k],
          features[indices[i] * K + k],
          1e4);
    }
  }

  if (!caffe2::HasCudaGPU()) {
    VLOG(2) << "No CudaGPU found. Skip GPU test." << std::endl;
    return;
  }

  // GPU outputs
  Tensor y_gpu{CPU};
  Tensor y_gpu_grad{CPU};

  // GPU BatchPermutation
  CreateAndRun<CUDAContext>(&y_gpu, N, shape, features, indices);

  // Compare CPU and GPU BatchPermutation outputs
  EXPECT_EQ(y_cpu.sizes(), y_gpu.sizes());
  ConstEigenVectorMap<float> y_cpu_vec(y_cpu.data<float>(), y_cpu.numel());
  ConstEigenVectorMap<float> y_gpu_vec(y_gpu.data<float>(), y_gpu.numel());
  EXPECT_TRUE(y_cpu_vec.isApprox(y_gpu_vec));

  // GPU BatchPermutationGradient
  CreateAndRunGradient<CUDAContext>(&y_gpu_grad, N, shape, features, indices);

  // Check GPU outputs
  for (auto i = 0; i < indices.size(); ++i) {
    for (auto k = 0; k < K; ++k) {
      EXPECT_NEAR(
          y_gpu.data<float>()[indices[i] * K + k], features[i * K + k], 1e4);
      EXPECT_NEAR(
          y_gpu_grad.data<float>()[i * K + k],
          features[indices[i] * K + k],
          1e4);
    }
  }

  // Compare CPU and GPU BatchPermutationGradient outputs
  EXPECT_EQ(y_cpu_grad.sizes(), y_gpu_grad.sizes());
  ConstEigenVectorMap<float> y_cpu_vec_grad(
      y_cpu_grad.data<float>(), y_cpu_grad.numel());
  ConstEigenVectorMap<float> y_gpu_vec_grad(
      y_gpu_grad.data<float>(), y_gpu_grad.numel());
  EXPECT_TRUE(y_cpu_vec_grad.isApprox(y_gpu_vec_grad));
}

} // namespace

TEST(BatchPermutationTest, CHECKCPUGPUEqualGenericDimension) {
  auto t0 = std::chrono::high_resolution_clock::now();
  int batch_size = 8;
  int max_dimension = 6;
  vector<int64_t> shape = vector<int64_t>{batch_size};

  auto seed = std::chrono::system_clock::now().time_since_epoch().count();
  std::default_random_engine generator(seed);

  for (int i = 2; i < max_dimension; ++i) {
    std::uniform_int_distribution<> dis(1, i);
    shape.push_back(dis(generator));
    CheckCPUGPUEqual(shape, vector<int>{0, 1, 2, 3, 4, 5, 6, 7});
    CheckCPUGPUEqual(shape, vector<int>{7, 6, 5, 4, 3, 2, 1, 0});
    CheckCPUGPUEqual(shape, vector<int>{1, 3, 5, 7, 0, 2, 4, 6});
    CheckCPUGPUEqual(shape, vector<int>{4, 5, 6, 7, 0, 1, 2, 3});
    CheckCPUGPUEqual(shape, vector<int>{3, 1, 5, 7, 6, 2, 4, 0});
  }
  auto t1 = std::chrono::high_resolution_clock::now();
  double elapsed =
      std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count();
  VLOG(2) << "Time elapsed: " << elapsed << " ms" << std::endl;
  return;
}
} // namespace caffe2
modules/detectron/batch_permutation_op.cc (deleted, 131 lines)
@@ -1,131 +0,0 @@
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "batch_permutation_op.h"
#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif

namespace caffe2 {

#ifdef CAFFE2_USE_MKLDNN
REGISTER_IDEEP_OPERATOR(
    BatchPermutation,
    IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif

REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BatchPermutationGradient,
    BatchPermutationGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(BatchPermutation)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Permute the batch elements of the input tensor X according to the permutation
specified in the input indices.

Warning: this op does not verify that indices is a valid permutation; gradient
computation is only correct if indices is a permutation.
)DOC")
|
||||
.Input(
|
||||
0,
|
||||
"X",
|
||||
"Tensor of at least 1D shape (N, D0, D1, ...).")
|
||||
.Input(
|
||||
1,
|
||||
"indices",
|
||||
"1D tensor of type int with shape (N, ) specifying a valid permutation "
|
||||
"of the indices in [0, N - 1] (inclusive).")
|
||||
.Output(
|
||||
0,
|
||||
"Y",
|
||||
"Tensor with the same shape as X where the (D0, D1, ...) dimensional "
|
||||
"batch elements of X are permuted according to the input indices.");
|
||||
|
||||
OPERATOR_SCHEMA(BatchPermutationGradient)
|
||||
.NumInputs(2)
|
||||
.NumOutputs(1)
|
||||
.Input(
|
||||
0,
|
||||
"indices",
|
||||
"See BatchPermutation.")
|
||||
.Input(
|
||||
1,
|
||||
"dY",
|
||||
"Gradient of forward output 0 (Y).")
|
||||
.Output(
|
||||
0,
|
||||
"dX",
|
||||
"Gradient of forward input 0 (X).");
|
||||
|
||||
template <>
|
||||
bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
|
||||
const auto& X = Input(0);
|
||||
const auto& indices = Input(1);
|
||||
|
||||
CAFFE_ENFORCE_EQ(indices.dim(), 1, "indices must be 1-d");
|
||||
CAFFE_ENFORCE_EQ(
|
||||
X.dim32(0), indices.dim32(0),
|
||||
"X.dim32(0) must be equal to indices.dim32(0)",
|
||||
"(",
|
||||
X.dim32(0),
|
||||
" vs. ",
|
||||
indices.dim32(0),
|
||||
")");
|
||||
|
||||
auto* Y = Output(0, X.sizes(), at::dtype<float>());
|
||||
|
||||
const int N = X.dim32(0);
|
||||
const int C = X.dim32(1);
|
||||
const int H = X.dim32(2);
|
||||
const int W = X.dim32(3);
|
||||
|
||||
const float *src = X.template data<float>();
|
||||
float *dst = Y->template mutable_data<float>();
|
||||
|
||||
#ifdef _OPENMP
|
||||
#if (_OPENMP >= 201307)
|
||||
#pragma omp parallel for simd
|
||||
#else
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
#endif
|
||||
for (int i = 0; i < N; i++) {
|
||||
int idx = indices.template data<int>()[i];
|
||||
|
||||
std::memcpy(dst + i * C * H * W, src + idx * C * H * W, sizeof(float) * C * H * W);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
class GetBatchPermutationGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
vector<OperatorDef> GetGradientDefs() override {
|
||||
return SingleGradientDef(
|
||||
"BatchPermutationGradient",
|
||||
"",
|
||||
vector<string>{I(1), GO(0)},
|
||||
vector<string>{GI(0)});
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);
|
||||
|
||||
} // namespace caffe2
|
||||
modules/detectron/batch_permutation_op.cu (deleted, 112 lines)
@@ -1,112 +0,0 @@
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "modules/detectron/batch_permutation_op.h"
#include "caffe2/core/context_gpu.h"

namespace caffe2 {

namespace {
template <bool forward>
__global__ void BatchPermutationKernel(
    int N,
    int C,
    int H,
    int W,
    const float* src,
    const int* indices,
    float* dst) {
  CUDA_1D_KERNEL_LOOP(index, N * C * H * W) {
    int w = index % W;
    int h = (index / W) % H;
    int c = (index / W / H) % C;
    int n = (index / W / H / C);
    int idx = indices[n];
    if (forward) {
      dst[n * C * H * W + c * H * W + h * W + w] =
          src[idx * C * H * W + c * H * W + h * W + w];
    } else {
      dst[idx * C * H * W + c * H * W + h * W + w] =
          src[n * C * H * W + c * H * W + h * W + w];
    }
  }
}
}

template <>
bool BatchPermutationOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& indices = Input(1);

  CAFFE_ENFORCE(indices.ndim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      X.dim32(0) == indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  auto* Y = Output(0, X.sizes(), at::dtype<float>());

  BatchPermutationKernel<true><<<
      CAFFE_GET_BLOCKS(X.size()),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context_.cuda_stream()>>>(
      X.dim32(0),
      X.dim32(1),
      X.dim32(2),
      X.dim32(3),
      X.data<float>(),
      indices.data<int>(),
      Y->mutable_data<float>());

  return true;
}

template <>
bool BatchPermutationGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& indices = Input(0);
  auto& dY = Input(1);

  auto* dX = Output(0, dY.sizes(), at::dtype<float>());

  BatchPermutationKernel<false><<<
      CAFFE_GET_BLOCKS(dY.size()),
      CAFFE_CUDA_NUM_THREADS,
      0,
      context_.cuda_stream()>>>(
      dY.dim32(0),
      dY.dim32(1),
      dY.dim32(2),
      dY.dim32(3),
      dY.data<float>(),
      indices.data<int>(),
      dX->mutable_data<float>());

  return true;
}

REGISTER_CUDA_OPERATOR(
    BatchPermutation,
    BatchPermutationOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
    BatchPermutationGradient,
    BatchPermutationGradientOp<float, CUDAContext>);
} // namespace caffe2
modules/detectron/batch_permutation_op.h (deleted, 53 lines)
@@ -1,53 +0,0 @@
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef BATCHPERMUTATION_OP_H_
#define BATCHPERMUTATION_OP_H_

#include <cstring>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename T, class Context>
class BatchPermutationOp final : public Operator<Context> {
 public:
  BatchPermutationOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override;
};

template <typename T, class Context>
class BatchPermutationGradientOp final : public Operator<Context> {
 public:
  BatchPermutationGradientOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    // No CPU implementation for now
    CAFFE_NOT_IMPLEMENTED;
  }
};

} // namespace caffe2

#endif // BATCHPERMUTATION_OP_H_