[BE] use pathlib.Path instead of os.path.* in setup.py (#156742)

Resolves:

- https://github.com/pytorch/pytorch/pull/155998#discussion_r2164376634

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156742
Approved by: https://github.com/malfet
Xuehai Pan
2025-06-29 01:09:29 +08:00
committed by PyTorch MergeBot
parent 90b973a2e2
commit 2380115f97
2 changed files with 265 additions and 257 deletions
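The whole diff follows the usual os.path-to-pathlib migration: string paths built with os.path.join/os.path.dirname become pathlib.Path objects composed with the / operator, and existence checks, reads and writes move to Path methods. A minimal sketch of the before/after shape, reusing the names the diff introduces (the snippet itself is illustrative and not part of the commit):

import os.path
from pathlib import Path

# Before: string-based paths assembled with os.path.*
cwd = os.path.dirname(os.path.abspath(__file__))
lib_path = os.path.join(cwd, "torch", "lib")
has_lib = os.path.exists(lib_path)

# After: the same locations as pathlib.Path objects
CWD = Path(__file__).absolute().parent
TORCH_LIB_DIR = CWD / "torch" / "lib"
has_lib = TORCH_LIB_DIR.exists()

# Convert back to a string only where an API still expects one,
# e.g. environment variables (as the diff does for LIBTORCH_LIB_PATH)
lib_path_str = TORCH_LIB_DIR.as_posix()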


@@ -122,6 +122,7 @@ is_formatter = true
 [[linter]]
 code = 'MYPY'
 include_patterns = [
+    'setup.py',
     'torch/**/*.py',
     'torch/**/*.pyi',
     'caffe2/**/*.py',

setup.py

@@ -257,23 +257,31 @@ if sys.version_info < python_min_version:
 import filecmp
 import glob
 import importlib
-import importlib.util
+import itertools
 import json
 import shutil
 import subprocess
 import sysconfig
 import time
 from collections import defaultdict
+from pathlib import Path
+from typing import Any, ClassVar, IO
 
 import setuptools.command.build_ext
-import setuptools.command.install
 import setuptools.command.sdist
-from setuptools import Extension, find_packages, setup
+import setuptools.errors
+from setuptools import Command, Extension, find_packages, setup
 from setuptools.dist import Distribution
 
 from tools.build_pytorch_libs import build_pytorch
 from tools.generate_torch_version import get_torch_version
-from tools.setup_helpers.cmake import CMake
-from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
+from tools.setup_helpers.cmake import CMake, CMakeValue
+from tools.setup_helpers.env import (
+    BUILD_DIR,
+    build_type,
+    IS_DARWIN,
+    IS_LINUX,
+    IS_WINDOWS,
+)
 from tools.setup_helpers.generate_linker_script import gen_linker_script
@@ -318,18 +326,20 @@ def str2bool(value: str | None) -> bool:
     raise ValueError(f"Invalid string value for boolean conversion: {value}")
 
 
-def _get_package_path(package_name):
-    spec = importlib.util.find_spec(package_name)
+def _get_package_path(package_name: str) -> Path:
+    from importlib.util import find_spec
+
+    spec = find_spec(package_name)
     if spec:
         # The package might be a namespace package, so get_data may fail
         try:
             loader = spec.loader
             if loader is not None:
                 file_path = loader.get_filename()  # type: ignore[attr-defined]
-                return os.path.dirname(file_path)
+                return Path(file_path).parent
         except AttributeError:
             pass
-    return None
+    return CWD / package_name
 
 
 BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL"))
@@ -343,7 +353,7 @@ if BUILD_LIBTORCH_WHL:
 
 if BUILD_PYTHON_ONLY:
     os.environ["BUILD_LIBTORCHLESS"] = "ON"
-    os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
+    os.environ["LIBTORCH_LIB_PATH"] = (_get_package_path("torch") / "lib").as_posix()
 
 ################################################################################
 # Parameters parsed from environment
@@ -381,60 +391,61 @@ sys.argv = filtered_args
 
 if VERBOSE_SCRIPT:
 
-    def report(*args, file=sys.stderr, **kwargs):
+    def report(*args: Any, file: IO[str] = sys.stderr, **kwargs: Any) -> None:
         print(*args, file=file, **kwargs)
 
 else:
 
-    def report(*args, **kwargs):
+    def report(*args: Any, file: IO[str] = sys.stderr, **kwargs: Any) -> None:
         pass
 
     # Make distutils respect --quiet too
-    setuptools.distutils.log.warn = report
+    setuptools.distutils.log.warn = report  # type: ignore[attr-defined]
 
 # Constant known variables used throughout this file
-cwd = os.path.dirname(os.path.abspath(__file__))
-lib_path = os.path.join(cwd, "torch", "lib")
-third_party_path = os.path.join(cwd, "third_party")
+CWD = Path(__file__).absolute().parent
+TORCH_LIB_DIR = CWD / "torch" / "lib"
+THIRD_PARTY_DIR = CWD / "third_party"
 
 # CMAKE: full path to python library
 if IS_WINDOWS:
-    cmake_python_library = "{}/libs/python{}.lib".format(
-        sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION")
+    CMAKE_PYTHON_LIBRARY = (
+        Path(sysconfig.get_config_var("prefix"))
+        / "libs"
+        / f"python{sysconfig.get_config_var('VERSION')}.lib"
     )
     # Fix virtualenv builds
-    if not os.path.exists(cmake_python_library):
-        cmake_python_library = "{}/libs/python{}.lib".format(
-            sys.base_prefix, sysconfig.get_config_var("VERSION")
+    if not CMAKE_PYTHON_LIBRARY.exists():
+        CMAKE_PYTHON_LIBRARY = (
+            Path(sys.base_prefix)
+            / "libs"
+            / f"python{sysconfig.get_config_var('VERSION')}.lib"
         )
 else:
-    cmake_python_library = "{}/{}".format(
-        sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME")
-    )
-cmake_python_include_dir = sysconfig.get_path("include")
+    CMAKE_PYTHON_LIBRARY = Path(
+        sysconfig.get_config_var("LIBDIR")
+    ) / sysconfig.get_config_var("INSTSONAME")
 
 ################################################################################
 # Version, create_version_file, and package_name
 ################################################################################
-package_name = os.getenv("TORCH_PACKAGE_NAME", "torch")
+TORCH_PACKAGE_NAME = os.getenv("TORCH_PACKAGE_NAME", "torch")
 LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python")
 if BUILD_LIBTORCH_WHL:
-    package_name = LIBTORCH_PKG_NAME
+    TORCH_PACKAGE_NAME = LIBTORCH_PKG_NAME
 
-package_type = os.getenv("PACKAGE_TYPE", "wheel")
-version = get_torch_version()
-report(f"Building wheel {package_name}-{version}")
+TORCH_VERSION = get_torch_version()
+report(f"Building wheel {TORCH_PACKAGE_NAME}-{TORCH_VERSION}")
 
 cmake = CMake()
 
 
-def get_submodule_folders():
-    git_modules_path = os.path.join(cwd, ".gitmodules")
+def get_submodule_folders() -> list[Path]:
+    git_modules_file = CWD / ".gitmodules"
     default_modules_path = [
-        os.path.join(third_party_path, name)
+        THIRD_PARTY_DIR / name
         for name in [
             "gloo",
             "cpuinfo",
@@ -443,26 +454,26 @@ def get_submodule_folders():
             "cutlass",
         ]
     ]
-    if not os.path.exists(git_modules_path):
+    if not git_modules_file.exists():
         return default_modules_path
-    with open(git_modules_path) as f:
+    with git_modules_file.open(encoding="utf-8") as f:
         return [
-            os.path.join(cwd, line.split("=", 1)[1].strip())
+            CWD / line.partition("=")[-1].strip()
            for line in f
            if line.strip().startswith("path")
        ]
 
 
-def check_submodules():
-    def check_for_files(folder, files):
-        if not any(os.path.exists(os.path.join(folder, f)) for f in files):
+def check_submodules() -> None:
+    def check_for_files(folder: Path, files: list[str]) -> None:
+        if not any((folder / f).exists() for f in files):
             report("Could not find any of {} in {}".format(", ".join(files), folder))
             report("Did you run 'git submodule update --init --recursive'?")
             sys.exit(1)
 
-    def not_exists_or_empty(folder):
-        return not os.path.exists(folder) or (
-            os.path.isdir(folder) and len(os.listdir(folder)) == 0
+    def not_exists_or_empty(folder: Path) -> bool:
+        return not folder.exists() or (
+            folder.is_dir() and next(folder.iterdir(), None) is None
         )
 
     if str2bool(os.getenv("USE_SYSTEM_LIBS")):
@@ -474,7 +485,7 @@ def check_submodules():
         report(" --- Trying to initialize submodules")
         start = time.time()
         subprocess.check_call(
-            ["git", "submodule", "update", "--init", "--recursive"], cwd=cwd
+            ["git", "submodule", "update", "--init", "--recursive"], cwd=CWD
         )
         end = time.time()
         report(f" --- Submodule initialization took {end - start:.2f} sec")
@@ -495,37 +506,41 @@ def check_submodules():
             ],
         )
     check_for_files(
-        os.path.join(third_party_path, "fbgemm", "external", "asmjit"),
+        THIRD_PARTY_DIR / "fbgemm" / "external" / "asmjit",
         ["CMakeLists.txt"],
     )
 
 
 # Windows has very bad support for symbolic links.
 # Instead of using symlinks, we're going to copy files over
-def mirror_files_into_torchgen():
+def mirror_files_into_torchgen() -> None:
     # (new_path, orig_path)
     # Directories are OK and are recursively mirrored.
-    paths = [
-        (
-            "torchgen/packaged/ATen/native/native_functions.yaml",
-            "aten/src/ATen/native/native_functions.yaml",
-        ),
-        ("torchgen/packaged/ATen/native/tags.yaml", "aten/src/ATen/native/tags.yaml"),
-        ("torchgen/packaged/ATen/templates", "aten/src/ATen/templates"),
-        ("torchgen/packaged/autograd", "tools/autograd"),
-        ("torchgen/packaged/autograd/templates", "tools/autograd/templates"),
+    new_paths = [
+        "torchgen/packaged/ATen/native/native_functions.yaml",
+        "torchgen/packaged/ATen/native/tags.yaml",
+        "torchgen/packaged/ATen/templates",
+        "torchgen/packaged/autograd",
+        "torchgen/packaged/autograd/templates",
     ]
-    for new_path, orig_path in paths:
+    orig_paths = [
+        "aten/src/ATen/native/native_functions.yaml",
+        "aten/src/ATen/native/tags.yaml",
+        "aten/src/ATen/templates",
+        "tools/autograd",
+        "tools/autograd/templates",
+    ]
+    for new_path, orig_path in zip(map(Path, new_paths), map(Path, orig_paths)):
         # Create the dirs involved in new_path if they don't exist
-        if not os.path.exists(new_path):
-            os.makedirs(os.path.dirname(new_path), exist_ok=True)
+        if not new_path.exists():
+            new_path.parent.mkdir(parents=True, exist_ok=True)
 
         # Copy the files from the orig location to the new location
-        if os.path.isfile(orig_path):
+        if orig_path.is_file():
             shutil.copyfile(orig_path, new_path)
             continue
-        if os.path.isdir(orig_path):
-            if os.path.exists(new_path):
+        if orig_path.is_dir():
+            if new_path.exists():
                 # copytree fails if the tree exists already, so remove it.
                 shutil.rmtree(new_path)
             shutil.copytree(orig_path, new_path)
@@ -534,15 +549,14 @@ def mirror_files_into_torchgen():
 
 # all the work we need to do _before_ setup runs
-def build_deps():
-    report("-- Building version " + version)
+def build_deps() -> None:
+    report(f"-- Building version {TORCH_VERSION}")
 
     check_submodules()
     check_pydep("yaml", "pyyaml")
-    build_python = not BUILD_LIBTORCH_WHL
     build_pytorch(
-        version=version,
-        cmake_python_library=cmake_python_library,
-        build_python=build_python,
+        version=TORCH_VERSION,
+        cmake_python_library=CMAKE_PYTHON_LIBRARY.as_posix(),
+        build_python=not BUILD_LIBTORCH_WHL,
         rerun_cmake=RERUN_CMAKE,
         cmake_only=CMAKE_ONLY,
         cmake=cmake,
@@ -568,13 +582,13 @@ def build_deps():
         "third_party/valgrind-headers/callgrind.h",
         "third_party/valgrind-headers/valgrind.h",
     ]
-    for sym_file, orig_file in zip(sym_files, orig_files):
+    for sym_file, orig_file in zip(map(Path, sym_files), map(Path, orig_files)):
         same = False
-        if os.path.exists(sym_file):
+        if sym_file.exists():
             if filecmp.cmp(sym_file, orig_file):
                 same = True
             else:
-                os.remove(sym_file)
+                sym_file.unlink()
         if not same:
             shutil.copyfile(orig_file, sym_file)
@@ -589,7 +603,7 @@ Please install it via `conda install {module}` or `pip install {module}`
 """.strip()
 
 
-def check_pydep(importname, module):
+def check_pydep(importname: str, module: str) -> None:
     try:
         importlib.import_module(importname)
     except ImportError as e:
@@ -599,19 +613,22 @@ def check_pydep(importname, module):
 class build_ext(setuptools.command.build_ext.build_ext):
-    def _embed_libomp(self):
+    def _embed_libomp(self) -> None:
         # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS
-        lib_dir = os.path.join(self.build_lib, "torch", "lib")
-        libtorch_cpu_path = os.path.join(lib_dir, "libtorch_cpu.dylib")
-        if not os.path.exists(libtorch_cpu_path):
+        build_lib = Path(self.build_lib)
+        build_torch_lib_dir = build_lib / "torch" / "lib"
+        build_torch_include_dir = build_lib / "torch" / "include"
+        libtorch_cpu_path = build_torch_lib_dir / "libtorch_cpu.dylib"
+        if not libtorch_cpu_path.exists():
             return
 
         # Parse libtorch_cpu load commands
         otool_cmds = (
-            subprocess.check_output(["otool", "-l", libtorch_cpu_path])
+            subprocess.check_output(["otool", "-l", str(libtorch_cpu_path)])
             .decode("utf-8")
             .split("\n")
         )
-        rpaths, libs = [], []
+        rpaths: list[str] = []
+        libs: list[str] = []
         for idx, line in enumerate(otool_cmds):
             if line.strip() == "cmd LC_LOAD_DYLIB":
                 lib_name = otool_cmds[idx + 2].strip()
@@ -623,8 +640,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
                 assert rpath.startswith("path ")
                 rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1])
 
-        omplib_path = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"]
-        omplib_name = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] + ".dylib"
+        omplib_path: str = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"]  # type: ignore[assignment]
+        omplib_name: str = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"]  # type: ignore[assignment]
+        omplib_name += ".dylib"
         omplib_rpath_path = os.path.join("@rpath", omplib_name)
 
         # This logic is fragile and checks only two cases:
@@ -634,8 +652,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
             return
 
         # Copy libomp/libiomp5 from rpath locations
-        target_lib = os.path.join(self.build_lib, "torch", "lib", omplib_name)
+        target_lib = build_torch_lib_dir / omplib_name
         libomp_relocated = False
+        install_name_tool_args: list[str] = []
         for rpath in rpaths:
             source_lib = os.path.join(rpath, omplib_name)
             if not os.path.exists(source_lib):
@@ -666,22 +685,29 @@ class build_ext(setuptools.command.build_ext.build_ext):
                 ]
                 libomp_relocated = True
         if libomp_relocated:
-            install_name_tool_args.insert(0, "install_name_tool")
-            install_name_tool_args.append(libtorch_cpu_path)
+            install_name_tool_args = [
+                "install_name_tool",
+                *install_name_tool_args,
+                str(libtorch_cpu_path),
+            ]
             subprocess.check_call(install_name_tool_args)
         # Copy omp.h from OpenMP_C_FLAGS and copy it into include folder
-        omp_cflags = get_cmake_cache_vars()["OpenMP_C_FLAGS"]
+        omp_cflags: str = get_cmake_cache_vars()["OpenMP_C_FLAGS"]  # type: ignore[assignment]
         if not omp_cflags:
             return
-        for include_dir in [f[2:] for f in omp_cflags.split(" ") if f.startswith("-I")]:
-            omp_h = os.path.join(include_dir, "omp.h")
-            if not os.path.exists(omp_h):
+        for include_dir in [
+            Path(f.removeprefix("-I"))
+            for f in omp_cflags.split(" ")
+            if f.startswith("-I")
+        ]:
+            omp_h = include_dir / "omp.h"
+            if not omp_h.exists():
                 continue
-            target_omp_h = os.path.join(self.build_lib, "torch", "include", "omp.h")
+            target_omp_h = build_torch_include_dir / "omp.h"
             self.copy_file(omp_h, target_omp_h)
             break
 
-    def run(self):
+    def run(self) -> None:
         # Report build options. This is run after the build completes so # `CMakeCache.txt` exists
         # and we can get an accurate report on what is used and what is not.
         cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables())
@@ -692,18 +718,17 @@ class build_ext(setuptools.command.build_ext.build_ext):
         if cmake_cache_vars["USE_CUDNN"]:
             report(
                 "-- Detected cuDNN at "
-                + cmake_cache_vars["CUDNN_LIBRARY"]
-                + ", "
-                + cmake_cache_vars["CUDNN_INCLUDE_DIR"]
+                f"{cmake_cache_vars['CUDNN_LIBRARY']}, "
+                f"{cmake_cache_vars['CUDNN_INCLUDE_DIR']}"
             )
         else:
             report("-- Not using cuDNN")
         if cmake_cache_vars["USE_CUDA"]:
-            report("-- Detected CUDA at " + cmake_cache_vars["CUDA_TOOLKIT_ROOT_DIR"])
+            report(f"-- Detected CUDA at {cmake_cache_vars['CUDA_TOOLKIT_ROOT_DIR']}")
         else:
             report("-- Not using CUDA")
         if cmake_cache_vars["USE_XPU"]:
-            report("-- Detected XPU runtime at " + cmake_cache_vars["SYCL_LIBRARY_DIR"])
+            report(f"-- Detected XPU runtime at {cmake_cache_vars['SYCL_LIBRARY_DIR']}")
         else:
             report("-- Not using XPU")
         if cmake_cache_vars["USE_MKLDNN"]:
@@ -722,10 +747,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
             report("-- Not using MKLDNN")
         if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]:
             report(
-                "-- Using system provided NCCL library at {}, {}".format(
-                    cmake_cache_vars["NCCL_LIBRARIES"],
-                    cmake_cache_vars["NCCL_INCLUDE_DIRS"],
-                )
+                "-- Using system provided NCCL library at "
+                f"{cmake_cache_vars['NCCL_LIBRARIES']}, "
+                f"{cmake_cache_vars['NCCL_INCLUDE_DIRS']}"
             )
         elif cmake_cache_vars["USE_NCCL"]:
             report("-- Building NCCL library")
@@ -736,18 +760,15 @@ class build_ext(setuptools.command.build_ext.build_ext):
                 report("-- Building without distributed package")
             else:
                 report("-- Building with distributed package: ")
-                report(
-                    "  -- USE_TENSORPIPE={}".format(cmake_cache_vars["USE_TENSORPIPE"])
-                )
-                report("  -- USE_GLOO={}".format(cmake_cache_vars["USE_GLOO"]))
-                report("  -- USE_MPI={}".format(cmake_cache_vars["USE_OPENMPI"]))
+                report(f"  -- USE_TENSORPIPE={cmake_cache_vars['USE_TENSORPIPE']}")
+                report(f"  -- USE_GLOO={cmake_cache_vars['USE_GLOO']}")
+                report(f"  -- USE_MPI={cmake_cache_vars['USE_OPENMPI']}")
         else:
             report("-- Building without distributed package")
         if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]:
             report(
-                "-- Using static dispatch with backend {}".format(
-                    cmake_cache_vars["STATIC_DISPATCH_BACKEND"]
-                )
+                "-- Using static dispatch with "
+                f"backend {cmake_cache_vars['STATIC_DISPATCH_BACKEND']}"
             )
         if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]:
             report("-- Using lightweight dispatch")
@@ -759,98 +780,90 @@ class build_ext(setuptools.command.build_ext.build_ext):
         # Do not use clang to compile extensions if `-fstack-clash-protection` is defined
         # in system CFLAGS
-        c_flags = str(os.getenv("CFLAGS", ""))
+        c_flags = os.getenv("CFLAGS", "")
         if (
             IS_LINUX
             and "-fstack-clash-protection" in c_flags
-            and "clang" in os.environ.get("CC", "")
+            and "clang" in os.getenv("CC", "")
         ):
             os.environ["CC"] = str(os.environ["CC"])
 
-        # It's an old-style class in Python 2.7...
-        setuptools.command.build_ext.build_ext.run(self)
+        super().run()
 
         if IS_DARWIN:
             self._embed_libomp()
 
         # Copy the essential export library to compile C++ extensions.
         if IS_WINDOWS:
-            build_temp = self.build_temp
+            build_temp = Path(self.build_temp)
+            build_lib = Path(self.build_lib)
 
             ext_filename = self.get_ext_filename("_C")
             lib_filename = ".".join(ext_filename.split(".")[:-1]) + ".lib"
 
-            export_lib = os.path.join(
-                build_temp, "torch", "csrc", lib_filename
-            ).replace("\\", "/")
-
-            build_lib = self.build_lib
-
-            target_lib = os.path.join(build_lib, "torch", "lib", "_C.lib").replace(
-                "\\", "/"
-            )
+            export_lib = build_temp / "torch" / "csrc" / lib_filename
+            target_lib = build_lib / "torch" / "lib" / "_C.lib"
 
             # Create "torch/lib" directory if not exists.
             # (It is not created yet in "develop" mode.)
-            target_dir = os.path.dirname(target_lib)
-            if not os.path.exists(target_dir):
-                os.makedirs(target_dir)
+            target_dir = target_lib.parent
+            target_dir.mkdir(parents=True, exist_ok=True)
 
             self.copy_file(export_lib, target_lib)
 
             # In ROCm on Windows case copy rocblas and hipblaslt files into
             # torch/lib/rocblas/library and torch/lib/hipblaslt/library
             if str2bool(os.getenv("USE_ROCM")):
-                rocm_dir_path = os.environ.get("ROCM_DIR")
-                rocm_bin_path = os.path.join(rocm_dir_path, "bin")
+                rocm_dir_path = Path(os.environ["ROCM_DIR"])
+                rocm_bin_path = rocm_dir_path / "bin"
+                rocblas_dir = rocm_bin_path / "rocblas"
+                target_rocblas_dir = target_dir / "rocblas"
+                target_rocblas_dir.mkdir(parents=True, exist_ok=True)
+                self.copy_tree(rocblas_dir, str(target_rocblas_dir))
 
-                rocblas_dir = os.path.join(rocm_bin_path, "rocblas")
-                target_rocblas_dir = os.path.join(target_dir, "rocblas")
-                os.makedirs(target_rocblas_dir, exist_ok=True)
-                self.copy_tree(rocblas_dir, target_rocblas_dir)
-
-                hipblaslt_dir = os.path.join(rocm_bin_path, "hipblaslt")
-                target_hipblaslt_dir = os.path.join(target_dir, "hipblaslt")
-                os.makedirs(target_hipblaslt_dir, exist_ok=True)
-                self.copy_tree(hipblaslt_dir, target_hipblaslt_dir)
+                hipblaslt_dir = rocm_bin_path / "hipblaslt"
+                target_hipblaslt_dir = target_dir / "hipblaslt"
+                target_hipblaslt_dir.mkdir(parents=True, exist_ok=True)
+                self.copy_tree(hipblaslt_dir, str(target_hipblaslt_dir))
             else:
                 report("The specified environment variable does not exist.")
 
-    def build_extensions(self):
+    def build_extensions(self) -> None:
         self.create_compile_commands()
 
+        build_lib = Path(self.build_lib).resolve()
+
         # Copy functorch extension
-        for i, ext in enumerate(self.extensions):
+        for ext in self.extensions:
             if ext.name != "functorch._C":
                 continue
             fullname = self.get_ext_fullname(ext.name)
-            filename = self.get_ext_filename(fullname)
-            fileext = os.path.splitext(filename)[1]
-            src = os.path.join(os.path.dirname(filename), "functorch" + fileext)
-            dst = os.path.join(os.path.realpath(self.build_lib), filename)
-            if os.path.exists(src):
+            filename = Path(self.get_ext_filename(fullname))
+            src = filename.with_stem("functorch")
+            dst = build_lib / filename
+            if src.exists():
                 report(f"Copying {ext.name} from {src} to {dst}")
-                dst_dir = os.path.dirname(dst)
-                if not os.path.exists(dst_dir):
-                    os.makedirs(dst_dir)
+                dst.parent.mkdir(parents=True, exist_ok=True)
                 self.copy_file(src, dst)
 
-        setuptools.command.build_ext.build_ext.build_extensions(self)
+        super().build_extensions()
 
-    def get_outputs(self):
-        outputs = setuptools.command.build_ext.build_ext.get_outputs(self)
+    def get_outputs(self) -> list[str]:
+        outputs = super().get_outputs()
         outputs.append(os.path.join(self.build_lib, "caffe2"))
         report(f"setup.py::get_outputs returning {outputs}")
         return outputs
 
-    def create_compile_commands(self):
-        def load(filename):
-            with open(filename) as f:
-                return json.load(f)
+    def create_compile_commands(self) -> None:
+        def load(file: Path) -> list[dict[str, Any]]:
+            return json.loads(file.read_text(encoding="utf-8"))
 
-        ninja_files = glob.glob("build/*compile_commands.json")
-        cmake_files = glob.glob("torch/lib/build/*/compile_commands.json")
-        all_commands = [entry for f in ninja_files + cmake_files for entry in load(f)]
+        ninja_files = (CWD / BUILD_DIR).glob("*compile_commands.json")
+        cmake_files = (CWD / "torch" / "lib" / "build").glob("*/compile_commands.json")
+        all_commands = [
+            entry
+            for f in itertools.chain(ninja_files, cmake_files)
+            for entry in load(f)
+        ]
 
         # cquery does not like c++ compiles that start with gcc.
         # It forgets to include the c++ header directories.
@@ -862,12 +875,11 @@ class build_ext(setuptools.command.build_ext.build_ext):
         new_contents = json.dumps(all_commands, indent=2)
         contents = ""
-        if os.path.exists("compile_commands.json"):
-            with open("compile_commands.json") as f:
-                contents = f.read()
+        compile_commands_json = CWD / "compile_commands.json"
+        if compile_commands_json.exists():
+            contents = compile_commands_json.read_text(encoding="utf-8")
         if contents != new_contents:
-            with open("compile_commands.json", "w") as f:
-                f.write(new_contents)
+            compile_commands_json.write_text(new_contents, encoding="utf-8")
 
 
 class concat_license_files:
@@ -879,115 +891,109 @@ class concat_license_files:
     licensing info.
     """
 
-    def __init__(self, include_files=False):
-        self.f1 = "LICENSE"
-        self.f2 = "third_party/LICENSES_BUNDLED.txt"
+    def __init__(self, include_files: bool = False) -> None:
+        self.f1 = CWD / "LICENSE"
+        self.f2 = THIRD_PARTY_DIR / "LICENSES_BUNDLED.txt"
         self.include_files = include_files
+        self.bsd_text = ""
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         """Concatenate files"""
         old_path = sys.path
-        sys.path.append(third_party_path)
+        sys.path.append(str(THIRD_PARTY_DIR))
         try:
-            from build_bundled import create_bundled
+            from build_bundled import create_bundled  # type: ignore[import-not-found]
         finally:
             sys.path = old_path
 
-        with open(self.f1) as f1:
-            self.bsd_text = f1.read()
+        self.bsd_text = self.f1.read_text(encoding="utf-8")
 
-        with open(self.f1, "a") as f1:
+        with self.f1.open(mode="a", encoding="utf-8") as f1:
             f1.write("\n\n")
             create_bundled(
-                os.path.relpath(third_party_path), f1, include_files=self.include_files
+                str(THIRD_PARTY_DIR.resolve()),
+                f1,
+                include_files=self.include_files,
             )
 
-    def __exit__(self, exception_type, exception_value, traceback):
+    def __exit__(self, *exc_info: object) -> None:
         """Restore content of f1"""
-        with open(self.f1, "w") as f:
-            f.write(self.bsd_text)
+        self.f1.write_text(self.bsd_text, encoding="utf-8")
 
 
 try:
-    from wheel.bdist_wheel import bdist_wheel
+    from wheel.bdist_wheel import bdist_wheel  # type: ignore[import-untyped]
 except ImportError:
     # This is useful when wheel is not installed and bdist_wheel is not
     # specified on the command line. If it _is_ specified, parsing the command
     # line will fail before wheel_concatenate is needed
-    wheel_concatenate = None
+    wheel_concatenate: type[Command] | None = None
 else:
     # Need to create the proper LICENSE.txt for the wheel
-    class wheel_concatenate(bdist_wheel):
+    class wheel_concatenate(bdist_wheel):  # type: ignore[no-redef]
         """check submodules on sdist to prevent incomplete tarballs"""
 
-        def run(self):
+        def run(self) -> None:
             with concat_license_files(include_files=True):
                 super().run()
 
-        def write_wheelfile(self, *args, **kwargs):
+        def write_wheelfile(self, *args: Any, **kwargs: Any) -> None:
             super().write_wheelfile(*args, **kwargs)
 
             if BUILD_LIBTORCH_WHL:
+                bdist_dir = Path(self.bdist_dir)
                 # Remove extraneneous files in the libtorch wheel
-                for root, dirs, files in os.walk(self.bdist_dir):
-                    for file in files:
-                        if file.endswith((".a", ".so")) and os.path.isfile(
-                            os.path.join(self.bdist_dir, file)
-                        ):
-                            os.remove(os.path.join(root, file))
-                        elif file.endswith(".py"):
-                            os.remove(os.path.join(root, file))
+                for file in itertools.chain(
+                    bdist_dir.glob("**/*.a"),
+                    bdist_dir.glob("**/*.so"),
+                ):
+                    if (bdist_dir / file.name).is_file():
+                        file.unlink()
+                for file in bdist_dir.glob("**/*.py"):
+                    file.unlink()
                 # need an __init__.py file otherwise we wouldn't have a package
-                open(os.path.join(self.bdist_dir, "torch", "__init__.py"), "w").close()
+                (bdist_dir / "torch" / "__init__.py").touch()
 
 
-class install(setuptools.command.install.install):
-    def run(self):
-        super().run()
-
-
-class clean(setuptools.Command):
-    user_options = []
+class clean(Command):
+    user_options: ClassVar[list[tuple[str, str | None, str]]] = []
 
-    def initialize_options(self):
+    def initialize_options(self) -> None:
         pass
 
-    def finalize_options(self):
+    def finalize_options(self) -> None:
         pass
 
-    def run(self):
-        import glob
+    def run(self) -> None:
         import re
 
-        with open(".gitignore") as f:
-            ignores = f.read()
-            pat = re.compile(r"^#( BEGIN NOT-CLEAN-FILES )?")
-            for wildcard in filter(None, ignores.split("\n")):
-                match = pat.match(wildcard)
-                if match:
-                    if match.group(1):
-                        # Marker is found and stop reading .gitignore.
-                        break
-                    # Ignore lines which begin with '#'.
-                else:
-                    # Don't remove absolute paths from the system
-                    wildcard = wildcard.lstrip("./")
-
-                    for filename in glob.glob(wildcard):
-                        try:
-                            os.remove(filename)
-                        except OSError:
-                            shutil.rmtree(filename, ignore_errors=True)
+        ignores = (CWD / ".gitignore").read_text(encoding="utf-8")
+        pattern = re.compile(r"^#( BEGIN NOT-CLEAN-FILES )?")
+        for wildcard in filter(None, ignores.splitlines()):
+            match = pattern.match(wildcard)
+            if match:
+                if match.group(1):
+                    # Marker is found and stop reading .gitignore.
+                    break
+                # Ignore lines which begin with '#'.
+            else:
+                # Don't remove absolute paths from the system
+                wildcard = wildcard.lstrip("./")
+                for filename in glob.iglob(wildcard):
+                    try:
+                        os.remove(filename)
+                    except OSError:
+                        shutil.rmtree(filename, ignore_errors=True)
 
 
 class sdist(setuptools.command.sdist.sdist):
-    def run(self):
+    def run(self) -> None:
         with concat_license_files():
             super().run()
 
 
-def get_cmake_cache_vars():
+def get_cmake_cache_vars() -> defaultdict[str, CMakeValue]:
     try:
         return defaultdict(lambda: False, cmake.get_cmake_cache_variables())
     except FileNotFoundError:
@@ -996,7 +1002,13 @@ def get_cmake_cache_vars():
         return defaultdict(lambda: False)
 
 
-def configure_extension_build():
+def configure_extension_build() -> tuple[
+    list[Extension],  # ext_modules
+    dict[str, type[Command]],  # cmdclass
+    list[str],  # packages
+    dict[str, list[str]],  # entry_points
+    list[str],  # extra_install_requires
+]:
     r"""Configures extension build options according to system environment and user's choice.
 
     Returns:
@@ -1009,17 +1021,17 @@ def configure_extension_build():
     # Configure compile flags
     ################################################################################
 
-    library_dirs = []
-    extra_install_requires = []
+    library_dirs: list[str] = [str(TORCH_LIB_DIR)]
+    extra_install_requires: list[str] = []
 
     if IS_WINDOWS:
         # /NODEFAULTLIB makes sure we only link to DLL runtime
         # and matches the flags set for protobuf and ONNX
-        extra_link_args = ["/NODEFAULTLIB:LIBCMT.LIB"]
+        extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"]
         # /MD links against DLL runtime
         # and matches the flags set for protobuf and ONNX
         # /EHsc is about standard C++ exception handling
-        extra_compile_args = ["/MD", "/FS", "/EHsc"]
+        extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
     else:
         extra_link_args = []
         extra_compile_args = [
@@ -1035,13 +1047,11 @@ def configure_extension_build():
             "-fno-strict-aliasing",
         ]
 
-    library_dirs.append(lib_path)
-
-    main_compile_args = []
-    main_libraries = ["torch_python"]
-    main_link_args = []
-    main_sources = ["torch/csrc/stub.c"]
+    main_compile_args: list[str] = []
+    main_libraries: list[str] = ["torch_python"]
+    main_link_args: list[str] = []
+    main_sources: list[str] = ["torch/csrc/stub.c"]
 
     if BUILD_LIBTORCH_WHL:
         main_libraries = ["torch"]
@@ -1049,16 +1059,16 @@ def configure_extension_build():
     if build_type.is_debug():
         if IS_WINDOWS:
-            extra_compile_args.append("/Z7")
-            extra_link_args.append("/DEBUG:FULL")
+            extra_compile_args += ["/Z7"]
+            extra_link_args += ["/DEBUG:FULL"]
         else:
             extra_compile_args += ["-O0", "-g"]
             extra_link_args += ["-O0", "-g"]
 
     if build_type.is_rel_with_deb_info():
         if IS_WINDOWS:
-            extra_compile_args.append("/Z7")
-            extra_link_args.append("/DEBUG:FULL")
+            extra_compile_args += ["/Z7"]
+            extra_link_args += ["/DEBUG:FULL"]
         else:
             extra_compile_args += ["-g"]
             extra_link_args += ["-g"]
@@ -1095,7 +1105,7 @@ def configure_extension_build():
         ]
         extra_link_args += ["-arch", macos_target_arch]
 
-    def make_relative_rpath_args(path):
+    def make_relative_rpath_args(path: str) -> list[str]:
         if IS_DARWIN:
             return ["-Wl,-rpath,@loader_path/" + path]
         elif IS_WINDOWS:
@@ -1120,26 +1130,24 @@ def configure_extension_build():
         extra_compile_args=main_compile_args + extra_compile_args,
         include_dirs=[],
         library_dirs=library_dirs,
-        extra_link_args=extra_link_args
-        + main_link_args
-        + make_relative_rpath_args("lib"),
+        extra_link_args=(
+            extra_link_args + main_link_args + make_relative_rpath_args("lib")
+        ),
     )
     extensions.append(C)
 
    # These extensions are built by cmake and copied manually in build_extensions()
    # inside the build_ext implementation
    if cmake_cache_vars["BUILD_FUNCTORCH"]:
-        extensions.append(
-            Extension(name="functorch._C", sources=[]),
-        )
+        extensions.append(Extension(name="functorch._C", sources=[]))
 
     cmdclass = {
-        "bdist_wheel": wheel_concatenate,
         "build_ext": build_ext,
         "clean": clean,
-        "install": install,
         "sdist": sdist,
     }
+    if wheel_concatenate is not None:
+        cmdclass["bdist_wheel"] = wheel_concatenate
 
     entry_points = {
         "console_scripts": [
@@ -1171,7 +1179,7 @@ build_update_message = """
 """
 
 
-def print_box(msg):
+def print_box(msg: str) -> None:
     lines = msg.split("\n")
     size = max(len(l) + 1 for l in lines)
     print("-" * (size + 2))
@@ -1180,7 +1188,7 @@ def print_box(msg):
     print("-" * (size + 2))
 
 
-def main():
+def main() -> None:
     if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY:
         raise RuntimeError(
             "Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. "
@@ -1226,7 +1234,7 @@ def main():
     dist.script_args = sys.argv[1:]
     try:
         dist.parse_command_line()
-    except setuptools.distutils.errors.DistutilsArgError as e:
+    except setuptools.errors.BaseError as e:
         print(e)
         sys.exit(1)
@@ -1235,7 +1243,7 @@ def main():
         build_deps()
 
     (
-        extensions,
+        ext_modules,
         cmdclass,
         packages,
         entry_points,
@@ -1250,7 +1258,7 @@ def main():
     }
 
     # Read in README.md for our long_description
-    with open(os.path.join(cwd, "README.md"), encoding="utf-8") as f:
+    with open(os.path.join(CWD, "README.md"), encoding="utf-8") as f:
         long_description = f.read()
 
     version_range_max = max(sys.version_info[1], 13) + 1
@@ -1312,12 +1320,11 @@ def main():
                 "lib/*.lib",
             ]
         )
-        aotriton_image_path = os.path.join(lib_path, "aotriton.images")
-        aks2_files = []
-        for root, dirs, files in os.walk(aotriton_image_path):
-            subpath = os.path.relpath(root, start=aotriton_image_path)
-            for fn in files:
-                aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn))
+        aotriton_image_path = TORCH_LIB_DIR / "aotriton.images"
+        aks2_files: list[str] = []
+        for file in filter(lambda p: p.is_file(), aotriton_image_path.glob("**")):
+            subpath = file.relative_to(aotriton_image_path)
+            aks2_files.append(os.path.join("lib/aotriton.images", subpath))
         torch_package_data += aks2_files
     if get_cmake_cache_vars()["USE_TENSORPIPE"]:
         torch_package_data.extend(
@@ -1345,17 +1352,17 @@ def main():
         package_data["torchgen"] = torchgen_package_data
     else:
         # no extensions in BUILD_LIBTORCH_WHL mode
-        extensions = []
+        ext_modules = []
 
     setup(
-        name=package_name,
-        version=version,
+        name=TORCH_PACKAGE_NAME,
+        version=TORCH_VERSION,
         description=(
             "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
         ),
         long_description=long_description,
         long_description_content_type="text/markdown",
-        ext_modules=extensions,
+        ext_modules=ext_modules,
         cmdclass=cmdclass,
         packages=packages,
         entry_points=entry_points,