# Mirror of https://github.com/pytorch/pytorch.git
# Synced 2025-10-20 21:14:14 +08:00
# File info: 34 lines, 1006 B, Python
import math
|
|
import torch
|
|
from .Module import Module
|
|
|
|
|
|
class Mul(Module):
    """Scale the input by a single learnable scalar.

    The forward pass multiplies every element of the input by ``weight[0]``.
    The backward pass scales ``gradOutput`` by the same value and accumulates
    the gradient with respect to the scalar into ``gradWeight[0]``.

    ``self.output`` and ``self.gradInput`` are used as reusable buffers; they
    are not created in this class (presumably set up by ``Module.__init__`` --
    confirm against the base class).
    """

    def __init__(self):
        super(Mul, self).__init__()
        # One-element tensor holding the scalar gain; filled by reset() below.
        self.weight = torch.Tensor(1)
        # torch.Tensor(1) does not initialise its storage. accGradParameters
        # adds onto gradWeight[0], so start the accumulator at zero rather
        # than at whatever happened to be in memory.
        self.gradWeight = torch.Tensor(1).zero_()
        self.reset()

    def reset(self, stdv=None):
        """Re-draw ``weight`` from a symmetric uniform distribution.

        Args:
            stdv: Optional spread. When given, the sampling bound is
                ``stdv * sqrt(3)`` (a uniform distribution on ``[-a, a]``
                has standard deviation ``a / sqrt(3)``). When ``None``, the
                bound is ``1 / sqrt(weight.size(0))``.
        """
        if stdv is not None:
            stdv = stdv * math.sqrt(3)
        else:
            stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.uniform_(-stdv, stdv)

    def updateOutput(self, input):
        """Compute ``input * weight[0]`` into ``self.output`` and return it.

        Args:
            input: Tensor to scale; it is copied, not modified.

        Returns:
            ``self.output``, resized to match ``input``.
        """
        # Copy first, then scale in place, so ``input`` stays untouched.
        self.output.resize_as_(input).copy_(input)
        self.output.mul_(self.weight[0])
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute ``gradOutput * weight[0]`` into ``self.gradInput``.

        Args:
            input: Tensor passed to the forward pass; only its size is used.
            gradOutput: Gradient with respect to this module's output.

        Returns:
            ``self.gradInput``, resized to match ``input``.
        """
        self.gradInput.resize_as_(input).zero_()
        # NOTE(review): this is the legacy ``add_(scalar, tensor)`` overload
        # (buffer += scalar * tensor). Newer PyTorch releases spell it
        # ``add_(tensor, alpha=scalar)`` -- confirm the targeted version
        # before changing the call.
        self.gradInput.add_(self.weight[0], gradOutput)
        return self.gradInput

    def accGradParameters(self, input, gradOutput, scale=1):
        """Accumulate the gradient of the scalar weight into ``gradWeight[0]``.

        Adds ``scale * <input, gradOutput>`` to ``gradWeight[0]``, where the
        inner product is taken over the flattened tensors.

        Args:
            input: Tensor passed to the forward pass.
            gradOutput: Gradient with respect to this module's output.
            scale: Multiplier applied to the contribution before it is added.
        """
        # Flatten both operands (making them contiguous first so view(-1)
        # is valid) and reduce them with a single dot product.
        flat_input = input.contiguous().view(-1)
        flat_grad = gradOutput.contiguous().view(-1)
        self.gradWeight[0] = (self.gradWeight[0] +
                              scale * flat_input.dot(flat_grad))
|