Compare commits

...

2 Commits

Author SHA1 Message Date
54c019979c Merge branch 'main' into adi/update_openblas 2025-05-20 19:42:52 +00:00
90701ab81b update openblas 2025-04-16 09:34:37 +00:00
5 changed files with 17 additions and 8 deletions

View File

@ -342,6 +342,7 @@ case "$tag" in
GCC_VERSION=11
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
@ -351,6 +352,7 @@ case "$tag" in
GCC_VERSION=11
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
@ -446,6 +448,7 @@ docker build \
--build-arg "XPU_VERSION=${XPU_VERSION}" \
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
--build-arg "ACL=${ACL:-}" \
--build-arg "OPENBLAS=${OPENBLAS:-}" \
--build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
--build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
-f $(dirname ${DOCKERFILE})/Dockerfile \

View File

@ -65,9 +65,7 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
conda_install libstdcxx-ng=12.3.0 --update-deps -c conda-forge
# Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README
if [[ $(uname -m) == "aarch64" ]]; then
conda_install "openblas==0.3.29=*openmp*"
else
if [[ $(uname -m) != "aarch64" ]]; then
conda_install "mkl=2021.4.0 mkl-include=2021.4.0"
fi

View File

@ -4,8 +4,11 @@
set -ex
cd /
git clone https://github.com/OpenMathLib/OpenBLAS.git -b v0.3.29 --depth 1 --shallow-submodules
OPENBLAS_HASH="b30dc9701f8e971720a02e24068acea274fd9cee" #Use SVE kernel for S/DGEMVT for SVE machines
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
git clone https://github.com/OpenMathLib/OpenBLAS.git -b develop --shallow-submodules
git -C $OPENBLAS_CHECKOUT_DIR fetch --depth 1 origin $OPENBLAS_HASH
git -C $OPENBLAS_CHECKOUT_DIR checkout $OPENBLAS_HASH
OPENBLAS_BUILD_FLAGS="
NUM_THREADS=128
@ -14,9 +17,8 @@ NO_SHARED=0
DYNAMIC_ARCH=1
TARGET=ARMV8
CFLAGS=-O3
BUILD_BFLOAT16=1
"
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
make -j8 ${OPENBLAS_BUILD_FLAGS} -C ${OPENBLAS_CHECKOUT_DIR}
make -j8 ${OPENBLAS_BUILD_FLAGS} install -C ${OPENBLAS_CHECKOUT_DIR}

View File

@ -147,6 +147,12 @@ RUN if [ -n "${ACL}" ]; then bash ./install_acl.sh; fi
RUN rm install_acl.sh
ENV INSTALLED_ACL ${ACL}
ARG OPENBLAS
COPY ./common/install_openblas.sh install_openblas.sh
RUN if [ -n "${OPENBLAS}" ]; then bash ./install_openblas.sh; fi
RUN rm install_openblas.sh
ENV INSTALLED_OPENBLAS ${OPENBLAS}
# Install ccache/sccache (do this last, so we get priority in PATH)
ARG SKIP_SCCACHE_INSTALL
COPY ./common/install_cache.sh install_cache.sh

View File

@ -420,7 +420,7 @@ void gemm(
const float beta,
float *c, int64_t ldc) {
internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM) && defined(__ARM_FEATURE_BF16)
if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
char transa_ = to_blas(transa), transb_ = to_blas(transb);