mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[CI] Add prebuild command option, set prebuild command option for CI to build flash attention (#156236)
Build flash attention separately in build using 2 jobs since it OOMs on more, then the rest of the job uses 6 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156236 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
3ed4384f5b
commit
f40efde2a4
@ -198,10 +198,8 @@ fi
|
||||
|
||||
# We only build FlashAttention files for CUDA 8.0+, and they require large amounts of
|
||||
# memory to build and will OOM
|
||||
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]] && [ -z "$MAX_JOBS_OVERRIDE" ]; then
|
||||
echo "WARNING: FlashAttention files require large amounts of memory to build and will OOM"
|
||||
echo "Setting MAX_JOBS=(nproc-2)/3 to reduce memory usage"
|
||||
export MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))"
|
||||
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]]; then
|
||||
export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j 2"
|
||||
fi
|
||||
|
||||
if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
from .optional_submodules import checkout_nccl
|
||||
from .setup_helpers.cmake import CMake, USE_NINJA
|
||||
@ -98,4 +99,20 @@ def build_pytorch(
|
||||
)
|
||||
if cmake_only:
|
||||
return
|
||||
build_custom_step = os.getenv("BUILD_CUSTOM_STEP")
|
||||
if build_custom_step:
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
build_custom_step,
|
||||
shell=True,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
print("Command output:")
|
||||
print(output)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Command failed with return code:", e.returncode)
|
||||
print("Output (stdout and stderr):")
|
||||
print(e.output)
|
||||
raise
|
||||
cmake.build(my_env)
|
||||
|
Reference in New Issue
Block a user