[BE] add type annotations and run mypy on setup.py (#156741)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156741
Approved by: https://github.com/aorenste
This commit is contained in:
Xuehai Pan
2025-07-01 16:18:53 +08:00
committed by PyTorch MergeBot
parent 47f10d0ad0
commit 3bc6bdc866
3 changed files with 119 additions and 111 deletions

View File

@ -122,6 +122,7 @@ is_formatter = true
[[linter]] [[linter]]
code = 'MYPY' code = 'MYPY'
include_patterns = [ include_patterns = [
'setup.py',
'torch/**/*.py', 'torch/**/*.py',
'torch/**/*.pyi', 'torch/**/*.pyi',
'caffe2/**/*.py', 'caffe2/**/*.py',

View File

@ -57,7 +57,8 @@ setup-env-cuda:
setup-env-rocm: setup-env-rocm:
$(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --rocm" $(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --rocm"
.lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrunner.toml .PHONY: setup-lint
setup-lint .lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrunner.toml
@echo "Setting up lintrunner..." @echo "Setting up lintrunner..."
$(PIP) install lintrunner $(PIP) install lintrunner
lintrunner init lintrunner init
@ -65,9 +66,6 @@ setup-env-rocm:
@mkdir -p .lintbin @mkdir -p .lintbin
@sha256sum requirements.txt pyproject.toml .lintrunner.toml > .lintbin/.lintrunner.sha256 @sha256sum requirements.txt pyproject.toml .lintrunner.toml > .lintbin/.lintrunner.sha256
.PHONY: setup-lint
setup-lint: .lintbin/.lintrunner.sha256
.PHONY: lazy-setup-lint .PHONY: lazy-setup-lint
lazy-setup-lint: .lintbin/.lintrunner.sha256 lazy-setup-lint: .lintbin/.lintrunner.sha256
@if [ ! -x "$(shell command -v lintrunner)" ]; then \ @if [ ! -x "$(shell command -v lintrunner)" ]; then \

223
setup.py
View File

@ -264,15 +264,17 @@ import subprocess
import sysconfig import sysconfig
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Any, ClassVar, IO
import setuptools.command.build_ext import setuptools.command.build_ext
import setuptools.command.install import setuptools.command.install
import setuptools.command.sdist import setuptools.command.sdist
from setuptools import Extension, find_packages, setup import setuptools.errors
from setuptools import Command, Extension, find_packages, setup
from setuptools.dist import Distribution from setuptools.dist import Distribution
from tools.build_pytorch_libs import build_pytorch from tools.build_pytorch_libs import build_pytorch
from tools.generate_torch_version import get_torch_version from tools.generate_torch_version import get_torch_version
from tools.setup_helpers.cmake import CMake from tools.setup_helpers.cmake import CMake, CMakeValue
from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
from tools.setup_helpers.generate_linker_script import gen_linker_script from tools.setup_helpers.generate_linker_script import gen_linker_script
@ -318,7 +320,7 @@ def str2bool(value: str | None) -> bool:
raise ValueError(f"Invalid string value for boolean conversion: {value}") raise ValueError(f"Invalid string value for boolean conversion: {value}")
def _get_package_path(package_name): def _get_package_path(package_name: str) -> str | None:
spec = importlib.util.find_spec(package_name) spec = importlib.util.find_spec(package_name)
if spec: if spec:
# The package might be a namespace package, so get_data may fail # The package might be a namespace package, so get_data may fail
@ -381,16 +383,20 @@ sys.argv = filtered_args
if VERBOSE_SCRIPT: if VERBOSE_SCRIPT:
def report(*args, file=sys.stderr, **kwargs): def report(
print(*args, file=file, **kwargs) *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any
) -> None:
print(*args, file=file, flush=flush, **kwargs)
else: else:
def report(*args, **kwargs): def report(
*args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any
) -> None:
pass pass
# Make distutils respect --quiet too # Make distutils respect --quiet too
setuptools.distutils.log.warn = report setuptools.distutils.log.warn = report # type: ignore[attr-defined]
# Constant known variables used throughout this file # Constant known variables used throughout this file
cwd = os.path.dirname(os.path.abspath(__file__)) cwd = os.path.dirname(os.path.abspath(__file__))
@ -399,16 +405,16 @@ third_party_path = os.path.join(cwd, "third_party")
# CMAKE: full path to python library # CMAKE: full path to python library
if IS_WINDOWS: if IS_WINDOWS:
cmake_python_library = "{}/libs/python{}.lib".format( CMAKE_PYTHON_LIBRARY = "{}/libs/python{}.lib".format(
sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION") sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION")
) )
# Fix virtualenv builds # Fix virtualenv builds
if not os.path.exists(cmake_python_library): if not os.path.exists(CMAKE_PYTHON_LIBRARY):
cmake_python_library = "{}/libs/python{}.lib".format( CMAKE_PYTHON_LIBRARY = "{}/libs/python{}.lib".format(
sys.base_prefix, sysconfig.get_config_var("VERSION") sys.base_prefix, sysconfig.get_config_var("VERSION")
) )
else: else:
cmake_python_library = "{}/{}".format( CMAKE_PYTHON_LIBRARY = "{}/{}".format(
sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME") sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME")
) )
cmake_python_include_dir = sysconfig.get_path("include") cmake_python_include_dir = sysconfig.get_path("include")
@ -418,20 +424,18 @@ cmake_python_include_dir = sysconfig.get_path("include")
# Version, create_version_file, and package_name # Version, create_version_file, and package_name
################################################################################ ################################################################################
package_name = os.getenv("TORCH_PACKAGE_NAME", "torch") TORCH_PACKAGE_NAME = os.getenv("TORCH_PACKAGE_NAME", "torch")
LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python") LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python")
if BUILD_LIBTORCH_WHL: if BUILD_LIBTORCH_WHL:
package_name = LIBTORCH_PKG_NAME TORCH_PACKAGE_NAME = LIBTORCH_PKG_NAME
TORCH_VERSION = get_torch_version()
package_type = os.getenv("PACKAGE_TYPE", "wheel") report(f"Building wheel {TORCH_PACKAGE_NAME}-{TORCH_VERSION}")
version = get_torch_version()
report(f"Building wheel {package_name}-{version}")
cmake = CMake() cmake = CMake()
def get_submodule_folders(): def get_submodule_folders() -> list[str]:
git_modules_path = os.path.join(cwd, ".gitmodules") git_modules_path = os.path.join(cwd, ".gitmodules")
default_modules_path = [ default_modules_path = [
os.path.join(third_party_path, name) os.path.join(third_party_path, name)
@ -453,14 +457,14 @@ def get_submodule_folders():
] ]
def check_submodules(): def check_submodules() -> None:
def check_for_files(folder, files): def check_for_files(folder: str, files: list[str]) -> None:
if not any(os.path.exists(os.path.join(folder, f)) for f in files): if not any(os.path.exists(os.path.join(folder, f)) for f in files):
report("Could not find any of {} in {}".format(", ".join(files), folder)) report("Could not find any of {} in {}".format(", ".join(files), folder))
report("Did you run 'git submodule update --init --recursive'?") report("Did you run 'git submodule update --init --recursive'?")
sys.exit(1) sys.exit(1)
def not_exists_or_empty(folder): def not_exists_or_empty(folder: str) -> bool:
return not os.path.exists(folder) or ( return not os.path.exists(folder) or (
os.path.isdir(folder) and len(os.listdir(folder)) == 0 os.path.isdir(folder) and len(os.listdir(folder)) == 0
) )
@ -502,7 +506,7 @@ def check_submodules():
# Windows has very bad support for symbolic links. # Windows has very bad support for symbolic links.
# Instead of using symlinks, we're going to copy files over # Instead of using symlinks, we're going to copy files over
def mirror_files_into_torchgen(): def mirror_files_into_torchgen() -> None:
# (new_path, orig_path) # (new_path, orig_path)
# Directories are OK and are recursively mirrored. # Directories are OK and are recursively mirrored.
paths = [ paths = [
@ -534,15 +538,14 @@ def mirror_files_into_torchgen():
# all the work we need to do _before_ setup runs # all the work we need to do _before_ setup runs
def build_deps(): def build_deps() -> None:
report("-- Building version " + version) report(f"-- Building version {TORCH_VERSION}")
check_submodules() check_submodules()
check_pydep("yaml", "pyyaml") check_pydep("yaml", "pyyaml")
build_python = not BUILD_LIBTORCH_WHL
build_pytorch( build_pytorch(
version=version, version=TORCH_VERSION,
cmake_python_library=cmake_python_library, cmake_python_library=CMAKE_PYTHON_LIBRARY,
build_python=build_python, build_python=not BUILD_LIBTORCH_WHL,
rerun_cmake=RERUN_CMAKE, rerun_cmake=RERUN_CMAKE,
cmake_only=CMAKE_ONLY, cmake_only=CMAKE_ONLY,
cmake=cmake, cmake=cmake,
@ -589,7 +592,7 @@ Please install it via `conda install {module}` or `pip install {module}`
""".strip() """.strip()
def check_pydep(importname, module): def check_pydep(importname: str, module: str) -> None:
try: try:
importlib.import_module(importname) importlib.import_module(importname)
except ImportError as e: except ImportError as e:
@ -599,7 +602,7 @@ def check_pydep(importname, module):
class build_ext(setuptools.command.build_ext.build_ext): class build_ext(setuptools.command.build_ext.build_ext):
def _embed_libomp(self): def _embed_libomp(self) -> None:
# Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS
lib_dir = os.path.join(self.build_lib, "torch", "lib") lib_dir = os.path.join(self.build_lib, "torch", "lib")
libtorch_cpu_path = os.path.join(lib_dir, "libtorch_cpu.dylib") libtorch_cpu_path = os.path.join(lib_dir, "libtorch_cpu.dylib")
@ -623,8 +626,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
assert rpath.startswith("path ") assert rpath.startswith("path ")
rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1]) rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1])
omplib_path = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"] omplib_path: str = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"] # type: ignore[assignment]
omplib_name = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] + ".dylib" omplib_name: str = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] # type: ignore[assignment]
omplib_name += ".dylib"
omplib_rpath_path = os.path.join("@rpath", omplib_name) omplib_rpath_path = os.path.join("@rpath", omplib_name)
# This logic is fragile and checks only two cases: # This logic is fragile and checks only two cases:
@ -636,6 +640,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
# Copy libomp/libiomp5 from rpath locations # Copy libomp/libiomp5 from rpath locations
target_lib = os.path.join(self.build_lib, "torch", "lib", omplib_name) target_lib = os.path.join(self.build_lib, "torch", "lib", omplib_name)
libomp_relocated = False libomp_relocated = False
install_name_tool_args: list[str] = []
for rpath in rpaths: for rpath in rpaths:
source_lib = os.path.join(rpath, omplib_name) source_lib = os.path.join(rpath, omplib_name)
if not os.path.exists(source_lib): if not os.path.exists(source_lib):
@ -670,7 +675,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
install_name_tool_args.append(libtorch_cpu_path) install_name_tool_args.append(libtorch_cpu_path)
subprocess.check_call(install_name_tool_args) subprocess.check_call(install_name_tool_args)
# Copy omp.h from OpenMP_C_FLAGS and copy it into include folder # Copy omp.h from OpenMP_C_FLAGS and copy it into include folder
omp_cflags = get_cmake_cache_vars()["OpenMP_C_FLAGS"] omp_cflags: str = get_cmake_cache_vars()["OpenMP_C_FLAGS"] # type: ignore[assignment]
if not omp_cflags: if not omp_cflags:
return return
for include_dir in [f[2:] for f in omp_cflags.split(" ") if f.startswith("-I")]: for include_dir in [f[2:] for f in omp_cflags.split(" ") if f.startswith("-I")]:
@ -681,7 +686,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
self.copy_file(omp_h, target_omp_h) self.copy_file(omp_h, target_omp_h)
break break
def run(self): def run(self) -> None:
# Report build options. This is run after the build completes so # `CMakeCache.txt` exists # Report build options. This is run after the build completes so # `CMakeCache.txt` exists
# and we can get an accurate report on what is used and what is not. # and we can get an accurate report on what is used and what is not.
cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables()) cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables())
@ -692,18 +697,17 @@ class build_ext(setuptools.command.build_ext.build_ext):
if cmake_cache_vars["USE_CUDNN"]: if cmake_cache_vars["USE_CUDNN"]:
report( report(
"-- Detected cuDNN at " "-- Detected cuDNN at "
+ cmake_cache_vars["CUDNN_LIBRARY"] f"{cmake_cache_vars['CUDNN_LIBRARY']}, "
+ ", " f"{cmake_cache_vars['CUDNN_INCLUDE_DIR']}"
+ cmake_cache_vars["CUDNN_INCLUDE_DIR"]
) )
else: else:
report("-- Not using cuDNN") report("-- Not using cuDNN")
if cmake_cache_vars["USE_CUDA"]: if cmake_cache_vars["USE_CUDA"]:
report("-- Detected CUDA at " + cmake_cache_vars["CUDA_TOOLKIT_ROOT_DIR"]) report(f"-- Detected CUDA at {cmake_cache_vars['CUDA_TOOLKIT_ROOT_DIR']}")
else: else:
report("-- Not using CUDA") report("-- Not using CUDA")
if cmake_cache_vars["USE_XPU"]: if cmake_cache_vars["USE_XPU"]:
report("-- Detected XPU runtime at " + cmake_cache_vars["SYCL_LIBRARY_DIR"]) report(f"-- Detected XPU runtime at {cmake_cache_vars['SYCL_LIBRARY_DIR']}")
else: else:
report("-- Not using XPU") report("-- Not using XPU")
if cmake_cache_vars["USE_MKLDNN"]: if cmake_cache_vars["USE_MKLDNN"]:
@ -722,10 +726,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
report("-- Not using MKLDNN") report("-- Not using MKLDNN")
if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]: if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]:
report( report(
"-- Using system provided NCCL library at {}, {}".format( "-- Using system provided NCCL library at "
cmake_cache_vars["NCCL_LIBRARIES"], f"{cmake_cache_vars['NCCL_LIBRARIES']}, "
cmake_cache_vars["NCCL_INCLUDE_DIRS"], f"{cmake_cache_vars['NCCL_INCLUDE_DIRS']}"
)
) )
elif cmake_cache_vars["USE_NCCL"]: elif cmake_cache_vars["USE_NCCL"]:
report("-- Building NCCL library") report("-- Building NCCL library")
@ -736,18 +739,15 @@ class build_ext(setuptools.command.build_ext.build_ext):
report("-- Building without distributed package") report("-- Building without distributed package")
else: else:
report("-- Building with distributed package: ") report("-- Building with distributed package: ")
report( report(f" -- USE_TENSORPIPE={cmake_cache_vars['USE_TENSORPIPE']}")
" -- USE_TENSORPIPE={}".format(cmake_cache_vars["USE_TENSORPIPE"]) report(f" -- USE_GLOO={cmake_cache_vars['USE_GLOO']}")
) report(f" -- USE_MPI={cmake_cache_vars['USE_OPENMPI']}")
report(" -- USE_GLOO={}".format(cmake_cache_vars["USE_GLOO"]))
report(" -- USE_MPI={}".format(cmake_cache_vars["USE_OPENMPI"]))
else: else:
report("-- Building without distributed package") report("-- Building without distributed package")
if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]: if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]:
report( report(
"-- Using static dispatch with backend {}".format( "-- Using static dispatch with "
cmake_cache_vars["STATIC_DISPATCH_BACKEND"] f"backend {cmake_cache_vars['STATIC_DISPATCH_BACKEND']}"
)
) )
if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]: if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]:
report("-- Using lightweight dispatch") report("-- Using lightweight dispatch")
@ -801,7 +801,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
# In ROCm on Windows case copy rocblas and hipblaslt files into # In ROCm on Windows case copy rocblas and hipblaslt files into
# torch/lib/rocblas/library and torch/lib/hipblaslt/library # torch/lib/rocblas/library and torch/lib/hipblaslt/library
if str2bool(os.getenv("USE_ROCM")): if str2bool(os.getenv("USE_ROCM")):
rocm_dir_path = os.environ.get("ROCM_DIR") rocm_dir_path = os.environ["ROCM_DIR"]
rocm_bin_path = os.path.join(rocm_dir_path, "bin") rocm_bin_path = os.path.join(rocm_dir_path, "bin")
rocblas_dir = os.path.join(rocm_bin_path, "rocblas") rocblas_dir = os.path.join(rocm_bin_path, "rocblas")
@ -816,7 +816,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
else: else:
report("The specified environment variable does not exist.") report("The specified environment variable does not exist.")
def build_extensions(self): def build_extensions(self) -> None:
self.create_compile_commands() self.create_compile_commands()
# Copy functorch extension # Copy functorch extension
@ -837,14 +837,14 @@ class build_ext(setuptools.command.build_ext.build_ext):
setuptools.command.build_ext.build_ext.build_extensions(self) setuptools.command.build_ext.build_ext.build_extensions(self)
def get_outputs(self): def get_outputs(self) -> list[str]:
outputs = setuptools.command.build_ext.build_ext.get_outputs(self) outputs = setuptools.command.build_ext.build_ext.get_outputs(self)
outputs.append(os.path.join(self.build_lib, "caffe2")) outputs.append(os.path.join(self.build_lib, "caffe2"))
report(f"setup.py::get_outputs returning {outputs}") report(f"setup.py::get_outputs returning {outputs}")
return outputs return outputs
def create_compile_commands(self): def create_compile_commands(self) -> None:
def load(filename): def load(filename: str) -> Any:
with open(filename) as f: with open(filename) as f:
return json.load(f) return json.load(f)
@ -879,18 +879,18 @@ class concat_license_files:
licensing info. licensing info.
""" """
def __init__(self, include_files=False): def __init__(self, include_files: bool = False) -> None:
self.f1 = "LICENSE" self.f1 = "LICENSE"
self.f2 = "third_party/LICENSES_BUNDLED.txt" self.f2 = "third_party/LICENSES_BUNDLED.txt"
self.include_files = include_files self.include_files = include_files
def __enter__(self): def __enter__(self) -> None:
"""Concatenate files""" """Concatenate files"""
old_path = sys.path old_path = sys.path
sys.path.append(third_party_path) sys.path.append(third_party_path)
try: try:
from build_bundled import create_bundled from build_bundled import create_bundled # type: ignore[import-not-found]
finally: finally:
sys.path = old_path sys.path = old_path
@ -903,29 +903,29 @@ class concat_license_files:
os.path.relpath(third_party_path), f1, include_files=self.include_files os.path.relpath(third_party_path), f1, include_files=self.include_files
) )
def __exit__(self, exception_type, exception_value, traceback): def __exit__(self, *exc_info: object) -> None:
"""Restore content of f1""" """Restore content of f1"""
with open(self.f1, "w") as f: with open(self.f1, "w") as f:
f.write(self.bsd_text) f.write(self.bsd_text)
try: try:
from wheel.bdist_wheel import bdist_wheel from wheel.bdist_wheel import bdist_wheel # type: ignore[import-untyped]
except ImportError: except ImportError:
# This is useful when wheel is not installed and bdist_wheel is not # This is useful when wheel is not installed and bdist_wheel is not
# specified on the command line. If it _is_ specified, parsing the command # specified on the command line. If it _is_ specified, parsing the command
# line will fail before wheel_concatenate is needed # line will fail before wheel_concatenate is needed
wheel_concatenate = None wheel_concatenate: type[Command] | None = None
else: else:
# Need to create the proper LICENSE.txt for the wheel # Need to create the proper LICENSE.txt for the wheel
class wheel_concatenate(bdist_wheel): class wheel_concatenate(bdist_wheel): # type: ignore[no-redef]
"""check submodules on sdist to prevent incomplete tarballs""" """check submodules on sdist to prevent incomplete tarballs"""
def run(self): def run(self) -> None:
with concat_license_files(include_files=True): with concat_license_files(include_files=True):
super().run() super().run()
def write_wheelfile(self, *args, **kwargs): def write_wheelfile(self, *args: Any, **kwargs: Any) -> None:
super().write_wheelfile(*args, **kwargs) super().write_wheelfile(*args, **kwargs)
if BUILD_LIBTORCH_WHL: if BUILD_LIBTORCH_WHL:
@ -943,21 +943,20 @@ else:
class install(setuptools.command.install.install): class install(setuptools.command.install.install):
def run(self): def run(self) -> None:
super().run() super().run()
class clean(setuptools.Command): class clean(Command):
user_options = [] user_options: ClassVar[list[tuple[str, str | None, str]]] = []
def initialize_options(self): def initialize_options(self) -> None:
pass pass
def finalize_options(self): def finalize_options(self) -> None:
pass pass
def run(self): def run(self) -> None:
import glob
import re import re
with open(".gitignore") as f: with open(".gitignore") as f:
@ -982,12 +981,12 @@ class clean(setuptools.Command):
class sdist(setuptools.command.sdist.sdist): class sdist(setuptools.command.sdist.sdist):
def run(self): def run(self) -> None:
with concat_license_files(): with concat_license_files():
super().run() super().run()
def get_cmake_cache_vars(): def get_cmake_cache_vars() -> defaultdict[str, CMakeValue]:
try: try:
return defaultdict(lambda: False, cmake.get_cmake_cache_variables()) return defaultdict(lambda: False, cmake.get_cmake_cache_variables())
except FileNotFoundError: except FileNotFoundError:
@ -996,7 +995,13 @@ def get_cmake_cache_vars():
return defaultdict(lambda: False) return defaultdict(lambda: False)
def configure_extension_build(): def configure_extension_build() -> tuple[
list[Extension], # ext_modules
dict[str, type[Command]], # cmdclass
list[str], # packages
dict[str, list[str]], # entry_points
list[str], # extra_install_requires
]:
r"""Configures extension build options according to system environment and user's choice. r"""Configures extension build options according to system environment and user's choice.
Returns: Returns:
@ -1009,17 +1014,17 @@ def configure_extension_build():
# Configure compile flags # Configure compile flags
################################################################################ ################################################################################
library_dirs = [] library_dirs: list[str] = []
extra_install_requires = [] extra_install_requires: list[str] = []
if IS_WINDOWS: if IS_WINDOWS:
# /NODEFAULTLIB makes sure we only link to DLL runtime # /NODEFAULTLIB makes sure we only link to DLL runtime
# and matches the flags set for protobuf and ONNX # and matches the flags set for protobuf and ONNX
extra_link_args = ["/NODEFAULTLIB:LIBCMT.LIB"] extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"]
# /MD links against DLL runtime # /MD links against DLL runtime
# and matches the flags set for protobuf and ONNX # and matches the flags set for protobuf and ONNX
# /EHsc is about standard C++ exception handling # /EHsc is about standard C++ exception handling
extra_compile_args = ["/MD", "/FS", "/EHsc"] extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
else: else:
extra_link_args = [] extra_link_args = []
extra_compile_args = [ extra_compile_args = [
@ -1037,11 +1042,11 @@ def configure_extension_build():
library_dirs.append(lib_path) library_dirs.append(lib_path)
main_compile_args = [] main_compile_args: list[str] = []
main_libraries = ["torch_python"] main_libraries: list[str] = ["torch_python"]
main_link_args = [] main_link_args: list[str] = []
main_sources = ["torch/csrc/stub.c"] main_sources: list[str] = ["torch/csrc/stub.c"]
if BUILD_LIBTORCH_WHL: if BUILD_LIBTORCH_WHL:
main_libraries = ["torch"] main_libraries = ["torch"]
@ -1049,16 +1054,16 @@ def configure_extension_build():
if build_type.is_debug(): if build_type.is_debug():
if IS_WINDOWS: if IS_WINDOWS:
extra_compile_args.append("/Z7") extra_compile_args += ["/Z7"]
extra_link_args.append("/DEBUG:FULL") extra_link_args += ["/DEBUG:FULL"]
else: else:
extra_compile_args += ["-O0", "-g"] extra_compile_args += ["-O0", "-g"]
extra_link_args += ["-O0", "-g"] extra_link_args += ["-O0", "-g"]
if build_type.is_rel_with_deb_info(): if build_type.is_rel_with_deb_info():
if IS_WINDOWS: if IS_WINDOWS:
extra_compile_args.append("/Z7") extra_compile_args += ["/Z7"]
extra_link_args.append("/DEBUG:FULL") extra_link_args += ["/DEBUG:FULL"]
else: else:
extra_compile_args += ["-g"] extra_compile_args += ["-g"]
extra_link_args += ["-g"] extra_link_args += ["-g"]
@ -1095,7 +1100,7 @@ def configure_extension_build():
] ]
extra_link_args += ["-arch", macos_target_arch] extra_link_args += ["-arch", macos_target_arch]
def make_relative_rpath_args(path): def make_relative_rpath_args(path: str) -> list[str]:
if IS_DARWIN: if IS_DARWIN:
return ["-Wl,-rpath,@loader_path/" + path] return ["-Wl,-rpath,@loader_path/" + path]
elif IS_WINDOWS: elif IS_WINDOWS:
@ -1107,7 +1112,7 @@ def configure_extension_build():
# Declare extensions and package # Declare extensions and package
################################################################################ ################################################################################
extensions = [] ext_modules: list[Extension] = []
excludes = ["tools", "tools.*", "caffe2", "caffe2.*"] excludes = ["tools", "tools.*", "caffe2", "caffe2.*"]
if not cmake_cache_vars["BUILD_FUNCTORCH"]: if not cmake_cache_vars["BUILD_FUNCTORCH"]:
excludes.extend(["functorch", "functorch.*"]) excludes.extend(["functorch", "functorch.*"])
@ -1117,29 +1122,33 @@ def configure_extension_build():
libraries=main_libraries, libraries=main_libraries,
sources=main_sources, sources=main_sources,
language="c", language="c",
extra_compile_args=main_compile_args + extra_compile_args, extra_compile_args=[
*main_compile_args,
*extra_compile_args,
],
include_dirs=[], include_dirs=[],
library_dirs=library_dirs, library_dirs=library_dirs,
extra_link_args=extra_link_args extra_link_args=[
+ main_link_args *extra_link_args,
+ make_relative_rpath_args("lib"), *main_link_args,
*make_relative_rpath_args("lib"),
],
) )
extensions.append(C) ext_modules.append(C)
# These extensions are built by cmake and copied manually in build_extensions() # These extensions are built by cmake and copied manually in build_extensions()
# inside the build_ext implementation # inside the build_ext implementation
if cmake_cache_vars["BUILD_FUNCTORCH"]: if cmake_cache_vars["BUILD_FUNCTORCH"]:
extensions.append( ext_modules.append(Extension(name="functorch._C", sources=[]))
Extension(name="functorch._C", sources=[]),
)
cmdclass = { cmdclass = {
"bdist_wheel": wheel_concatenate,
"build_ext": build_ext, "build_ext": build_ext,
"clean": clean, "clean": clean,
"install": install, "install": install,
"sdist": sdist, "sdist": sdist,
} }
if wheel_concatenate is not None:
cmdclass["bdist_wheel"] = wheel_concatenate
entry_points = { entry_points = {
"console_scripts": [ "console_scripts": [
@ -1155,7 +1164,7 @@ def configure_extension_build():
entry_points["console_scripts"].append( entry_points["console_scripts"].append(
"torchfrtrace = tools.flight_recorder.fr_trace:main", "torchfrtrace = tools.flight_recorder.fr_trace:main",
) )
return extensions, cmdclass, packages, entry_points, extra_install_requires return ext_modules, cmdclass, packages, entry_points, extra_install_requires
# post run, warnings, printed at the end to make them more visible # post run, warnings, printed at the end to make them more visible
@ -1171,7 +1180,7 @@ build_update_message = """
""" """
def print_box(msg): def print_box(msg: str) -> None:
lines = msg.split("\n") lines = msg.split("\n")
size = max(len(l) + 1 for l in lines) size = max(len(l) + 1 for l in lines)
print("-" * (size + 2)) print("-" * (size + 2))
@ -1180,7 +1189,7 @@ def print_box(msg):
print("-" * (size + 2)) print("-" * (size + 2))
def main(): def main() -> None:
if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY: if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY:
raise RuntimeError( raise RuntimeError(
"Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. " "Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. "
@ -1226,7 +1235,7 @@ def main():
dist.script_args = sys.argv[1:] dist.script_args = sys.argv[1:]
try: try:
dist.parse_command_line() dist.parse_command_line()
except setuptools.distutils.errors.DistutilsArgError as e: except setuptools.errors.BaseError as e:
print(e) print(e)
sys.exit(1) sys.exit(1)
@ -1235,7 +1244,7 @@ def main():
build_deps() build_deps()
( (
extensions, ext_modules,
cmdclass, cmdclass,
packages, packages,
entry_points, entry_points,
@ -1345,17 +1354,17 @@ def main():
package_data["torchgen"] = torchgen_package_data package_data["torchgen"] = torchgen_package_data
else: else:
# no extensions in BUILD_LIBTORCH_WHL mode # no extensions in BUILD_LIBTORCH_WHL mode
extensions = [] ext_modules = []
setup( setup(
name=package_name, name=TORCH_PACKAGE_NAME,
version=version, version=TORCH_VERSION,
description=( description=(
"Tensors and Dynamic neural networks in Python with strong GPU acceleration" "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
), ),
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
ext_modules=extensions, ext_modules=ext_modules,
cmdclass=cmdclass, cmdclass=cmdclass,
packages=packages, packages=packages,
entry_points=entry_points, entry_points=entry_points,