move BatchPermutationOp to caffe2/operators

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31350

Reviewed By: houseroad

Differential Revision: D19053527

fbshipit-source-id: 50d11f137d0f5c07e8ad899a3a84d56a042bbc32
commit d9c3913dfc
parent 0b8332efb4
Author: Yanghan Wang
Date: 2019-12-17 14:53:33 -08:00
Committed by: Facebook Github Bot
7 changed files with 588 additions and 296 deletions

caffe2/operators/batch_permutation_op.cc
@@ -0,0 +1,169 @@
#include "caffe2/operators/batch_permutation_op.h"
#include <cstring>
#include <vector>
#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif
namespace caffe2 {
template <bool forwards>
void batch_permutation_loop(
const int N,
const int K,
const float* src,
const int* indices,
float* dst) {
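// Copies whole rows of width K. Forward is a gather, dst[n] = src[indices[n]];
// backward applies the inverse permutation, i.e. the scatter
// dst[indices[n]] = src[n].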
long numBytes = K * sizeof(float);
if (forwards) {
#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd
#else
#pragma omp parallel for
#endif
#endif
for (int n = 0; n < N; n++) {
int origIdx = n * K;
int permuteIdx = indices[n] * K;
std::memcpy(dst + origIdx, src + permuteIdx, numBytes);
}
} else {
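// Build the inverse permutation: backward_indices[j] is the row i with
// indices[i] == j, so the destination can be filled row by row.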
std::vector<int> backward_indices(N);
for (int i = 0; i < N; ++i) {
backward_indices[indices[i]] = i;
}
for (int n = 0; n < N; n++) {
int permuteIdx = n * K;
int origIdx = backward_indices[n] * K;
std::memcpy(dst + permuteIdx, src + origIdx, numBytes);
}
}
}
template <>
bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
auto& X = Input(0);
auto& indices = Input(1);
CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
CAFFE_ENFORCE(
X.dim32(0) == indices.dim32(0),
"X.dim32(0) must be equal to indices.dim32(0)",
"(",
X.dim32(0),
" vs. ",
indices.dim32(0),
")");
auto* Y = Output(0, X.sizes(), at::dtype<float>());
CAFFE_ENFORCE_GT(X.dim32(0), 0);
batch_permutation_loop<true>(
X.dim32(0),
X.numel() / X.dim32(0),
X.data<float>(),
indices.data<int>(),
Y->mutable_data<float>());
return true;
}
template <>
bool BatchPermutationGradientOp<float, CPUContext>::RunOnDevice() {
auto& indices = Input(0);
auto& dY = Input(1);
auto* dX = Output(0, dY.sizes(), at::dtype<float>());
CAFFE_ENFORCE_GT(dY.dim32(0), 0);
batch_permutation_loop<false>(
dY.dim32(0),
dY.numel() / dY.dim32(0),
dY.data<float>(),
indices.data<int>(),
dX->mutable_data<float>());
return true;
}
#ifdef CAFFE2_USE_MKLDNN
REGISTER_IDEEP_OPERATOR(
BatchPermutation,
IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif
REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
BatchPermutationGradient,
BatchPermutationGradientOp<float, CPUContext>);
// Input: X, indices; Output: Y
OPERATOR_SCHEMA(BatchPermutation)
.NumInputs(2)
.NumOutputs(1)
.SetDoc(R"DOC(
Batch permutation of an input tensor X given input indices. The first
dimension of X equals the batch size N. The indices tensor stores a
permutation of [0, N - 1]. The output Y is a tensor of the same shape as X,
with the batch elements re-ordered according to the indices.
Example of batch permutation on a 2-D tensor with batch size 4:
X = [
[1, 5, 2, 3, 4, 6, 0],
[4, 3, 3, 5, 2, 3, 1],
[2, 2, 3, 6, 0, 0, 1],
[0, 0, 1, 1, 2, 2, 3]
]
indices = [2, 0, 1, 3]
Y = [
[2, 2, 3, 6, 0, 0, 1],
[1, 5, 2, 3, 4, 6, 0],
[4, 3, 3, 5, 2, 3, 1],
[0, 0, 1, 1, 2, 2, 3]
]
Example of batch permutation on a 3-D tensor with batch size 4:
X = [
[[1, 5, 2], [3, 4, 6, 0]],
[[4, 3, 3], [5, 2, 3, 1]],
[[2, 2, 3], [6, 0, 0, 1]],
[[0, 0, 1], [1, 2, 2, 3]]
]
indices = [2, 0, 1, 3]
Y = [
[[2, 2, 3], [6, 0, 0, 1]],
[[1, 5, 2], [3, 4, 6, 0]],
[[4, 3, 3], [5, 2, 3, 1]],
[[0, 0, 1], [1, 2, 2, 3]]
]
)DOC")
.Input(0, "X", "Input tensor, where 1st dimension equals batch size")
.Input(1, "indices", "Input indices of batch to permute")
.Output(0, "Y", "Output permuted tensor");
// Input: indices, dY (aka "gradOutput"); Output: dX (aka "gradInput")
OPERATOR_SCHEMA(BatchPermutationGradient).NumInputs(2).NumOutputs(1);
class GetBatchPermutationGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
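// I(1) is the forward op's "indices" input, GO(0) the gradient of forward
// output Y, and GI(0) the gradient of forward input X.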
return SingleGradientDef(
"BatchPermutationGradient",
"",
vector<string>{I(1), GO(0)},
vector<string>{GI(0)});
}
};
REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);
} // namespace caffe2
using BatchPermutationOpFloatCPU =
caffe2::BatchPermutationOp<float, caffe2::CPUContext>;
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
BatchPermutation,
"_caffe2::BatchPermutation(Tensor X, Tensor indices) -> Tensor",
BatchPermutationOpFloatCPU);
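As a reading aid, the schema above boils down to a row gather over the batch
dimension. Below is a minimal standalone sketch, assuming a hypothetical
batch_permute helper (not part of the diff), checked against the 2-D example
from the doc string:

#include <cassert>
#include <vector>

// Reference semantics of BatchPermutation: gather rows of width K.
// The gradient is the matching scatter: dX[indices[n]] = dY[n].
std::vector<float> batch_permute(
    const std::vector<float>& X, const std::vector<int>& indices, int K) {
  std::vector<float> Y(X.size());
  for (size_t n = 0; n < indices.size(); ++n) {
    for (int k = 0; k < K; ++k) {
      Y[n * K + k] = X[indices[n] * K + k]; // Y[n] = X[indices[n]]
    }
  }
  return Y;
}

int main() {
  // The 4 x 7 example from the operator schema.
  std::vector<float> X = {1, 5, 2, 3, 4, 6, 0, 4, 3, 3, 5, 2, 3, 1,
                          2, 2, 3, 6, 0, 0, 1, 0, 0, 1, 1, 2, 2, 3};
  std::vector<int> indices = {2, 0, 1, 3};
  std::vector<float> Y = batch_permute(X, indices, 7);
  assert(Y[0] == 2 && Y[7] == 1 && Y[14] == 4 && Y[21] == 0);
  return 0;
}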

caffe2/operators/batch_permutation_op.cu
@@ -0,0 +1,113 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/batch_permutation_op.h"
namespace caffe2 {
namespace {
template <bool forward>
__global__ void BatchPermutationKernel(
int N,
int K,
const float* src,
const int* indices,
float* dst) {
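// One thread handles one element of the flattened N x K view: n is the batch
// row, k the offset within the row. Forward gathers rows through indices;
// backward scatters them.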
if (forward) {
CUDA_1D_KERNEL_LOOP(index, N * K) {
int k = index % K;
int n = index / K;
int idx = indices[n];
CUDA_KERNEL_ASSERT(idx >= 0);
CUDA_KERNEL_ASSERT(idx < N);
dst[index] = src[idx * K + k];
}
} else {
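// Scatter form: each thread writes dst[indices[n] * K + k]. Because indices
// is a permutation, every destination index is written exactly once, so
// there are no write races between threads.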
CUDA_1D_KERNEL_LOOP(index, N * K) {
int k = index % K;
int n = index / K;
// NOTE: an alternative implementation if we want to align the index with
// the output tensor (rather than the input tensor).
// int idx = -1;
// for (size_t i = 0; i < N; ++i) {
// if (indices[i] == n) {
// idx = i;
// }
// }
// CUDA_KERNEL_ASSERT(idx >= 0);
// CUDA_KERNEL_ASSERT(idx < N);
// dst[index] = src[idx * K + k];
int idx = indices[n];
CUDA_KERNEL_ASSERT(idx >= 0);
CUDA_KERNEL_ASSERT(idx < N);
dst[idx * K + k] = src[index];
}
}
}
} // namespace
template <>
bool BatchPermutationOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto& indices = Input(1);
CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
CAFFE_ENFORCE(
X.dim32(0) == indices.dim32(0),
"X.dim32(0) must be equal to indices.dim32(0)",
"(",
X.dim32(0),
" vs. ",
indices.dim32(0),
")");
auto* Y = Output(0, X.sizes(), at::dtype<float>());
CAFFE_ENFORCE_GT(X.dim32(0), 0);
BatchPermutationKernel<true>
<<<CAFFE_GET_BLOCKS(X.numel()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
X.dim32(0),
X.numel() / X.dim32(0),
X.data<float>(),
indices.data<int>(),
Y->mutable_data<float>());
return true;
}
template <>
bool BatchPermutationGradientOp<float, CUDAContext>::RunOnDevice() {
auto& indices = Input(0);
auto& dY = Input(1);
auto* dX = Output(0, dY.sizes(), at::dtype<float>());
CAFFE_ENFORCE_GT(dY.dim32(0), 0);
BatchPermutationKernel<false>
<<<CAFFE_GET_BLOCKS(dY.numel()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
dY.dim32(0),
dY.numel() / dY.dim32(0),
dY.data<float>(),
indices.data<int>(),
dX->mutable_data<float>());
return true;
}
REGISTER_CUDA_OPERATOR(
BatchPermutation,
BatchPermutationOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
BatchPermutationGradient,
BatchPermutationGradientOp<float, CUDAContext>);
} // namespace caffe2
using BatchPermutationOpFloatCUDA =
caffe2::BatchPermutationOp<float, caffe2::CUDAContext>;
C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(BatchPermutation, BatchPermutationOpFloatCUDA);

caffe2/operators/batch_permutation_op.h
@@ -0,0 +1,37 @@
#ifndef BATCHPERMUTATION_OP_H_
#define BATCHPERMUTATION_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(BatchPermutation)
namespace caffe2 {
template <typename T, class Context>
class BatchPermutationOp final : public Operator<Context> {
public:
template <class... Args>
explicit BatchPermutationOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
};
template <typename T, class Context>
class BatchPermutationGradientOp final : public Operator<Context> {
public:
BatchPermutationGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
};
} // namespace caffe2
#endif // BATCHPERMUTATION_OP_H_

caffe2/operators/batch_permutation_op_gpu_test.cc
@@ -0,0 +1,269 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/flags.h"
#include "caffe2/operators/batch_permutation_op.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#include "gtest/gtest.h"
namespace caffe2 {
namespace {
// Add the given values as an input blob to the Workspace, on CPU or GPU
// depending on the Context template parameter
template <typename T>
void AddInputCPU(
const vector<int64_t>& shape,
const vector<T>& values,
const string& name,
Workspace* ws) {
Blob* blob = ws->CreateBlob(name);
auto* tensor = BlobGetMutableTensor(blob, CPU);
tensor->Resize(shape);
EigenVectorMap<T> tensor_vec(tensor->mutable_data<T>(), tensor->numel());
tensor_vec.array() = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>{
values.data(), static_cast<int>(values.size())};
}
template <typename T>
void AddInputGPU(
const vector<int64_t>& shape,
const vector<T>& values,
const string& name,
Workspace* ws) {
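// Stage the values in a CPU tensor first, then copy them into the CUDA blob;
// host code cannot write device memory directly.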
Tensor tmp(shape, CPU);
EigenVectorMap<T> tmp_vec(tmp.mutable_data<T>(), tmp.numel());
tmp_vec.array() = Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>>{
values.data(), static_cast<int>(values.size())};
Blob* blob = ws->CreateBlob(name);
auto* tensor = BlobGetMutableTensor(blob, CUDA);
tensor->CopyFrom(tmp);
}
// Provide 4 full specializations of AddInput because C++ does not allow
// partial specialization of a function template such as
// template <typename T>
// void AddInput<CPUContext>(...) {...}
template <typename T, class Context>
void AddInput(
const vector<int64_t>& shape,
const vector<T>& values,
const string& name,
Workspace* ws);
template <>
void AddInput<int, CPUContext>(
const vector<int64_t>& shape,
const vector<int>& values,
const string& name,
Workspace* ws) {
AddInputCPU<int>(shape, values, name, ws);
}
template <>
void AddInput<float, CPUContext>(
const vector<int64_t>& shape,
const vector<float>& values,
const string& name,
Workspace* ws) {
AddInputCPU<float>(shape, values, name, ws);
}
template <>
void AddInput<int, CUDAContext>(
const vector<int64_t>& shape,
const vector<int>& values,
const string& name,
Workspace* ws) {
AddInputGPU<int>(shape, values, name, ws);
}
template <>
void AddInput<float, CUDAContext>(
const vector<int64_t>& shape,
const vector<float>& values,
const string& name,
Workspace* ws) {
AddInputGPU<float>(shape, values, name, ws);
}
template <class Context>
DeviceTypeProto GetDeviceType() {
return PROTO_CPU;
}
template <>
DeviceTypeProto GetDeviceType<CUDAContext>() {
return PROTO_CUDA;
}
// Create a BatchPermutation operator with the given inputs, run it, and copy
// the output Y back to a CPU tensor
template <class Context>
void CreateAndRun(
TensorCPU* outResult,
int N,
vector<int64_t>& shape,
vector<float>& features,
vector<int> indices) {
Workspace ws;
AddInput<float, Context>(shape, features, "X", &ws);
AddInput<int, Context>(vector<int64_t>{N}, indices, "indices", &ws);
OperatorDef def;
def.set_name("test");
def.set_type("BatchPermutation");
def.add_input("X");
def.add_input("indices");
def.add_output("Y");
def.mutable_device_option()->set_device_type(GetDeviceType<Context>());
unique_ptr<OperatorBase> op = CreateOperator(def, &ws);
EXPECT_NE(nullptr, op.get());
EXPECT_TRUE(op->Run());
Blob* Y_blob = ws.GetBlob("Y");
EXPECT_NE(nullptr, Y_blob);
auto& Y = Y_blob->Get<Tensor>();
outResult->CopyFrom(Y);
}
// Create a BatchPermutationGradient operator with the given inputs, run it,
// and copy the output dX back to a CPU tensor
template <class Context>
void CreateAndRunGradient(
TensorCPU* outResult,
int N,
vector<int64_t>& shape,
vector<float>& features,
vector<int> indices) {
Workspace ws;
AddInput<float, Context>(shape, features, "dY", &ws);
AddInput<int, Context>(vector<int64_t>{N}, indices, "indices", &ws);
OperatorDef def;
def.set_name("test");
def.set_type("BatchPermutationGradient");
def.add_input("indices");
def.add_input("dY");
def.add_output("dX");
def.mutable_device_option()->set_device_type(GetDeviceType<Context>());
unique_ptr<OperatorBase> op = CreateOperator(def, &ws);
EXPECT_NE(nullptr, op.get());
EXPECT_TRUE(op->Run());
Blob* Y_blob = ws.GetBlob("dX");
EXPECT_NE(nullptr, Y_blob);
auto& Y = Y_blob->Get<Tensor>();
outResult->CopyFrom(Y);
}
// Check that the CPU and GPU implementations provide the exact same results
void CheckCPUGPUEqual(vector<int64_t> shape, vector<int> indices) {
// Prepare input data
EXPECT_GT(shape.size(), 1);
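// N is the batch size; K is the number of elements per batch row, i.e. the
// product of all non-batch dimensions.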
int N = shape[0];
int input_size = 1;
for (auto k : shape) {
input_size *= k;
}
int K = input_size / N;
vector<float> features(input_size);
std::iota(features.begin(), features.end(), 0);
// CPU outputs
Tensor y_cpu{CPU};
Tensor y_cpu_grad{CPU};
// CPU BatchPermutation
CreateAndRun<CPUContext>(&y_cpu, N, shape, features, indices);
// CPU BatchPermutationGradient
CreateAndRunGradient<CPUContext>(&y_cpu_grad, N, shape, features, indices);
// Check CPU output values
for (size_t i = 0; i < indices.size(); ++i) {
for (int k = 0; k < K; ++k) {
// Forward is a gather: Y[i] = X[indices[i]].
EXPECT_NEAR(
y_cpu.data<float>()[i * K + k], features[indices[i] * K + k], 1e-4);
// Backward is a scatter: dX[indices[i]] = dY[i].
EXPECT_NEAR(
y_cpu_grad.data<float>()[indices[i] * K + k],
features[i * K + k],
1e-4);
}
}
if (!caffe2::HasCudaGPU()) {
VLOG(2) << "No CudaGPU found. Skip GPU test." << std::endl;
return;
}
// GPU outputs
Tensor y_gpu{CPU};
Tensor y_gpu_grad{CPU};
// GPU BatchPermutation
CreateAndRun<CUDAContext>(&y_gpu, N, shape, features, indices);
// Compare CPU and GPU BatchPermutation outputs
EXPECT_EQ(y_cpu.sizes(), y_gpu.sizes());
ConstEigenVectorMap<float> y_cpu_vec(y_cpu.data<float>(), y_cpu.numel());
ConstEigenVectorMap<float> y_gpu_vec(y_gpu.data<float>(), y_gpu.numel());
EXPECT_TRUE(y_cpu_vec.isApprox(y_gpu_vec));
// GPU BatchPermutationGradient
CreateAndRunGradient<CUDAContext>(&y_gpu_grad, N, shape, features, indices);
// Check GPU outputs
for (size_t i = 0; i < indices.size(); ++i) {
for (int k = 0; k < K; ++k) {
EXPECT_NEAR(
y_gpu.data<float>()[i * K + k], features[indices[i] * K + k], 1e-4);
EXPECT_NEAR(
y_gpu_grad.data<float>()[indices[i] * K + k],
features[i * K + k],
1e-4);
}
}
// Compare CPU and GPU BatchPermutationGradient outputs
EXPECT_EQ(y_cpu_grad.sizes(), y_gpu_grad.sizes());
ConstEigenVectorMap<float> y_cpu_vec_grad(
y_cpu_grad.data<float>(), y_cpu_grad.numel());
ConstEigenVectorMap<float> y_gpu_vec_grad(
y_gpu_grad.data<float>(), y_gpu_grad.numel());
EXPECT_TRUE(y_cpu_vec_grad.isApprox(y_gpu_vec_grad));
}
} // namespace
TEST(BatchPermutationTest, CHECKCPUGPUEqualGenericDimension) {
auto t0 = std::chrono::high_resolution_clock::now();
int batch_size = 8;
int max_dimension = 6;
vector<int64_t> shape = vector<int64_t>{batch_size};
auto seed = std::chrono::system_clock::now().time_since_epoch().count();
std::default_random_engine generator(seed);
for (int i = 2; i < max_dimension; ++i) {
std::uniform_int_distribution<> dis(1, i);
shape.push_back(dis(generator));
CheckCPUGPUEqual(shape, vector<int>{0, 1, 2, 3, 4, 5, 6, 7});
CheckCPUGPUEqual(shape, vector<int>{7, 6, 5, 4, 3, 2, 1, 0});
CheckCPUGPUEqual(shape, vector<int>{1, 3, 5, 7, 0, 2, 4, 6});
CheckCPUGPUEqual(shape, vector<int>{4, 5, 6, 7, 0, 1, 2, 3});
CheckCPUGPUEqual(shape, vector<int>{3, 1, 5, 7, 6, 2, 4, 0});
}
auto t1 = std::chrono::high_resolution_clock::now();
double elapsed =
std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count();
VLOG(2) << "Time elapsed: " << elapsed << " ms" << std::endl;
return;
}
} // namespace caffe2

modules/detectron/batch_permutation_op.cc
@@ -1,131 +0,0 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "batch_permutation_op.h"
#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif
namespace caffe2 {
#ifdef CAFFE2_USE_MKLDNN
REGISTER_IDEEP_OPERATOR(
BatchPermutation,
IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif
REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
BatchPermutationGradient,
BatchPermutationGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(BatchPermutation)
.NumInputs(2)
.NumOutputs(1)
.SetDoc(R"DOC(
Permute the batch elements of the input tensor X according to the permutation
specified in the input indices.
Warning: this op does not verify that indices is a valid permutation; gradient
computation is only correct if indices is a permutation.
)DOC")
.Input(
0,
"X",
"Tensor of at least 1D shape (N, D0, D1, ...).")
.Input(
1,
"indices",
"1D tensor of type int with shape (N, ) specifying a valid permutation "
"of the indices in [0, N - 1] (inclusive).")
.Output(
0,
"Y",
"Tensor with the same shape as X where the (D0, D1, ...) dimensional "
"batch elements of X are permuted according to the input indices.");
OPERATOR_SCHEMA(BatchPermutationGradient)
.NumInputs(2)
.NumOutputs(1)
.Input(
0,
"indices",
"See BatchPermutation.")
.Input(
1,
"dY",
"Gradient of forward output 0 (Y).")
.Output(
0,
"dX",
"Gradient of forward input 0 (X).");
template <>
bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
const auto& X = Input(0);
const auto& indices = Input(1);
CAFFE_ENFORCE_EQ(indices.dim(), 1, "indices must be 1-d");
CAFFE_ENFORCE_EQ(
X.dim32(0), indices.dim32(0),
"X.dim32(0) must be equal to indices.dim32(0)",
"(",
X.dim32(0),
" vs. ",
indices.dim32(0),
")");
auto* Y = Output(0, X.sizes(), at::dtype<float>());
const int N = X.dim32(0);
const int C = X.dim32(1);
const int H = X.dim32(2);
const int W = X.dim32(3);
const float *src = X.template data<float>();
float *dst = Y->template mutable_data<float>();
#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd
#else
#pragma omp parallel for
#endif
#endif
for (int i = 0; i < N; i++) {
int idx = indices.template data<int>()[i];
std::memcpy(dst + i * C * H * W, src + idx * C * H * W, sizeof(float) * C * H * W);
}
return true;
}
class GetBatchPermutationGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
return SingleGradientDef(
"BatchPermutationGradient",
"",
vector<string>{I(1), GO(0)},
vector<string>{GI(0)});
}
};
REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);
} // namespace caffe2

modules/detectron/batch_permutation_op.cu
@@ -1,112 +0,0 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "modules/detectron/batch_permutation_op.h"
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
namespace {
template <bool forward>
__global__ void BatchPermutationKernel(
int N,
int C,
int H,
int W,
const float* src,
const int* indices,
float* dst) {
CUDA_1D_KERNEL_LOOP(index, N * C * H * W) {
int w = index % W;
int h = (index / W) % H;
int c = (index / W / H) % C;
int n = (index / W / H / C);
int idx = indices[n];
if (forward) {
dst[n * C * H * W + c * H * W + h * W + w] =
src[idx * C * H * W + c * H * W + h * W + w];
} else {
dst[idx * C * H * W + c * H * W + h * W + w] =
src[n * C * H * W + c * H * W + h * W + w];
}
}
}
} // namespace
template <>
bool BatchPermutationOp<float, CUDAContext>::RunOnDevice() {
auto& X = Input(0);
auto& indices = Input(1);
CAFFE_ENFORCE(indices.ndim() == 1, "indices must be 1-d");
CAFFE_ENFORCE(
X.dim32(0) == indices.dim32(0),
"X.dim32(0) must be equal to indices.dim32(0)",
"(",
X.dim32(0),
" vs. ",
indices.dim32(0),
")");
auto* Y = Output(0, X.sizes(), at::dtype<float>());
BatchPermutationKernel<true><<<
CAFFE_GET_BLOCKS(X.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
X.dim32(0),
X.dim32(1),
X.dim32(2),
X.dim32(3),
X.data<float>(),
indices.data<int>(),
Y->mutable_data<float>());
return true;
}
template <>
bool BatchPermutationGradientOp<float, CUDAContext>::RunOnDevice() {
auto& indices = Input(0);
auto& dY = Input(1);
auto* dX = Output(0, dY.sizes(), at::dtype<float>());
BatchPermutationKernel<false><<<
CAFFE_GET_BLOCKS(dY.size()),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
dY.dim32(0),
dY.dim32(1),
dY.dim32(2),
dY.dim32(3),
dY.data<float>(),
indices.data<int>(),
dX->mutable_data<float>());
return true;
}
REGISTER_CUDA_OPERATOR(
BatchPermutation,
BatchPermutationOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
BatchPermutationGradient,
BatchPermutationGradientOp<float, CUDAContext>);
} // namespace caffe2

modules/detectron/batch_permutation_op.h
@@ -1,53 +0,0 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef BATCHPERMUTATION_OP_H_
#define BATCHPERMUTATION_OP_H_
#include <cstring>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <typename T, class Context>
class BatchPermutationOp final : public Operator<Context> {
public:
BatchPermutationOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override;
};
template <typename T, class Context>
class BatchPermutationGradientOp final : public Operator<Context> {
public:
BatchPermutationGradientOp(const OperatorDef& def, Workspace* ws)
: Operator<Context>(def, ws) {}
USE_OPERATOR_CONTEXT_FUNCTIONS;
bool RunOnDevice() override {
// No CPU implementation for now
CAFFE_NOT_IMPLEMENTED;
}
};
} // namespace caffe2
#endif // BATCHPERMUTATION_OP_H_