mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Major refactor
This commit is contained in:
110
setup.py
110
setup.py
@ -1,13 +1,16 @@
|
||||
from setuptools import setup, Extension, distutils
|
||||
from os.path import expanduser
|
||||
from tools.nnwrap import generate_wrappers as generate_nn_wrappers
|
||||
from tools.cwrap import cwrap
|
||||
from tools.cwrap.plugins.THPPlugin import THPPlugin
|
||||
from tools.cwrap.plugins.THPLongArgsPlugin import THPLongArgsPlugin
|
||||
from tools.cwrap.plugins.ArgcountSortPlugin import ArgcountSortPlugin
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
subprocess.call(['bash', 'torch/lib/build_all.sh', '--with-cuda'])
|
||||
import sys
|
||||
import os
|
||||
|
||||
# TODO: detect CUDA
|
||||
WITH_CUDA = False
|
||||
WITH_CUDA = True
|
||||
DEBUG = False
|
||||
|
||||
################################################################################
|
||||
@ -32,31 +35,42 @@ def parallelCCompile(self, sources, output_dir=None, macros=None, include_dirs=N
|
||||
distutils.ccompiler.CCompiler.compile = parallelCCompile
|
||||
|
||||
################################################################################
|
||||
# Generate Tensor methods
|
||||
# Build libraries
|
||||
################################################################################
|
||||
|
||||
cwrap_src = ['torch/csrc/generic/TensorMethods.cwrap.cpp']
|
||||
for src in cwrap_src:
|
||||
print("Generating code for " + src)
|
||||
cwrap(src)
|
||||
if subprocess.call(['bash', 'torch/lib/build_all.sh'] + (['--with-cuda'] if WITH_CUDA else [])) != 0:
|
||||
sys.exit(1)
|
||||
|
||||
################################################################################
|
||||
# Declare the package
|
||||
# Generate cpp code
|
||||
################################################################################
|
||||
|
||||
cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[THPLongArgsPlugin(), THPPlugin(), ArgcountSortPlugin()])
|
||||
generate_nn_wrappers()
|
||||
|
||||
################################################################################
|
||||
# Configure compile flags
|
||||
################################################################################
|
||||
|
||||
include_dirs = []
|
||||
extra_link_args = []
|
||||
|
||||
# TODO: remove and properly submodule TH in the repo itself
|
||||
th_path = expanduser("~/torch/install/")
|
||||
torch_headers = th_path + "include"
|
||||
th_header_path = th_path + "include/TH"
|
||||
th_lib_path = th_path + "lib"
|
||||
extra_link_args.append('-L' + th_lib_path)
|
||||
extra_link_args.append('-Wl,-rpath,' + th_lib_path)
|
||||
|
||||
libraries = ['TH']
|
||||
extra_compile_args = ['-std=c++11']
|
||||
sources = [
|
||||
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
lib_path = os.path.join(cwd, "torch", "lib")
|
||||
|
||||
tmp_install_path = lib_path + "/tmp_install"
|
||||
include_dirs += [
|
||||
cwd,
|
||||
os.path.join(cwd, "torch", "csrc"),
|
||||
tmp_install_path + "/include",
|
||||
tmp_install_path + "/include/TH",
|
||||
]
|
||||
|
||||
extra_link_args.append('-L' + lib_path)
|
||||
|
||||
main_libraries = ['TH']
|
||||
main_sources = [
|
||||
"torch/csrc/Module.cpp",
|
||||
"torch/csrc/Generator.cpp",
|
||||
"torch/csrc/Tensor.cpp",
|
||||
@ -65,26 +79,62 @@ sources = [
|
||||
]
|
||||
|
||||
if WITH_CUDA:
|
||||
libraries += ['THC']
|
||||
if platform.system() == 'Darwin':
|
||||
include_dirs += ['/Developer/NVIDIA/CUDA-7.5/include']
|
||||
else:
|
||||
include_dirs += ['/usr/local/cuda/include']
|
||||
extra_compile_args += ['-DWITH_CUDA']
|
||||
sources += [
|
||||
main_libraries += ['THC']
|
||||
main_sources += [
|
||||
"torch/csrc/cuda/Module.cpp",
|
||||
"torch/csrc/cuda/Storage.cpp",
|
||||
"torch/csrc/cuda/Tensor.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
]
|
||||
|
||||
if DEBUG:
|
||||
extra_compile_args += ['-O0', '-g']
|
||||
extra_link_args += ['-O0', '-g']
|
||||
|
||||
################################################################################
|
||||
# Declare extensions and the package
|
||||
################################################################################
|
||||
|
||||
extensions = []
|
||||
|
||||
C = Extension("torch._C",
|
||||
libraries=libraries,
|
||||
sources=sources,
|
||||
libraries=main_libraries,
|
||||
sources=main_sources,
|
||||
language='c++',
|
||||
extra_compile_args=extra_compile_args + (['-O0', '-g'] if DEBUG else []),
|
||||
include_dirs=([".", "torch/csrc", "cutorch/csrc", torch_headers, th_header_path, "/Developer/NVIDIA/CUDA-7.5/include", "/usr/local/cuda/include"]),
|
||||
extra_link_args = extra_link_args + (['-O0', '-g'] if DEBUG else []),
|
||||
extra_compile_args=extra_compile_args,
|
||||
include_dirs=include_dirs,
|
||||
extra_link_args=extra_link_args + ['-Wl,-rpath,$ORIGIN/lib'],
|
||||
)
|
||||
extensions.append(C)
|
||||
|
||||
THNN = Extension("torch._thnn._THNN",
|
||||
libraries=['TH', 'THNN'],
|
||||
sources=['torch/csrc/nn/THNN.cpp'],
|
||||
language='c++',
|
||||
extra_compile_args=extra_compile_args,
|
||||
include_dirs=include_dirs,
|
||||
extra_link_args=extra_link_args + ['-Wl,-rpath,$ORIGIN/../lib'],
|
||||
)
|
||||
extensions.append(THNN)
|
||||
|
||||
if WITH_CUDA:
|
||||
THCUNN = Extension("torch._thnn._THCUNN",
|
||||
libraries=['TH', 'THC', 'THCUNN'],
|
||||
sources=['torch/csrc/nn/THCUNN.cpp'],
|
||||
language='c++',
|
||||
extra_compile_args=extra_compile_args,
|
||||
include_dirs=include_dirs,
|
||||
extra_link_args=extra_link_args + ['-Wl,-rpath,$ORIGIN/../lib'],
|
||||
)
|
||||
extensions.append(THCUNN)
|
||||
|
||||
setup(name="torch", version="0.1",
|
||||
ext_modules=[C],
|
||||
packages=['torch', 'torch.legacy', 'torch.legacy.nn', 'torch.legacy.optim'] + (['torch.cuda', 'torch.legacy.cunn'] if WITH_CUDA else []),
|
||||
ext_modules=extensions,
|
||||
packages=['torch', 'torch._thnn', 'torch.legacy', 'torch.legacy.nn', 'torch.legacy.optim'] + (['torch.cuda', 'torch.legacy.cunn'] if WITH_CUDA else []),
|
||||
package_data={'torch': ['lib/*.so', 'lib/*.h']}
|
||||
)
|
||||
|
Reference in New Issue
Block a user