mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

added cuda convolutionMap

THCTensorConv.cu (+417)
@@ -903,3 +903,420 @@ TH_API void THCudaTensor_conv2DRevgerm(THCudaTensor *output, float beta, float a
    THError("aborting");
  }
}


///////////////////////////////////
///// ConvolutionMap
/*
 * Description:
 *   base conv2D routine: 3D input, 3D output, 4D kernel
 *
 *   - all chunks of data should be contiguous
 *   - the swapkernel flag can be used to generate a conv2 instead of an xcorr2
 *   - the templated kernel size is useful to generate code that's 2x faster,
 *     but can be set to 0 to allow arbitrary kernel sizes
 *   - the table should have the outputs along its first dimension; each output
 *     should have a fanin set of (input, kernel) pairs stored contiguously
 */
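
/*
 * Illustrative example (not part of the original source): with fanin = 2 and
 * two outputs, the table holds fanin*2 entries per output, i.e. contiguous
 * (input plane, kernel index) pairs, read below as table[ii] / table[ii + 1]:
 *
 *   table = { i00, k00, i01, k01,    // pairs feeding output 0
 *             i10, k10, i11, k11 }   // pairs feeding output 1
 */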
template <bool swapkernel, int T_kernel_h, int T_kernel_w>
__global__ void conv2mapgeneric(float *input, float *kernel, float *output,
                                int input_n, int input_h, int input_w,
                                int kernel_n, int kernel_h, int kernel_w,
                                int stride_h, int stride_w,
                                float *table, int fanin)
{
  // output dimensions
  int output_h = (input_h - kernel_h) / stride_h + 1;
  int output_w = (input_w - kernel_w) / stride_w + 1;
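  // e.g. a 32x32 input plane with a 5x5 kernel and stride 1 gives
  // (32 - 5) / 1 + 1 = 28, i.e. a 28x28 output plane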

  // xcorr or conv
  int koffset = swapkernel ? kernel_w*kernel_h-1 : 0;

  // nb outputs
  // int output_n = kernel_n / fanin;

  // generate offsets according to block/thread ids
  int xx_start = threadIdx.x;
  int xx_end   = output_w;
  int xx_step  = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end   = output_h;
  int yy_step  = blockDim.y*gridDim.y;

  int oo_start = blockIdx.x;
  int oo_end   = oo_start+1;

  int table_start = blockIdx.x * (fanin * 2);
  int table_end   = table_start + (fanin * 2);

  // nb threads, unique thread id
  int tid = blockDim.x*blockDim.y*threadIdx.z
          + blockDim.x*threadIdx.y + threadIdx.x;
  int nthreads = blockDim.x * blockDim.y * blockDim.z;
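  // with the launch configuration used by THCudaTensor_conv2Dmap below,
  // blockDim = (32, 8, 1), so nthreads = 256 and tid ranges over [0, 256)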

  // iterators
  int oo, ii, xx, yy, kx, ky, kk;

  // do the kernels fit in shared mem ?
  if (kernel_w*kernel_h*kernel_n <= CUDA_SHARED_MEM_SIZE) {
    // put the kernel in shared memory
    __shared__ float shared_kernel[CUDA_SHARED_MEM_SIZE];

    // all threads of the block cooperate to copy the kernels
    for (kk = tid; kk < kernel_w*kernel_h*kernel_n; kk += nthreads) {
      shared_kernel[kk] = kernel[kk];
    }
    __syncthreads();
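    // the barrier above ensures every kernel weight is visible in shared
    // memory before any thread enters the convolution loops below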

    // templated kernel size
    if ((T_kernel_w > 0) && (T_kernel_h > 0)) {
      // unrolled convolution loop
      for (oo = oo_start; oo < oo_end; oo++) {
        for (ii = table_start; ii < table_end; ii = ii + 2) {
          for (yy = yy_start; yy < yy_end; yy += yy_step) {
            for (xx = xx_start; xx < xx_end; xx += xx_step) {
              // Dot product in two dimensions (between input image and the mask)
              float *input_p = input + ((long)table[ii])*input_h*input_w
                             + yy*stride_h*input_w + xx*stride_w;
              float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
              float *kernel_p = shared_kernel
                              + ((long)table[ii + 1])*kernel_w*kernel_h + koffset;
              float sum = 0;
              if (swapkernel) {
#pragma unroll
                for (ky = 0; ky < T_kernel_h; ky++) {
#pragma unroll
                  for (kx = 0; kx < T_kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p--);
                  }
                  input_p += input_w;
                }
              } else {
#pragma unroll
                for (ky = 0; ky < T_kernel_h; ky++) {
#pragma unroll
                  for (kx = 0; kx < T_kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p++);
                  }
                  input_p += input_w;
                }
              }
              *output_p += sum;
            }
          }
        }
      }
    } else {
      // default convolution loop; the table is read in (input, kernel) pairs,
      // matching the templated and streaming paths
      for (oo = oo_start; oo < oo_end; oo++) {
        for (ii = table_start; ii < table_end; ii = ii + 2) {
          for (yy = yy_start; yy < yy_end; yy += yy_step) {
            for (xx = xx_start; xx < xx_end; xx += xx_step) {
              // Dot product in two dims (between input image and the mask)
              float *input_p = input + ((long)table[ii])*input_h*input_w
                             + yy*stride_h*input_w + xx*stride_w;
              float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
              float *kernel_p = shared_kernel
                              + ((long)table[ii + 1])*kernel_w*kernel_h + koffset;
              float sum = 0;
              if (swapkernel) {
                for (ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                  for (kx = 0; kx < kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p--);
                  }
                  input_p += input_w;
                }
              } else {
                for (ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                  for (kx = 0; kx < kernel_w; kx++) {
                    sum += input_p[kx]*(*kernel_p++);
                  }
                  input_p += input_w;
                }
              }
              *output_p += sum;
            }
          }
        }
      }
    }
  } else { // not enough shared mem for kernels, simply stream them

    // convolution loop
    for (oo = oo_start; oo < oo_end; oo++) {
      for (ii = table_start; ii < table_end; ii = ii + 2) {
        for (yy = yy_start; yy < yy_end; yy += yy_step) {
          for (xx = xx_start; xx < xx_end; xx += xx_step) {
            // Dot product in two dimensions (between input image and the mask)
            float *input_p = input + ((long)table[ii])*input_h*input_w
                           + yy*stride_h*input_w + xx*stride_w;
            float *output_p = output + oo*output_h*output_w + yy*output_w + xx;
            float *kernel_p = kernel + ((long)table[ii + 1])*kernel_w*kernel_h + koffset;
            float sum = 0;
            if (swapkernel) {
              for (ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                for (kx = 0; kx < kernel_w; kx++) {
                  sum += input_p[kx]*(*kernel_p--);
                }
                input_p += input_w;
              }
            } else {
              for (ky = 0; ky < kernel_h; ky++) {
#pragma unroll 5
                for (kx = 0; kx < kernel_w; kx++) {
                  sum += input_p[kx]*(*kernel_p++);
                }
                input_p += input_w;
              }
            }
            *output_p += sum;
          }
        }
      }
    }
  }
}


/*
 * API-compatible with THRealTensor_conv2Dmv
 * 3D input, 4D kernel, 3D output
 * matrix vector product like: y <- Ax + beta*y
 */
TH_API void THCudaTensor_conv2Dmap(THCudaTensor *output, THCudaTensor *input,
                                   THCudaTensor *kernel, long stride_x, long stride_y,
                                   THCudaTensor *table, long fanin)
{
  long nInputPlane, nInputRows, nInputCols;
  long nKernelRows, nKernelCols;
  long nOutputPlane, nOutputRows, nOutputCols;

  THArgCheck(kernel->nDimension == 4, 4, "kernel: 4D Tensor expected");
  THArgCheck(stride_x >= 1, 5, "Stride should be a positive integer");
  THArgCheck(stride_y >= 1, 6, "Stride should be a positive integer");

  input  = THCudaTensor_newContiguous(input);
  kernel = THCudaTensor_newContiguous(kernel);
  table  = THCudaTensor_newContiguous(table); // retained here, freed below

  nInputPlane = input->size[0];
  nInputRows  = input->size[1];
  nInputCols  = input->size[2];

  nKernelRows  = kernel->size[2];
  nKernelCols  = kernel->size[3];
  nOutputPlane = kernel->size[0];
  THArgCheck(kernel->size[1] == nInputPlane, 2, "invalid number of input planes");

  THArgCheck((nInputRows >= nKernelRows && nInputCols >= nKernelCols), 2,
             "conv2Dmap : Input image is smaller than kernel");

  // output dims
  nOutputRows = (nInputRows - nKernelRows) / stride_x + 1;
  nOutputCols = (nInputCols - nKernelCols) / stride_y + 1;

  // long nelem = THCudaTensor_nElement(output);
  THCudaTensor_resize3d(output, nOutputPlane, nOutputRows, nOutputCols);

  float *input_data  = THCudaTensor_data(input);
  float *kernel_data = THCudaTensor_data(kernel);
  float *output_data = THCudaTensor_data(output);
  float *table_data  = THCudaTensor_data(table);

  // set the number of blocks and threads
  int nthreads_x = 32;
  int nthreads_y = 8;
  int block_height = 16 / nOutputPlane; // integer division
  if (block_height < 1)
    block_height = 1;
  dim3 blocks(nOutputPlane, block_height);
  dim3 threads(nthreads_x, nthreads_y);
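  // e.g. for nOutputPlane = 16 this yields a (16, 1) grid of 32x8 = 256-thread
  // blocks, one block column per output plane (matching oo_start = blockIdx.x)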
  // sync any previous kernel exec
  cudaDeviceSynchronize();
  if ((nKernelCols == 3) && (nKernelRows == 3))
    conv2mapgeneric <false, 3, 3> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 5) && (nKernelRows == 5))
    conv2mapgeneric <false, 5, 5> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 7) && (nKernelRows == 7))
    conv2mapgeneric <false, 7, 7> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 9) && (nKernelRows == 9))
    conv2mapgeneric <false, 9, 9> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 11) && (nKernelRows == 11))
    conv2mapgeneric <false, 11, 11> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 13) && (nKernelRows == 13))
    conv2mapgeneric <false, 13, 13> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 4) && (nKernelRows == 4))
    conv2mapgeneric <false, 4, 4> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 6) && (nKernelRows == 6))
    conv2mapgeneric <false, 6, 6> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 8) && (nKernelRows == 8))
    conv2mapgeneric <false, 8, 8> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 10) && (nKernelRows == 10))
    conv2mapgeneric <false, 10, 10> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else if ((nKernelCols == 12) && (nKernelRows == 12))
    conv2mapgeneric <false, 12, 12> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);
  else
    conv2mapgeneric <false, 0, 0> <<<blocks, threads>>> (
      input_data, kernel_data, output_data,
      nInputPlane, nInputRows, nInputCols,
      nOutputPlane*fanin, nKernelRows, nKernelCols,
      stride_x, stride_y, table_data, fanin);

  // sync & clean
  cudaDeviceSynchronize();

  THCudaTensor_free(input);
  THCudaTensor_free(kernel);
  THCudaTensor_free(table);

  // check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in conv2Dmap: %s\n", cudaGetErrorString(err));
    THError("aborting");
  }
}
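
/*
 * Usage sketch (illustrative, not part of this commit): one plausible way to
 * drive THCudaTensor_conv2Dmap from the host. The sizes and table contents
 * are assumptions for the example; only the final call is the API added here.
 *
 *   long fanin = 2;
 *   THCudaTensor *input  = THCudaTensor_newWithSize3d(3, 32, 32);  // 3 input planes
 *   THCudaTensor *kernel = THCudaTensor_newWithSize4d(4, 3, 5, 5); // 4 output planes
 *   THCudaTensor *table  = THCudaTensor_newWithSize2d(4 * fanin, 2);
 *   THCudaTensor *output = THCudaTensor_new();
 *   // ... fill input, kernel, and the (input plane, kernel index) table ...
 *   THCudaTensor_conv2Dmap(output, input, kernel, 1, 1, table, fanin);
 */
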
@@ -15,4 +15,8 @@ TH_API void THCudaTensor_conv2DRevgerm(THCudaTensor *output, float beta, float a
                                       THCudaTensor *input, THCudaTensor *kernel,
                                       long srow, long scol);

TH_API void THCudaTensor_conv2Dmap(THCudaTensor *output, THCudaTensor *input,
                                   THCudaTensor *kernel, long stride_x, long stride_y,
                                   THCudaTensor *table, long fanin);

#endif