Files
pytorch/torch/legacy/nn/Mul.py
2017-06-16 02:15:33 +02:00

34 lines
1006 B
Python

import math
import torch
from .Module import Module
class Mul(Module):
def __init__(self):
super(Mul, self).__init__()
self.weight = torch.Tensor(1)
self.gradWeight = torch.Tensor(1)
self.reset()
def reset(self, stdv=None):
if stdv is not None:
stdv = stdv * math.sqrt(3)
else:
stdv = 1. / math.sqrt(self.weight.size(0))
self.weight.uniform_(-stdv, stdv)
def updateOutput(self, input):
self.output.resize_as_(input).copy_(input)
self.output.mul_(self.weight[0])
return self.output
def updateGradInput(self, input, gradOutput):
self.gradInput.resize_as_(input).zero_()
self.gradInput.add_(self.weight[0], gradOutput)
return self.gradInput
def accGradParameters(self, input, gradOutput, scale=1):
self.gradWeight[0] = (self.gradWeight[0] +
scale * input.contiguous().view(-1).dot(gradOutput.contiguous().view(-1)))