Add Storage.py template

This commit is contained in:
Adam Paszke
2016-05-03 15:11:32 +02:00
parent b0d90e3688
commit 690d470c71
7 changed files with 103 additions and 37 deletions

16
Makefile Normal file
View File

@ -0,0 +1,16 @@
# Add main target here - setup.py doesn't understand the need to recompile
# after generic files change
.PHONY: all clean torch
all: torch install
torch: clean
python3 setup.py build
install:
python3 setup.py install
clean:
@rm -rf build
@rm -rf dist
@rm -rf torch.egg-info

11
README.md Normal file
View File

@ -0,0 +1,11 @@
# torch
### Installation
```bash
python3 setup.py install
```
### TODO
See list in issue #1

View File

@ -1,35 +1,51 @@
from setuptools import setup, Extension
################################################################################
# Generate __init__.py from templates
################################################################################
in_init = "torch/__init__.py.in"
out_init = "torch/__init__.py"
templates = ["torch/Tensor.py", "torch/Storage.py"]
types = ['Double', 'Float', 'Long', 'Int', 'Char', 'Byte']
generated = ''
for template in templates:
with open(template, 'r') as f:
template_content = f.read()
for T in types:
replacements = [
('RealTensorBase', T + 'TensorBase'),
('RealTensor', T + 'Tensor'),
('RealStorageBase', T + 'StorageBase'),
('RealStorage', T + 'Storage'),
]
new_content = template_content
for pattern, replacement in replacements:
new_content = new_content.replace(pattern, replacement)
generated += '\n{}\n'.format(new_content)
with open(in_init, 'r') as i:
header = i.read()
with open(out_init, 'w') as f:
f.write("""
################################################################################
# WARNING
# This file is generated automatically. Do not edit it, as it will be
# regenerated at next build
################################################################################
""")
f.write(header)
f.write(generated)
################################################################################
# Declare the package
################################################################################
C = Extension("torch.C",
libraries=['TH'],
sources=["torch/csrc/Module.c", "torch/csrc/Tensor.c", "torch/csrc/Storage.c"],
include_dirs=["torch/csrc"])
in_init = "torch/__init__.py.in"
out_init = "torch/__init__.py"
templates = ["torch/Tensor.py"]
types = ['Double', 'Float', 'Long', 'Int', 'Char', 'Byte']
to_append = ''
for template in templates:
with open(template, 'r') as f:
template_content = f.read()
for T in types:
t = T.lower()
replacements = [
('RealTensorBase', T + 'TensorBase'),
('RealTensor', T + 'Tensor'),
]
new_content = template_content
for pattern, replacement in replacements:
new_content = new_content.replace(pattern, replacement)
to_append += '\n'
to_append += new_content
to_append += '\n'
with open(out_init, 'w') as f:
with open(in_init, 'r') as i:
header = i.read()
f.write(header)
f.write(to_append)
setup(name="torch", version="0.1",
ext_modules=[C],

8
torch/Storage.py Normal file
View File

@ -0,0 +1,8 @@
class RealStorage(C.RealStorageBase):
def __str__(self):
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
return content + '\n[torch.RealStorage of size {}]'.format(len(self))
def __repr__(self):
return str(self)

View File

@ -1,5 +1,3 @@
# This will be included in __init__.py
class RealTensor(C.RealTensorBase):
def __str__(self):
return "RealTensor"

View File

@ -1,9 +1 @@
import torch.C
class LongStorage(C.LongStorageBase):
def __str__(self):
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
return content + '\n[torch.LongStorage of size {}]'.format(len(self))
def __repr__(self):
return str(self)

View File

@ -63,11 +63,19 @@ static Py_ssize_t THPStorage_(length)(PyObject *_self)
return THStorage_(size)(self->cdata);
}
static PyObject * THPStorage_(size)(PyObject *_self)
{
GET_SELF;
return PyLong_FromLong(THStorage_(size)(self->cdata));
}
static PyObject * THPStorage_(get)(PyObject *_self, PyObject *index)
{
GET_SELF;
if (!PyLong_Check(index))
if (!PyLong_Check(index)) {
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers supported at the moment");
return NULL;
}
size_t nindex = PyLong_AsSize_t(index);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
return PyFloat_FromDouble(THStorage_(get)(self->cdata, nindex));
@ -78,7 +86,23 @@ static PyObject * THPStorage_(get)(PyObject *_self, PyObject *index)
static int THPStorage_(set)(PyObject *_self, PyObject *index, PyObject *value)
{
GET_SELF;
if (!PyLong_Check(index)) {
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers supported at the moment");
return -1;
}
real rvalue;
// TODO: overflow checks
if (PyLong_Check(value)) {
rvalue = PyLong_AsLongLong(value);
} else if (PyFloat_Check(value)) {
rvalue = PyFloat_AsDouble(value);
} else {
PyErr_SetString(PyExc_RuntimeError, "Only assignment of floats and integers supported at the moment");
return -1;
}
THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue);
return 0;
}
static struct PyMemberDef THPStorage_(members)[] = {
@ -87,6 +111,7 @@ static struct PyMemberDef THPStorage_(members)[] = {
};
static PyMethodDef THPStorage_(methods)[] = {
{"size", (PyCFunction)THPStorage_(size), METH_NOARGS, NULL},
{NULL}
};