ScaleBlobs Operator (#19660)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19660

Implementation of an aggregated Scale operator.
The operator takes a list of tensors as input and scales all of them by a single float argument.
The tensor sizes can differ, so the GPU version of the kernel needs bookkeeping of the sizes
of, and pointers to, the individual tensors.
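
A minimal usage sketch from Python (blob names are illustrative; this assumes the standard
Caffe2 workspace/core API, the same one exercised by the unit test below):

    import numpy as np
    from caffe2.python import core, workspace

    # Feed two input blobs of different sizes.
    workspace.FeedBlob("X_0", np.random.rand(4, 3).astype(np.float32))
    workspace.FeedBlob("X_1", np.random.rand(128).astype(np.float32))

    # Scale both blobs by 0.5 with a single operator invocation.
    op = core.CreateOperator(
        "ScaleBlobs",
        ["X_0", "X_1"],
        ["Y_0", "Y_1"],
        scale=0.5,
    )
    workspace.RunOperatorOnce(op)
    print(workspace.FetchBlob("Y_0"))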

Reviewed By: BIT-silence

Differential Revision: D14984233

fbshipit-source-id: 37cc97159a4f2c38cd6fff4f5710ab7d3a773611
Bilge Acun
2019-05-08 17:53:31 -07:00
committed by Facebook Github Bot
parent 4d676d53a6
commit 3ee97183b0
6 changed files with 313 additions and 3 deletions


@@ -81,7 +81,7 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(
detail::_guard_long_unique<std::vector<long>>);
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(29, float*)
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(30, at::Half*)
} // namespace caffe2


@@ -624,6 +624,7 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(
detail::_guard_long_unique<std::vector<long>>)
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(29, float*)
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(30, at::Half*)
} // namespace caffe2


@@ -0,0 +1,18 @@
#include "caffe2/operators/scale_blobs_op.h"
namespace caffe2 {
REGISTER_CPU_OPERATOR(ScaleBlobs, ScaleBlobsOp<CPUContext>);
OPERATOR_SCHEMA(ScaleBlobs)
.NumInputs(1, INT_MAX)
.NumOutputs(1, INT_MAX)
.AllowInplace([](int, int) { return true; })
.IdenticalTypeAndShape()
.SetDoc(R"DOC(
ScaleBlobs takes one or more input data (Tensor) and produces one
or more output data (Tensor) whose value is the input data tensor
scaled element-wise.
)DOC")
.Arg("scale", "(float, default 1.0) the scale to apply.");
} // namespace caffe2


@@ -0,0 +1,169 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/scale_blobs_op.h"
namespace caffe2 {
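// Each thread loops over every blob; within a blob, CUDA_1D_KERNEL_LOOP
// distributes the elements across the whole grid (grid-stride loop).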
template <typename T>
__global__ void ScaleBlobsCUDAKernel(
const float scale,
const int numBlobs,
const int* sizeArr,
T** X,
T** Y) {
for (size_t i = 0; i < numBlobs; ++i) {
CUDA_1D_KERNEL_LOOP(j, sizeArr[i]) {
Y[i][j] = X[i][j] * scale;
}
}
}
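// One thread block per blob: block blockIdx.x strides over the elements of the
// blob with the same index, so the kernel must be launched with numBlobs blocks.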
template <typename T>
__global__ void ScaleBlobsCUDAKernelManyTensors(
const float scale,
const int* sizeArr,
T** X,
T** Y) {
for (size_t i = threadIdx.x; i < sizeArr[blockIdx.x]; i += blockDim.x) {
Y[blockIdx.x][i] = X[blockIdx.x][i] * scale;
}
}
template <>
template <typename T>
bool ScaleBlobsOp<CUDAContext>::DoRunWithType() {
const int numBlobs = InputSize();
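// Stage per-blob sizes and data pointers in host-side tensors; they are copied
// to the device below so the kernels can index into every blob.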
ReinitializeTensor(&hostBlobSizes_, {numBlobs}, at::dtype<int>().device(CPU));
int* hostBlobSizesData = hostBlobSizes_.mutable_data<int>();
ReinitializeTensor(&hostInputs_, {numBlobs}, at::dtype<T*>().device(CPU));
T** hostInputsData = hostInputs_.mutable_data<T*>();
ReinitializeTensor(&hostOutputs_, {numBlobs}, at::dtype<T*>().device(CPU));
T** hostOutputsData = hostOutputs_.mutable_data<T*>();
int totalSize = 0;
int maxSize = 0;
for (int i = 0; i < numBlobs; ++i) {
hostBlobSizesData[i] = Input(i).numel();
totalSize += hostBlobSizesData[i];
maxSize = max(maxSize, hostBlobSizesData[i]);
hostInputsData[i] = Input(i).template data<T>();
hostOutputsData[i] = Output(i)->template mutable_data<T>();
}
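// Copy the bookkeeping arrays (sizes and pointers) over to the device.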
ReinitializeTensor(&inputs_, {numBlobs}, at::dtype<T*>().device(CUDA));
ReinitializeTensor(&outputs_, {numBlobs}, at::dtype<T*>().device(CUDA));
ReinitializeTensor(&blobSizes_, {numBlobs}, at::dtype<int>().device(CUDA));
blobSizes_.CopyFrom(hostBlobSizes_);
inputs_.CopyFrom(hostInputs_);
outputs_.CopyFrom(hostOutputs_);
// Select which kernel to launch based on the sizes of the tensors:
// ScaleBlobsCUDAKernelManyTensors performs better when there are many short
// tensors, while ScaleBlobsCUDAKernel is better for a small number of long tensors.
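// Illustration: thousands of small blobs (a few hundred elements each) take the
// many-tensors path, while a handful of multi-million-element blobs fall through
// to the grid-stride kernel below.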
if (numBlobs > CAFFE_GET_BLOCKS(maxSize)) {
// Note: number of blocks has to be equal to the numBlobs
ScaleBlobsCUDAKernelManyTensors<T>
<<<numBlobs, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
scale_,
blobSizes_.data<int>(),
inputs_.mutable_data<T*>(),
outputs_.mutable_data<T*>());
} else {
ScaleBlobsCUDAKernel<T>
<<<CAFFE_GET_BLOCKS(maxSize),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
scale_,
numBlobs,
blobSizes_.data<int>(),
inputs_.mutable_data<T*>(),
outputs_.mutable_data<T*>());
}
return true;
}
template <>
bool ScaleBlobsOp<CUDAContext>::RunOnDevice() {
for (int i = 0; i < InputSize(); ++i) {
auto& input = this->template Input<Tensor>(i, CUDA);
auto* output = this->template Output<Tensor>(i, CUDA);
output->ResizeLike(input);
}
return DispatchHelper<TensorTypes<at::Half, float>>::call(this, Input(0));
}
REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp<CUDAContext>);
/*
* Implementation of a different version of the kernel
* This balances the work per thread and could be useful
* when there is a high imbalance between tensors
* However the memory requirement is very high so it does
* not perform well for common scenarios
*
*
* Additional storage for the start pointers is required
* for ScaleBlobsCUDAKernelBalanced setup
*
int threadsPerBlock = CAFFE_CUDA_NUM_THREADS;
int coorArrSize = 2 * ((totalSize - 1) / threadsPerBlock + 1);
int startCoorArr[coorArrSize];
int* dStartCoorArr;
int j = 0, cur = 0, elemsLeftInRow = 0;
for (int i = 0; i < numBlobs; ++i) {
if (i == 0) {
startCoorArr[cur++] = i;
startCoorArr[cur++] = j;
elemsLeftInRow = 0;
}
while (j < sizeArr[i]) {
j += threadsPerBlock - elemsLeftInRow;
if (j < sizeArr[i]) {
startCoorArr[cur++] = i;
startCoorArr[cur++] = j;
elemsLeftInRow = 0;
} else {
elemsLeftInRow = sizeArr[i] - j + threadsPerBlock;
j = 0;
break;
}
}
}
cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize);
cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize,
cudaMemcpyHostToDevice);
// ScaleBlobsCUDAKernelBalanced kernel launch
ScaleBlobsCUDAKernelBalanced<T>
<<<(totalSize-1)/CAFFE_CUDA_NUM_THREADS+1, CAFFE_CUDA_NUM_THREADS, 0,
context_.cuda_stream()>>>(
scale_, numBlobs, coorArrSize, dStartCoorArr, dSizeArr, dInputArr,
dOutputArr);
cudaFree(dStartCoorArr);
*/
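// Balanced variant used by the commented-out setup above: coorArr holds one
// (tensor index, element offset) pair per thread block, so each block starts at
// a precomputed position and every thread handles at most one element.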
template <typename T>
__global__ void ScaleBlobsCUDAKernelBalanced(
const float scale,
const int numBlobs,
const int coorArrSize,
const int* coorArr,
const int* sizeArr,
T** X,
T** Y) {
int i = coorArr[2 * blockIdx.x + 1] + threadIdx.x;
int curTen = coorArr[2 * blockIdx.x];
while (curTen < numBlobs && i >= sizeArr[curTen]) {
i -= sizeArr[curTen++];
}
if (curTen < numBlobs) {
Y[curTen][i] = X[curTen][i] * scale;
}
}
} // namespace caffe2


@@ -0,0 +1,58 @@
#ifndef CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
#define CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
template <class Context>
class ScaleBlobsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit ScaleBlobsOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(float, "scale", scale_, 1.0f) {}
template <typename T>
bool DoRunWithType() {
int batchSize = InputSize();
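// Scale each input blob independently; Output(i) is allocated with Input(i)'s shape.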
for (int i = 0; i < batchSize; ++i) {
const auto& X = Input(i);
auto* Y = Output(i, X.sizes(), at::dtype<T>());
math::Scale<float, T, Context>(
X.numel(),
scale_,
X.template data<T>(),
Y->template mutable_data<T>(),
&context_);
}
return true;
}
bool RunOnDevice() override {
for (int i = 0; i < InputSize(); ++i) {
auto& input = this->template Input<Tensor>(i, CPU);
auto* output = this->template Output<Tensor>(i, CPU);
output->ResizeLike(input);
}
return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
}
private:
const float scale_;
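// Scratch tensors used only by the CUDA implementation for size/pointer bookkeeping.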
Tensor blobSizes_;
Tensor inputs_;
Tensor outputs_;
Tensor hostBlobSizes_;
Tensor hostInputs_;
Tensor hostOutputs_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_


@@ -0,0 +1,64 @@
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import unittest
from caffe2.python import core, workspace
import caffe2.python.hypothesis_test_util as hu
import caffe2.python.serialized_test.serialized_test_util as serial
import hypothesis.strategies as st
import numpy as np
class TestScaleOps(serial.SerializedTestCase):
@serial.given(dim=st.sampled_from([[1, 386, 1], [386, 1, 1],
[1, 256, 1], [256, 1, 1],
[1024, 256, 1], [1, 1024, 1],
[1, 1, 1]]),
scale=st.floats(0.0, 10.0),
num_tensors=st.integers(1, 10),
**hu.gcs)
def test_scale_ops(self, dim, scale, num_tensors, gc, dc):
in_tensors = []
in_tensor_ps = []
out_tensors = []
out_ref_tensors = []
# initialize tensors
for i in range(num_tensors):
tensor = "X_{}".format(i)
X = np.random.rand(*dim).astype(np.float32) - 0.5
in_tensors.append(tensor)
in_tensor_ps.append(X)
out_tensor = "O_{}".format(i)
out_tensors.append(out_tensor)
workspace.FeedBlob(tensor, X, device_option=gc)
# run ScaleBlobs operator
scale_blobs_op = core.CreateOperator(
"ScaleBlobs",
in_tensors,
out_tensors,
scale=scale,
)
scale_blobs_op.device_option.CopyFrom(gc)
workspace.RunOperatorOnce(scale_blobs_op)
# run Scale op for each tensor and compare with ScaleBlobs
for i in range(num_tensors):
tensor = "X_{}".format(i)
out_ref_tensor = "O_ref_{}".format(i)
scale_op = core.CreateOperator(
"Scale",
[tensor],
[out_ref_tensor],
scale=scale,
)
scale_op.device_option.CopyFrom(gc)
workspace.RunOperatorOnce(scale_op)
o_ref = workspace.FetchBlob(out_ref_tensor)
o = workspace.FetchBlob(out_tensors[i])
np.testing.assert_allclose(o, o_ref)
if __name__ == '__main__':
unittest.main()