Add templated __init__

This commit is contained in:
Adam Paszke
2016-05-02 23:54:59 +02:00
parent 9c18e7c990
commit b0d90e3688
4 changed files with 35 additions and 7 deletions

1
.gitignore vendored
View File

@ -2,3 +2,4 @@ build/
dist/
torch.egg-info/
*/**/__pycache__
torch/__init__.py

View File

@ -5,6 +5,32 @@ C = Extension("torch.C",
sources=["torch/csrc/Module.c", "torch/csrc/Tensor.c", "torch/csrc/Storage.c"],
include_dirs=["torch/csrc"])
in_init = "torch/__init__.py.in"
out_init = "torch/__init__.py"
templates = ["torch/Tensor.py"]
types = ['Double', 'Float', 'Long', 'Int', 'Char', 'Byte']
to_append = ''
for template in templates:
with open(template, 'r') as f:
template_content = f.read()
for T in types:
t = T.lower()
replacements = [
('RealTensorBase', T + 'TensorBase'),
('RealTensor', T + 'Tensor'),
]
new_content = template_content
for pattern, replacement in replacements:
new_content = new_content.replace(pattern, replacement)
to_append += '\n'
to_append += new_content
to_append += '\n'
with open(out_init, 'w') as f:
with open(in_init, 'r') as i:
header = i.read()
f.write(header)
f.write(to_append)
setup(name="torch", version="0.1",
ext_modules=[C],
packages=['torch'])

8
torch/Tensor.py Normal file
View File

@ -0,0 +1,8 @@
# This will be included in __init__.py
class RealTensor(C.RealTensorBase):
def __str__(self):
return "RealTensor"
def __repr__(self):
return str(self)

View File

@ -1,12 +1,5 @@
import torch.C
class FloatTensor(C.FloatTensorBase):
def __str__(self):
return "Tensor"
def __repr__(self):
return str(self)
class LongStorage(C.LongStorageBase):
def __str__(self):
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))