mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] use pathlib.Path instead of os.path.* in setup.py (#156742)
Resolves: - https://github.com/pytorch/pytorch/pull/155998#discussion_r2164376634 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156742 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
90b973a2e2
commit
2380115f97
@ -122,6 +122,7 @@ is_formatter = true
|
|||||||
[[linter]]
|
[[linter]]
|
||||||
code = 'MYPY'
|
code = 'MYPY'
|
||||||
include_patterns = [
|
include_patterns = [
|
||||||
|
'setup.py',
|
||||||
'torch/**/*.py',
|
'torch/**/*.py',
|
||||||
'torch/**/*.pyi',
|
'torch/**/*.pyi',
|
||||||
'caffe2/**/*.py',
|
'caffe2/**/*.py',
|
||||||
|
|||||||
521
setup.py
521
setup.py
@ -257,23 +257,31 @@ if sys.version_info < python_min_version:
|
|||||||
import filecmp
|
import filecmp
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import itertools
|
||||||
import json
|
import json
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sysconfig
|
import sysconfig
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, ClassVar, IO
|
||||||
|
|
||||||
import setuptools.command.build_ext
|
import setuptools.command.build_ext
|
||||||
import setuptools.command.install
|
|
||||||
import setuptools.command.sdist
|
import setuptools.command.sdist
|
||||||
from setuptools import Extension, find_packages, setup
|
import setuptools.errors
|
||||||
|
from setuptools import Command, Extension, find_packages, setup
|
||||||
from setuptools.dist import Distribution
|
from setuptools.dist import Distribution
|
||||||
from tools.build_pytorch_libs import build_pytorch
|
from tools.build_pytorch_libs import build_pytorch
|
||||||
from tools.generate_torch_version import get_torch_version
|
from tools.generate_torch_version import get_torch_version
|
||||||
from tools.setup_helpers.cmake import CMake
|
from tools.setup_helpers.cmake import CMake, CMakeValue
|
||||||
from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
|
from tools.setup_helpers.env import (
|
||||||
|
BUILD_DIR,
|
||||||
|
build_type,
|
||||||
|
IS_DARWIN,
|
||||||
|
IS_LINUX,
|
||||||
|
IS_WINDOWS,
|
||||||
|
)
|
||||||
from tools.setup_helpers.generate_linker_script import gen_linker_script
|
from tools.setup_helpers.generate_linker_script import gen_linker_script
|
||||||
|
|
||||||
|
|
||||||
@ -318,18 +326,20 @@ def str2bool(value: str | None) -> bool:
|
|||||||
raise ValueError(f"Invalid string value for boolean conversion: {value}")
|
raise ValueError(f"Invalid string value for boolean conversion: {value}")
|
||||||
|
|
||||||
|
|
||||||
def _get_package_path(package_name):
|
def _get_package_path(package_name: str) -> Path:
|
||||||
spec = importlib.util.find_spec(package_name)
|
from importlib.util import find_spec
|
||||||
|
|
||||||
|
spec = find_spec(package_name)
|
||||||
if spec:
|
if spec:
|
||||||
# The package might be a namespace package, so get_data may fail
|
# The package might be a namespace package, so get_data may fail
|
||||||
try:
|
try:
|
||||||
loader = spec.loader
|
loader = spec.loader
|
||||||
if loader is not None:
|
if loader is not None:
|
||||||
file_path = loader.get_filename() # type: ignore[attr-defined]
|
file_path = loader.get_filename() # type: ignore[attr-defined]
|
||||||
return os.path.dirname(file_path)
|
return Path(file_path).parent
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
return None
|
return CWD / package_name
|
||||||
|
|
||||||
|
|
||||||
BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL"))
|
BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL"))
|
||||||
@ -343,7 +353,7 @@ if BUILD_LIBTORCH_WHL:
|
|||||||
|
|
||||||
if BUILD_PYTHON_ONLY:
|
if BUILD_PYTHON_ONLY:
|
||||||
os.environ["BUILD_LIBTORCHLESS"] = "ON"
|
os.environ["BUILD_LIBTORCHLESS"] = "ON"
|
||||||
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
|
os.environ["LIBTORCH_LIB_PATH"] = (_get_package_path("torch") / "lib").as_posix()
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Parameters parsed from environment
|
# Parameters parsed from environment
|
||||||
@ -381,60 +391,61 @@ sys.argv = filtered_args
|
|||||||
|
|
||||||
if VERBOSE_SCRIPT:
|
if VERBOSE_SCRIPT:
|
||||||
|
|
||||||
def report(*args, file=sys.stderr, **kwargs):
|
def report(*args: Any, file: IO[str] = sys.stderr, **kwargs: Any) -> None:
|
||||||
print(*args, file=file, **kwargs)
|
print(*args, file=file, **kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def report(*args, **kwargs):
|
def report(*args: Any, file: IO[str] = sys.stderr, **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Make distutils respect --quiet too
|
# Make distutils respect --quiet too
|
||||||
setuptools.distutils.log.warn = report
|
setuptools.distutils.log.warn = report # type: ignore[attr-defined]
|
||||||
|
|
||||||
# Constant known variables used throughout this file
|
# Constant known variables used throughout this file
|
||||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
CWD = Path(__file__).absolute().parent
|
||||||
lib_path = os.path.join(cwd, "torch", "lib")
|
TORCH_LIB_DIR = CWD / "torch" / "lib"
|
||||||
third_party_path = os.path.join(cwd, "third_party")
|
THIRD_PARTY_DIR = CWD / "third_party"
|
||||||
|
|
||||||
# CMAKE: full path to python library
|
# CMAKE: full path to python library
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
cmake_python_library = "{}/libs/python{}.lib".format(
|
CMAKE_PYTHON_LIBRARY = (
|
||||||
sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION")
|
Path(sysconfig.get_config_var("prefix"))
|
||||||
|
/ "libs"
|
||||||
|
/ f"python{sysconfig.get_config_var('VERSION')}.lib"
|
||||||
)
|
)
|
||||||
# Fix virtualenv builds
|
# Fix virtualenv builds
|
||||||
if not os.path.exists(cmake_python_library):
|
if not CMAKE_PYTHON_LIBRARY.exists():
|
||||||
cmake_python_library = "{}/libs/python{}.lib".format(
|
CMAKE_PYTHON_LIBRARY = (
|
||||||
sys.base_prefix, sysconfig.get_config_var("VERSION")
|
Path(sys.base_prefix)
|
||||||
|
/ "libs"
|
||||||
|
/ f"python{sysconfig.get_config_var('VERSION')}.lib"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cmake_python_library = "{}/{}".format(
|
CMAKE_PYTHON_LIBRARY = Path(
|
||||||
sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME")
|
sysconfig.get_config_var("LIBDIR")
|
||||||
)
|
) / sysconfig.get_config_var("INSTSONAME")
|
||||||
cmake_python_include_dir = sysconfig.get_path("include")
|
|
||||||
|
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
# Version, create_version_file, and package_name
|
# Version, create_version_file, and package_name
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
package_name = os.getenv("TORCH_PACKAGE_NAME", "torch")
|
TORCH_PACKAGE_NAME = os.getenv("TORCH_PACKAGE_NAME", "torch")
|
||||||
LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python")
|
LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python")
|
||||||
if BUILD_LIBTORCH_WHL:
|
if BUILD_LIBTORCH_WHL:
|
||||||
package_name = LIBTORCH_PKG_NAME
|
TORCH_PACKAGE_NAME = LIBTORCH_PKG_NAME
|
||||||
|
|
||||||
|
TORCH_VERSION = get_torch_version()
|
||||||
package_type = os.getenv("PACKAGE_TYPE", "wheel")
|
report(f"Building wheel {TORCH_PACKAGE_NAME}-{TORCH_VERSION}")
|
||||||
version = get_torch_version()
|
|
||||||
report(f"Building wheel {package_name}-{version}")
|
|
||||||
|
|
||||||
cmake = CMake()
|
cmake = CMake()
|
||||||
|
|
||||||
|
|
||||||
def get_submodule_folders():
|
def get_submodule_folders() -> list[Path]:
|
||||||
git_modules_path = os.path.join(cwd, ".gitmodules")
|
git_modules_file = CWD / ".gitmodules"
|
||||||
default_modules_path = [
|
default_modules_path = [
|
||||||
os.path.join(third_party_path, name)
|
THIRD_PARTY_DIR / name
|
||||||
for name in [
|
for name in [
|
||||||
"gloo",
|
"gloo",
|
||||||
"cpuinfo",
|
"cpuinfo",
|
||||||
@ -443,26 +454,26 @@ def get_submodule_folders():
|
|||||||
"cutlass",
|
"cutlass",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
if not os.path.exists(git_modules_path):
|
if not git_modules_file.exists():
|
||||||
return default_modules_path
|
return default_modules_path
|
||||||
with open(git_modules_path) as f:
|
with git_modules_file.open(encoding="utf-8") as f:
|
||||||
return [
|
return [
|
||||||
os.path.join(cwd, line.split("=", 1)[1].strip())
|
CWD / line.partition("=")[-1].strip()
|
||||||
for line in f
|
for line in f
|
||||||
if line.strip().startswith("path")
|
if line.strip().startswith("path")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def check_submodules():
|
def check_submodules() -> None:
|
||||||
def check_for_files(folder, files):
|
def check_for_files(folder: Path, files: list[str]) -> None:
|
||||||
if not any(os.path.exists(os.path.join(folder, f)) for f in files):
|
if not any((folder / f).exists() for f in files):
|
||||||
report("Could not find any of {} in {}".format(", ".join(files), folder))
|
report("Could not find any of {} in {}".format(", ".join(files), folder))
|
||||||
report("Did you run 'git submodule update --init --recursive'?")
|
report("Did you run 'git submodule update --init --recursive'?")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
def not_exists_or_empty(folder):
|
def not_exists_or_empty(folder: Path) -> bool:
|
||||||
return not os.path.exists(folder) or (
|
return not folder.exists() or (
|
||||||
os.path.isdir(folder) and len(os.listdir(folder)) == 0
|
folder.is_dir() and next(folder.iterdir(), None) is None
|
||||||
)
|
)
|
||||||
|
|
||||||
if str2bool(os.getenv("USE_SYSTEM_LIBS")):
|
if str2bool(os.getenv("USE_SYSTEM_LIBS")):
|
||||||
@ -474,7 +485,7 @@ def check_submodules():
|
|||||||
report(" --- Trying to initialize submodules")
|
report(" --- Trying to initialize submodules")
|
||||||
start = time.time()
|
start = time.time()
|
||||||
subprocess.check_call(
|
subprocess.check_call(
|
||||||
["git", "submodule", "update", "--init", "--recursive"], cwd=cwd
|
["git", "submodule", "update", "--init", "--recursive"], cwd=CWD
|
||||||
)
|
)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
report(f" --- Submodule initialization took {end - start:.2f} sec")
|
report(f" --- Submodule initialization took {end - start:.2f} sec")
|
||||||
@ -495,37 +506,41 @@ def check_submodules():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
check_for_files(
|
check_for_files(
|
||||||
os.path.join(third_party_path, "fbgemm", "external", "asmjit"),
|
THIRD_PARTY_DIR / "fbgemm" / "external" / "asmjit",
|
||||||
["CMakeLists.txt"],
|
["CMakeLists.txt"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Windows has very bad support for symbolic links.
|
# Windows has very bad support for symbolic links.
|
||||||
# Instead of using symlinks, we're going to copy files over
|
# Instead of using symlinks, we're going to copy files over
|
||||||
def mirror_files_into_torchgen():
|
def mirror_files_into_torchgen() -> None:
|
||||||
# (new_path, orig_path)
|
# (new_path, orig_path)
|
||||||
# Directories are OK and are recursively mirrored.
|
# Directories are OK and are recursively mirrored.
|
||||||
paths = [
|
new_paths = [
|
||||||
(
|
"torchgen/packaged/ATen/native/native_functions.yaml",
|
||||||
"torchgen/packaged/ATen/native/native_functions.yaml",
|
"torchgen/packaged/ATen/native/tags.yaml",
|
||||||
"aten/src/ATen/native/native_functions.yaml",
|
"torchgen/packaged/ATen/templates",
|
||||||
),
|
"torchgen/packaged/autograd",
|
||||||
("torchgen/packaged/ATen/native/tags.yaml", "aten/src/ATen/native/tags.yaml"),
|
"torchgen/packaged/autograd/templates",
|
||||||
("torchgen/packaged/ATen/templates", "aten/src/ATen/templates"),
|
|
||||||
("torchgen/packaged/autograd", "tools/autograd"),
|
|
||||||
("torchgen/packaged/autograd/templates", "tools/autograd/templates"),
|
|
||||||
]
|
]
|
||||||
for new_path, orig_path in paths:
|
orig_paths = [
|
||||||
|
"aten/src/ATen/native/native_functions.yaml",
|
||||||
|
"aten/src/ATen/native/tags.yaml",
|
||||||
|
"aten/src/ATen/templates",
|
||||||
|
"tools/autograd",
|
||||||
|
"tools/autograd/templates",
|
||||||
|
]
|
||||||
|
for new_path, orig_path in zip(map(Path, new_paths), map(Path, orig_paths)):
|
||||||
# Create the dirs involved in new_path if they don't exist
|
# Create the dirs involved in new_path if they don't exist
|
||||||
if not os.path.exists(new_path):
|
if not new_path.exists():
|
||||||
os.makedirs(os.path.dirname(new_path), exist_ok=True)
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Copy the files from the orig location to the new location
|
# Copy the files from the orig location to the new location
|
||||||
if os.path.isfile(orig_path):
|
if orig_path.is_file():
|
||||||
shutil.copyfile(orig_path, new_path)
|
shutil.copyfile(orig_path, new_path)
|
||||||
continue
|
continue
|
||||||
if os.path.isdir(orig_path):
|
if orig_path.is_dir():
|
||||||
if os.path.exists(new_path):
|
if new_path.exists():
|
||||||
# copytree fails if the tree exists already, so remove it.
|
# copytree fails if the tree exists already, so remove it.
|
||||||
shutil.rmtree(new_path)
|
shutil.rmtree(new_path)
|
||||||
shutil.copytree(orig_path, new_path)
|
shutil.copytree(orig_path, new_path)
|
||||||
@ -534,15 +549,14 @@ def mirror_files_into_torchgen():
|
|||||||
|
|
||||||
|
|
||||||
# all the work we need to do _before_ setup runs
|
# all the work we need to do _before_ setup runs
|
||||||
def build_deps():
|
def build_deps() -> None:
|
||||||
report("-- Building version " + version)
|
report(f"-- Building version {TORCH_VERSION}")
|
||||||
check_submodules()
|
check_submodules()
|
||||||
check_pydep("yaml", "pyyaml")
|
check_pydep("yaml", "pyyaml")
|
||||||
build_python = not BUILD_LIBTORCH_WHL
|
|
||||||
build_pytorch(
|
build_pytorch(
|
||||||
version=version,
|
version=TORCH_VERSION,
|
||||||
cmake_python_library=cmake_python_library,
|
cmake_python_library=CMAKE_PYTHON_LIBRARY.as_posix(),
|
||||||
build_python=build_python,
|
build_python=not BUILD_LIBTORCH_WHL,
|
||||||
rerun_cmake=RERUN_CMAKE,
|
rerun_cmake=RERUN_CMAKE,
|
||||||
cmake_only=CMAKE_ONLY,
|
cmake_only=CMAKE_ONLY,
|
||||||
cmake=cmake,
|
cmake=cmake,
|
||||||
@ -568,13 +582,13 @@ def build_deps():
|
|||||||
"third_party/valgrind-headers/callgrind.h",
|
"third_party/valgrind-headers/callgrind.h",
|
||||||
"third_party/valgrind-headers/valgrind.h",
|
"third_party/valgrind-headers/valgrind.h",
|
||||||
]
|
]
|
||||||
for sym_file, orig_file in zip(sym_files, orig_files):
|
for sym_file, orig_file in zip(map(Path, sym_files), map(Path, orig_files)):
|
||||||
same = False
|
same = False
|
||||||
if os.path.exists(sym_file):
|
if sym_file.exists():
|
||||||
if filecmp.cmp(sym_file, orig_file):
|
if filecmp.cmp(sym_file, orig_file):
|
||||||
same = True
|
same = True
|
||||||
else:
|
else:
|
||||||
os.remove(sym_file)
|
sym_file.unlink()
|
||||||
if not same:
|
if not same:
|
||||||
shutil.copyfile(orig_file, sym_file)
|
shutil.copyfile(orig_file, sym_file)
|
||||||
|
|
||||||
@ -589,7 +603,7 @@ Please install it via `conda install {module}` or `pip install {module}`
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
def check_pydep(importname, module):
|
def check_pydep(importname: str, module: str) -> None:
|
||||||
try:
|
try:
|
||||||
importlib.import_module(importname)
|
importlib.import_module(importname)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@ -599,19 +613,22 @@ def check_pydep(importname, module):
|
|||||||
|
|
||||||
|
|
||||||
class build_ext(setuptools.command.build_ext.build_ext):
|
class build_ext(setuptools.command.build_ext.build_ext):
|
||||||
def _embed_libomp(self):
|
def _embed_libomp(self) -> None:
|
||||||
# Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS
|
# Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS
|
||||||
lib_dir = os.path.join(self.build_lib, "torch", "lib")
|
build_lib = Path(self.build_lib)
|
||||||
libtorch_cpu_path = os.path.join(lib_dir, "libtorch_cpu.dylib")
|
build_torch_lib_dir = build_lib / "torch" / "lib"
|
||||||
if not os.path.exists(libtorch_cpu_path):
|
build_torch_include_dir = build_lib / "torch" / "include"
|
||||||
|
libtorch_cpu_path = build_torch_lib_dir / "libtorch_cpu.dylib"
|
||||||
|
if not libtorch_cpu_path.exists():
|
||||||
return
|
return
|
||||||
# Parse libtorch_cpu load commands
|
# Parse libtorch_cpu load commands
|
||||||
otool_cmds = (
|
otool_cmds = (
|
||||||
subprocess.check_output(["otool", "-l", libtorch_cpu_path])
|
subprocess.check_output(["otool", "-l", str(libtorch_cpu_path)])
|
||||||
.decode("utf-8")
|
.decode("utf-8")
|
||||||
.split("\n")
|
.split("\n")
|
||||||
)
|
)
|
||||||
rpaths, libs = [], []
|
rpaths: list[str] = []
|
||||||
|
libs: list[str] = []
|
||||||
for idx, line in enumerate(otool_cmds):
|
for idx, line in enumerate(otool_cmds):
|
||||||
if line.strip() == "cmd LC_LOAD_DYLIB":
|
if line.strip() == "cmd LC_LOAD_DYLIB":
|
||||||
lib_name = otool_cmds[idx + 2].strip()
|
lib_name = otool_cmds[idx + 2].strip()
|
||||||
@ -623,8 +640,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
assert rpath.startswith("path ")
|
assert rpath.startswith("path ")
|
||||||
rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1])
|
rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1])
|
||||||
|
|
||||||
omplib_path = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"]
|
omplib_path: str = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"] # type: ignore[assignment]
|
||||||
omplib_name = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] + ".dylib"
|
omplib_name: str = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] # type: ignore[assignment]
|
||||||
|
omplib_name += ".dylib"
|
||||||
omplib_rpath_path = os.path.join("@rpath", omplib_name)
|
omplib_rpath_path = os.path.join("@rpath", omplib_name)
|
||||||
|
|
||||||
# This logic is fragile and checks only two cases:
|
# This logic is fragile and checks only two cases:
|
||||||
@ -634,8 +652,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Copy libomp/libiomp5 from rpath locations
|
# Copy libomp/libiomp5 from rpath locations
|
||||||
target_lib = os.path.join(self.build_lib, "torch", "lib", omplib_name)
|
target_lib = build_torch_lib_dir / omplib_name
|
||||||
libomp_relocated = False
|
libomp_relocated = False
|
||||||
|
install_name_tool_args: list[str] = []
|
||||||
for rpath in rpaths:
|
for rpath in rpaths:
|
||||||
source_lib = os.path.join(rpath, omplib_name)
|
source_lib = os.path.join(rpath, omplib_name)
|
||||||
if not os.path.exists(source_lib):
|
if not os.path.exists(source_lib):
|
||||||
@ -666,22 +685,29 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
]
|
]
|
||||||
libomp_relocated = True
|
libomp_relocated = True
|
||||||
if libomp_relocated:
|
if libomp_relocated:
|
||||||
install_name_tool_args.insert(0, "install_name_tool")
|
install_name_tool_args = [
|
||||||
install_name_tool_args.append(libtorch_cpu_path)
|
"install_name_tool",
|
||||||
|
*install_name_tool_args,
|
||||||
|
str(libtorch_cpu_path),
|
||||||
|
]
|
||||||
subprocess.check_call(install_name_tool_args)
|
subprocess.check_call(install_name_tool_args)
|
||||||
# Copy omp.h from OpenMP_C_FLAGS and copy it into include folder
|
# Copy omp.h from OpenMP_C_FLAGS and copy it into include folder
|
||||||
omp_cflags = get_cmake_cache_vars()["OpenMP_C_FLAGS"]
|
omp_cflags: str = get_cmake_cache_vars()["OpenMP_C_FLAGS"] # type: ignore[assignment]
|
||||||
if not omp_cflags:
|
if not omp_cflags:
|
||||||
return
|
return
|
||||||
for include_dir in [f[2:] for f in omp_cflags.split(" ") if f.startswith("-I")]:
|
for include_dir in [
|
||||||
omp_h = os.path.join(include_dir, "omp.h")
|
Path(f.removeprefix("-I"))
|
||||||
if not os.path.exists(omp_h):
|
for f in omp_cflags.split(" ")
|
||||||
|
if f.startswith("-I")
|
||||||
|
]:
|
||||||
|
omp_h = include_dir / "omp.h"
|
||||||
|
if not omp_h.exists():
|
||||||
continue
|
continue
|
||||||
target_omp_h = os.path.join(self.build_lib, "torch", "include", "omp.h")
|
target_omp_h = build_torch_include_dir / "omp.h"
|
||||||
self.copy_file(omp_h, target_omp_h)
|
self.copy_file(omp_h, target_omp_h)
|
||||||
break
|
break
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
# Report build options. This is run after the build completes so # `CMakeCache.txt` exists
|
# Report build options. This is run after the build completes so # `CMakeCache.txt` exists
|
||||||
# and we can get an accurate report on what is used and what is not.
|
# and we can get an accurate report on what is used and what is not.
|
||||||
cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables())
|
cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables())
|
||||||
@ -692,18 +718,17 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
if cmake_cache_vars["USE_CUDNN"]:
|
if cmake_cache_vars["USE_CUDNN"]:
|
||||||
report(
|
report(
|
||||||
"-- Detected cuDNN at "
|
"-- Detected cuDNN at "
|
||||||
+ cmake_cache_vars["CUDNN_LIBRARY"]
|
f"{cmake_cache_vars['CUDNN_LIBRARY']}, "
|
||||||
+ ", "
|
f"{cmake_cache_vars['CUDNN_INCLUDE_DIR']}"
|
||||||
+ cmake_cache_vars["CUDNN_INCLUDE_DIR"]
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
report("-- Not using cuDNN")
|
report("-- Not using cuDNN")
|
||||||
if cmake_cache_vars["USE_CUDA"]:
|
if cmake_cache_vars["USE_CUDA"]:
|
||||||
report("-- Detected CUDA at " + cmake_cache_vars["CUDA_TOOLKIT_ROOT_DIR"])
|
report(f"-- Detected CUDA at {cmake_cache_vars['CUDA_TOOLKIT_ROOT_DIR']}")
|
||||||
else:
|
else:
|
||||||
report("-- Not using CUDA")
|
report("-- Not using CUDA")
|
||||||
if cmake_cache_vars["USE_XPU"]:
|
if cmake_cache_vars["USE_XPU"]:
|
||||||
report("-- Detected XPU runtime at " + cmake_cache_vars["SYCL_LIBRARY_DIR"])
|
report(f"-- Detected XPU runtime at {cmake_cache_vars['SYCL_LIBRARY_DIR']}")
|
||||||
else:
|
else:
|
||||||
report("-- Not using XPU")
|
report("-- Not using XPU")
|
||||||
if cmake_cache_vars["USE_MKLDNN"]:
|
if cmake_cache_vars["USE_MKLDNN"]:
|
||||||
@ -722,10 +747,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
report("-- Not using MKLDNN")
|
report("-- Not using MKLDNN")
|
||||||
if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]:
|
if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]:
|
||||||
report(
|
report(
|
||||||
"-- Using system provided NCCL library at {}, {}".format(
|
"-- Using system provided NCCL library at "
|
||||||
cmake_cache_vars["NCCL_LIBRARIES"],
|
f"{cmake_cache_vars['NCCL_LIBRARIES']}, "
|
||||||
cmake_cache_vars["NCCL_INCLUDE_DIRS"],
|
f"{cmake_cache_vars['NCCL_INCLUDE_DIRS']}"
|
||||||
)
|
|
||||||
)
|
)
|
||||||
elif cmake_cache_vars["USE_NCCL"]:
|
elif cmake_cache_vars["USE_NCCL"]:
|
||||||
report("-- Building NCCL library")
|
report("-- Building NCCL library")
|
||||||
@ -736,18 +760,15 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
report("-- Building without distributed package")
|
report("-- Building without distributed package")
|
||||||
else:
|
else:
|
||||||
report("-- Building with distributed package: ")
|
report("-- Building with distributed package: ")
|
||||||
report(
|
report(f" -- USE_TENSORPIPE={cmake_cache_vars['USE_TENSORPIPE']}")
|
||||||
" -- USE_TENSORPIPE={}".format(cmake_cache_vars["USE_TENSORPIPE"])
|
report(f" -- USE_GLOO={cmake_cache_vars['USE_GLOO']}")
|
||||||
)
|
report(f" -- USE_MPI={cmake_cache_vars['USE_OPENMPI']}")
|
||||||
report(" -- USE_GLOO={}".format(cmake_cache_vars["USE_GLOO"]))
|
|
||||||
report(" -- USE_MPI={}".format(cmake_cache_vars["USE_OPENMPI"]))
|
|
||||||
else:
|
else:
|
||||||
report("-- Building without distributed package")
|
report("-- Building without distributed package")
|
||||||
if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]:
|
if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]:
|
||||||
report(
|
report(
|
||||||
"-- Using static dispatch with backend {}".format(
|
"-- Using static dispatch with "
|
||||||
cmake_cache_vars["STATIC_DISPATCH_BACKEND"]
|
f"backend {cmake_cache_vars['STATIC_DISPATCH_BACKEND']}"
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]:
|
if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]:
|
||||||
report("-- Using lightweight dispatch")
|
report("-- Using lightweight dispatch")
|
||||||
@ -759,98 +780,90 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
|
|
||||||
# Do not use clang to compile extensions if `-fstack-clash-protection` is defined
|
# Do not use clang to compile extensions if `-fstack-clash-protection` is defined
|
||||||
# in system CFLAGS
|
# in system CFLAGS
|
||||||
c_flags = str(os.getenv("CFLAGS", ""))
|
c_flags = os.getenv("CFLAGS", "")
|
||||||
if (
|
if (
|
||||||
IS_LINUX
|
IS_LINUX
|
||||||
and "-fstack-clash-protection" in c_flags
|
and "-fstack-clash-protection" in c_flags
|
||||||
and "clang" in os.environ.get("CC", "")
|
and "clang" in os.getenv("CC", "")
|
||||||
):
|
):
|
||||||
os.environ["CC"] = str(os.environ["CC"])
|
os.environ["CC"] = str(os.environ["CC"])
|
||||||
|
|
||||||
# It's an old-style class in Python 2.7...
|
super().run()
|
||||||
setuptools.command.build_ext.build_ext.run(self)
|
|
||||||
|
|
||||||
if IS_DARWIN:
|
if IS_DARWIN:
|
||||||
self._embed_libomp()
|
self._embed_libomp()
|
||||||
|
|
||||||
# Copy the essential export library to compile C++ extensions.
|
# Copy the essential export library to compile C++ extensions.
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
build_temp = self.build_temp
|
build_temp = Path(self.build_temp)
|
||||||
|
build_lib = Path(self.build_lib)
|
||||||
|
|
||||||
ext_filename = self.get_ext_filename("_C")
|
ext_filename = self.get_ext_filename("_C")
|
||||||
lib_filename = ".".join(ext_filename.split(".")[:-1]) + ".lib"
|
lib_filename = ".".join(ext_filename.split(".")[:-1]) + ".lib"
|
||||||
|
|
||||||
export_lib = os.path.join(
|
export_lib = build_temp / "torch" / "csrc" / lib_filename
|
||||||
build_temp, "torch", "csrc", lib_filename
|
target_lib = build_lib / "torch" / "lib" / "_C.lib"
|
||||||
).replace("\\", "/")
|
|
||||||
|
|
||||||
build_lib = self.build_lib
|
|
||||||
|
|
||||||
target_lib = os.path.join(build_lib, "torch", "lib", "_C.lib").replace(
|
|
||||||
"\\", "/"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create "torch/lib" directory if not exists.
|
# Create "torch/lib" directory if not exists.
|
||||||
# (It is not created yet in "develop" mode.)
|
# (It is not created yet in "develop" mode.)
|
||||||
target_dir = os.path.dirname(target_lib)
|
target_dir = target_lib.parent
|
||||||
if not os.path.exists(target_dir):
|
target_dir.mkdir(parents=True, exist_ok=True)
|
||||||
os.makedirs(target_dir)
|
|
||||||
|
|
||||||
self.copy_file(export_lib, target_lib)
|
self.copy_file(export_lib, target_lib)
|
||||||
|
|
||||||
# In ROCm on Windows case copy rocblas and hipblaslt files into
|
# In ROCm on Windows case copy rocblas and hipblaslt files into
|
||||||
# torch/lib/rocblas/library and torch/lib/hipblaslt/library
|
# torch/lib/rocblas/library and torch/lib/hipblaslt/library
|
||||||
if str2bool(os.getenv("USE_ROCM")):
|
if str2bool(os.getenv("USE_ROCM")):
|
||||||
rocm_dir_path = os.environ.get("ROCM_DIR")
|
rocm_dir_path = Path(os.environ["ROCM_DIR"])
|
||||||
rocm_bin_path = os.path.join(rocm_dir_path, "bin")
|
rocm_bin_path = rocm_dir_path / "bin"
|
||||||
|
rocblas_dir = rocm_bin_path / "rocblas"
|
||||||
|
target_rocblas_dir = target_dir / "rocblas"
|
||||||
|
target_rocblas_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.copy_tree(rocblas_dir, str(target_rocblas_dir))
|
||||||
|
|
||||||
rocblas_dir = os.path.join(rocm_bin_path, "rocblas")
|
hipblaslt_dir = rocm_bin_path / "hipblaslt"
|
||||||
target_rocblas_dir = os.path.join(target_dir, "rocblas")
|
target_hipblaslt_dir = target_dir / "hipblaslt"
|
||||||
os.makedirs(target_rocblas_dir, exist_ok=True)
|
target_hipblaslt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.copy_tree(rocblas_dir, target_rocblas_dir)
|
self.copy_tree(hipblaslt_dir, str(target_hipblaslt_dir))
|
||||||
|
|
||||||
hipblaslt_dir = os.path.join(rocm_bin_path, "hipblaslt")
|
|
||||||
target_hipblaslt_dir = os.path.join(target_dir, "hipblaslt")
|
|
||||||
os.makedirs(target_hipblaslt_dir, exist_ok=True)
|
|
||||||
self.copy_tree(hipblaslt_dir, target_hipblaslt_dir)
|
|
||||||
else:
|
else:
|
||||||
report("The specified environment variable does not exist.")
|
report("The specified environment variable does not exist.")
|
||||||
|
|
||||||
def build_extensions(self):
|
def build_extensions(self) -> None:
|
||||||
self.create_compile_commands()
|
self.create_compile_commands()
|
||||||
|
|
||||||
|
build_lib = Path(self.build_lib).resolve()
|
||||||
|
|
||||||
# Copy functorch extension
|
# Copy functorch extension
|
||||||
for i, ext in enumerate(self.extensions):
|
for ext in self.extensions:
|
||||||
if ext.name != "functorch._C":
|
if ext.name != "functorch._C":
|
||||||
continue
|
continue
|
||||||
fullname = self.get_ext_fullname(ext.name)
|
fullname = self.get_ext_fullname(ext.name)
|
||||||
filename = self.get_ext_filename(fullname)
|
filename = Path(self.get_ext_filename(fullname))
|
||||||
fileext = os.path.splitext(filename)[1]
|
src = filename.with_stem("functorch")
|
||||||
src = os.path.join(os.path.dirname(filename), "functorch" + fileext)
|
dst = build_lib / filename
|
||||||
dst = os.path.join(os.path.realpath(self.build_lib), filename)
|
if src.exists():
|
||||||
if os.path.exists(src):
|
|
||||||
report(f"Copying {ext.name} from {src} to {dst}")
|
report(f"Copying {ext.name} from {src} to {dst}")
|
||||||
dst_dir = os.path.dirname(dst)
|
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if not os.path.exists(dst_dir):
|
|
||||||
os.makedirs(dst_dir)
|
|
||||||
self.copy_file(src, dst)
|
self.copy_file(src, dst)
|
||||||
|
|
||||||
setuptools.command.build_ext.build_ext.build_extensions(self)
|
super().build_extensions()
|
||||||
|
|
||||||
def get_outputs(self):
|
def get_outputs(self) -> list[str]:
|
||||||
outputs = setuptools.command.build_ext.build_ext.get_outputs(self)
|
outputs = super().get_outputs()
|
||||||
outputs.append(os.path.join(self.build_lib, "caffe2"))
|
outputs.append(os.path.join(self.build_lib, "caffe2"))
|
||||||
report(f"setup.py::get_outputs returning {outputs}")
|
report(f"setup.py::get_outputs returning {outputs}")
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def create_compile_commands(self):
|
def create_compile_commands(self) -> None:
|
||||||
def load(filename):
|
def load(file: Path) -> list[dict[str, Any]]:
|
||||||
with open(filename) as f:
|
return json.loads(file.read_text(encoding="utf-8"))
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
ninja_files = glob.glob("build/*compile_commands.json")
|
ninja_files = (CWD / BUILD_DIR).glob("*compile_commands.json")
|
||||||
cmake_files = glob.glob("torch/lib/build/*/compile_commands.json")
|
cmake_files = (CWD / "torch" / "lib" / "build").glob("*/compile_commands.json")
|
||||||
all_commands = [entry for f in ninja_files + cmake_files for entry in load(f)]
|
all_commands = [
|
||||||
|
entry
|
||||||
|
for f in itertools.chain(ninja_files, cmake_files)
|
||||||
|
for entry in load(f)
|
||||||
|
]
|
||||||
|
|
||||||
# cquery does not like c++ compiles that start with gcc.
|
# cquery does not like c++ compiles that start with gcc.
|
||||||
# It forgets to include the c++ header directories.
|
# It forgets to include the c++ header directories.
|
||||||
@ -862,12 +875,11 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
|
|
||||||
new_contents = json.dumps(all_commands, indent=2)
|
new_contents = json.dumps(all_commands, indent=2)
|
||||||
contents = ""
|
contents = ""
|
||||||
if os.path.exists("compile_commands.json"):
|
compile_commands_json = CWD / "compile_commands.json"
|
||||||
with open("compile_commands.json") as f:
|
if compile_commands_json.exists():
|
||||||
contents = f.read()
|
contents = compile_commands_json.read_text(encoding="utf-8")
|
||||||
if contents != new_contents:
|
if contents != new_contents:
|
||||||
with open("compile_commands.json", "w") as f:
|
compile_commands_json.write_text(new_contents, encoding="utf-8")
|
||||||
f.write(new_contents)
|
|
||||||
|
|
||||||
|
|
||||||
class concat_license_files:
|
class concat_license_files:
|
||||||
@ -879,115 +891,109 @@ class concat_license_files:
|
|||||||
licensing info.
|
licensing info.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, include_files=False):
|
def __init__(self, include_files: bool = False) -> None:
|
||||||
self.f1 = "LICENSE"
|
self.f1 = CWD / "LICENSE"
|
||||||
self.f2 = "third_party/LICENSES_BUNDLED.txt"
|
self.f2 = THIRD_PARTY_DIR / "LICENSES_BUNDLED.txt"
|
||||||
self.include_files = include_files
|
self.include_files = include_files
|
||||||
|
self.bsd_text = ""
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> None:
|
||||||
"""Concatenate files"""
|
"""Concatenate files"""
|
||||||
|
|
||||||
old_path = sys.path
|
old_path = sys.path
|
||||||
sys.path.append(third_party_path)
|
sys.path.append(str(THIRD_PARTY_DIR))
|
||||||
try:
|
try:
|
||||||
from build_bundled import create_bundled
|
from build_bundled import create_bundled # type: ignore[import-not-found]
|
||||||
finally:
|
finally:
|
||||||
sys.path = old_path
|
sys.path = old_path
|
||||||
|
|
||||||
with open(self.f1) as f1:
|
self.bsd_text = self.f1.read_text(encoding="utf-8")
|
||||||
self.bsd_text = f1.read()
|
|
||||||
|
|
||||||
with open(self.f1, "a") as f1:
|
with self.f1.open(mode="a", encoding="utf-8") as f1:
|
||||||
f1.write("\n\n")
|
f1.write("\n\n")
|
||||||
create_bundled(
|
create_bundled(
|
||||||
os.path.relpath(third_party_path), f1, include_files=self.include_files
|
str(THIRD_PARTY_DIR.resolve()),
|
||||||
|
f1,
|
||||||
|
include_files=self.include_files,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __exit__(self, exception_type, exception_value, traceback):
|
def __exit__(self, *exc_info: object) -> None:
|
||||||
"""Restore content of f1"""
|
"""Restore content of f1"""
|
||||||
with open(self.f1, "w") as f:
|
self.f1.write_text(self.bsd_text, encoding="utf-8")
|
||||||
f.write(self.bsd_text)
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from wheel.bdist_wheel import bdist_wheel
|
from wheel.bdist_wheel import bdist_wheel # type: ignore[import-untyped]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# This is useful when wheel is not installed and bdist_wheel is not
|
# This is useful when wheel is not installed and bdist_wheel is not
|
||||||
# specified on the command line. If it _is_ specified, parsing the command
|
# specified on the command line. If it _is_ specified, parsing the command
|
||||||
# line will fail before wheel_concatenate is needed
|
# line will fail before wheel_concatenate is needed
|
||||||
wheel_concatenate = None
|
wheel_concatenate: type[Command] | None = None
|
||||||
else:
|
else:
|
||||||
# Need to create the proper LICENSE.txt for the wheel
|
# Need to create the proper LICENSE.txt for the wheel
|
||||||
class wheel_concatenate(bdist_wheel):
|
class wheel_concatenate(bdist_wheel): # type: ignore[no-redef]
|
||||||
"""check submodules on sdist to prevent incomplete tarballs"""
|
"""check submodules on sdist to prevent incomplete tarballs"""
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
with concat_license_files(include_files=True):
|
with concat_license_files(include_files=True):
|
||||||
super().run()
|
super().run()
|
||||||
|
|
||||||
def write_wheelfile(self, *args, **kwargs):
|
def write_wheelfile(self, *args: Any, **kwargs: Any) -> None:
|
||||||
super().write_wheelfile(*args, **kwargs)
|
super().write_wheelfile(*args, **kwargs)
|
||||||
|
|
||||||
if BUILD_LIBTORCH_WHL:
|
if BUILD_LIBTORCH_WHL:
|
||||||
|
bdist_dir = Path(self.bdist_dir)
|
||||||
# Remove extraneneous files in the libtorch wheel
|
# Remove extraneneous files in the libtorch wheel
|
||||||
for root, dirs, files in os.walk(self.bdist_dir):
|
for file in itertools.chain(
|
||||||
for file in files:
|
bdist_dir.glob("**/*.a"),
|
||||||
if file.endswith((".a", ".so")) and os.path.isfile(
|
bdist_dir.glob("**/*.so"),
|
||||||
os.path.join(self.bdist_dir, file)
|
):
|
||||||
):
|
if (bdist_dir / file.name).is_file():
|
||||||
os.remove(os.path.join(root, file))
|
file.unlink()
|
||||||
elif file.endswith(".py"):
|
for file in bdist_dir.glob("**/*.py"):
|
||||||
os.remove(os.path.join(root, file))
|
file.unlink()
|
||||||
# need an __init__.py file otherwise we wouldn't have a package
|
# need an __init__.py file otherwise we wouldn't have a package
|
||||||
open(os.path.join(self.bdist_dir, "torch", "__init__.py"), "w").close()
|
(bdist_dir / "torch" / "__init__.py").touch()
|
||||||
|
|
||||||
|
|
||||||
class install(setuptools.command.install.install):
|
class clean(Command):
|
||||||
def run(self):
|
user_options: ClassVar[list[tuple[str, str | None, str]]] = []
|
||||||
super().run()
|
|
||||||
|
|
||||||
|
def initialize_options(self) -> None:
|
||||||
class clean(setuptools.Command):
|
|
||||||
user_options = []
|
|
||||||
|
|
||||||
def initialize_options(self):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def finalize_options(self):
|
def finalize_options(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
import glob
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
with open(".gitignore") as f:
|
ignores = (CWD / ".gitignore").read_text(encoding="utf-8")
|
||||||
ignores = f.read()
|
pattern = re.compile(r"^#( BEGIN NOT-CLEAN-FILES )?")
|
||||||
pat = re.compile(r"^#( BEGIN NOT-CLEAN-FILES )?")
|
for wildcard in filter(None, ignores.splitlines()):
|
||||||
for wildcard in filter(None, ignores.split("\n")):
|
match = pattern.match(wildcard)
|
||||||
match = pat.match(wildcard)
|
if match:
|
||||||
if match:
|
if match.group(1):
|
||||||
if match.group(1):
|
# Marker is found and stop reading .gitignore.
|
||||||
# Marker is found and stop reading .gitignore.
|
break
|
||||||
break
|
# Ignore lines which begin with '#'.
|
||||||
# Ignore lines which begin with '#'.
|
else:
|
||||||
else:
|
# Don't remove absolute paths from the system
|
||||||
# Don't remove absolute paths from the system
|
wildcard = wildcard.lstrip("./")
|
||||||
wildcard = wildcard.lstrip("./")
|
for filename in glob.iglob(wildcard):
|
||||||
|
try:
|
||||||
for filename in glob.glob(wildcard):
|
os.remove(filename)
|
||||||
try:
|
except OSError:
|
||||||
os.remove(filename)
|
shutil.rmtree(filename, ignore_errors=True)
|
||||||
except OSError:
|
|
||||||
shutil.rmtree(filename, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
class sdist(setuptools.command.sdist.sdist):
|
class sdist(setuptools.command.sdist.sdist):
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
with concat_license_files():
|
with concat_license_files():
|
||||||
super().run()
|
super().run()
|
||||||
|
|
||||||
|
|
||||||
def get_cmake_cache_vars():
|
def get_cmake_cache_vars() -> defaultdict[str, CMakeValue]:
|
||||||
try:
|
try:
|
||||||
return defaultdict(lambda: False, cmake.get_cmake_cache_variables())
|
return defaultdict(lambda: False, cmake.get_cmake_cache_variables())
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@ -996,7 +1002,13 @@ def get_cmake_cache_vars():
|
|||||||
return defaultdict(lambda: False)
|
return defaultdict(lambda: False)
|
||||||
|
|
||||||
|
|
||||||
def configure_extension_build():
|
def configure_extension_build() -> tuple[
|
||||||
|
list[Extension], # ext_modules
|
||||||
|
dict[str, type[Command]], # cmdclass
|
||||||
|
list[str], # packages
|
||||||
|
dict[str, list[str]], # entry_points
|
||||||
|
list[str], # extra_install_requires
|
||||||
|
]:
|
||||||
r"""Configures extension build options according to system environment and user's choice.
|
r"""Configures extension build options according to system environment and user's choice.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1009,17 +1021,17 @@ def configure_extension_build():
|
|||||||
# Configure compile flags
|
# Configure compile flags
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
library_dirs = []
|
library_dirs: list[str] = [str(TORCH_LIB_DIR)]
|
||||||
extra_install_requires = []
|
extra_install_requires: list[str] = []
|
||||||
|
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
# /NODEFAULTLIB makes sure we only link to DLL runtime
|
# /NODEFAULTLIB makes sure we only link to DLL runtime
|
||||||
# and matches the flags set for protobuf and ONNX
|
# and matches the flags set for protobuf and ONNX
|
||||||
extra_link_args = ["/NODEFAULTLIB:LIBCMT.LIB"]
|
extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"]
|
||||||
# /MD links against DLL runtime
|
# /MD links against DLL runtime
|
||||||
# and matches the flags set for protobuf and ONNX
|
# and matches the flags set for protobuf and ONNX
|
||||||
# /EHsc is about standard C++ exception handling
|
# /EHsc is about standard C++ exception handling
|
||||||
extra_compile_args = ["/MD", "/FS", "/EHsc"]
|
extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
|
||||||
else:
|
else:
|
||||||
extra_link_args = []
|
extra_link_args = []
|
||||||
extra_compile_args = [
|
extra_compile_args = [
|
||||||
@ -1035,13 +1047,11 @@ def configure_extension_build():
|
|||||||
"-fno-strict-aliasing",
|
"-fno-strict-aliasing",
|
||||||
]
|
]
|
||||||
|
|
||||||
library_dirs.append(lib_path)
|
main_compile_args: list[str] = []
|
||||||
|
main_libraries: list[str] = ["torch_python"]
|
||||||
|
|
||||||
main_compile_args = []
|
main_link_args: list[str] = []
|
||||||
main_libraries = ["torch_python"]
|
main_sources: list[str] = ["torch/csrc/stub.c"]
|
||||||
|
|
||||||
main_link_args = []
|
|
||||||
main_sources = ["torch/csrc/stub.c"]
|
|
||||||
|
|
||||||
if BUILD_LIBTORCH_WHL:
|
if BUILD_LIBTORCH_WHL:
|
||||||
main_libraries = ["torch"]
|
main_libraries = ["torch"]
|
||||||
@ -1049,16 +1059,16 @@ def configure_extension_build():
|
|||||||
|
|
||||||
if build_type.is_debug():
|
if build_type.is_debug():
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
extra_compile_args.append("/Z7")
|
extra_compile_args += ["/Z7"]
|
||||||
extra_link_args.append("/DEBUG:FULL")
|
extra_link_args += ["/DEBUG:FULL"]
|
||||||
else:
|
else:
|
||||||
extra_compile_args += ["-O0", "-g"]
|
extra_compile_args += ["-O0", "-g"]
|
||||||
extra_link_args += ["-O0", "-g"]
|
extra_link_args += ["-O0", "-g"]
|
||||||
|
|
||||||
if build_type.is_rel_with_deb_info():
|
if build_type.is_rel_with_deb_info():
|
||||||
if IS_WINDOWS:
|
if IS_WINDOWS:
|
||||||
extra_compile_args.append("/Z7")
|
extra_compile_args += ["/Z7"]
|
||||||
extra_link_args.append("/DEBUG:FULL")
|
extra_link_args += ["/DEBUG:FULL"]
|
||||||
else:
|
else:
|
||||||
extra_compile_args += ["-g"]
|
extra_compile_args += ["-g"]
|
||||||
extra_link_args += ["-g"]
|
extra_link_args += ["-g"]
|
||||||
@ -1095,7 +1105,7 @@ def configure_extension_build():
|
|||||||
]
|
]
|
||||||
extra_link_args += ["-arch", macos_target_arch]
|
extra_link_args += ["-arch", macos_target_arch]
|
||||||
|
|
||||||
def make_relative_rpath_args(path):
|
def make_relative_rpath_args(path: str) -> list[str]:
|
||||||
if IS_DARWIN:
|
if IS_DARWIN:
|
||||||
return ["-Wl,-rpath,@loader_path/" + path]
|
return ["-Wl,-rpath,@loader_path/" + path]
|
||||||
elif IS_WINDOWS:
|
elif IS_WINDOWS:
|
||||||
@ -1120,26 +1130,24 @@ def configure_extension_build():
|
|||||||
extra_compile_args=main_compile_args + extra_compile_args,
|
extra_compile_args=main_compile_args + extra_compile_args,
|
||||||
include_dirs=[],
|
include_dirs=[],
|
||||||
library_dirs=library_dirs,
|
library_dirs=library_dirs,
|
||||||
extra_link_args=extra_link_args
|
extra_link_args=(
|
||||||
+ main_link_args
|
extra_link_args + main_link_args + make_relative_rpath_args("lib")
|
||||||
+ make_relative_rpath_args("lib"),
|
),
|
||||||
)
|
)
|
||||||
extensions.append(C)
|
extensions.append(C)
|
||||||
|
|
||||||
# These extensions are built by cmake and copied manually in build_extensions()
|
# These extensions are built by cmake and copied manually in build_extensions()
|
||||||
# inside the build_ext implementation
|
# inside the build_ext implementation
|
||||||
if cmake_cache_vars["BUILD_FUNCTORCH"]:
|
if cmake_cache_vars["BUILD_FUNCTORCH"]:
|
||||||
extensions.append(
|
extensions.append(Extension(name="functorch._C", sources=[]))
|
||||||
Extension(name="functorch._C", sources=[]),
|
|
||||||
)
|
|
||||||
|
|
||||||
cmdclass = {
|
cmdclass = {
|
||||||
"bdist_wheel": wheel_concatenate,
|
|
||||||
"build_ext": build_ext,
|
"build_ext": build_ext,
|
||||||
"clean": clean,
|
"clean": clean,
|
||||||
"install": install,
|
|
||||||
"sdist": sdist,
|
"sdist": sdist,
|
||||||
}
|
}
|
||||||
|
if wheel_concatenate is not None:
|
||||||
|
cmdclass["bdist_wheel"] = wheel_concatenate
|
||||||
|
|
||||||
entry_points = {
|
entry_points = {
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
@ -1171,7 +1179,7 @@ build_update_message = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def print_box(msg):
|
def print_box(msg: str) -> None:
|
||||||
lines = msg.split("\n")
|
lines = msg.split("\n")
|
||||||
size = max(len(l) + 1 for l in lines)
|
size = max(len(l) + 1 for l in lines)
|
||||||
print("-" * (size + 2))
|
print("-" * (size + 2))
|
||||||
@ -1180,7 +1188,7 @@ def print_box(msg):
|
|||||||
print("-" * (size + 2))
|
print("-" * (size + 2))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY:
|
if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. "
|
"Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. "
|
||||||
@ -1226,7 +1234,7 @@ def main():
|
|||||||
dist.script_args = sys.argv[1:]
|
dist.script_args = sys.argv[1:]
|
||||||
try:
|
try:
|
||||||
dist.parse_command_line()
|
dist.parse_command_line()
|
||||||
except setuptools.distutils.errors.DistutilsArgError as e:
|
except setuptools.errors.BaseError as e:
|
||||||
print(e)
|
print(e)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@ -1235,7 +1243,7 @@ def main():
|
|||||||
build_deps()
|
build_deps()
|
||||||
|
|
||||||
(
|
(
|
||||||
extensions,
|
ext_modules,
|
||||||
cmdclass,
|
cmdclass,
|
||||||
packages,
|
packages,
|
||||||
entry_points,
|
entry_points,
|
||||||
@ -1250,7 +1258,7 @@ def main():
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Read in README.md for our long_description
|
# Read in README.md for our long_description
|
||||||
with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
|
with open(os.path.join(CWD, "README.md"), encoding="utf-8") as f:
|
||||||
long_description = f.read()
|
long_description = f.read()
|
||||||
|
|
||||||
version_range_max = max(sys.version_info[1], 13) + 1
|
version_range_max = max(sys.version_info[1], 13) + 1
|
||||||
@ -1312,12 +1320,11 @@ def main():
|
|||||||
"lib/*.lib",
|
"lib/*.lib",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
aotriton_image_path = os.path.join(lib_path, "aotriton.images")
|
aotriton_image_path = TORCH_LIB_DIR / "aotriton.images"
|
||||||
aks2_files = []
|
aks2_files: list[str] = []
|
||||||
for root, dirs, files in os.walk(aotriton_image_path):
|
for file in filter(lambda p: p.is_file(), aotriton_image_path.glob("**")):
|
||||||
subpath = os.path.relpath(root, start=aotriton_image_path)
|
subpath = file.relative_to(aotriton_image_path)
|
||||||
for fn in files:
|
aks2_files.append(os.path.join("lib/aotriton.images", subpath))
|
||||||
aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn))
|
|
||||||
torch_package_data += aks2_files
|
torch_package_data += aks2_files
|
||||||
if get_cmake_cache_vars()["USE_TENSORPIPE"]:
|
if get_cmake_cache_vars()["USE_TENSORPIPE"]:
|
||||||
torch_package_data.extend(
|
torch_package_data.extend(
|
||||||
@ -1345,17 +1352,17 @@ def main():
|
|||||||
package_data["torchgen"] = torchgen_package_data
|
package_data["torchgen"] = torchgen_package_data
|
||||||
else:
|
else:
|
||||||
# no extensions in BUILD_LIBTORCH_WHL mode
|
# no extensions in BUILD_LIBTORCH_WHL mode
|
||||||
extensions = []
|
ext_modules = []
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name=package_name,
|
name=TORCH_PACKAGE_NAME,
|
||||||
version=version,
|
version=TORCH_VERSION,
|
||||||
description=(
|
description=(
|
||||||
"Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
"Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
||||||
),
|
),
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
ext_modules=extensions,
|
ext_modules=ext_modules,
|
||||||
cmdclass=cmdclass,
|
cmdclass=cmdclass,
|
||||||
packages=packages,
|
packages=packages,
|
||||||
entry_points=entry_points,
|
entry_points=entry_points,
|
||||||
|
|||||||
Reference in New Issue
Block a user