mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][setup] gracefully handle envvars representing a boolean in setup.py
(#156040)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156040 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
f48a157660
commit
4162c0f702
76
setup.py
76
setup.py
@ -221,6 +221,8 @@
|
||||
# BUILD_PYTHON_ONLY
|
||||
# Builds pytorch as a wheel using libtorch.so from a separate wheel
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
@ -234,9 +236,6 @@ if sys.platform == "win32" and sys.maxsize.bit_length() == 31:
|
||||
import platform
|
||||
|
||||
|
||||
BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1"
|
||||
BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1"
|
||||
|
||||
python_min_version = (3, 9, 0)
|
||||
python_min_version_str = ".".join(map(str, python_min_version))
|
||||
if sys.version_info < python_min_version:
|
||||
@ -268,6 +267,47 @@ from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
|
||||
from tools.setup_helpers.generate_linker_script import gen_linker_script
|
||||
|
||||
|
||||
def str2bool(value: str | None) -> bool:
|
||||
"""Convert environment variables to boolean values."""
|
||||
if not value:
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Expected a string value for boolean conversion, got {type(value)}"
|
||||
)
|
||||
value = value.strip().lower()
|
||||
if value in (
|
||||
"1",
|
||||
"true",
|
||||
"t",
|
||||
"yes",
|
||||
"y",
|
||||
"on",
|
||||
"enable",
|
||||
"enabled",
|
||||
"found",
|
||||
):
|
||||
return True
|
||||
if value in (
|
||||
"0",
|
||||
"false",
|
||||
"f",
|
||||
"no",
|
||||
"n",
|
||||
"off",
|
||||
"disable",
|
||||
"disabled",
|
||||
"notfound",
|
||||
"none",
|
||||
"null",
|
||||
"nil",
|
||||
"undefined",
|
||||
"n/a",
|
||||
):
|
||||
return False
|
||||
raise ValueError(f"Invalid string value for boolean conversion: {value}")
|
||||
|
||||
|
||||
def _get_package_path(package_name):
|
||||
spec = importlib.util.find_spec(package_name)
|
||||
if spec:
|
||||
@ -282,13 +322,15 @@ def _get_package_path(package_name):
|
||||
return None
|
||||
|
||||
|
||||
BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL"))
|
||||
BUILD_PYTHON_ONLY = str2bool(os.getenv("BUILD_PYTHON_ONLY"))
|
||||
|
||||
# set up appropriate env variables
|
||||
if BUILD_LIBTORCH_WHL:
|
||||
# Set up environment variables for ONLY building libtorch.so and not libtorch_python.so
|
||||
# functorch is not supported without python
|
||||
os.environ["BUILD_FUNCTORCH"] = "OFF"
|
||||
|
||||
|
||||
if BUILD_PYTHON_ONLY:
|
||||
os.environ["BUILD_LIBTORCHLESS"] = "ON"
|
||||
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
|
||||
@ -413,7 +455,7 @@ def check_submodules():
|
||||
os.path.isdir(folder) and len(os.listdir(folder)) == 0
|
||||
)
|
||||
|
||||
if bool(os.getenv("USE_SYSTEM_LIBS", False)):
|
||||
if str2bool(os.getenv("USE_SYSTEM_LIBS")):
|
||||
return
|
||||
folders = get_submodule_folders()
|
||||
# If none of the submodule folders exists, try to initialize them
|
||||
@ -748,8 +790,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
||||
|
||||
# In ROCm on Windows case copy rocblas and hipblaslt files into
|
||||
# torch/lib/rocblas/library and torch/lib/hipblaslt/library
|
||||
use_rocm = os.environ.get("USE_ROCM")
|
||||
if use_rocm:
|
||||
if str2bool(os.getenv("USE_ROCM")):
|
||||
rocm_dir_path = os.environ.get("ROCM_DIR")
|
||||
rocm_bin_path = os.path.join(rocm_dir_path, "bin")
|
||||
|
||||
@ -1146,19 +1187,7 @@ def main():
|
||||
if BUILD_PYTHON_ONLY:
|
||||
install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}")
|
||||
|
||||
use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", ""))
|
||||
if (
|
||||
use_prioritized_text == ""
|
||||
and platform.system() == "Linux"
|
||||
and platform.processor() == "aarch64"
|
||||
):
|
||||
print_box(
|
||||
"""
|
||||
WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA.
|
||||
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
||||
"""
|
||||
)
|
||||
if use_prioritized_text == "1" or use_prioritized_text == "True":
|
||||
if str2bool(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD")):
|
||||
gen_linker_script(
|
||||
filein="cmake/prioritized_text.txt", fout="cmake/linker_script.ld"
|
||||
)
|
||||
@ -1170,6 +1199,13 @@ def main():
|
||||
os.environ["CXXFLAGS"] = (
|
||||
os.getenv("CXXFLAGS", "") + " -ffunction-sections -fdata-sections"
|
||||
)
|
||||
elif platform.system() == "Linux" and platform.processor() == "aarch64":
|
||||
print_box(
|
||||
"""
|
||||
WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA.
|
||||
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
||||
"""
|
||||
)
|
||||
|
||||
# Parse the command line and check the arguments before we proceed with
|
||||
# building deps and setup. We need to set values so `--help` works.
|
||||
|
Reference in New Issue
Block a user