From 3bc6bdc8660c052d932f550d5734da6f801c2630 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Tue, 1 Jul 2025 16:18:53 +0800
Subject: [PATCH] [BE] add type annotations and run `mypy` on `setup.py` (#156741)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156741
Approved by: https://github.com/aorenste
---
 .lintrunner.toml |   1 +
 Makefile         |   6 +-
 setup.py         | 223 ++++++++++++++++++++++++-----------------
 3 files changed, 119 insertions(+), 111 deletions(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index 4b925b09a50d..77efcc8e1327 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -122,6 +122,7 @@ is_formatter = true
 [[linter]]
 code = 'MYPY'
 include_patterns = [
+    'setup.py',
     'torch/**/*.py',
     'torch/**/*.pyi',
     'caffe2/**/*.py',
diff --git a/Makefile b/Makefile
index c5d8e71632dc..3db2b7aa44e7 100644
--- a/Makefile
+++ b/Makefile
@@ -57,7 +57,8 @@ setup-env-cuda:
 setup-env-rocm:
 	$(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --rocm"

-.lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrunner.toml
+.PHONY: setup-lint
+setup-lint .lintbin/.lintrunner.sha256: requirements.txt pyproject.toml .lintrunner.toml
 	@echo "Setting up lintrunner..."
 	$(PIP) install lintrunner
 	lintrunner init
@@ -65,9 +66,6 @@ setup-env-rocm:
 	@mkdir -p .lintbin
 	@sha256sum requirements.txt pyproject.toml .lintrunner.toml > .lintbin/.lintrunner.sha256

-.PHONY: setup-lint
-setup-lint: .lintbin/.lintrunner.sha256
-
 .PHONY: lazy-setup-lint
 lazy-setup-lint: .lintbin/.lintrunner.sha256
 	@if [ ! -x "$(shell command -v lintrunner)" ]; then \
diff --git a/setup.py b/setup.py
index 67833b326f66..26c694c88c98 100644
--- a/setup.py
+++ b/setup.py
@@ -264,15 +264,17 @@ import subprocess
 import sysconfig
 import time
 from collections import defaultdict
+from typing import Any, ClassVar, IO

 import setuptools.command.build_ext
 import setuptools.command.install
 import setuptools.command.sdist
-from setuptools import Extension, find_packages, setup
+import setuptools.errors
+from setuptools import Command, Extension, find_packages, setup
 from setuptools.dist import Distribution

 from tools.build_pytorch_libs import build_pytorch
 from tools.generate_torch_version import get_torch_version
-from tools.setup_helpers.cmake import CMake
+from tools.setup_helpers.cmake import CMake, CMakeValue
 from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
 from tools.setup_helpers.generate_linker_script import gen_linker_script
@@ -318,7 +320,7 @@ def str2bool(value: str | None) -> bool:
     raise ValueError(f"Invalid string value for boolean conversion: {value}")


-def _get_package_path(package_name):
+def _get_package_path(package_name: str) -> str | None:
     spec = importlib.util.find_spec(package_name)
     if spec:
         # The package might be a namespace package, so get_data may fail
@@ -381,16 +383,20 @@ sys.argv = filtered_args

 if VERBOSE_SCRIPT:

-    def report(*args, file=sys.stderr, **kwargs):
-        print(*args, file=file, **kwargs)
+    def report(
+        *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any
+    ) -> None:
+        print(*args, file=file, flush=flush, **kwargs)

 else:

-    def report(*args, **kwargs):
+    def report(
+        *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any
+    ) -> None:
         pass

     # Make distutils respect --quiet too
-    setuptools.distutils.log.warn = report
+    setuptools.distutils.log.warn = report  # type: ignore[attr-defined]

 # Constant known variables used throughout this file
 cwd = os.path.dirname(os.path.abspath(__file__))
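The hunk above gives both branches of the verbose/quiet `report()` shim the same fully annotated signature, which is what lets mypy accept the conditional redefinition. A minimal standalone sketch of that pattern follows; the `VERBOSE` constant here is illustrative, whereas setup.py derives its flag from its own command-line filtering.

# Sketch of a typed, verbosity-gated print shim (assumes a simple VERBOSE flag).
import sys
from typing import Any, IO

VERBOSE = True

if VERBOSE:

    def report(
        *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any
    ) -> None:
        print(*args, file=file, flush=flush, **kwargs)

else:

    def report(
        *args: Any, file: IO[str] = sys.stderr, flush: bool = True, **kwargs: Any
    ) -> None:
        pass  # swallow output when running quietly


report("-- example status line")  # printed to stderr only in verbose mode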
@@ -399,16 +405,16 @@ third_party_path = os.path.join(cwd, "third_party")

 # CMAKE: full path to python library
 if IS_WINDOWS:
-    cmake_python_library = "{}/libs/python{}.lib".format(
+    CMAKE_PYTHON_LIBRARY = "{}/libs/python{}.lib".format(
         sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION")
     )
     # Fix virtualenv builds
-    if not os.path.exists(cmake_python_library):
-        cmake_python_library = "{}/libs/python{}.lib".format(
+    if not os.path.exists(CMAKE_PYTHON_LIBRARY):
+        CMAKE_PYTHON_LIBRARY = "{}/libs/python{}.lib".format(
             sys.base_prefix, sysconfig.get_config_var("VERSION")
         )
 else:
-    cmake_python_library = "{}/{}".format(
+    CMAKE_PYTHON_LIBRARY = "{}/{}".format(
         sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME")
     )
 cmake_python_include_dir = sysconfig.get_path("include")
@@ -418,20 +424,18 @@ cmake_python_include_dir = sysconfig.get_path("include")
 # Version, create_version_file, and package_name
 ################################################################################

-package_name = os.getenv("TORCH_PACKAGE_NAME", "torch")
+TORCH_PACKAGE_NAME = os.getenv("TORCH_PACKAGE_NAME", "torch")
 LIBTORCH_PKG_NAME = os.getenv("LIBTORCH_PACKAGE_NAME", "torch_no_python")
 if BUILD_LIBTORCH_WHL:
-    package_name = LIBTORCH_PKG_NAME
+    TORCH_PACKAGE_NAME = LIBTORCH_PKG_NAME

-
-package_type = os.getenv("PACKAGE_TYPE", "wheel")
-version = get_torch_version()
-report(f"Building wheel {package_name}-{version}")
+TORCH_VERSION = get_torch_version()
+report(f"Building wheel {TORCH_PACKAGE_NAME}-{TORCH_VERSION}")

 cmake = CMake()


-def get_submodule_folders():
+def get_submodule_folders() -> list[str]:
     git_modules_path = os.path.join(cwd, ".gitmodules")
     default_modules_path = [
         os.path.join(third_party_path, name)
@@ -453,14 +457,14 @@ def get_submodule_folders():
     ]


-def check_submodules():
-    def check_for_files(folder, files):
+def check_submodules() -> None:
+    def check_for_files(folder: str, files: list[str]) -> None:
         if not any(os.path.exists(os.path.join(folder, f)) for f in files):
             report("Could not find any of {} in {}".format(", ".join(files), folder))
             report("Did you run 'git submodule update --init --recursive'?")
             sys.exit(1)

-    def not_exists_or_empty(folder):
+    def not_exists_or_empty(folder: str) -> bool:
         return not os.path.exists(folder) or (
             os.path.isdir(folder) and len(os.listdir(folder)) == 0
         )
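For reference, a small self-contained sketch of the interpreter-library lookup that the hunk above promotes to the module-level constant CMAKE_PYTHON_LIBRARY; the resolved path depends entirely on the local Python installation.

# Sketch of resolving the path to the Python import library via sysconfig.
import os
import sys
import sysconfig

if sys.platform == "win32":
    candidate = "{}/libs/python{}.lib".format(
        sysconfig.get_config_var("prefix"), sysconfig.get_config_var("VERSION")
    )
    if not os.path.exists(candidate):
        # Fall back to the base interpreter prefix (the "Fix virtualenv builds" case above).
        candidate = "{}/libs/python{}.lib".format(
            sys.base_prefix, sysconfig.get_config_var("VERSION")
        )
else:
    candidate = "{}/{}".format(
        sysconfig.get_config_var("LIBDIR"), sysconfig.get_config_var("INSTSONAME")
    )

print(candidate)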
@@ -502,7 +506,7 @@ def check_submodules():

 # Windows has very bad support for symbolic links.
 # Instead of using symlinks, we're going to copy files over
-def mirror_files_into_torchgen():
+def mirror_files_into_torchgen() -> None:
     # (new_path, orig_path)
     # Directories are OK and are recursively mirrored.
     paths = [
@@ -534,15 +538,14 @@ def mirror_files_into_torchgen():


 # all the work we need to do _before_ setup runs
-def build_deps():
-    report("-- Building version " + version)
+def build_deps() -> None:
+    report(f"-- Building version {TORCH_VERSION}")
     check_submodules()
     check_pydep("yaml", "pyyaml")
-    build_python = not BUILD_LIBTORCH_WHL
     build_pytorch(
-        version=version,
-        cmake_python_library=cmake_python_library,
-        build_python=build_python,
+        version=TORCH_VERSION,
+        cmake_python_library=CMAKE_PYTHON_LIBRARY,
+        build_python=not BUILD_LIBTORCH_WHL,
         rerun_cmake=RERUN_CMAKE,
         cmake_only=CMAKE_ONLY,
         cmake=cmake,
@@ -589,7 +592,7 @@ Please install it via `conda install {module}` or `pip install {module}`
 """.strip()


-def check_pydep(importname, module):
+def check_pydep(importname: str, module: str) -> None:
     try:
         importlib.import_module(importname)
     except ImportError as e:
@@ -599,7 +602,7 @@ def check_pydep(importname, module):


 class build_ext(setuptools.command.build_ext.build_ext):
-    def _embed_libomp(self):
+    def _embed_libomp(self) -> None:
         # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS
         lib_dir = os.path.join(self.build_lib, "torch", "lib")
         libtorch_cpu_path = os.path.join(lib_dir, "libtorch_cpu.dylib")
@@ -623,8 +626,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
             assert rpath.startswith("path ")
             rpaths.append(rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1])

-        omplib_path = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"]
-        omplib_name = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"] + ".dylib"
+        omplib_path: str = get_cmake_cache_vars()["OpenMP_libomp_LIBRARY"]  # type: ignore[assignment]
+        omplib_name: str = get_cmake_cache_vars()["OpenMP_C_LIB_NAMES"]  # type: ignore[assignment]
+        omplib_name += ".dylib"
         omplib_rpath_path = os.path.join("@rpath", omplib_name)

         # This logic is fragile and checks only two cases:
@@ -636,6 +640,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
         # Copy libomp/libiomp5 from rpath locations
         target_lib = os.path.join(self.build_lib, "torch", "lib", omplib_name)
         libomp_relocated = False
+        install_name_tool_args: list[str] = []
         for rpath in rpaths:
             source_lib = os.path.join(rpath, omplib_name)
             if not os.path.exists(source_lib):
@@ -670,7 +675,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
             install_name_tool_args.append(libtorch_cpu_path)
             subprocess.check_call(install_name_tool_args)
         # Copy omp.h from OpenMP_C_FLAGS and copy it into include folder
-        omp_cflags = get_cmake_cache_vars()["OpenMP_C_FLAGS"]
+        omp_cflags: str = get_cmake_cache_vars()["OpenMP_C_FLAGS"]  # type: ignore[assignment]
         if not omp_cflags:
             return
         for include_dir in [f[2:] for f in omp_cflags.split(" ") if f.startswith("-I")]:
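The rpath handling in `_embed_libomp()` above leans on a small string-parsing step for the `path <dir> (offset N)` entries reported by `otool -l`. A sketch of just that step, with an illustrative sample line:

# Sketch of the rpath-string parsing used in _embed_libomp above.
def parse_rpath(rpath: str) -> str:
    assert rpath.startswith("path ")
    # "path /some/dir (offset 12)" -> "/some/dir"
    return rpath.split(" ", 1)[1].rsplit("(", 1)[0][:-1]


# The sample directory below is made up; real values come from otool output.
assert parse_rpath("path /usr/local/opt/libomp/lib (offset 12)") == (
    "/usr/local/opt/libomp/lib"
)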
@@ -681,7 +686,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
                     self.copy_file(omp_h, target_omp_h)
                     break

-    def run(self):
+    def run(self) -> None:
         # Report build options. This is run after the build completes so
         # `CMakeCache.txt` exists
         # and we can get an accurate report on what is used and what is not.
         cmake_cache_vars = defaultdict(lambda: False, cmake.get_cmake_cache_variables())
@@ -692,18 +697,17 @@ class build_ext(setuptools.command.build_ext.build_ext):
         if cmake_cache_vars["USE_CUDNN"]:
             report(
                 "-- Detected cuDNN at "
-                + cmake_cache_vars["CUDNN_LIBRARY"]
-                + ", "
-                + cmake_cache_vars["CUDNN_INCLUDE_DIR"]
+                f"{cmake_cache_vars['CUDNN_LIBRARY']}, "
+                f"{cmake_cache_vars['CUDNN_INCLUDE_DIR']}"
             )
         else:
             report("-- Not using cuDNN")
         if cmake_cache_vars["USE_CUDA"]:
-            report("-- Detected CUDA at " + cmake_cache_vars["CUDA_TOOLKIT_ROOT_DIR"])
+            report(f"-- Detected CUDA at {cmake_cache_vars['CUDA_TOOLKIT_ROOT_DIR']}")
         else:
             report("-- Not using CUDA")
         if cmake_cache_vars["USE_XPU"]:
-            report("-- Detected XPU runtime at " + cmake_cache_vars["SYCL_LIBRARY_DIR"])
+            report(f"-- Detected XPU runtime at {cmake_cache_vars['SYCL_LIBRARY_DIR']}")
         else:
             report("-- Not using XPU")
         if cmake_cache_vars["USE_MKLDNN"]:
@@ -722,10 +726,9 @@ class build_ext(setuptools.command.build_ext.build_ext):
             report("-- Not using MKLDNN")
         if cmake_cache_vars["USE_NCCL"] and cmake_cache_vars["USE_SYSTEM_NCCL"]:
             report(
-                "-- Using system provided NCCL library at {}, {}".format(
-                    cmake_cache_vars["NCCL_LIBRARIES"],
-                    cmake_cache_vars["NCCL_INCLUDE_DIRS"],
-                )
+                "-- Using system provided NCCL library at "
+                f"{cmake_cache_vars['NCCL_LIBRARIES']}, "
+                f"{cmake_cache_vars['NCCL_INCLUDE_DIRS']}"
             )
         elif cmake_cache_vars["USE_NCCL"]:
             report("-- Building NCCL library")
@@ -736,18 +739,15 @@ class build_ext(setuptools.command.build_ext.build_ext):
                 report("-- Building without distributed package")
             else:
                 report("-- Building with distributed package: ")
-                report(
-                    "  -- USE_TENSORPIPE={}".format(cmake_cache_vars["USE_TENSORPIPE"])
-                )
-                report("  -- USE_GLOO={}".format(cmake_cache_vars["USE_GLOO"]))
-                report("  -- USE_MPI={}".format(cmake_cache_vars["USE_OPENMPI"]))
+                report(f"  -- USE_TENSORPIPE={cmake_cache_vars['USE_TENSORPIPE']}")
+                report(f"  -- USE_GLOO={cmake_cache_vars['USE_GLOO']}")
+                report(f"  -- USE_MPI={cmake_cache_vars['USE_OPENMPI']}")
         else:
             report("-- Building without distributed package")
         if cmake_cache_vars["STATIC_DISPATCH_BACKEND"]:
             report(
-                "-- Using static dispatch with backend {}".format(
-                    cmake_cache_vars["STATIC_DISPATCH_BACKEND"]
-                )
+                "-- Using static dispatch with "
+                f"backend {cmake_cache_vars['STATIC_DISPATCH_BACKEND']}"
             )
         if cmake_cache_vars["USE_LIGHTWEIGHT_DISPATCH"]:
             report("-- Using lightweight dispatch")
@@ -801,7 +801,7 @@ class build_ext(setuptools.command.build_ext.build_ext):

         # In ROCm on Windows case copy rocblas and hipblaslt files into
         # torch/lib/rocblas/library and torch/lib/hipblaslt/library
         if str2bool(os.getenv("USE_ROCM")):
-            rocm_dir_path = os.environ.get("ROCM_DIR")
+            rocm_dir_path = os.environ["ROCM_DIR"]
             rocm_bin_path = os.path.join(rocm_dir_path, "bin")
             rocblas_dir = os.path.join(rocm_bin_path, "rocblas")
@@ -816,7 +816,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
             else:
                 report("The specified environment variable does not exist.")

-    def build_extensions(self):
+    def build_extensions(self) -> None:
         self.create_compile_commands()

         # Copy functorch extension
@@ -837,14 +837,14 @@ class build_ext(setuptools.command.build_ext.build_ext):

         setuptools.command.build_ext.build_ext.build_extensions(self)

-    def get_outputs(self):
+    def get_outputs(self) -> list[str]:
         outputs = setuptools.command.build_ext.build_ext.get_outputs(self)
         outputs.append(os.path.join(self.build_lib, "caffe2"))
         report(f"setup.py::get_outputs returning {outputs}")
         return outputs

-    def create_compile_commands(self):
-        def load(filename):
+    def create_compile_commands(self) -> None:
+        def load(filename: str) -> Any:
             with open(filename) as f:
                 return json.load(f)
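`run()` above reads every CMake cache variable through `defaultdict(lambda: False, ...)`, so flags that were never written to `CMakeCache.txt` simply evaluate as False. A sketch of that pattern, with a made-up cache dict standing in for `cmake.get_cmake_cache_variables()`:

# Sketch: missing CMake cache keys default to False instead of raising KeyError.
from collections import defaultdict

fake_cache = {"USE_CUDA": True, "CUDA_TOOLKIT_ROOT_DIR": "/usr/local/cuda"}  # stand-in values
cmake_cache_vars = defaultdict(lambda: False, fake_cache)

if cmake_cache_vars["USE_CUDA"]:
    print(f"-- Detected CUDA at {cmake_cache_vars['CUDA_TOOLKIT_ROOT_DIR']}")
if not cmake_cache_vars["USE_ROCM"]:  # key absent -> defaults to False
    print("-- Not using ROCm")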
@@ -879,18 +879,18 @@ class concat_license_files:
     licensing info.
     """

-    def __init__(self, include_files=False):
+    def __init__(self, include_files: bool = False) -> None:
         self.f1 = "LICENSE"
         self.f2 = "third_party/LICENSES_BUNDLED.txt"
         self.include_files = include_files

-    def __enter__(self):
+    def __enter__(self) -> None:
         """Concatenate files"""
         old_path = sys.path
         sys.path.append(third_party_path)
         try:
-            from build_bundled import create_bundled
+            from build_bundled import create_bundled  # type: ignore[import-not-found]
         finally:
             sys.path = old_path
@@ -903,29 +903,29 @@ class concat_license_files:
             os.path.relpath(third_party_path), f1, include_files=self.include_files
         )

-    def __exit__(self, exception_type, exception_value, traceback):
+    def __exit__(self, *exc_info: object) -> None:
         """Restore content of f1"""
         with open(self.f1, "w") as f:
             f.write(self.bsd_text)


 try:
-    from wheel.bdist_wheel import bdist_wheel
+    from wheel.bdist_wheel import bdist_wheel  # type: ignore[import-untyped]
 except ImportError:
     # This is useful when wheel is not installed and bdist_wheel is not
     # specified on the command line. If it _is_ specified, parsing the command
     # line will fail before wheel_concatenate is needed
-    wheel_concatenate = None
+    wheel_concatenate: type[Command] | None = None
 else:
     # Need to create the proper LICENSE.txt for the wheel
-    class wheel_concatenate(bdist_wheel):
+    class wheel_concatenate(bdist_wheel):  # type: ignore[no-redef]
         """check submodules on sdist to prevent incomplete tarballs"""

-        def run(self):
+        def run(self) -> None:
             with concat_license_files(include_files=True):
                 super().run()

-        def write_wheelfile(self, *args, **kwargs):
+        def write_wheelfile(self, *args: Any, **kwargs: Any) -> None:
             super().write_wheelfile(*args, **kwargs)

             if BUILD_LIBTORCH_WHL:
@@ -943,21 +943,20 @@ else:


 class install(setuptools.command.install.install):
-    def run(self):
+    def run(self) -> None:
         super().run()


-class clean(setuptools.Command):
-    user_options = []
+class clean(Command):
+    user_options: ClassVar[list[tuple[str, str | None, str]]] = []

-    def initialize_options(self):
+    def initialize_options(self) -> None:
         pass

-    def finalize_options(self):
+    def finalize_options(self) -> None:
         pass

-    def run(self):
-        import glob
+    def run(self) -> None:
         import re

         with open(".gitignore") as f:
@@ -982,12 +981,12 @@ class clean(setuptools.Command):


 class sdist(setuptools.command.sdist.sdist):
-    def run(self):
+    def run(self) -> None:
         with concat_license_files():
             super().run()


-def get_cmake_cache_vars():
+def get_cmake_cache_vars() -> defaultdict[str, CMakeValue]:
     try:
         return defaultdict(lambda: False, cmake.get_cmake_cache_variables())
     except FileNotFoundError:
@@ -996,7 +995,13 @@ def get_cmake_cache_vars():
         return defaultdict(lambda: False)


-def configure_extension_build():
+def configure_extension_build() -> tuple[
+    list[Extension],  # ext_modules
+    dict[str, type[Command]],  # cmdclass
+    list[str],  # packages
+    dict[str, list[str]],  # entry_points
+    list[str],  # extra_install_requires
+]:
     r"""Configures extension build options according to system environment and user's choice.

     Returns:
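The `ClassVar` annotation on `clean.user_options` above tells mypy the attribute is class-level option metadata rather than per-instance state. A trimmed illustration on a do-nothing command; the `noop` class is hypothetical and not part of setup.py.

# Sketch of a minimal setuptools Command with a ClassVar-typed user_options.
from typing import ClassVar

from setuptools import Command


class noop(Command):
    description = "do nothing (illustration only)"
    user_options: ClassVar[list[tuple[str, str | None, str]]] = []

    def initialize_options(self) -> None:
        pass

    def finalize_options(self) -> None:
        pass

    def run(self) -> None:
        pass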
@@ -1009,17 +1014,17 @@ def configure_extension_build():
     # Configure compile flags
     ################################################################################

-    library_dirs = []
-    extra_install_requires = []
+    library_dirs: list[str] = []
+    extra_install_requires: list[str] = []

     if IS_WINDOWS:
         # /NODEFAULTLIB makes sure we only link to DLL runtime
         # and matches the flags set for protobuf and ONNX
-        extra_link_args = ["/NODEFAULTLIB:LIBCMT.LIB"]
+        extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"]
         # /MD links against DLL runtime
         # and matches the flags set for protobuf and ONNX
         # /EHsc is about standard C++ exception handling
-        extra_compile_args = ["/MD", "/FS", "/EHsc"]
+        extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
     else:
         extra_link_args = []
         extra_compile_args = [
@@ -1037,11 +1042,11 @@ def configure_extension_build():

         library_dirs.append(lib_path)

-    main_compile_args = []
-    main_libraries = ["torch_python"]
+    main_compile_args: list[str] = []
+    main_libraries: list[str] = ["torch_python"]

-    main_link_args = []
-    main_sources = ["torch/csrc/stub.c"]
+    main_link_args: list[str] = []
+    main_sources: list[str] = ["torch/csrc/stub.c"]

     if BUILD_LIBTORCH_WHL:
         main_libraries = ["torch"]
@@ -1049,16 +1054,16 @@ def configure_extension_build():

     if build_type.is_debug():
         if IS_WINDOWS:
-            extra_compile_args.append("/Z7")
-            extra_link_args.append("/DEBUG:FULL")
+            extra_compile_args += ["/Z7"]
+            extra_link_args += ["/DEBUG:FULL"]
         else:
             extra_compile_args += ["-O0", "-g"]
             extra_link_args += ["-O0", "-g"]

     if build_type.is_rel_with_deb_info():
         if IS_WINDOWS:
-            extra_compile_args.append("/Z7")
-            extra_link_args.append("/DEBUG:FULL")
+            extra_compile_args += ["/Z7"]
+            extra_link_args += ["/DEBUG:FULL"]
         else:
             extra_compile_args += ["-g"]
             extra_link_args += ["-g"]
@@ -1095,7 +1100,7 @@ def configure_extension_build():
         ]
         extra_link_args += ["-arch", macos_target_arch]

-    def make_relative_rpath_args(path):
+    def make_relative_rpath_args(path: str) -> list[str]:
         if IS_DARWIN:
             return ["-Wl,-rpath,@loader_path/" + path]
         elif IS_WINDOWS:
@@ -1107,7 +1112,7 @@ def configure_extension_build():
     # Declare extensions and package
     ################################################################################

-    extensions = []
+    ext_modules: list[Extension] = []
     excludes = ["tools", "tools.*", "caffe2", "caffe2.*"]
     if not cmake_cache_vars["BUILD_FUNCTORCH"]:
         excludes.extend(["functorch", "functorch.*"])
@@ -1117,29 +1122,33 @@ def configure_extension_build():
         libraries=main_libraries,
         sources=main_sources,
         language="c",
-        extra_compile_args=main_compile_args + extra_compile_args,
+        extra_compile_args=[
+            *main_compile_args,
+            *extra_compile_args,
+        ],
         include_dirs=[],
         library_dirs=library_dirs,
-        extra_link_args=extra_link_args
-        + main_link_args
-        + make_relative_rpath_args("lib"),
+        extra_link_args=[
+            *extra_link_args,
+            *main_link_args,
+            *make_relative_rpath_args("lib"),
+        ],
     )
-    extensions.append(C)
+    ext_modules.append(C)

     # These extensions are built by cmake and copied manually in build_extensions()
     # inside the build_ext implementation
     if cmake_cache_vars["BUILD_FUNCTORCH"]:
-        extensions.append(
-            Extension(name="functorch._C", sources=[]),
-        )
+        ext_modules.append(Extension(name="functorch._C", sources=[]))

     cmdclass = {
-        "bdist_wheel": wheel_concatenate,
         "build_ext": build_ext,
         "clean": clean,
         "install": install,
         "sdist": sdist,
     }
+    if wheel_concatenate is not None:
+        cmdclass["bdist_wheel"] = wheel_concatenate

     entry_points = {
         "console_scripts": [
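The cmdclass change above registers `bdist_wheel` only when the `wheel` package could actually be imported, so sdist/clean style invocations keep working without it. A condensed sketch of that registration pattern; names prefixed with `my_` are illustrative.

# Sketch: register an optional bdist_wheel command only when wheel is available.
from setuptools import Command

try:
    from wheel.bdist_wheel import bdist_wheel  # type: ignore[import-untyped]
except ImportError:
    my_bdist_wheel: type[Command] | None = None
else:

    class my_bdist_wheel(bdist_wheel):  # type: ignore[no-redef]
        pass  # customizations (e.g. license concatenation) would go here


cmdclass: dict[str, type[Command]] = {}
if my_bdist_wheel is not None:
    cmdclass["bdist_wheel"] = my_bdist_wheel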
@@ -1155,7 +1164,7 @@ def configure_extension_build():
         entry_points["console_scripts"].append(
             "torchfrtrace = tools.flight_recorder.fr_trace:main",
         )
-    return extensions, cmdclass, packages, entry_points, extra_install_requires
+    return ext_modules, cmdclass, packages, entry_points, extra_install_requires


 # post run, warnings, printed at the end to make them more visible
@@ -1171,7 +1180,7 @@ build_update_message = """
 """


-def print_box(msg):
+def print_box(msg: str) -> None:
     lines = msg.split("\n")
     size = max(len(l) + 1 for l in lines)
     print("-" * (size + 2))
@@ -1180,7 +1189,7 @@ def print_box(msg):
     print("-" * (size + 2))


-def main():
+def main() -> None:
     if BUILD_LIBTORCH_WHL and BUILD_PYTHON_ONLY:
         raise RuntimeError(
             "Conflict: 'BUILD_LIBTORCH_WHL' and 'BUILD_PYTHON_ONLY' can't both be 1. "
@@ -1226,7 +1235,7 @@ def main():
     dist.script_args = sys.argv[1:]
     try:
         dist.parse_command_line()
-    except setuptools.distutils.errors.DistutilsArgError as e:
+    except setuptools.errors.BaseError as e:
         print(e)
         sys.exit(1)
@@ -1235,7 +1244,7 @@ def main():
         build_deps()

     (
-        extensions,
+        ext_modules,
         cmdclass,
         packages,
         entry_points,
@@ -1345,17 +1354,17 @@ def main():
         package_data["torchgen"] = torchgen_package_data
     else:
         # no extensions in BUILD_LIBTORCH_WHL mode
-        extensions = []
+        ext_modules = []

     setup(
-        name=package_name,
-        version=version,
+        name=TORCH_PACKAGE_NAME,
+        version=TORCH_VERSION,
         description=(
             "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
         ),
         long_description=long_description,
         long_description_content_type="text/markdown",
-        ext_modules=extensions,
+        ext_modules=ext_modules,
         cmdclass=cmdclass,
         packages=packages,
         entry_points=entry_points,
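The hunk in `main()` above switches the caught exception to `setuptools.errors.BaseError`, the public alias newer setuptools versions expose for the distutils error hierarchy. A minimal sketch of catching it around `Distribution.parse_command_line()`; the bogus option is illustrative and is expected to trigger the error path.

# Sketch: catch argument-parsing failures via setuptools.errors instead of
# reaching into setuptools.distutils (assumes a reasonably recent setuptools).
import sys

import setuptools.errors
from setuptools.dist import Distribution

dist = Distribution()
dist.script_name = "setup.py"
dist.script_args = ["--bogus-option"]  # deliberately invalid argument
try:
    dist.parse_command_line()
except setuptools.errors.BaseError as e:
    print(e)
    sys.exit(1)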