pytorch/torch/legacy/nn/Concat.py
Luke Yeager e7c1e6a8e3 [pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything it can fix safely, and to disable any rules that are
tricky or controversial to address. We may want to come back and
re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.

Also configures flake8 to match pep8's behavior.
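
For illustration only, a setup.cfg of the kind described might look like
the sketch below; the rule codes shown are placeholders, not the set
actually disabled by this patch:

    # setup.cfg (illustrative sketch; E123/E501 are placeholder codes,
    # not the rules actually ignored in this patch)
    [pep8]
    max-line-length = 100
    ignore = E123,E501

    [flake8]
    max-line-length = 100
    ignore = E123,E501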

Also configures TravisCI to check the whole project for lint.
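
A project-wide lint check on Travis could be as simple as the following
sketch (not the exact configuration added by this commit):

    # .travis.yml (sketch): install flake8 and lint the whole repository
    install:
      - pip install flake8
    script:
      - flake8 .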
2017-01-28 01:15:51 +01:00


import torch
from .Container import Container


class Concat(Container):
    """Concatenates the outputs of its child modules along a given dimension."""

    def __init__(self, dimension):
        super(Concat, self).__init__()
        self.outputSize = torch.Size()
        # dimension along which the child modules' outputs are concatenated
        self.dimension = dimension
    def updateOutput(self, input):
        # First pass: run every child module on the same input and accumulate
        # the size of the concatenated output along self.dimension.
        outs = []
        for i in range(len(self.modules)):
            currentOutput = self.modules[i].updateOutput(input)
            outs.append(currentOutput)
            if i == 0:
                size = list(currentOutput.size())
            else:
                size[self.dimension] += currentOutput.size(self.dimension)
        self.outputSize = torch.Size(size)
        self.output.resize_(self.outputSize)

        # Second pass: copy each child's output into its slice of self.output.
        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = outs[i]
            self.output.narrow(self.dimension, offset, currentOutput.size(self.dimension)).copy_(currentOutput)
            offset = offset + currentOutput.size(self.dimension)

        return self.output
    def updateGradInput(self, input, gradOutput):
        self.gradInput.resize_as_(input)

        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            # Each child sees only the slice of gradOutput that corresponds to its own output.
            currentGradInput = module.updateGradInput(input, gradOutput.narrow(
                self.dimension, offset, currentOutput.size(self.dimension)))

            # If the module does not produce a gradInput (for example the first layer), ignore it and move on.
            if currentGradInput is not None:
                if i == 0:
                    self.gradInput.copy_(currentGradInput)
                else:
                    self.gradInput.add_(currentGradInput)

            offset = offset + currentOutput.size(self.dimension)

        return self.gradInput
    def accGradParameters(self, input, gradOutput, scale=1):
        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            module.accGradParameters(
                input,
                gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
                scale)
            offset = offset + currentOutput.size(self.dimension)
    def backward(self, input, gradOutput, scale=1):
        self.gradInput.resize_as_(input)
        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            currentGradInput = module.backward(input, gradOutput.narrow(
                self.dimension, offset, currentOutput.size(self.dimension)), scale)
            # If the module does not produce a gradInput (for example the first layer), ignore it and move on.
            if currentGradInput is not None:
                if i == 0:
                    self.gradInput.copy_(currentGradInput)
                else:
                    self.gradInput.add_(currentGradInput)
            offset = offset + currentOutput.size(self.dimension)

        return self.gradInput
    def accUpdateGradParameters(self, input, gradOutput, lr):
        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            module.accUpdateGradParameters(
                input,
                gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
                lr)
            offset = offset + currentOutput.size(self.dimension)
    def __repr__(self):
        tab = '  '
        line = '\n'
        next = '  |`-> '
        ext = '  |    '
        extlast = '       '
        last = '   +. -> '
        res = torch.typename(self)
        res += ' {' + line + tab + 'input'
        for i in range(len(self.modules)):
            if i == len(self.modules) - 1:
                res += line + tab + next + '(' + str(i) + '): ' + str(self.modules[i]).replace(line, line + tab + extlast)
            else:
                res += line + tab + next + '(' + str(i) + '): ' + str(self.modules[i]).replace(line, line + tab + ext)

        res += line + tab + last + 'output'
        res += line + '}'
        return res
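
For context, a minimal usage sketch of this module, assuming the legacy
torch.legacy.nn package of this era of PyTorch and hypothetical layer sizes:

# Usage sketch (assumption: torch.legacy.nn is importable, as in PyTorch ~0.1.x).
# Concatenate the outputs of two Linear layers along dimension 1.
import torch
from torch.legacy import nn

model = nn.Concat(1)
model.add(nn.Linear(10, 3))
model.add(nn.Linear(10, 5))

x = torch.randn(4, 10)
out = model.updateOutput(x)  # shape (4, 8): (4, 3) and (4, 5) joined along dim 1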