mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
ScaleBlobs Operator (#19660)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19660 Implementation of aggregated Scale operator. The operator takes a list of tensors as an input and scales all of them them with the argument float value. The tensor sizes can be different, therefore bookkeeping of the sizes and pointers to the tensors are necessary for the GPU version of the kernel. Reviewed By: BIT-silence Differential Revision: D14984233 fbshipit-source-id: 37cc97159a4f2c38cd6fff4f5710ab7d3a773611
This commit is contained in:
committed by
Facebook Github Bot
parent
4d676d53a6
commit
3ee97183b0
@ -81,7 +81,7 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(
|
||||
detail::_guard_long_unique<std::vector<long>>);
|
||||
|
||||
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
|
||||
|
||||
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
|
||||
|
||||
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(29, float*)
|
||||
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(30, at::Half*)
|
||||
} // namespace caffe2
|
||||
|
@ -624,6 +624,7 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(
|
||||
detail::_guard_long_unique<std::vector<long>>)
|
||||
|
||||
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, c10::qint8);
|
||||
|
||||
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
|
||||
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(29, float*)
|
||||
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(30, at::Half*)
|
||||
} // namespace caffe2
|
||||
|
18
caffe2/operators/scale_blobs_op.cc
Normal file
18
caffe2/operators/scale_blobs_op.cc
Normal file
@ -0,0 +1,18 @@
|
||||
#include "caffe2/operators/scale_blobs_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
REGISTER_CPU_OPERATOR(ScaleBlobs, ScaleBlobsOp<CPUContext>);
|
||||
OPERATOR_SCHEMA(ScaleBlobs)
|
||||
.NumInputs(1, INT_MAX)
|
||||
.NumOutputs(1, INT_MAX)
|
||||
.AllowInplace([](int, int) { return true; })
|
||||
.IdenticalTypeAndShape()
|
||||
.SetDoc(R"DOC(
|
||||
ScaleBlobs takes one or more input data (Tensor) and produces one
|
||||
or more output data (Tensor) whose value is the input data tensor
|
||||
scaled element-wise.
|
||||
)DOC")
|
||||
.Arg("scale", "(float, default 1.0) the scale to apply.");
|
||||
|
||||
} // namespace caffe2
|
169
caffe2/operators/scale_blobs_op.cu
Normal file
169
caffe2/operators/scale_blobs_op.cu
Normal file
@ -0,0 +1,169 @@
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
#include "caffe2/operators/scale_blobs_op.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <typename T>
|
||||
__global__ void ScaleBlobsCUDAKernel(
|
||||
const float scale,
|
||||
const int numBlobs,
|
||||
const int* sizeArr,
|
||||
T** X,
|
||||
T** Y) {
|
||||
for (size_t i = 0; i < numBlobs; ++i) {
|
||||
CUDA_1D_KERNEL_LOOP(j, sizeArr[i]) {
|
||||
Y[i][j] = X[i][j] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ScaleBlobsCUDAKernelManyTensors(
|
||||
const float scale,
|
||||
const int* sizeArr,
|
||||
T** X,
|
||||
T** Y) {
|
||||
for (size_t i = threadIdx.x; i < sizeArr[blockIdx.x]; i += blockDim.x) {
|
||||
Y[blockIdx.x][i] = X[blockIdx.x][i] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
template <typename T>
|
||||
bool ScaleBlobsOp<CUDAContext>::DoRunWithType() {
|
||||
const int numBlobs = InputSize();
|
||||
|
||||
ReinitializeTensor(&hostBlobSizes_, {numBlobs}, at::dtype<int>().device(CPU));
|
||||
int* hostBlobSizesData = hostBlobSizes_.mutable_data<int>();
|
||||
|
||||
ReinitializeTensor(&hostInputs_, {numBlobs}, at::dtype<T*>().device(CPU));
|
||||
T** hostInputsData = hostInputs_.mutable_data<T*>();
|
||||
|
||||
ReinitializeTensor(&hostOutputs_, {numBlobs}, at::dtype<T*>().device(CPU));
|
||||
T** hostOutputsData = hostOutputs_.mutable_data<T*>();
|
||||
|
||||
int totalSize = 0;
|
||||
int maxSize = 0;
|
||||
for (int i = 0; i < numBlobs; ++i) {
|
||||
hostBlobSizesData[i] = Input(i).numel();
|
||||
totalSize += hostBlobSizesData[i];
|
||||
maxSize = max(maxSize, hostBlobSizesData[i]);
|
||||
hostInputsData[i] = Input(i).template data<T>();
|
||||
hostOutputsData[i] = Output(i)->template mutable_data<T>();
|
||||
}
|
||||
|
||||
ReinitializeTensor(&inputs_, {numBlobs}, at::dtype<T*>().device(CUDA));
|
||||
ReinitializeTensor(&outputs_, {numBlobs}, at::dtype<T*>().device(CUDA));
|
||||
ReinitializeTensor(&blobSizes_, {numBlobs}, at::dtype<T*>().device(CUDA));
|
||||
|
||||
blobSizes_.CopyFrom(hostBlobSizes_);
|
||||
inputs_.CopyFrom(hostInputs_);
|
||||
outputs_.CopyFrom(hostOutputs_);
|
||||
|
||||
// Select which kernel to launch based on the length of the tensors
|
||||
// The first one performs better when there are many tensors of short length
|
||||
// The second one is better when there are small number of long tensors
|
||||
if (numBlobs > CAFFE_GET_BLOCKS(maxSize)) {
|
||||
// Note: number of blocks has to be equal to the numBlobs
|
||||
ScaleBlobsCUDAKernelManyTensors<T>
|
||||
<<<numBlobs, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
|
||||
scale_,
|
||||
blobSizes_.data<int>(),
|
||||
inputs_.mutable_data<T*>(),
|
||||
outputs_.mutable_data<T*>());
|
||||
} else {
|
||||
ScaleBlobsCUDAKernel<T>
|
||||
<<<CAFFE_GET_BLOCKS(maxSize),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
0,
|
||||
context_.cuda_stream()>>>(
|
||||
scale_,
|
||||
numBlobs,
|
||||
blobSizes_.data<int>(),
|
||||
inputs_.mutable_data<T*>(),
|
||||
outputs_.mutable_data<T*>());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool ScaleBlobsOp<CUDAContext>::RunOnDevice() {
|
||||
for (int i = 0; i < InputSize(); ++i) {
|
||||
auto& input = this->template Input<Tensor>(i, CUDA);
|
||||
auto* output = this->template Output<Tensor>(i, CUDA);
|
||||
output->ResizeLike(input);
|
||||
}
|
||||
return DispatchHelper<TensorTypes<at::Half, float>>::call(this, Input(0));
|
||||
}
|
||||
|
||||
REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp<CUDAContext>);
|
||||
|
||||
/*
|
||||
* Implementation of a different version of the kernel
|
||||
* This balances the work per thread and could be useful
|
||||
* when there is a high imbalance between tensors
|
||||
* However the memory requirement is very high so it does
|
||||
* not perform well for common scenarios
|
||||
*
|
||||
*
|
||||
* Additional storage for the start pointers is required
|
||||
* for ScaleBlobsCUDAKernelBalanced setup
|
||||
*
|
||||
int threadsPerBlock = CAFFE_CUDA_NUM_THREADS;
|
||||
int coorArrSize = 2 * ((totalSize - 1) / threadsPerBlock + 1);
|
||||
int startCoorArr[coorArrSize];
|
||||
int* dStartCoorArr;
|
||||
|
||||
int j = 0, cur = 0, elemsLeftInRow = 0;
|
||||
for (int i = 0; i < numBlobs; ++i) {
|
||||
if (i == 0) {
|
||||
startCoorArr[cur++] = i;
|
||||
startCoorArr[cur++] = j;
|
||||
elemsLeftInRow = 0;
|
||||
}
|
||||
while (j < sizeArr[i]) {
|
||||
j += threadsPerBlock - elemsLeftInRow;
|
||||
if (j < sizeArr[i]) {
|
||||
startCoorArr[cur++] = i;
|
||||
startCoorArr[cur++] = j;
|
||||
elemsLeftInRow = 0;
|
||||
} else {
|
||||
elemsLeftInRow = sizeArr[i] - j + threadsPerBlock;
|
||||
j = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize);
|
||||
cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize,
|
||||
cudaMemcpyHostToDevice);
|
||||
|
||||
// ScaleBlobsCUDAKernelBalanced kernel launch
|
||||
ScaleBlobsCUDAKernelBalanced<T>
|
||||
<<<(totalSize-1)/CAFFE_CUDA_NUM_THREADS+1, CAFFE_CUDA_NUM_THREADS, 0,
|
||||
context_.cuda_stream()>>>(
|
||||
scale_, numBlobs, coorArrSize, dStartCoorArr, dSizeArr, dInputArr,
|
||||
dOutputArr);
|
||||
cudaFree(dStartCoorArr);
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
__global__ void ScaleBlobsCUDAKernelBalanced(
|
||||
const float scale,
|
||||
const int numBlobs,
|
||||
const int coorArrSize,
|
||||
const int* coorArr,
|
||||
const int* sizeArr,
|
||||
T** X,
|
||||
T** Y) {
|
||||
int i = coorArr[2 * blockIdx.x + 1] + threadIdx.x;
|
||||
int curTen = coorArr[2 * blockIdx.x];
|
||||
while (curTen < numBlobs && i >= sizeArr[curTen]) {
|
||||
i -= sizeArr[curTen++];
|
||||
}
|
||||
if (curTen < numBlobs) {
|
||||
Y[curTen][i] = X[curTen][i] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
58
caffe2/operators/scale_blobs_op.h
Normal file
58
caffe2/operators/scale_blobs_op.h
Normal file
@ -0,0 +1,58 @@
|
||||
#ifndef CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
|
||||
#define CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <class Context>
|
||||
class ScaleBlobsOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
template <class... Args>
|
||||
explicit ScaleBlobsOp(Args&&... args)
|
||||
: Operator<Context>(std::forward<Args>(args)...),
|
||||
OP_SINGLE_ARG(float, "scale", scale_, 1.0f) {}
|
||||
|
||||
template <typename T>
|
||||
bool DoRunWithType() {
|
||||
int batchSize = InputSize();
|
||||
|
||||
for (int i = 0; i < batchSize; ++i) {
|
||||
const auto& X = Input(i);
|
||||
auto* Y = Output(i, X.sizes(), at::dtype<T>());
|
||||
math::Scale<float, T, Context>(
|
||||
X.numel(),
|
||||
scale_,
|
||||
X.template data<T>(),
|
||||
Y->template mutable_data<T>(),
|
||||
&context_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
for (int i = 0; i < InputSize(); ++i) {
|
||||
auto& input = this->template Input<Tensor>(i, CPU);
|
||||
auto* output = this->template Output<Tensor>(i, CPU);
|
||||
output->ResizeLike(input);
|
||||
}
|
||||
return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
|
||||
}
|
||||
|
||||
private:
|
||||
const float scale_;
|
||||
Tensor blobSizes_;
|
||||
Tensor inputs_;
|
||||
Tensor outputs_;
|
||||
|
||||
Tensor hostBlobSizes_;
|
||||
Tensor hostInputs_;
|
||||
Tensor hostOutputs_;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_OPERATORS_SCALE_BLOBS_OP_H_
|
64
caffe2/python/operator_test/scale_op_test.py
Normal file
64
caffe2/python/operator_test/scale_op_test.py
Normal file
@ -0,0 +1,64 @@
|
||||
from __future__ import division
|
||||
from __future__ import absolute_import
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core, workspace
|
||||
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
import caffe2.python.serialized_test.serialized_test_util as serial
|
||||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestScaleOps(serial.SerializedTestCase):
|
||||
@serial.given(dim=st.sampled_from([[1, 386, 1], [386, 1, 1],
|
||||
[1, 256, 1], [256, 1, 1],
|
||||
[1024, 256, 1], [1, 1024, 1],
|
||||
[1, 1, 1]]),
|
||||
scale=st.floats(0.0, 10.0),
|
||||
num_tensors=st.integers(1, 10),
|
||||
**hu.gcs)
|
||||
def test_scale_ops(self, dim, scale, num_tensors, gc, dc):
|
||||
in_tensors = []
|
||||
in_tensor_ps = []
|
||||
out_tensors = []
|
||||
out_ref_tensors = []
|
||||
# initialize tensors
|
||||
for i in range(num_tensors):
|
||||
tensor = "X_{}".format(i)
|
||||
X = np.random.rand(*dim).astype(np.float32) - 0.5
|
||||
in_tensors.append(tensor)
|
||||
in_tensor_ps.append(X)
|
||||
out_tensor = "O_{}".format(i)
|
||||
out_tensors.append(out_tensor)
|
||||
workspace.FeedBlob(tensor, X, device_option=gc)
|
||||
|
||||
# run ScaleBlobs operator
|
||||
scale_blobs_op = core.CreateOperator(
|
||||
"ScaleBlobs",
|
||||
in_tensors,
|
||||
out_tensors,
|
||||
scale=scale,
|
||||
)
|
||||
scale_blobs_op.device_option.CopyFrom(gc)
|
||||
workspace.RunOperatorOnce(scale_blobs_op)
|
||||
|
||||
# run Scale op for each tensor and compare with ScaleBlobs
|
||||
for i in range(num_tensors):
|
||||
tensor = "X_{}".format(i)
|
||||
out_ref_tensor = "O_ref_{}".format(i)
|
||||
scale_op = core.CreateOperator(
|
||||
"Scale",
|
||||
[tensor],
|
||||
[out_ref_tensor],
|
||||
scale=scale,
|
||||
)
|
||||
scale_op.device_option.CopyFrom(gc)
|
||||
workspace.RunOperatorOnce(scale_op)
|
||||
o_ref = workspace.FetchBlob(out_ref_tensor)
|
||||
o = workspace.FetchBlob(out_tensors[i])
|
||||
np.testing.assert_allclose(o, o_ref)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Reference in New Issue
Block a user