Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8 handle everything it can handle safely and to disable any rules that are tricky or controversial to address. We may want to come back and re-enable some of these rules later, but I'm trying to make this patch as safe as possible. This change also configures flake8 to match pep8's behavior and configures TravisCI to check the whole project for lint.
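
For illustration, a rough Python equivalent of that parallel invocation (a sketch, not part of the commit; it assumes git and autopep8 are on PATH and that it is run from the repository root):

import os
import subprocess
from concurrent.futures import ThreadPoolExecutor

def fix_in_place(path):
    # Mirrors `autopep8 -i <file>`; the ignore list the commit puts in
    # setup.cfg should apply here as well.
    subprocess.run(["autopep8", "-i", path], check=True)

# Same file selection as `git ls-files | grep '\.py$'`.
tracked = subprocess.run(["git", "ls-files"], capture_output=True,
                         text=True, check=True).stdout.splitlines()
py_files = [f for f in tracked if f.endswith(".py")]

# One file per task, one worker per CPU, like `xargs -n1 -P$(nproc)`.
with ThreadPoolExecutor(max_workers=os.cpu_count()) as pool:
    list(pool.map(fix_in_place, py_files))
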
67 lines
1.7 KiB
Python
import torch
from .Module import Module
from .utils import clear
from functools import wraps
import sys


class Container(Module):
    """Base class for legacy nn modules that hold a list of child modules."""

    def __init__(self, *args):
        super(Container, self).__init__(*args)
        self.modules = []

    def add(self, module):
        # Append a child module; returning self allows chained calls.
        self.modules.append(module)
        return self

    def get(self, index):
        return self.modules[index]

    def size(self):
        return len(self.modules)

    def applyToModules(self, func):
        # Apply func to each direct child module.
        for module in self.modules:
            func(module)

    def zeroGradParameters(self):
        self.applyToModules(lambda m: m.zeroGradParameters())

    def updateParameters(self, learningRate):
        self.applyToModules(lambda m: m.updateParameters(learningRate))

    def training(self):
        self.applyToModules(lambda m: m.training())
        super(Container, self).training()

    def evaluate(self):
        self.applyToModules(lambda m: m.evaluate())
        super(Container, self).evaluate()

    def share(self, mlp, *args):
        # Share the named parameters/buffers with the corresponding modules
        # of another container, pairing children by position.
        for module, other_module in zip(self.modules, mlp.modules):
            module.share(other_module, *args)

    def reset(self, stdv=None):
        self.applyToModules(lambda m: m.reset(stdv))

    def parameters(self):
        # Collect flat lists of parameter tensors and their gradient buffers
        # from every child that reports parameters; return None if there are none.
        w = []
        gw = []
        for module in self.modules:
            mparam = module.parameters()
            if mparam is not None:
                w.extend(mparam[0])
                gw.extend(mparam[1])
        if not w:
            return
        return w, gw

    def clearState(self):
        # Release cached tensors on this container, then on every child module.
        clear(self, 'output')
        clear(self, 'gradInput')
        for module in self.modules:
            module.clearState()
        return self
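
A minimal usage sketch of the container API above (an illustration, assuming the legacy torch.legacy.nn package, where Sequential subclasses Container and Linear is a child module with parameters):

from torch.legacy import nn as legacy_nn

net = legacy_nn.Sequential()       # Sequential is a Container subclass
net.add(legacy_nn.Linear(10, 20)).add(legacy_nn.Linear(20, 5))

print(net.size())                  # 2 (number of direct child modules)
first = net.get(0)                 # the first Linear module

w, gw = net.parameters()           # flat lists of parameter tensors and gradient buffers
net.zeroGradParameters()           # zero every child's accumulated gradients
net.clearState()                   # drop cached output/gradInput tensors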