mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Base for nn conversion
This commit is contained in:
2
setup.py
2
setup.py
@ -83,5 +83,5 @@ C = Extension("torch._C",
|
|||||||
|
|
||||||
setup(name="torch", version="0.1",
|
setup(name="torch", version="0.1",
|
||||||
ext_modules=[C],
|
ext_modules=[C],
|
||||||
packages=['torch', 'torch.cuda', 'torch.optim'],
|
packages=['torch', 'torch.cuda', 'torch.optim', 'torch.legacy', 'torch.legacy.nn'],
|
||||||
)
|
)
|
||||||
|
52
tools/convert.vim
Normal file
52
tools/convert.vim
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
"Slightly adjust indentation
|
||||||
|
%s/^ / /g
|
||||||
|
|
||||||
|
" # -> len
|
||||||
|
%s/#\(\S*\) /len(\1)/g
|
||||||
|
|
||||||
|
" for loops
|
||||||
|
%s/for\( \)\{-\}\(\S*\)\( \)\{-\}=\( \)\{-\}\(\S*\),\( \)\{-\}\(\S*\)\( \)\{-\}do/for \2 in range(\5, \7+1)/g
|
||||||
|
|
||||||
|
" Change comments
|
||||||
|
%s/--\[\[/"""/g
|
||||||
|
%s/]]/"""/g
|
||||||
|
%s/--/#/g
|
||||||
|
|
||||||
|
" Add spacing between commas
|
||||||
|
%s/\(\S\),\(\S\)/\1, \2/g
|
||||||
|
|
||||||
|
%s/local //g
|
||||||
|
%s/ then/:/g
|
||||||
|
%s/ do/:/g
|
||||||
|
%s/end//g
|
||||||
|
%s/elseif/elif/g
|
||||||
|
%s/else/else:/g
|
||||||
|
%s/true/True/g
|
||||||
|
%s/false/False/g
|
||||||
|
%s/\~=/!=/g
|
||||||
|
%s/math\.min/min/g
|
||||||
|
%s/math\.max/max/g
|
||||||
|
%s/math\.abs/abs/g
|
||||||
|
|
||||||
|
|
||||||
|
%s/__init/__init__/g
|
||||||
|
|
||||||
|
" Rewrite function declarations
|
||||||
|
%s/function \w*:\(\w*\)/ def \1/g
|
||||||
|
%s/def \(.*\)$/def \1:/g
|
||||||
|
|
||||||
|
" class declaration
|
||||||
|
%s/\(\w*\), parent = torch\.class.*$/import torch\rfrom torch.legacy import nn\r\rclass \1(nn.Module):/g
|
||||||
|
|
||||||
|
%s/input\.THNN/self._backend/g
|
||||||
|
%s/\(self\.backend\w*$\)/\1\r self._backend.library_state,/g
|
||||||
|
%s/def \(\w*\)(/def \1(self, /g
|
||||||
|
|
||||||
|
%s/__init__(self)/__init__()/g
|
||||||
|
|
||||||
|
%s/:\(\S\)/.\1/g
|
||||||
|
|
||||||
|
%s/\.cdata()//g
|
||||||
|
%s/THNN\.optionalTensor(\(.*\))/\1/g
|
||||||
|
|
||||||
|
%s/parent\./super(##, self)./g
|
@ -104,6 +104,12 @@
|
|||||||
set -> self
|
set -> self
|
||||||
- self
|
- self
|
||||||
- THTensor source
|
- THTensor source
|
||||||
|
setStorage -> self
|
||||||
|
- self
|
||||||
|
- CONSTANT NULL
|
||||||
|
- CONSTANT 0
|
||||||
|
- CONSTANT NULL
|
||||||
|
- CONSTANT NULL
|
||||||
setStorage -> self
|
setStorage -> self
|
||||||
- self
|
- self
|
||||||
- THStorage sourceStorage
|
- THStorage sourceStorage
|
||||||
|
0
torch/legacy/__init__.py
Normal file
0
torch/legacy/__init__.py
Normal file
24
torch/legacy/nn/Abs.py
Normal file
24
torch/legacy/nn/Abs.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
from torch.legacy import nn
|
||||||
|
|
||||||
|
class Abs(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Abs, self).__init__()
|
||||||
|
|
||||||
|
def updateOutput(self, input):
|
||||||
|
self._backend.Abs_updateOutput(
|
||||||
|
self._backend.library_state,
|
||||||
|
input,
|
||||||
|
self.output
|
||||||
|
)
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
def updateGradInput(self, input, gradOutput):
|
||||||
|
self._backend.Abs_updateGradInput(
|
||||||
|
self._backend.library_state,
|
||||||
|
input,
|
||||||
|
gradOutput,
|
||||||
|
self.gradInput
|
||||||
|
)
|
||||||
|
return self.gradInput
|
||||||
|
|
35
torch/legacy/nn/AbsCriterion.py
Normal file
35
torch/legacy/nn/AbsCriterion.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import torch
|
||||||
|
from torch.legacy import nn
|
||||||
|
|
||||||
|
class AbsCriterion(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, sizeAverage):
|
||||||
|
super(AbsCriterion, self).__init__()
|
||||||
|
if sizeAverage != nil:
|
||||||
|
self.sizeAverage = sizeAverage
|
||||||
|
else:
|
||||||
|
self.sizeAverage = True
|
||||||
|
|
||||||
|
def updateOutput(self, input, target):
|
||||||
|
self.output_tensor = self.output_tensor or input.new(1)
|
||||||
|
self._backend.AbsCriterion_updateOutput(
|
||||||
|
self._backend.library_state,
|
||||||
|
input._cdata,
|
||||||
|
target._cdata,
|
||||||
|
self.output_tensor._cdata,
|
||||||
|
self.sizeAverage
|
||||||
|
)
|
||||||
|
self.output = self.output_tensor[1]
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
|
||||||
|
def updateGradInput(self, input, target):
|
||||||
|
self._backend.AbsCriterion_updateGradInput(
|
||||||
|
self._backend.library_state,
|
||||||
|
input._cdata,
|
||||||
|
target._cdata,
|
||||||
|
self.gradInput._cdata,
|
||||||
|
self.sizeAverage
|
||||||
|
)
|
||||||
|
return self.gradInput
|
||||||
|
|
298
torch/legacy/nn/Module.py
Normal file
298
torch/legacy/nn/Module.py
Normal file
@ -0,0 +1,298 @@
|
|||||||
|
import torch
|
||||||
|
from torch.legacy import nn
|
||||||
|
|
||||||
|
class Module(object):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.gradInput = torch.Tensor()
|
||||||
|
self.output = torch.Tensor()
|
||||||
|
self._type = self.output.type()
|
||||||
|
self._backend = nn._backends.THNNDoubleBackend
|
||||||
|
|
||||||
|
def parameters(self):
|
||||||
|
if self.weight and self.bias:
|
||||||
|
return [self.weight, self.bias], [self.gradWeight, self.gradBias]
|
||||||
|
elif self.weight:
|
||||||
|
return [self.weight], [self.gradWeight]
|
||||||
|
elif self.bias:
|
||||||
|
return [self.bias], [self.gradBias]
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
def updateOutput(self, input):
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return self.updateOutput(input)
|
||||||
|
|
||||||
|
def backward(self, input, gradOutput, scale=1):
|
||||||
|
self.updateGradInput(input, gradOutput)
|
||||||
|
self.accGradParameters(input, gradOutput, scale)
|
||||||
|
return self.gradInput
|
||||||
|
|
||||||
|
|
||||||
|
def backwardUpdate(self, input, gradOutput, lr):
|
||||||
|
self.updateGradInput(input, gradOutput)
|
||||||
|
self.accUpdateGradParameters(input, gradOutput, lr)
|
||||||
|
return self.gradInput
|
||||||
|
|
||||||
|
|
||||||
|
def updateGradInput(self, input, gradOutput):
|
||||||
|
return self.gradInput
|
||||||
|
|
||||||
|
def accGradParameters(self, input, gradOutput, scale=1):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def accUpdateGradParameters(self, input, gradOutput, lr):
|
||||||
|
gradWeight = self.gradWeight
|
||||||
|
gradBias = self.gradBias
|
||||||
|
self.gradWeight = self.weight
|
||||||
|
self.gradBias = self.bias
|
||||||
|
self.accGradParameters(input, gradOutput, -lr)
|
||||||
|
self.gradWeight = gradWeight
|
||||||
|
self.gradBias = gradBias
|
||||||
|
|
||||||
|
|
||||||
|
def sharedAccUpdateGradParameters(self, input, gradOutput, lr):
|
||||||
|
if self.parameters():
|
||||||
|
self.zeroGradParameters()
|
||||||
|
self.accGradParameters(input, gradOutput, 1)
|
||||||
|
self.updateParameters(lr)
|
||||||
|
|
||||||
|
def zeroGradParameters(self):
|
||||||
|
_, gradParams = self.parameters()
|
||||||
|
if gradParams:
|
||||||
|
for grad in gradParams:
|
||||||
|
grad.zero()
|
||||||
|
|
||||||
|
def updateParameters(self, learningRate):
|
||||||
|
params, gradParams = self.parameters()
|
||||||
|
if params:
|
||||||
|
for p, gp in zip(params, gradParams):
|
||||||
|
p.add(-learningRate, gp)
|
||||||
|
|
||||||
|
def training(self):
|
||||||
|
self.train = True
|
||||||
|
|
||||||
|
def evaluate(self):
|
||||||
|
self.train = False
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
def share(self, mlp, *arg):
|
||||||
|
for i, v in ipairs(arg):
|
||||||
|
if self[v] != nil:
|
||||||
|
self[v].set(mlp[v])
|
||||||
|
self.accUpdateGradParameters = self.sharedAccUpdateGradParameters
|
||||||
|
mlp.accUpdateGradParameters = mlp.sharedAccUpdateGradParameters
|
||||||
|
return self
|
||||||
|
|
||||||
|
def clone(self, *arg):
|
||||||
|
f = torch.MemoryFile("rw").binary()
|
||||||
|
f.writeObject(self)
|
||||||
|
f.seek(1)
|
||||||
|
clone = f.readObject()
|
||||||
|
f.close()
|
||||||
|
if len(arg) > 0:
|
||||||
|
clone.share(self, *arg)
|
||||||
|
return clone
|
||||||
|
|
||||||
|
def type(self, type, tensorCache):
|
||||||
|
if not type:
|
||||||
|
return self._type
|
||||||
|
|
||||||
|
tensorCache = tensorCache or {}
|
||||||
|
|
||||||
|
# find all tensors and convert them
|
||||||
|
for key, param in pairs(self):
|
||||||
|
self[key] = nn.utils.recursiveType(param, type, tensorCache)
|
||||||
|
|
||||||
|
self._type = type
|
||||||
|
return self
|
||||||
|
|
||||||
|
def float(self, *args):
|
||||||
|
return self.type('torch.FloatTensor', *args)
|
||||||
|
|
||||||
|
def double(self, *args):
|
||||||
|
return self.type('torch.DoubleTensor', *args)
|
||||||
|
|
||||||
|
def cuda(self, *args):
|
||||||
|
return self.type('torch.CudaTensor', *args)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def write(self, f):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def read(self, f):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# This function is not easy to understand. It works as follows:
|
||||||
|
#
|
||||||
|
# - gather all parameter tensors for this module (and children);
|
||||||
|
# count all parameter values (floats)
|
||||||
|
# - create one ginormous memory area (Storage object) with room for all
|
||||||
|
# parameters
|
||||||
|
# - remap each parameter tensor to point to an area within the ginormous
|
||||||
|
# Storage, and copy it there
|
||||||
|
#
|
||||||
|
# It has the effect of making all parameters point to the same memory area,
|
||||||
|
# which is: returned.
|
||||||
|
#
|
||||||
|
# The purpose is to allow operations over all parameters (such as momentum
|
||||||
|
# updates and serialization), but it assumes that all parameters are of
|
||||||
|
# the same type (and, in the case of CUDA, on the same device), which
|
||||||
|
# is not always True. Use for_each() to iterate over this module and
|
||||||
|
# children instead.
|
||||||
|
#
|
||||||
|
# Module._flattenTensorBuffer can be used by other packages (e.g. cunn)
|
||||||
|
# to specify the type of temporary buffers. For example, the temporary
|
||||||
|
# buffers for CudaTensor could be FloatTensor, to avoid GPU memory usage.
|
||||||
|
#
|
||||||
|
# TODO: This logically belongs to torch.Tensor, not nn.
|
||||||
|
_flattenTensorBuffer = {}
|
||||||
|
def _flatten(self, parameters=[]):
|
||||||
|
|
||||||
|
# returns True if tensor occupies a contiguous region of memory (no holes)
|
||||||
|
def isCompact(tensor):
|
||||||
|
# TODO: wut, does it really need to create this tensor?
|
||||||
|
# isn't it enough to check if strides == size.cumprod(0)?
|
||||||
|
sortedStride, perm = torch.sort(torch.LongTensor(tensor.nDimension()).set(tensor.stride()), 0, True)
|
||||||
|
sortedSize = torch.LongTensor(tensor.nDimension()).set(tensor.size()).index(1, perm)
|
||||||
|
nRealDim = torch.clamp(sortedStride, 0, 1).sum()
|
||||||
|
sortedStride = sortedStride.narrow(1, 1, nRealDim).clone()
|
||||||
|
sortedSize = sortedSize.narrow(1, 1, nRealDim).clone()
|
||||||
|
t = tensor.new().set(tensor.storage(), 1,
|
||||||
|
sortedSize.storage(),
|
||||||
|
sortedStride.storage())
|
||||||
|
return t.isContiguous()
|
||||||
|
|
||||||
|
if not parameters:
|
||||||
|
return torch.Tensor()
|
||||||
|
|
||||||
|
Tensor = parameters[0].new
|
||||||
|
BufferTensor = Module._flattenTensorBuffer[torch.type(parameters[1])] or Tensor
|
||||||
|
|
||||||
|
# 1. construct the set of all unique storages referenced by parameter tensors
|
||||||
|
storages = {}
|
||||||
|
num_parameters = 0
|
||||||
|
parameterMeta = []
|
||||||
|
for i, param in enumerate(parameters):
|
||||||
|
storage = param.storage()
|
||||||
|
key = storage._cdata
|
||||||
|
|
||||||
|
if not storages[key]:
|
||||||
|
storages[key] = (storage, num_parameters)
|
||||||
|
num_parameters = num_parameters + storage.size()
|
||||||
|
|
||||||
|
|
||||||
|
parameterMeta[i] = {
|
||||||
|
'storageOffset': param.storageOffset() + storages[key][1],
|
||||||
|
'size' : param.size(),
|
||||||
|
'stride' : param.stride()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 2. construct a single tensor that will hold all the parameters
|
||||||
|
flatParameters = BufferTensor(num_parameters).zero()
|
||||||
|
|
||||||
|
# 3. determine if there are elements in the storage that none of the
|
||||||
|
# parameter tensors reference ('holes')
|
||||||
|
tensorsCompact = True
|
||||||
|
for meta in parameterMeta:
|
||||||
|
# TODO: reuse one Tensor
|
||||||
|
tmp = BufferTensor().set(flatParameters.storage(), meta.storageOffset, meta.size, meta.stride)
|
||||||
|
tmp.fill(1)
|
||||||
|
tensorsCompact = tensorsCompact and isCompact(tmp)
|
||||||
|
|
||||||
|
maskParameters = flatParameters.byte().clone()
|
||||||
|
compactOffsets = flatParameters.long().cumsum(1)
|
||||||
|
used_parameters = compactOffsets[-1]
|
||||||
|
|
||||||
|
# 4. copy storages into the flattened parameter tensor
|
||||||
|
for storageAndOffset in storages.values():
|
||||||
|
storage, offset = storageAndOffset
|
||||||
|
# TODO: reuse Tensor
|
||||||
|
flatParameters[slice(offset, offset+storage.size())].copy(Tensor().set(storage))
|
||||||
|
|
||||||
|
# 5. allow garbage collection
|
||||||
|
storages = None
|
||||||
|
for param in parameters:
|
||||||
|
param.set()
|
||||||
|
|
||||||
|
# 6. compact the flattened parameters if there were holes
|
||||||
|
if used_parameters != num_parameters:
|
||||||
|
assert tensorsCompact
|
||||||
|
|
||||||
|
flatParameters = BufferTensor(used_parameters).copy(
|
||||||
|
flatParameters.maskedSelect(maskParameters))
|
||||||
|
for meta in parameterMeta:
|
||||||
|
meta['storageOffset'] = compactOffsets[meta['storageOffset']]
|
||||||
|
|
||||||
|
if BufferTensor != Tensor:
|
||||||
|
flatParameters = Tensor(flatParameters.nElement()).copy(flatParameters)
|
||||||
|
|
||||||
|
# 7. fix up the parameter tensors to point at the flattened parameters
|
||||||
|
for param, meta in zip(parameters, parameterMeta):
|
||||||
|
param.set(flatParameters.storage(),
|
||||||
|
meta['storageOffset'],
|
||||||
|
meta['size'],
|
||||||
|
meta['stride'])
|
||||||
|
|
||||||
|
return flatParameters
|
||||||
|
|
||||||
|
def flattenParameters(self):
|
||||||
|
parameters, gradParameters = self.parameters()
|
||||||
|
p, g = self._flatten(parameters), self._flatten(gradParameters)
|
||||||
|
|
||||||
|
assert p.nElement() == g.nElement()
|
||||||
|
if parameters:
|
||||||
|
for param, grad in zip(parameters, gradParameters):
|
||||||
|
assert param.storageOffset() == grad.storageOffset()
|
||||||
|
|
||||||
|
return p, g
|
||||||
|
|
||||||
|
def apply(self, callback):
|
||||||
|
callback(self)
|
||||||
|
for _, module in self.modules:
|
||||||
|
module.apply(callback)
|
||||||
|
|
||||||
|
def findModules(self, typename, container=None):
|
||||||
|
nodes = []
|
||||||
|
containers = []
|
||||||
|
mod_type = str(type(self))
|
||||||
|
if mod_type == typename:
|
||||||
|
nodes.append(self)
|
||||||
|
containers.append(container)
|
||||||
|
|
||||||
|
# Recurse on nodes with 'modules'
|
||||||
|
if self.modules:
|
||||||
|
for child in self.modules:
|
||||||
|
child_nodes, child_containers = child.findModules(typename, self)
|
||||||
|
assert len(child_nodes) == len(child_containers)
|
||||||
|
# add the list items from our child to our list (i.e. return a
|
||||||
|
# flattened table of the return nodes).
|
||||||
|
nodes.extend(child_nodes)
|
||||||
|
containers.extend(child_containers)
|
||||||
|
|
||||||
|
return nodes, containers
|
||||||
|
|
||||||
|
def listModules(self):
|
||||||
|
# include self first
|
||||||
|
modules = [self]
|
||||||
|
if self.modules:
|
||||||
|
for child in self.modules:
|
||||||
|
modules.extend(child.listModules())
|
||||||
|
return modules
|
||||||
|
|
||||||
|
def clearState(self):
|
||||||
|
return nn.utils.clear(self, 'output', 'gradInput')
|
||||||
|
|
||||||
|
def replace(self, callback):
|
||||||
|
out = callback(self)
|
||||||
|
# TODO: not out.modules?
|
||||||
|
if self.modules:
|
||||||
|
for i, module in self.modules:
|
||||||
|
self.modules[i] = module.replace(callback)
|
||||||
|
return out
|
||||||
|
|
1119
torch/legacy/nn/THNN.h
Normal file
1119
torch/legacy/nn/THNN.h
Normal file
File diff suppressed because it is too large
Load Diff
5
torch/legacy/nn/__init__.py
Normal file
5
torch/legacy/nn/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from .ffi import _backends
|
||||||
|
|
||||||
|
from .Module import Module
|
||||||
|
from .Abs import Abs
|
||||||
|
from .AbsCriterion import AbsCriterion
|
108
torch/legacy/nn/ffi.py
Normal file
108
torch/legacy/nn/ffi.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import ctypes
|
||||||
|
import itertools
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: submodule THNN
|
||||||
|
THNN_H_PATH = '/Users/apaszke/pytorch/pytorch/torch/legacy/nn/THNN.h'
|
||||||
|
THNN_LIB_PATH = '/Users/apaszke/torch/install/lib/lua/5.1/libTHNN.so'
|
||||||
|
|
||||||
|
with open(THNN_H_PATH, 'r') as f:
|
||||||
|
lines = f.read().split('\n')
|
||||||
|
|
||||||
|
# Remove empty lines and preprocessor directives
|
||||||
|
lines = filter(lambda l: l and not l.startswith('#'), lines)
|
||||||
|
# Remove line comments
|
||||||
|
lines = map(lambda l: l.partition('//')[0], lines)
|
||||||
|
# Remove trailing special signs
|
||||||
|
lines = map(lambda l: l.rstrip(');').rstrip(','), lines)
|
||||||
|
# Split arguments
|
||||||
|
lines = map(lambda l: l.split(','), lines)
|
||||||
|
# Flatten list
|
||||||
|
lines = itertools.chain.from_iterable(lines)
|
||||||
|
# Remove unnecessary whitespace
|
||||||
|
lines = map(lambda l: l.strip(), lines)
|
||||||
|
# Remove empty lines
|
||||||
|
lines = filter(lambda l: l, lines)
|
||||||
|
|
||||||
|
class Function(object):
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
self.arguments = []
|
||||||
|
|
||||||
|
def add_argument(self, arg):
|
||||||
|
self.arguments.append(arg)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.name + '(' + ', '.join(self.arguments) + ')'
|
||||||
|
|
||||||
|
generic_functions = []
|
||||||
|
for l in lines:
|
||||||
|
if l.startswith('TH_API void THNN_'):
|
||||||
|
fn_name = l.lstrip('TH_API void THNN_')[1:-2]
|
||||||
|
generic_functions.append(Function(fn_name))
|
||||||
|
else:
|
||||||
|
t, name = l.split(' ')
|
||||||
|
if '*' in name:
|
||||||
|
t = t + '*'
|
||||||
|
generic_functions[-1].add_argument(t)
|
||||||
|
|
||||||
|
types = ['Float', 'Double']
|
||||||
|
|
||||||
|
class THNNBackendBase(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.methods = {}
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
method = self.methods.get(name, None)
|
||||||
|
if method is None:
|
||||||
|
raise NotImplementedError
|
||||||
|
return method
|
||||||
|
|
||||||
|
def register_method(self, name, ctypes_fn):
|
||||||
|
self.methods[name] = ctypes_fn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def library_state(self):
|
||||||
|
return ctypes.c_void_p()
|
||||||
|
|
||||||
|
lib_handle = ctypes.cdll.LoadLibrary(THNN_LIB_PATH)
|
||||||
|
|
||||||
|
# TODO: typechecking
|
||||||
|
class TorchArgument(object):
|
||||||
|
@staticmethod
|
||||||
|
def from_param(obj):
|
||||||
|
if hasattr(obj, '_cdata'):
|
||||||
|
return ctypes.c_void_p(obj._cdata)
|
||||||
|
else:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
TYPE_CONVERTERS = {
|
||||||
|
# TODO: this won't work for CUDA
|
||||||
|
'THNNState*': ctypes.c_void_p,
|
||||||
|
'THTensor*': TorchArgument,
|
||||||
|
'THIndexTensor*': TorchArgument,
|
||||||
|
'THIntegerTensor*': TorchArgument,
|
||||||
|
'THGenerator*': TorchArgument,
|
||||||
|
'int': ctypes.c_int,
|
||||||
|
'real': ctypes.c_double,
|
||||||
|
'double': ctypes.c_double,
|
||||||
|
'bool': ctypes.c_bool,
|
||||||
|
'long': ctypes.c_long,
|
||||||
|
'THIndex_t': ctypes.c_long,
|
||||||
|
}
|
||||||
|
|
||||||
|
class Backends(object):
|
||||||
|
pass
|
||||||
|
_backends = Backends()
|
||||||
|
|
||||||
|
for t in types:
|
||||||
|
backend_name = 'THNN{}Backend'.format(t)
|
||||||
|
backend = THNNBackendBase()
|
||||||
|
setattr(_backends, backend_name, backend)
|
||||||
|
for function in generic_functions:
|
||||||
|
full_fn_name = 'THNN_{}{}'.format(t, function.name)
|
||||||
|
ctypes_fn = getattr(lib_handle, full_fn_name)
|
||||||
|
ctypes_fn.restype = None # All functions return void
|
||||||
|
ctypes_fn.argtypes = [TYPE_CONVERTERS[t] for t in function.arguments]
|
||||||
|
backend.register_method(function.name, ctypes_fn)
|
Reference in New Issue
Block a user