From b0d90e3688840381534dd0f2ae8d74518e48a87f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 2 May 2016 23:54:59 +0200 Subject: [PATCH] Add templated __init__ --- .gitignore | 1 + setup.py | 26 ++++++++++++++++++++++++++ torch/Tensor.py | 8 ++++++++ torch/{__init__.py => __init__.py.in} | 7 ------- 4 files changed, 35 insertions(+), 7 deletions(-) create mode 100644 torch/Tensor.py rename torch/{__init__.py => __init__.py.in} (67%) diff --git a/.gitignore b/.gitignore index d19052f93db8..6d0e99b9f084 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ build/ dist/ torch.egg-info/ */**/__pycache__ +torch/__init__.py diff --git a/setup.py b/setup.py index 1d57e08bd09e..9f16663a071c 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,32 @@ C = Extension("torch.C", sources=["torch/csrc/Module.c", "torch/csrc/Tensor.c", "torch/csrc/Storage.c"], include_dirs=["torch/csrc"]) +in_init = "torch/__init__.py.in" +out_init = "torch/__init__.py" +templates = ["torch/Tensor.py"] +types = ['Double', 'Float', 'Long', 'Int', 'Char', 'Byte'] +to_append = '' +for template in templates: + with open(template, 'r') as f: + template_content = f.read() + for T in types: + t = T.lower() + replacements = [ + ('RealTensorBase', T + 'TensorBase'), + ('RealTensor', T + 'Tensor'), + ] + new_content = template_content + for pattern, replacement in replacements: + new_content = new_content.replace(pattern, replacement) + to_append += '\n' + to_append += new_content + to_append += '\n' +with open(out_init, 'w') as f: + with open(in_init, 'r') as i: + header = i.read() + f.write(header) + f.write(to_append) + setup(name="torch", version="0.1", ext_modules=[C], packages=['torch']) diff --git a/torch/Tensor.py b/torch/Tensor.py new file mode 100644 index 000000000000..118acbb104bf --- /dev/null +++ b/torch/Tensor.py @@ -0,0 +1,8 @@ +# This will be included in __init__.py + +class RealTensor(C.RealTensorBase): + def __str__(self): + return "RealTensor" + + def __repr__(self): + return str(self) diff --git a/torch/__init__.py b/torch/__init__.py.in similarity index 67% rename from torch/__init__.py rename to torch/__init__.py.in index 6e0559706d38..2cd61f9b23fe 100644 --- a/torch/__init__.py +++ b/torch/__init__.py.in @@ -1,12 +1,5 @@ import torch.C -class FloatTensor(C.FloatTensorBase): - def __str__(self): - return "Tensor" - - def __repr__(self): - return str(self) - class LongStorage(C.LongStorageBase): def __str__(self): content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))