[BE][setup] gracefully handle envvars representing a boolean in setup.py (#156040)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156040
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2025-06-16 22:36:59 +08:00
committed by PyTorch MergeBot
parent f48a157660
commit 4162c0f702

View File

@ -221,6 +221,8 @@
# BUILD_PYTHON_ONLY # BUILD_PYTHON_ONLY
# Builds pytorch as a wheel using libtorch.so from a separate wheel # Builds pytorch as a wheel using libtorch.so from a separate wheel
from __future__ import annotations
import os import os
import sys import sys
@ -234,9 +236,6 @@ if sys.platform == "win32" and sys.maxsize.bit_length() == 31:
import platform import platform
BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1"
BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1"
python_min_version = (3, 9, 0) python_min_version = (3, 9, 0)
python_min_version_str = ".".join(map(str, python_min_version)) python_min_version_str = ".".join(map(str, python_min_version))
if sys.version_info < python_min_version: if sys.version_info < python_min_version:
@ -268,6 +267,47 @@ from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
from tools.setup_helpers.generate_linker_script import gen_linker_script from tools.setup_helpers.generate_linker_script import gen_linker_script
def str2bool(value: str | None) -> bool:
"""Convert environment variables to boolean values."""
if not value:
return False
if not isinstance(value, str):
raise ValueError(
f"Expected a string value for boolean conversion, got {type(value)}"
)
value = value.strip().lower()
if value in (
"1",
"true",
"t",
"yes",
"y",
"on",
"enable",
"enabled",
"found",
):
return True
if value in (
"0",
"false",
"f",
"no",
"n",
"off",
"disable",
"disabled",
"notfound",
"none",
"null",
"nil",
"undefined",
"n/a",
):
return False
raise ValueError(f"Invalid string value for boolean conversion: {value}")
def _get_package_path(package_name): def _get_package_path(package_name):
spec = importlib.util.find_spec(package_name) spec = importlib.util.find_spec(package_name)
if spec: if spec:
@ -282,13 +322,15 @@ def _get_package_path(package_name):
return None return None
BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL"))
BUILD_PYTHON_ONLY = str2bool(os.getenv("BUILD_PYTHON_ONLY"))
# set up appropriate env variables # set up appropriate env variables
if BUILD_LIBTORCH_WHL: if BUILD_LIBTORCH_WHL:
# Set up environment variables for ONLY building libtorch.so and not libtorch_python.so # Set up environment variables for ONLY building libtorch.so and not libtorch_python.so
# functorch is not supported without python # functorch is not supported without python
os.environ["BUILD_FUNCTORCH"] = "OFF" os.environ["BUILD_FUNCTORCH"] = "OFF"
if BUILD_PYTHON_ONLY: if BUILD_PYTHON_ONLY:
os.environ["BUILD_LIBTORCHLESS"] = "ON" os.environ["BUILD_LIBTORCHLESS"] = "ON"
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib" os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
@ -413,7 +455,7 @@ def check_submodules():
os.path.isdir(folder) and len(os.listdir(folder)) == 0 os.path.isdir(folder) and len(os.listdir(folder)) == 0
) )
if bool(os.getenv("USE_SYSTEM_LIBS", False)): if str2bool(os.getenv("USE_SYSTEM_LIBS")):
return return
folders = get_submodule_folders() folders = get_submodule_folders()
# If none of the submodule folders exists, try to initialize them # If none of the submodule folders exists, try to initialize them
@ -748,8 +790,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
# In ROCm on Windows case copy rocblas and hipblaslt files into # In ROCm on Windows case copy rocblas and hipblaslt files into
# torch/lib/rocblas/library and torch/lib/hipblaslt/library # torch/lib/rocblas/library and torch/lib/hipblaslt/library
use_rocm = os.environ.get("USE_ROCM") if str2bool(os.getenv("USE_ROCM")):
if use_rocm:
rocm_dir_path = os.environ.get("ROCM_DIR") rocm_dir_path = os.environ.get("ROCM_DIR")
rocm_bin_path = os.path.join(rocm_dir_path, "bin") rocm_bin_path = os.path.join(rocm_dir_path, "bin")
@ -1146,19 +1187,7 @@ def main():
if BUILD_PYTHON_ONLY: if BUILD_PYTHON_ONLY:
install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}") install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}")
use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", "")) if str2bool(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD")):
if (
use_prioritized_text == ""
and platform.system() == "Linux"
and platform.processor() == "aarch64"
):
print_box(
"""
WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA.
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
"""
)
if use_prioritized_text == "1" or use_prioritized_text == "True":
gen_linker_script( gen_linker_script(
filein="cmake/prioritized_text.txt", fout="cmake/linker_script.ld" filein="cmake/prioritized_text.txt", fout="cmake/linker_script.ld"
) )
@ -1170,6 +1199,13 @@ def main():
os.environ["CXXFLAGS"] = ( os.environ["CXXFLAGS"] = (
os.getenv("CXXFLAGS", "") + " -ffunction-sections -fdata-sections" os.getenv("CXXFLAGS", "") + " -ffunction-sections -fdata-sections"
) )
elif platform.system() == "Linux" and platform.processor() == "aarch64":
print_box(
"""
WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA.
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
"""
)
# Parse the command line and check the arguments before we proceed with # Parse the command line and check the arguments before we proceed with
# building deps and setup. We need to set values so `--help` works. # building deps and setup. We need to set values so `--help` works.