mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add templated __init__
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,3 +2,4 @@ build/
|
||||
dist/
|
||||
torch.egg-info/
|
||||
*/**/__pycache__
|
||||
torch/__init__.py
|
||||
|
26
setup.py
26
setup.py
@ -5,6 +5,32 @@ C = Extension("torch.C",
|
||||
sources=["torch/csrc/Module.c", "torch/csrc/Tensor.c", "torch/csrc/Storage.c"],
|
||||
include_dirs=["torch/csrc"])
|
||||
|
||||
in_init = "torch/__init__.py.in"
|
||||
out_init = "torch/__init__.py"
|
||||
templates = ["torch/Tensor.py"]
|
||||
types = ['Double', 'Float', 'Long', 'Int', 'Char', 'Byte']
|
||||
to_append = ''
|
||||
for template in templates:
|
||||
with open(template, 'r') as f:
|
||||
template_content = f.read()
|
||||
for T in types:
|
||||
t = T.lower()
|
||||
replacements = [
|
||||
('RealTensorBase', T + 'TensorBase'),
|
||||
('RealTensor', T + 'Tensor'),
|
||||
]
|
||||
new_content = template_content
|
||||
for pattern, replacement in replacements:
|
||||
new_content = new_content.replace(pattern, replacement)
|
||||
to_append += '\n'
|
||||
to_append += new_content
|
||||
to_append += '\n'
|
||||
with open(out_init, 'w') as f:
|
||||
with open(in_init, 'r') as i:
|
||||
header = i.read()
|
||||
f.write(header)
|
||||
f.write(to_append)
|
||||
|
||||
setup(name="torch", version="0.1",
|
||||
ext_modules=[C],
|
||||
packages=['torch'])
|
||||
|
8
torch/Tensor.py
Normal file
8
torch/Tensor.py
Normal file
@ -0,0 +1,8 @@
|
||||
# This will be included in __init__.py
|
||||
|
||||
class RealTensor(C.RealTensorBase):
|
||||
def __str__(self):
|
||||
return "RealTensor"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
@ -1,12 +1,5 @@
|
||||
import torch.C
|
||||
|
||||
class FloatTensor(C.FloatTensorBase):
|
||||
def __str__(self):
|
||||
return "Tensor"
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
class LongStorage(C.LongStorageBase):
|
||||
def __str__(self):
|
||||
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
|
Reference in New Issue
Block a user