mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Merge branch 'master' of https://github.com/Yangqing/caffe2
This commit is contained in:
@ -9,8 +9,7 @@ __global__ void LRNFillScaleNCHW(const int nthreads, const T* in,
|
||||
const int num, const int channels, const int height,
|
||||
const int width, const int size, const T alpha_over_size,
|
||||
const T bias, T* scale) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index < nthreads) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
// find out the local offset
|
||||
int w = index % width;
|
||||
int h = (index / width) % height;
|
||||
@ -48,6 +47,9 @@ __global__ void LRNFillScaleNCHW(const int nthreads, const T* in,
|
||||
scale[(head - post_pad) * step] = bias + accum_scale * alpha_over_size;
|
||||
++head;
|
||||
}
|
||||
// recover the pointers for the next loop.
|
||||
in -= offset;
|
||||
scale -= offset;
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,8 +58,7 @@ __global__ void LRNFillScaleNHWC(const int nthreads, const T* in,
|
||||
const int num, const int height, const int width,
|
||||
const int channels, const int size, const T alpha_over_size,
|
||||
const T bias, T* scale) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index < nthreads) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
int c = index % channels;
|
||||
int pre_pad = (size - 1) / 2;
|
||||
scale[index] = 0;
|
||||
@ -76,8 +77,7 @@ __global__ void LRNFillScaleNHWC(const int nthreads, const T* in,
|
||||
template <typename T>
|
||||
__global__ void LRNComputeOutput(const int nthreads, const T* in,
|
||||
const T* scale, const T negative_beta, T* out) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index < nthreads) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
out[index] = in[index] * pow(scale[index], negative_beta);
|
||||
}
|
||||
}
|
||||
@ -89,8 +89,7 @@ __global__ void LRNComputeDiffNCHW(const int nthreads, const T* bottom_data,
|
||||
const int width, const int size, const T negative_beta,
|
||||
const T cache_ratio,
|
||||
T* bottom_diff) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index < nthreads) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
// find out the local offset
|
||||
int w = index % width;
|
||||
int h = (index / width) % height;
|
||||
@ -141,6 +140,12 @@ __global__ void LRNComputeDiffNCHW(const int nthreads, const T* bottom_data,
|
||||
bottom_data[(head - post_pad) * step] * accum_ratio;
|
||||
++head;
|
||||
}
|
||||
// recover pointer for next iteration.
|
||||
bottom_data -= offset;
|
||||
top_data -= offset;
|
||||
scale -= offset;
|
||||
top_diff -= offset;
|
||||
bottom_diff -= offset;
|
||||
}
|
||||
}
|
||||
|
||||
@ -153,8 +158,7 @@ __global__ void LRNComputeDiffNHWC(const int nthreads, const T* bottom_data,
|
||||
const int num, const int height, const int width, const int channels,
|
||||
const int size, const T negative_beta, const T cache_ratio,
|
||||
T* bottom_diff) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (index < nthreads) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
// find out the local channel offset
|
||||
int c = index % channels;
|
||||
int pre_pad = size / 2;
|
||||
|
Reference in New Issue
Block a user