mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
* Simplify PReLU binding
  - Remove internal buffers from function signature
  - Compute nOutputPlane internally
* Fix legacy PReLU
49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
import torch
|
|
from .Module import Module
|
|
from .utils import clear
|
|
|
|
|
|
class PReLU(Module):
    """Legacy parametric ReLU layer with a learnable slope tensor.

    All numeric work is delegated to the ``PReLU_*`` functions of
    ``self._backend``. This class only owns the learnable ``weight`` tensor
    and its gradient buffer ``gradWeight``.

    Args:
        nOutputPlane: number of learnable slopes (presumably one per input
            plane/channel -- confirm against the backend). The default ``0``
            selects a single slope shared by the whole model, in which case
            ``weight`` holds one element.
    """

    def __init__(self, nOutputPlane=0):
        super(PReLU, self).__init__()
        # Store the value exactly as given (0 means "shared"); the parameter
        # tensors themselves always have at least one element.
        self.nOutputPlane = nOutputPlane
        num_params = nOutputPlane or 1
        # Every slope starts at 0.25.
        self.weight = torch.Tensor(num_params).fill_(0.25)
        # Allocated without explicit initialisation; the backend accumulates
        # into it. NOTE(review): presumably zeroed via zeroGradParameters
        # before the first accumulation -- confirm.
        self.gradWeight = torch.Tensor(num_params)

    def updateOutput(self, input):
        """Run the forward pass into ``self.output`` and return it."""
        backend = self._backend
        backend.PReLU_updateOutput(backend.library_state, input,
                                   self.output, self.weight)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute the gradient w.r.t. ``input`` into ``self.gradInput``.

        Returns:
            ``self.gradInput``, as filled in by the backend.
        """
        backend = self._backend
        backend.PReLU_updateGradInput(backend.library_state, input,
                                      gradOutput, self.gradInput,
                                      self.weight)
        return self.gradInput

    def accGradParameters(self, input, gradOutput, scale=1):
        """Ask the backend to accumulate the slope gradient into ``gradWeight``.

        ``scale`` is forwarded unchanged to the backend. The backend call
        also receives ``self.gradInput`` as part of its signature.

        Returns:
            ``self.gradWeight``.
        """
        backend = self._backend
        backend.PReLU_accGradParameters(backend.library_state, input,
                                        gradOutput, self.gradInput,
                                        self.weight, self.gradWeight,
                                        scale)
        return self.gradWeight

    def clearState(self):
        """Release the legacy internal buffers, then defer to ``Module``."""
        # This class never allocates these attributes. They are still cleared,
        # presumably so instances created by older versions (which kept such
        # internal buffers) release them. NOTE(review): confirm.
        stale_buffers = ('gradWeightBuf', 'gradWeightBuf2')
        clear(self, *stale_buffers)
        return super(PReLU, self).clearState()