mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:30:26 +08:00
[BE][setup] gracefully handle envvars representing a boolean in setup.py
(#156040)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156040 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
f48a157660
commit
4162c0f702
76
setup.py
76
setup.py
@ -221,6 +221,8 @@
|
|||||||
# BUILD_PYTHON_ONLY
|
# BUILD_PYTHON_ONLY
|
||||||
# Builds pytorch as a wheel using libtorch.so from a separate wheel
|
# Builds pytorch as a wheel using libtorch.so from a separate wheel
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@ -234,9 +236,6 @@ if sys.platform == "win32" and sys.maxsize.bit_length() == 31:
|
|||||||
import platform
|
import platform
|
||||||
|
|
||||||
|
|
||||||
BUILD_LIBTORCH_WHL = os.getenv("BUILD_LIBTORCH_WHL", "0") == "1"
|
|
||||||
BUILD_PYTHON_ONLY = os.getenv("BUILD_PYTHON_ONLY", "0") == "1"
|
|
||||||
|
|
||||||
python_min_version = (3, 9, 0)
|
python_min_version = (3, 9, 0)
|
||||||
python_min_version_str = ".".join(map(str, python_min_version))
|
python_min_version_str = ".".join(map(str, python_min_version))
|
||||||
if sys.version_info < python_min_version:
|
if sys.version_info < python_min_version:
|
||||||
@ -268,6 +267,47 @@ from tools.setup_helpers.env import build_type, IS_DARWIN, IS_LINUX, IS_WINDOWS
|
|||||||
from tools.setup_helpers.generate_linker_script import gen_linker_script
|
from tools.setup_helpers.generate_linker_script import gen_linker_script
|
||||||
|
|
||||||
|
|
||||||
|
def str2bool(value: str | None) -> bool:
|
||||||
|
"""Convert environment variables to boolean values."""
|
||||||
|
if not value:
|
||||||
|
return False
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected a string value for boolean conversion, got {type(value)}"
|
||||||
|
)
|
||||||
|
value = value.strip().lower()
|
||||||
|
if value in (
|
||||||
|
"1",
|
||||||
|
"true",
|
||||||
|
"t",
|
||||||
|
"yes",
|
||||||
|
"y",
|
||||||
|
"on",
|
||||||
|
"enable",
|
||||||
|
"enabled",
|
||||||
|
"found",
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
if value in (
|
||||||
|
"0",
|
||||||
|
"false",
|
||||||
|
"f",
|
||||||
|
"no",
|
||||||
|
"n",
|
||||||
|
"off",
|
||||||
|
"disable",
|
||||||
|
"disabled",
|
||||||
|
"notfound",
|
||||||
|
"none",
|
||||||
|
"null",
|
||||||
|
"nil",
|
||||||
|
"undefined",
|
||||||
|
"n/a",
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
raise ValueError(f"Invalid string value for boolean conversion: {value}")
|
||||||
|
|
||||||
|
|
||||||
def _get_package_path(package_name):
|
def _get_package_path(package_name):
|
||||||
spec = importlib.util.find_spec(package_name)
|
spec = importlib.util.find_spec(package_name)
|
||||||
if spec:
|
if spec:
|
||||||
@ -282,13 +322,15 @@ def _get_package_path(package_name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
BUILD_LIBTORCH_WHL = str2bool(os.getenv("BUILD_LIBTORCH_WHL"))
|
||||||
|
BUILD_PYTHON_ONLY = str2bool(os.getenv("BUILD_PYTHON_ONLY"))
|
||||||
|
|
||||||
# set up appropriate env variables
|
# set up appropriate env variables
|
||||||
if BUILD_LIBTORCH_WHL:
|
if BUILD_LIBTORCH_WHL:
|
||||||
# Set up environment variables for ONLY building libtorch.so and not libtorch_python.so
|
# Set up environment variables for ONLY building libtorch.so and not libtorch_python.so
|
||||||
# functorch is not supported without python
|
# functorch is not supported without python
|
||||||
os.environ["BUILD_FUNCTORCH"] = "OFF"
|
os.environ["BUILD_FUNCTORCH"] = "OFF"
|
||||||
|
|
||||||
|
|
||||||
if BUILD_PYTHON_ONLY:
|
if BUILD_PYTHON_ONLY:
|
||||||
os.environ["BUILD_LIBTORCHLESS"] = "ON"
|
os.environ["BUILD_LIBTORCHLESS"] = "ON"
|
||||||
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
|
os.environ["LIBTORCH_LIB_PATH"] = f"{_get_package_path('torch')}/lib"
|
||||||
@ -413,7 +455,7 @@ def check_submodules():
|
|||||||
os.path.isdir(folder) and len(os.listdir(folder)) == 0
|
os.path.isdir(folder) and len(os.listdir(folder)) == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
if bool(os.getenv("USE_SYSTEM_LIBS", False)):
|
if str2bool(os.getenv("USE_SYSTEM_LIBS")):
|
||||||
return
|
return
|
||||||
folders = get_submodule_folders()
|
folders = get_submodule_folders()
|
||||||
# If none of the submodule folders exists, try to initialize them
|
# If none of the submodule folders exists, try to initialize them
|
||||||
@ -748,8 +790,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
|||||||
|
|
||||||
# In ROCm on Windows case copy rocblas and hipblaslt files into
|
# In ROCm on Windows case copy rocblas and hipblaslt files into
|
||||||
# torch/lib/rocblas/library and torch/lib/hipblaslt/library
|
# torch/lib/rocblas/library and torch/lib/hipblaslt/library
|
||||||
use_rocm = os.environ.get("USE_ROCM")
|
if str2bool(os.getenv("USE_ROCM")):
|
||||||
if use_rocm:
|
|
||||||
rocm_dir_path = os.environ.get("ROCM_DIR")
|
rocm_dir_path = os.environ.get("ROCM_DIR")
|
||||||
rocm_bin_path = os.path.join(rocm_dir_path, "bin")
|
rocm_bin_path = os.path.join(rocm_dir_path, "bin")
|
||||||
|
|
||||||
@ -1146,19 +1187,7 @@ def main():
|
|||||||
if BUILD_PYTHON_ONLY:
|
if BUILD_PYTHON_ONLY:
|
||||||
install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}")
|
install_requires.append(f"{LIBTORCH_PKG_NAME}=={get_torch_version()}")
|
||||||
|
|
||||||
use_prioritized_text = str(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD", ""))
|
if str2bool(os.getenv("USE_PRIORITIZED_TEXT_FOR_LD")):
|
||||||
if (
|
|
||||||
use_prioritized_text == ""
|
|
||||||
and platform.system() == "Linux"
|
|
||||||
and platform.processor() == "aarch64"
|
|
||||||
):
|
|
||||||
print_box(
|
|
||||||
"""
|
|
||||||
WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA.
|
|
||||||
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
if use_prioritized_text == "1" or use_prioritized_text == "True":
|
|
||||||
gen_linker_script(
|
gen_linker_script(
|
||||||
filein="cmake/prioritized_text.txt", fout="cmake/linker_script.ld"
|
filein="cmake/prioritized_text.txt", fout="cmake/linker_script.ld"
|
||||||
)
|
)
|
||||||
@ -1170,6 +1199,13 @@ def main():
|
|||||||
os.environ["CXXFLAGS"] = (
|
os.environ["CXXFLAGS"] = (
|
||||||
os.getenv("CXXFLAGS", "") + " -ffunction-sections -fdata-sections"
|
os.getenv("CXXFLAGS", "") + " -ffunction-sections -fdata-sections"
|
||||||
)
|
)
|
||||||
|
elif platform.system() == "Linux" and platform.processor() == "aarch64":
|
||||||
|
print_box(
|
||||||
|
"""
|
||||||
|
WARNING: we strongly recommend enabling linker script optimization for ARM + CUDA.
|
||||||
|
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Parse the command line and check the arguments before we proceed with
|
# Parse the command line and check the arguments before we proceed with
|
||||||
# building deps and setup. We need to set values so `--help` works.
|
# building deps and setup. We need to set values so `--help` works.
|
||||||
|
Reference in New Issue
Block a user