mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: In TorchScript and C++ extensions we currently advocate a mix of `torch::` and `at::` namespace usage. In the C++ frontend I had instead exported all symbols from `at::` and some from `c10::` into the `torch::` namespace. This is far, far easier for users to understand, and also avoids bugs around creating tensors vs. variables. The same should from now on be true for the TorchScript C++ API (for running and loading models) and all C++ extensions. Note that since we're just talking about typedefs, this change does not break any existing code. Once this lands I will update stuff in `pytorch/tutorials` too. zdevito ezyang gchanan Pull Request resolved: https://github.com/pytorch/pytorch/pull/13523 Differential Revision: D12942787 Pulled By: goldsborough fbshipit-source-id: 76058936bd8707b33d9e5bbc2d0705fc3d820763
1073 lines
41 KiB
Python
1073 lines
41 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
import copy
|
|
import glob
|
|
import imp
|
|
import os
|
|
import re
|
|
import setuptools
|
|
import subprocess
|
|
import sys
|
|
import sysconfig
|
|
import tempfile
|
|
import warnings
|
|
|
|
import torch
|
|
from .file_baton import FileBaton
|
|
from ._cpp_extension_versioner import ExtensionVersioner
|
|
|
|
from setuptools.command.build_ext import build_ext
|
|
|
|
|
|
# Platform switch used throughout this module to select Windows/MSVC-specific
# paths, flags and tools.
IS_WINDOWS = (sys.platform == 'win32')
|
|
|
|
|
|
def _find_cuda_home():
    '''
    Finds the CUDA install path.

    Locations are tried in this order:

    1. The ``CUDA_HOME`` or ``CUDA_PATH`` environment variable. It is
       returned as-is, without checking that the directory exists.
    2. A default install location: the first directory matching
       ``C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*`` on
       Windows, or ``/usr/local/cuda`` elsewhere.
    3. If the location from (2) does not exist, the directory two levels
       above the first ``nvcc`` reported by ``where`` (Windows) or ``which``.

    If a CUDA home is found but ``torch.cuda.is_available()`` is false, a
    notice is printed.

    Returns:
        The CUDA root directory, or ``None`` if guess #3 raises (for example
        because ``nvcc`` is not on ``PATH``).
    '''
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2
        if IS_WINDOWS:
            cuda_homes = glob.glob(
                'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
            if len(cuda_homes) == 0:
                cuda_home = ''
            else:
                cuda_home = cuda_homes[0]
        else:
            cuda_home = '/usr/local/cuda'
        if not os.path.exists(cuda_home):
            # Guess #3
            try:
                which = 'where' if IS_WINDOWS else 'which'
                # Keep `which`/`where` diagnostics (e.g. "no nvcc in ...")
                # off the user's terminal; failure is handled below.
                with open(os.devnull, 'w') as devnull:
                    output = subprocess.check_output(
                        [which, 'nvcc'], stderr=devnull).decode()
                # `where` prints one line per match on PATH. Treating the
                # whole multi-line output as a single path would produce a
                # bogus directory, so only the first match is used.
                nvcc_paths = output.splitlines()
                nvcc = nvcc_paths[0] if nvcc_paths else ''
                cuda_home = os.path.dirname(os.path.dirname(nvcc))
            except Exception:
                cuda_home = None
    if cuda_home and not torch.cuda.is_available():
        print("No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home))
    return cuda_home
|
|
|
|
|
|
# Oldest compiler versions considered ABI-compatible with PyTorch binary
# builds (compared as integer tuples by check_compiler_abi_compatibility).
MINIMUM_GCC_VERSION = (4, 9, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)
# Formatted with the compiler name and version; emitted via warnings.warn.
ABI_INCOMPATIBILITY_WARNING = '''

!! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({}) may be ABI-incompatible with PyTorch!
Please use a compiler that is ABI-compatible with GCC 4.9 and above.
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.

See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
for instructions on how to install GCC 4.9 or higher.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

!! WARNING !!
'''
# Formatted with user_compiler, pytorch_compiler and platform keywords.
WRONG_COMPILER_WARNING = '''

!! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({user_compiler}) is not compatible with the compiler PyTorch was
built with for this platform, which is {pytorch_compiler} on {platform}. Please
use {pytorch_compiler} to compile your extension. Alternatively, you may
compile PyTorch from source using {user_compiler}, and then you can also use
{user_compiler} to compile your extension.

See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

!! WARNING !!
'''
|
|
# Compilers that PyTorch binaries are expected to be built with, keyed by
# ``sys.platform``. The first entry is the one suggested in warnings.
ACCEPTED_COMPILERS_FOR_PLATFORM = {
    'darwin': ['clang++', 'clang'],
    'linux': ['g++', 'gcc'],
}

# Both resolved once, at import time.
CUDA_HOME = _find_cuda_home()
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')

# PyTorch releases have the version pattern major.minor.patch, whereas when
# PyTorch is built from source, we append the git commit hash, which gives
# it the below pattern (e.g. ``1.0.0a0+abc1234``).
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')

# Macros defined on every nvcc command line.
COMMON_NVCC_FLAGS = [
    '-D__CUDA_NO_HALF_OPERATORS__',
    '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_HALF2_OPERATORS__',
]

# Tracks a version per JIT extension name; a bumped version makes
# _jit_compile rebuild and import the module as ``<name>_v<version>``.
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
|
|
|
|
|
|
def _is_binary_build():
    '''
    Returns ``True`` unless ``torch.version.__version__`` matches
    ``BUILT_FROM_SOURCE_VERSION_PATTERN`` (i.e. looks like a source build).
    '''
    source_build_match = BUILT_FROM_SOURCE_VERSION_PATTERN.match(
        torch.version.__version__)
    return source_build_match is None
|
|
|
|
|
|
def get_default_build_root():
    '''
    Returns the directory under which JIT-compiled extensions are built by default.

    The result is ``<tmp>/torch_extensions`` with symlinks resolved, where
    ``<tmp>`` is :func:`tempfile.gettempdir`. Each extension gets its own
    subfolder: if ``p`` is the value returned here and ``ext`` is the name of
    an extension, that extension's build folder is ``p/ext``.
    '''
    # tempfile.gettempdir() is typically /tmp on UNIX and \TEMP on Windows.
    temp_root = tempfile.gettempdir()
    extensions_root = os.path.join(temp_root, 'torch_extensions')
    return os.path.realpath(extensions_root)
|
|
|
|
|
|
def check_compiler_ok_for_platform(compiler):
    '''
    Verifies that the compiler is the expected one for the current platform.

    The compiler is located with ``which`` and symlinks are resolved, so
    e.g. ``c++`` pointing at ``g++`` is accepted on Linux.

    Arguments:
        compiler (str): The compiler executable to check.

    Returns:
        True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS,
        and always True for Windows.

    Raises:
        subprocess.CalledProcessError: if ``which`` cannot locate ``compiler``.
        KeyError: on non-Windows platforms that have no entry in
            ``ACCEPTED_COMPILERS_FOR_PLATFORM``.
    '''
    if IS_WINDOWS:
        return True
    which = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT)
    # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'.
    compiler_path = os.path.realpath(which.decode().strip())
    # Python 2 reports sys.platform as 'linux2', Python 3 as 'linux'; both
    # must map onto the single 'linux' entry.
    platform_key = 'linux' if sys.platform.startswith('linux') else sys.platform
    accepted_compilers = ACCEPTED_COMPILERS_FOR_PLATFORM[platform_key]
    return any(name in compiler_path for name in accepted_compilers)
|
|
|
|
|
|
def check_compiler_abi_compatibility(compiler):
    '''
    Verifies that the given compiler is ABI-compatible with PyTorch.

    The check is skipped (returns ``True``) for source builds of PyTorch and
    when ``TORCH_DONT_CHECK_COMPILER_ABI`` is one of ``ON``, ``1``, ``YES``,
    ``TRUE`` or ``Y``. Otherwise, the compiler must be the one expected for
    the platform. On platforms other than macOS, it must also be at least
    ``MINIMUM_GCC_VERSION`` (Linux) or ``MINIMUM_MSVC_VERSION`` (otherwise).
    Problems are reported with :func:`warnings.warn`, not raised.

    Arguments:
        compiler (str): The compiler executable name to check (e.g. ``g++``).
            Must be executable in a shell process.

    Returns:
        False if the compiler is (likely) ABI-incompatible with PyTorch,
        else True.

    Exceptions raised by :func:`check_compiler_ok_for_platform` (e.g. when
    ``which`` fails) are not caught here.
    '''
    if not _is_binary_build():
        return True
    if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
        return True

    # Python 2 reports sys.platform as 'linux2', Python 3 as 'linux'.
    platform_key = 'linux' if sys.platform.startswith('linux') else sys.platform

    # First check if the compiler is one of the expected ones for the particular platform.
    if not check_compiler_ok_for_platform(compiler):
        warnings.warn(WRONG_COMPILER_WARNING.format(
            user_compiler=compiler,
            pytorch_compiler=ACCEPTED_COMPILERS_FOR_PLATFORM[platform_key][0],
            platform=sys.platform))
        return False

    if sys.platform == 'darwin':
        # There is no particular minimum version we need for clang, so we're good here.
        return True
    try:
        if platform_key == 'linux':
            minimum_required_version = MINIMUM_GCC_VERSION
            output = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
            # check_output returns bytes on Python 3; decode before splitting
            # with a str separator. Output looks like "7.3.0\n".
            version = output.decode('utf-8', 'replace').strip().split('.')
        else:
            minimum_required_version = MINIMUM_MSVC_VERSION
            compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            # The MSVC banner may be localized; only the digits matter, so
            # undecodable bytes are replaced rather than raising.
            match = re.search(r'(\d+)\.(\d+)\.(\d+)',
                              compiler_info.decode('utf-8', 'replace'))
            # Keep components as strings so they can be joined for the warning below.
            version = ('0', '0', '0') if match is None else match.groups()
    except Exception:
        _, error, _ = sys.exc_info()
        warnings.warn('Error checking compiler version for {}: {}'.format(compiler, error))
        return False

    if tuple(map(int, version)) >= minimum_required_version:
        return True

    # `version` is a sequence of version components (not a regex match).
    compiler = '{} {}'.format(compiler, '.'.join(version))
    warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))

    return False
|
|
|
|
|
|
class BuildExtension(build_ext):
    '''
    A custom :mod:`setuptools` build extension.

    This :class:`setuptools.build_ext` subclass takes care of passing the
    minimum required compiler flags (e.g. ``-std=c++11``) as well as mixed
    C++/CUDA compilation (and support for CUDA files in general).

    When using :class:`BuildExtension`, it is allowed to supply a dictionary
    for ``extra_compile_args`` (rather than the usual list) that maps from
    languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
    supply to the compiler. This makes it possible to supply different flags to
    the C++ and CUDA compiler during mixed compilation.
    '''

    def build_extensions(self):
        '''
        Prepares every extension and the compiler, then runs the normal build.

        Runs the compiler ABI check, adds the torch-specific defines to each
        extension, registers ``.cu``/``.cuh`` as source suffixes, and
        monkey-patches the compiler so that CUDA sources are compiled with
        nvcc (``_compile`` on Unix-like compilers, ``compile`` on MSVC).
        Finally delegates to ``build_ext.build_extensions``.
        '''
        self._check_abi()
        for extension in self.extensions:
            self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
            self._define_torch_extension_name(extension)
            self._add_gnu_abi_flag_if_binary(extension)

        # Register .cu and .cuh as valid source extensions.
        self.compiler.src_extensions += ['.cu', '.cuh']
        # Save the original _compile method for later.
        if self.compiler.compiler_type == 'msvc':
            self.compiler._cpp_extensions += ['.cu', '.cuh']
            original_compile = self.compiler.compile
            original_spawn = self.compiler.spawn
        else:
            original_compile = self.compiler._compile

        def unix_wrap_compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
            # Replacement for the compiler's per-file `_compile` (non-MSVC).
            # For CUDA sources, `compiler_so` is temporarily swapped for nvcc
            # and restored in the `finally` below. When `extra_postargs` is a
            # dict, the 'nvcc' or 'cxx' entry is selected per source type.
            # Copy before we make any modifications.
            cflags = copy.deepcopy(extra_postargs)
            try:
                original_compiler = self.compiler.compiler_so
                if _is_cuda_file(src):
                    nvcc = _join_cuda_home('bin', 'nvcc')
                    if not isinstance(nvcc, list):
                        nvcc = [nvcc]
                    self.compiler.set_executable('compiler_so', nvcc)
                    if isinstance(cflags, dict):
                        cflags = cflags['nvcc']
                    # -fPIC is forwarded to the host compiler via --compiler-options.
                    cflags = COMMON_NVCC_FLAGS + ['--compiler-options', "'-fPIC'"] + cflags
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']
                # NVCC does not allow multiple -std to be passed, so we avoid
                # overriding the option if the user explicitly passed it.
                if not any(flag.startswith('-std=') for flag in cflags):
                    cflags.append('-std=c++11')

                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                # Put the original compiler back in place.
                self.compiler.set_executable('compiler_so', original_compiler)

        def win_wrap_compile(sources,
                             output_dir=None,
                             macros=None,
                             include_dirs=None,
                             debug=0,
                             extra_preargs=None,
                             extra_postargs=None,
                             depends=None):
            # Replacement for MSVC's `compile`. The user's flags are stashed on
            # `self.cflags` and withheld from the original compile call; the
            # temporary `spawn` below re-inserts them per command instead.

            self.cflags = copy.deepcopy(extra_postargs)
            extra_postargs = None

            def spawn(cmd):
                # Rewrites one MSVC command line before running it. CUDA
                # sources get a fresh nvcc command built from the /Tp|/Tc
                # source, /Fo object and -I|/I include arguments; other
                # sources get the 'cxx' (or plain list) flags appended.
                # NOTE(review): orig_cmd is assigned but never used.
                orig_cmd = cmd
                # Using regex to match src, obj and include files

                src_regex = re.compile('/T(p|c)(.*)')
                src_list = [
                    m.group(2) for m in (src_regex.match(elem) for elem in cmd)
                    if m
                ]

                obj_regex = re.compile('/Fo(.*)')
                obj_list = [
                    m.group(1) for m in (obj_regex.match(elem) for elem in cmd)
                    if m
                ]

                include_regex = re.compile(r'((\-|\/)I.*)')
                include_list = [
                    m.group(1)
                    for m in (include_regex.match(elem) for elem in cmd) if m
                ]

                # Commands without both a source and an object file (e.g. link
                # steps) are passed through unchanged.
                if len(src_list) >= 1 and len(obj_list) >= 1:
                    src = src_list[0]
                    obj = obj_list[0]
                    if _is_cuda_file(src):
                        nvcc = _join_cuda_home('bin', 'nvcc')
                        if isinstance(self.cflags, dict):
                            cflags = self.cflags['nvcc']
                        elif isinstance(self.cflags, list):
                            cflags = self.cflags
                        else:
                            cflags = []
                        cmd = [
                            nvcc, '-c', src, '-o', obj, '-Xcompiler',
                            '/wd4819', '-Xcompiler', '/MD'
                        ] + include_list + cflags
                    elif isinstance(self.cflags, dict):
                        cflags = self.cflags['cxx']
                        cmd += cflags
                    elif isinstance(self.cflags, list):
                        cflags = self.cflags
                        cmd += cflags

                return original_spawn(cmd)

            try:
                # Install the rewriting spawn only for the duration of this compile.
                self.compiler.spawn = spawn
                return original_compile(sources, output_dir, macros,
                                        include_dirs, debug, extra_preargs,
                                        extra_postargs, depends)
            finally:
                self.compiler.spawn = original_spawn

        # Monkey-patch the _compile method.
        if self.compiler.compiler_type == 'msvc':
            self.compiler.compile = win_wrap_compile
        else:
            self.compiler._compile = unix_wrap_compile

        build_ext.build_extensions(self)

    def _check_abi(self):
        # Picks the C++ compiler executable and runs the ABI check on it. The
        # return value is ignored here, so an incompatible compiler only
        # produces warnings and does not stop the build.
        # On some platforms, like Windows, compiler_cxx is not available.
        if hasattr(self.compiler, 'compiler_cxx'):
            compiler = self.compiler.compiler_cxx[0]
        elif IS_WINDOWS:
            compiler = os.environ.get('CXX', 'cl')
        else:
            compiler = os.environ.get('CXX', 'c++')
        check_compiler_abi_compatibility(compiler)

    def _add_compile_flag(self, extension, flag):
        # Appends `flag` to extension.extra_compile_args in place. When that is
        # a per-language dict, the flag is appended to every language's list.
        if isinstance(extension.extra_compile_args, dict):
            for args in extension.extra_compile_args.values():
                args.append(flag)
        else:
            extension.extra_compile_args.append(flag)

    def _define_torch_extension_name(self, extension):
        # pybind11 doesn't support dots in the names
        # so in order to support extensions in the packages
        # like torch._C, we take the last part of the string
        # as the library name
        names = extension.name.split('.')
        name = names[-1]
        define = '-DTORCH_EXTENSION_NAME={}'.format(name)
        self._add_compile_flag(extension, define)

    def _add_gnu_abi_flag_if_binary(self, extension):
        # If the version string looks like a binary build,
        # we know that PyTorch was compiled with gcc 4.9.2.
        # if the extension is compiled with gcc >= 5.1,
        # then we have to define _GLIBCXX_USE_CXX11_ABI=0
        # so that the std::string in the API is resolved to
        # non-C++11 symbols
        if _is_binary_build():
            self._add_compile_flag(extension, '-D_GLIBCXX_USE_CXX11_ABI=0')
|
|
|
|
|
|
def CppExtension(name, sources, *args, **kwargs):
    '''
    Creates a :class:`setuptools.Extension` for C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a C++ extension.
    The PyTorch include paths are appended to ``include_dirs``. On Windows,
    the PyTorch library directory is also appended to ``library_dirs``, and
    ``c10``, ``caffe2``, ``torch`` and ``_C`` are appended to ``libraries``.
    List arguments passed by the caller are copied, not modified in place.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor.

    Example:
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CppExtension
        >>> setup(
                name='extension',
                ext_modules=[
                    CppExtension(
                        name='extension',
                        sources=['extension.cpp'],
                        extra_compile_args=['-g']),
                ],
                cmdclass={
                    'build_ext': BuildExtension
                })
    '''
    # Build new lists rather than `+=`/append on the caller's objects, which
    # would mutate lists the caller may reuse (and fail for tuples or None).
    include_dirs = list(kwargs.get('include_dirs') or [])
    include_dirs += include_paths()
    kwargs['include_dirs'] = include_dirs

    if IS_WINDOWS:
        library_dirs = list(kwargs.get('library_dirs') or [])
        library_dirs += library_paths()
        kwargs['library_dirs'] = library_dirs

        libraries = list(kwargs.get('libraries') or [])
        libraries.append('c10')
        libraries.append('caffe2')
        libraries.append('torch')
        libraries.append('_C')
        kwargs['libraries'] = libraries

    kwargs['language'] = 'c++'
    return setuptools.Extension(name, sources, *args, **kwargs)
|
|
|
|
|
|
def CUDAExtension(name, sources, *args, **kwargs):
    '''
    Creates a :class:`setuptools.Extension` for CUDA/C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a CUDA/C++
    extension. This includes the CUDA include path, library path and runtime
    library: ``cudart`` is always linked, and on Windows ``c10``, ``caffe2``,
    ``torch``, ``caffe2_gpu`` and ``_C`` are linked as well. List arguments
    passed by the caller are copied, not modified in place.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor.

    Example:
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
        >>> setup(
                name='cuda_extension',
                ext_modules=[
                    CUDAExtension(
                        name='cuda_extension',
                        sources=['extension.cpp', 'extension_kernel.cu'],
                        extra_compile_args={'cxx': ['-g'],
                                            'nvcc': ['-O2']})
                ],
                cmdclass={
                    'build_ext': BuildExtension
                })
    '''
    # Build new lists rather than `+=`/append on the caller's objects, which
    # would mutate lists the caller may reuse (and fail for tuples or None).
    library_dirs = list(kwargs.get('library_dirs') or [])
    library_dirs += library_paths(cuda=True)
    kwargs['library_dirs'] = library_dirs

    libraries = list(kwargs.get('libraries') or [])
    libraries.append('cudart')
    if IS_WINDOWS:
        libraries.append('c10')
        libraries.append('caffe2')
        libraries.append('torch')
        libraries.append('caffe2_gpu')
        libraries.append('_C')
    kwargs['libraries'] = libraries

    include_dirs = list(kwargs.get('include_dirs') or [])
    include_dirs += include_paths(cuda=True)
    kwargs['include_dirs'] = include_dirs

    kwargs['language'] = 'c++'

    return setuptools.Extension(name, sources, *args, **kwargs)
|
|
|
|
|
|
def include_paths(cuda=False):
    '''
    Get the include paths required to build a C++ or CUDA extension.

    The paths are rooted at ``<torch>/lib/include``, where ``<torch>`` is the
    parent of the directory containing this file.

    Args:
        cuda: If `True`, also includes the CUDA include directory and, when
            ``CUDNN_HOME`` is set, the cuDNN include directory.

    Returns:
        A list of include path strings.
    '''
    torch_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    include_root = os.path.join(torch_root, 'lib', 'include')

    result = [include_root]
    # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
    result.append(os.path.join(include_root, 'torch', 'csrc', 'api', 'include'))
    # Some internal (old) Torch headers don't properly prefix their includes,
    # so we need to pass -Itorch/lib/include/TH (and THC) as well.
    result.extend(os.path.join(include_root, legacy) for legacy in ('TH', 'THC'))

    if not cuda:
        return result

    result.append(_join_cuda_home('include'))
    if CUDNN_HOME is not None:
        result.append(os.path.join(CUDNN_HOME, 'include'))
    return result
|
|
|
|
|
|
def library_paths(cuda=False):
    '''
    Get the library paths required to build a C++ or CUDA extension.

    On Windows the list starts with ``<torch>/lib``, where ``<torch>`` is the
    parent of the directory containing this file; elsewhere no PyTorch
    library path is added.

    Args:
        cuda: If `True`, includes the CUDA library directory (``lib/x64`` on
            Windows, ``lib64`` otherwise) and, when ``CUDNN_HOME`` is set, the
            cuDNN directory with the same suffix.

    Returns:
        A list of library path strings.
    '''
    result = []

    if IS_WINDOWS:
        torch_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        result.append(os.path.join(torch_root, 'lib'))

    if cuda:
        cuda_lib_subdir = 'lib/x64' if IS_WINDOWS else 'lib64'
        result.append(_join_cuda_home(cuda_lib_subdir))
        if CUDNN_HOME is not None:
            result.append(os.path.join(CUDNN_HOME, cuda_lib_subdir))

    return result
|
|
|
|
|
|
def load(name,
         sources,
         extra_cflags=None,
         extra_cuda_cflags=None,
         extra_ldflags=None,
         extra_include_paths=None,
         build_directory=None,
         verbose=False,
         with_cuda=None):
    '''
    Loads a PyTorch C++ extension just-in-time (JIT).

    To load an extension, a Ninja build file is emitted, which is used to
    compile the given sources into a dynamic library. This library is
    subsequently loaded into the current Python process as a module and
    returned from this function, ready for use.

    By default, the directory to which the build file is emitted and the
    resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where
    ``<tmp>`` is the temporary folder on the current platform and ``<name>``
    the name of the extension. This location can be overridden in two ways.
    First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it
    replaces ``<tmp>/torch_extensions`` and all extensions will be compiled
    into subfolders of this directory. Second, if the ``build_directory``
    argument to this function is supplied, it overrides the entire path, i.e.
    the library will be compiled into that folder directly.

    To compile the sources, the default system compiler (``c++``) is used,
    which can be overridden by setting the ``CXX`` environment variable. To pass
    additional arguments to the compilation process, ``extra_cflags`` or
    ``extra_ldflags`` can be provided. For example, to compile your extension
    with optimizations, pass ``extra_cflags=['-O3']``. You can also use
    ``extra_cflags`` to pass further include directories.

    CUDA support with mixed compilation is provided. Simply pass CUDA source
    files (``.cu`` or ``.cuh``) along with other sources. Such files will be
    detected and compiled with nvcc rather than the C++ compiler. This includes
    passing the CUDA lib64 directory as a library directory, and linking
    ``cudart``. You can pass additional flags to nvcc via
    ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
    heuristics for finding the CUDA install directory are used, which usually
    work fine. If not, setting the ``CUDA_HOME`` environment variable is the
    safest option.

    Args:
        name: The name of the extension to build. This MUST be the same as the
            name of the pybind11 module!
        sources: A list of relative or absolute paths to C++ source files, or
            a single path string.
        extra_cflags: optional list of compiler flags to forward to the build.
        extra_cuda_cflags: optional list of compiler flags to forward to nvcc
            when building CUDA sources.
        extra_ldflags: optional list of linker flags to forward to the build.
        extra_include_paths: optional list of include directories to forward
            to the build.
        build_directory: optional path to use as build workspace.
        verbose: If ``True``, turns on verbose logging of load steps.
        with_cuda: Determines whether CUDA headers and libraries are added to
            the build. If set to ``None`` (default), this value is
            automatically determined based on the existence of ``.cu`` or
            ``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
            and libraries to be included.

    Returns:
        The loaded PyTorch extension as a Python module.

    Example:
        >>> from torch.utils.cpp_extension import load
        >>> module = load(
                name='extension',
                sources=['extension.cpp', 'extension_kernel.cu'],
                extra_cflags=['-O2'],
                verbose=True)
    '''
    # A lone path string is accepted as shorthand for a one-element list.
    if isinstance(sources, str):
        sources = [sources]
    if not build_directory:
        build_directory = _get_build_directory(name, verbose)
    return _jit_compile(
        name=name,
        sources=sources,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_ldflags=extra_ldflags,
        extra_include_paths=extra_include_paths,
        build_directory=build_directory,
        verbose=verbose,
        with_cuda=with_cuda)
|
|
|
|
|
|
def load_inline(name,
                cpp_sources,
                cuda_sources=None,
                functions=None,
                extra_cflags=None,
                extra_cuda_cflags=None,
                extra_ldflags=None,
                extra_include_paths=None,
                build_directory=None,
                verbose=False,
                with_cuda=None):
    '''
    Loads a PyTorch C++ extension just-in-time (JIT) from string sources.

    This function behaves exactly like :func:`load`, but takes its sources as
    strings rather than filenames. These strings are stored to files in the
    build directory, after which the behavior of :func:`load_inline` is
    identical to :func:`load`.

    See `the
    tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions.py>`_
    for good examples of using this function.

    Sources may omit two required parts of a typical non-inline C++ extension:
    the necessary header includes, as well as the (pybind11) binding code. More
    precisely, strings passed to ``cpp_sources`` are first concatenated into a
    single ``.cpp`` file. This file is then prepended with ``#include
    <torch/extension.h>``.

    Furthermore, if the ``functions`` argument is supplied, bindings will be
    automatically generated for each function specified. ``functions`` can
    either be a list of function names, or a dictionary mapping from function
    names to docstrings. If a list is given, the name of each function is used
    as its docstring.

    The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
    file and prepended with ``torch/types.h``, ``cuda.h`` and
    ``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
    separately, but ultimately linked into a single library. Note that no
    bindings are generated for functions in ``cuda_sources`` per se. To bind
    to a CUDA kernel, you must create a C++ function that calls it, and either
    declare or define this C++ function in one of the ``cpp_sources`` (and
    include its name in ``functions``).

    The lists passed as ``cpp_sources``/``cuda_sources`` are not modified.

    See :func:`load` for a description of arguments omitted below.

    Args:
        cpp_sources: A string, or list of strings, containing C++ source code.
        cuda_sources: A string, or list of strings, containing CUDA source code.
        functions: A list of function names for which to generate function
            bindings. If a dictionary is given, it should map function names to
            docstrings (which are otherwise just the function names).
        with_cuda: Determines whether CUDA headers and libraries are added to
            the build. If set to ``None`` (default), this value is
            automatically determined based on whether ``cuda_sources`` is
            provided. Set it to ``True`` to force CUDA headers
            and libraries to be included.

    Raises:
        ValueError: if ``functions`` is not a string, list or dict.

    Example:
        >>> from torch.utils.cpp_extension import load_inline
        >>> source = \'\'\'
        torch::Tensor sin_add(torch::Tensor x, torch::Tensor y) {
          return x.sin() + y.sin();
        }
        \'\'\'
        >>> module = load_inline(name='inline_extension',
                                 cpp_sources=[source],
                                 functions=['sin_add'])
    '''
    build_directory = build_directory or _get_build_directory(name, verbose)

    # Because of `unicode_literals`, '' is a unicode object on Python 2, so
    # accept both the native str type and the type of '' as "a string".
    string_types = (str, type(''))

    if isinstance(cpp_sources, string_types):
        cpp_sources = [cpp_sources]
    cuda_sources = cuda_sources or []
    if isinstance(cuda_sources, string_types):
        cuda_sources = [cuda_sources]

    # Build a new list instead of inserting into / appending to the caller's
    # list: mutating it would make a second call with the same list emit a
    # duplicate include and a second PYBIND11_MODULE block.
    cpp_sources = ['#include <torch/extension.h>'] + list(cpp_sources)

    # If `functions` is supplied, we create the pybind11 bindings for the user.
    # Here, `functions` is (or becomes, after some processing) a map from
    # function names to function docstrings.
    if functions is not None:
        cpp_sources.append('PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {')
        if isinstance(functions, string_types):
            functions = [functions]
        if isinstance(functions, list):
            # Make the function docstring the same as the function name.
            functions = dict((f, f) for f in functions)
        elif not isinstance(functions, dict):
            raise ValueError(
                "Expected 'functions' to be a list or dict, but was {}".format(
                    type(functions)))
        for function_name, docstring in functions.items():
            cpp_sources.append('m.def("{0}", &{0}, "{1}");'.format(
                function_name, docstring))
        cpp_sources.append('}')

    cpp_source_path = os.path.join(build_directory, 'main.cpp')
    with open(cpp_source_path, 'w') as cpp_source_file:
        cpp_source_file.write('\n'.join(cpp_sources))

    sources = [cpp_source_path]

    if cuda_sources:
        # Same reasoning as above: prepend the headers to a fresh list.
        cuda_sources = [
            '#include <torch/types.h>',
            '#include <cuda.h>',
            '#include <cuda_runtime.h>',
        ] + list(cuda_sources)

        cuda_source_path = os.path.join(build_directory, 'cuda.cu')
        with open(cuda_source_path, 'w') as cuda_source_file:
            cuda_source_file.write('\n'.join(cuda_sources))

        sources.append(cuda_source_path)

    return _jit_compile(
        name,
        sources,
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
        build_directory,
        verbose,
        with_cuda)
|
|
|
|
|
|
def _jit_compile(name,
                 sources,
                 extra_cflags,
                 extra_cuda_cflags,
                 extra_ldflags,
                 extra_include_paths,
                 build_directory,
                 verbose,
                 with_cuda=None):
    '''
    Builds (when needed) and imports a JIT extension module.

    ``JIT_EXTENSION_VERSIONER`` decides whether the sources or build
    arguments changed since the last call for ``name``. A positive version
    renames the module to ``<name>_v<version>``. When the version changed,
    the build runs under a :class:`FileBaton` lock file in
    ``build_directory``: the process that acquires it builds, and the others
    wait for it. The module is then imported from ``build_directory``.
    '''
    previous_version = JIT_EXTENSION_VERSIONER.get_version(name)
    current_version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
        name,
        sources,
        build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
        build_directory=build_directory,
        with_cuda=with_cuda
    )
    needs_build = current_version != previous_version

    if current_version > 0:
        if needs_build and verbose:
            print('The input conditions for extension module {} have changed. '.format(name) +
                  'Bumping to version {0} and re-building as {1}_v{0}...'.format(current_version, name))
        name = '{}_v{}'.format(name, current_version)

    if not needs_build:
        if verbose:
            print('No modifications detected for re-loaded extension '
                  'module {}, skipping build step...'.format(name))
    else:
        baton = FileBaton(os.path.join(build_directory, 'lock'))
        if not baton.try_acquire():
            # Another process holds the lock and is building; wait for it.
            baton.wait()
        else:
            try:
                _write_ninja_file_and_build(
                    name=name,
                    sources=sources,
                    extra_cflags=extra_cflags or [],
                    extra_cuda_cflags=extra_cuda_cflags or [],
                    extra_ldflags=extra_ldflags or [],
                    extra_include_paths=extra_include_paths or [],
                    build_directory=build_directory,
                    verbose=verbose,
                    with_cuda=with_cuda)
            finally:
                baton.release()

    if verbose:
        print('Loading extension module {}...'.format(name))
    return _import_module_from_library(name, build_directory)
|
|
|
|
|
|
def _write_ninja_file_and_build(name,
                                sources,
                                extra_cflags,
                                extra_cuda_cflags,
                                extra_ldflags,
                                extra_include_paths,
                                build_directory,
                                verbose,
                                with_cuda):
    '''
    Emits ``build.ninja`` into ``build_directory`` and runs ninja on it.

    Ninja availability and the ``CXX`` (default ``c++``) compiler ABI are
    checked first. When ``with_cuda`` is ``None``, CUDA support is enabled
    if any of ``sources`` is a CUDA file.
    '''
    verify_ninja_availability()
    check_compiler_abi_compatibility(os.environ.get('CXX', 'c++'))

    if with_cuda is None:
        with_cuda = any(_is_cuda_file(source) for source in sources)

    linker_flags = _prepare_ldflags(extra_ldflags or [], with_cuda, verbose)

    ninja_file = os.path.join(build_directory, 'build.ninja')
    if verbose:
        print('Emitting ninja build file {}...'.format(ninja_file))
    # NOTE: Emitting a new ninja build file does not cause re-compilation if
    # the sources did not change, so it's ok to re-emit (and it's fast).
    _write_ninja_file(
        path=ninja_file,
        name=name,
        sources=sources,
        extra_cflags=extra_cflags or [],
        extra_cuda_cflags=extra_cuda_cflags or [],
        extra_ldflags=linker_flags or [],
        extra_include_paths=extra_include_paths or [],
        with_cuda=with_cuda)

    if verbose:
        print('Building extension module {}...'.format(name))
    _build_extension_module(name, build_directory, verbose)
|
|
|
|
|
|
def verify_ninja_availability():
    '''
    Checks that the `ninja <https://ninja-build.org/>`_ build system is
    available on the system by running ``ninja --version``.

    Returns:
        ``True`` when the command runs successfully.

    Raises:
        RuntimeError: if the ``ninja`` executable cannot be launched.
    '''
    with open(os.devnull, 'wb') as sink:
        try:
            subprocess.check_call(['ninja', '--version'], stdout=sink)
        except OSError:
            raise RuntimeError("Ninja is required to load C++ extensions")
    return True
|
|
|
|
|
|
def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
    '''
    Returns the linker flags for a JIT extension build.

    Starts from ``extra_ldflags``. On Windows it adds the PyTorch import
    libraries (plus ``caffe2_gpu.lib`` with CUDA) and ``/LIBPATH`` entries
    for the Python installation's ``libs`` folder and ``torch/lib``. When
    ``with_cuda`` is true, the CUDA runtime library and its directory are
    added too, as is the cuDNN library directory if ``CUDNN_HOME`` is set.

    Args:
        extra_ldflags: user-supplied linker flags; this list is not modified.
        with_cuda: whether to add CUDA linker flags.
        verbose: if true, print a message when CUDA flags are added.

    Returns:
        A new list of linker flags.
    '''
    # Work on a copy: the caller may pass the user's own ``extra_ldflags``
    # list (which _jit_compile also feeds to the extension versioner), and
    # appending to it would change that input on the next load() call.
    extra_ldflags = list(extra_ldflags)

    if IS_WINDOWS:
        python_path = os.path.dirname(sys.executable)
        python_lib_path = os.path.join(python_path, 'libs')

        here = os.path.abspath(__file__)
        torch_path = os.path.dirname(os.path.dirname(here))
        lib_path = os.path.join(torch_path, 'lib')

        extra_ldflags.append('c10.lib')
        extra_ldflags.append('caffe2.lib')
        extra_ldflags.append('torch.lib')
        if with_cuda:
            extra_ldflags.append('caffe2_gpu.lib')
        extra_ldflags.append('_C.lib')
        extra_ldflags.append('/LIBPATH:{}'.format(python_lib_path))
        extra_ldflags.append('/LIBPATH:{}'.format(lib_path))

    if with_cuda:
        if verbose:
            print('Detected CUDA files, patching ldflags')
        if IS_WINDOWS:
            extra_ldflags.append('/LIBPATH:{}'.format(
                _join_cuda_home('lib/x64')))
            extra_ldflags.append('cudart.lib')
            if CUDNN_HOME is not None:
                # A bare directory would be treated by the linker as an input
                # file; it must be passed as a library search path.
                extra_ldflags.append('/LIBPATH:{}'.format(
                    os.path.join(CUDNN_HOME, 'lib/x64')))
        else:
            extra_ldflags.append('-L{}'.format(_join_cuda_home('lib64')))
            extra_ldflags.append('-lcudart')
            if CUDNN_HOME is not None:
                extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64')))

    return extra_ldflags
|
|
|
|
|
|
def _get_build_directory(name, verbose):
    '''
    Returns the build directory for extension ``name``, creating it if needed.

    The directory is ``<root>/<name>``, where ``<root>`` is the
    ``TORCH_EXTENSIONS_DIR`` environment variable if set, otherwise
    :func:`get_default_build_root`. Safe to call concurrently from several
    processes for the same ``name``.
    '''
    root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
    if root_extensions_directory is None:
        root_extensions_directory = get_default_build_root()

    if verbose:
        print('Using {} as PyTorch extensions root...'.format(
            root_extensions_directory))

    build_directory = os.path.join(root_extensions_directory, name)
    if not os.path.exists(build_directory):
        if verbose:
            print('Creating extension directory {}...'.format(build_directory))
        # This is like mkdir -p, i.e. will also create parent directories.
        try:
            os.makedirs(build_directory)
        except OSError:
            # Another process (e.g. a parallel load() of the same extension)
            # may have created the directory between the exists() check and
            # makedirs(); that is fine. Anything else is re-raised.
            if not os.path.isdir(build_directory):
                raise

    return build_directory
|
|
|
|
|
|
def _build_extension_module(name, build_directory, verbose):
    '''
    Runs ``ninja -v`` in ``build_directory`` to build extension ``name``.

    When ``verbose`` is true, ninja's output goes straight to this process's
    stdout. Otherwise it is captured and only reported if the build fails.

    Raises:
        RuntimeError: if ninja exits with a non-zero status. The message
            includes the captured build output when it is available.
    '''
    try:
        # Flush our own buffered output first so it is not interleaved with
        # ninja's output.
        sys.stdout.flush()
        sys.stderr.flush()
        if sys.version_info >= (3, 5):
            subprocess.run(
                ['ninja', '-v'],
                stdout=None if verbose else subprocess.PIPE,
                stderr=subprocess.STDOUT,
                cwd=build_directory,
                check=True)
        elif verbose:
            # No subprocess.run before 3.5: check_call leaves stdout attached
            # to the terminal, matching the verbose behavior above.
            subprocess.check_call(
                ['ninja', '-v'],
                stderr=subprocess.STDOUT,
                cwd=build_directory)
        else:
            subprocess.check_output(
                ['ninja', '-v'],
                stderr=subprocess.STDOUT,
                cwd=build_directory)
    except subprocess.CalledProcessError:
        # Python 2 and 3 compatible way of getting the error object.
        _, error, _ = sys.exc_info()
        # error.output contains the stdout and stderr of the build attempt;
        # it is None when the output was not captured.
        message = "Error building extension '{}'".format(name)
        if hasattr(error, 'output') and error.output:
            # Compiler output is not guaranteed to be valid UTF-8; replace
            # undecodable bytes so a UnicodeDecodeError cannot hide the real
            # build error.
            message += ": {}".format(error.output.decode('utf-8', 'replace'))
        raise RuntimeError(message)
|
|
|
|
|
|
def _import_module_from_library(module_name, path):
|
|
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
|
|
file, path, description = imp.find_module(module_name, [path])
|
|
# Close the .so file after load.
|
|
with file:
|
|
return imp.load_module(module_name, file, path, description)
|
|
|
|
|
|
def _write_ninja_file(path,
                      name,
                      sources,
                      extra_cflags,
                      extra_cuda_cflags,
                      extra_ldflags,
                      extra_include_paths,
                      with_cuda):
    '''
    Writes a ninja build file at ``path`` that compiles ``sources`` and links
    them into ``name.so`` (``name.pyd`` on Windows).

    Args:
        path: file to write the ninja build description to.
        name: extension name; also passed to the compiler as the
            ``TORCH_EXTENSION_NAME`` macro.
        sources: C++/CUDA source files (converted to absolute paths).
        extra_cflags: additional flags for the C++ compiler.
        extra_cuda_cflags: additional flags for nvcc.
        extra_ldflags: additional linker flags.
        extra_include_paths: additional user include directories.
        with_cuda: whether to configure nvcc and compile CUDA sources with
            the ``cuda_compile`` rule.

    Raises:
        RuntimeError: on Windows, if the MSVC compiler ``cl`` cannot be found.
    '''
    extra_cflags = [flag.strip() for flag in extra_cflags]
    extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
    extra_ldflags = [flag.strip() for flag in extra_ldflags]
    extra_include_paths = [flag.strip() for flag in extra_include_paths]

    # Version 1.3 is required for the `deps` directive.
    config = ['ninja_required_version = 1.3']
    config.append('cxx = {}'.format(os.environ.get('CXX', 'c++')))
    if with_cuda:
        config.append('nvcc = {}'.format(_join_cuda_home('bin', 'nvcc')))

    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    sources = [os.path.abspath(source) for source in sources]
    user_includes = [os.path.abspath(include_dir)
                     for include_dir in extra_include_paths]

    # include_paths() gives us the location of torch/extension.h
    system_includes = include_paths(with_cuda)
    # sysconfig.get_paths()['include'] gives us the location of Python.h
    system_includes.append(sysconfig.get_paths()['include'])

    # Windows does not understand `-isystem`; pass everything as `-I`.
    # Rebinding (rather than list.clear()) works on Python 2 as well.
    if IS_WINDOWS:
        user_includes += system_includes
        system_includes = []

    common_cflags = ['-DTORCH_EXTENSION_NAME={}'.format(name)]
    common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')
    common_cflags += ['-I{}'.format(include) for include in user_includes]
    common_cflags += ['-isystem {}'.format(include) for include in system_includes]

    if _is_binary_build():
        common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=0']

    cflags = common_cflags + ['-fPIC', '-std=c++11'] + extra_cflags
    if IS_WINDOWS:
        from distutils.spawn import _nt_quote_args
        cflags = _nt_quote_args(cflags)
    flags = ['cflags = {}'.format(' '.join(cflags))]

    if with_cuda:
        cuda_flags = common_cflags + COMMON_NVCC_FLAGS
        if IS_WINDOWS:
            cuda_flags = _nt_quote_args(cuda_flags)
            # User-supplied nvcc flags must be honored on Windows too.
            cuda_flags += _nt_quote_args(extra_cuda_cflags)
        else:
            cuda_flags += ['--compiler-options', "'-fPIC'"]
            cuda_flags += extra_cuda_cflags
            if not any(flag.startswith('-std=') for flag in cuda_flags):
                cuda_flags.append('-std=c++11')

        flags.append('cuda_flags = {}'.format(' '.join(cuda_flags)))

    if IS_WINDOWS:
        ldflags = ['/DLL'] + extra_ldflags
    else:
        ldflags = ['-shared'] + extra_ldflags
    # The darwin linker needs explicit consent to ignore unresolved symbols.
    if sys.platform == 'darwin':
        ldflags.append('-undefined dynamic_lookup')
    elif IS_WINDOWS:
        ldflags = _nt_quote_args(ldflags)
    flags.append('ldflags = {}'.format(' '.join(ldflags)))

    # See https://ninja-build.org/build.ninja.html for reference.
    compile_rule = ['rule compile']
    if IS_WINDOWS:
        compile_rule.append(
            ' command = cl /showIncludes $cflags -c $in /Fo$out')
        compile_rule.append(' deps = msvc')
    else:
        compile_rule.append(
            ' command = $cxx -MMD -MF $out.d $cflags -c $in -o $out')
        compile_rule.append(' depfile = $out.d')
        compile_rule.append(' deps = gcc')

    if with_cuda:
        cuda_compile_rule = ['rule cuda_compile']
        cuda_compile_rule.append(
            ' command = $nvcc $cuda_flags -c $in -o $out')

    link_rule = ['rule link']
    if IS_WINDOWS:
        # `where cl` exits non-zero when cl is not on PATH; treat that (and a
        # missing `where` binary) the same as empty output.
        try:
            where_output = subprocess.check_output(['where', 'cl']).decode()
        except (OSError, subprocess.CalledProcessError):
            where_output = ''
        cl_paths = [p for p in where_output.split('\r\n') if p.strip()]
        if not cl_paths:
            raise RuntimeError("MSVC is required to load C++ extensions")
        # Ninja needs ':' escaped as '$:'.
        cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
        link_rule.append(
            ' command = "{}/link.exe" $in /nologo $ldflags /out:$out'.format(
                cl_path))
    else:
        link_rule.append(' command = $cxx $in $ldflags -o $out')

    # Emit one build rule per source to enable incremental build.
    object_files = []
    build = []
    for source_file in sources:
        # '/path/to/file.cpp' -> 'file'
        file_name = os.path.splitext(os.path.basename(source_file))[0]
        if _is_cuda_file(source_file) and with_cuda:
            rule = 'cuda_compile'
            # Use a different object filename in case a C++ and CUDA file have
            # the same filename but different extension (.cpp vs. .cu).
            target = '{}.cuda.o'.format(file_name)
        else:
            rule = 'compile'
            target = '{}.o'.format(file_name)
        object_files.append(target)
        # In ninja build lines ':' (Windows drive letters) and ' ' must be
        # escaped with '$'. Spaces can occur in paths on every platform.
        if IS_WINDOWS:
            source_file = source_file.replace(':', '$:')
        source_file = source_file.replace(" ", "$ ")
        build.append('build {}: {} {}'.format(target, rule, source_file))

    ext = 'pyd' if IS_WINDOWS else 'so'
    library_target = '{}.{}'.format(name, ext)

    link = ['build {}: link {}'.format(library_target, ' '.join(object_files))]

    default = ['default {}'.format(library_target)]

    # 'Blocks' should be separated by newlines, for visual benefit.
    blocks = [config, flags, compile_rule]
    if with_cuda:
        blocks.append(cuda_compile_rule)
    blocks += [link_rule, build, link, default]
    with open(path, 'w') as build_file:
        for block in blocks:
            lines = '\n'.join(block)
            build_file.write('{}\n\n'.format(lines))
|
|
|
|
|
|
def _join_cuda_home(*paths):
    '''
    Returns ``paths`` joined onto ``CUDA_HOME``.

    Routing every CUDA-specific path through this helper means a missing
    CUDA install is only reported once such a path is actually needed.

    Raises:
        EnvironmentError: if ``CUDA_HOME`` is ``None``.
    '''
    if CUDA_HOME is not None:
        return os.path.join(CUDA_HOME, *paths)
    raise EnvironmentError('CUDA_HOME environment variable is not set. '
                           'Please set it to your CUDA install root.')
|
|
|
|
|
|
def _is_cuda_file(path):
|
|
return os.path.splitext(path)[1] in ['.cu', '.cuh']
|