Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-07 10:01:39 +08:00)
Compare commits: 6 commits, gh/lw/8/he...copilot/co
| Author | SHA1 | Date |
|---|---|---|
| | 241b702918 | |
| | 83df2e0610 | |
| | 77fe8234bb | |
| | 6ece527fc5 | |
| | ce29d0d796 | |
| | 7231118db3 | |
@@ -20,7 +20,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
 
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-    python3 -mpip install cmake==3.18.4 && \
+    python3 -m pip install cmake==3.18.4 && \
     ln -s /usr/local/bin/cmake /usr/bin/cmake3
 RUN rm -rf /usr/local/cuda-*
 
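Note: most hunks in this compare are the same mechanical cleanup, rewriting `python -mpip ...` as `python -m pip ...`. Both spellings are accepted by CPython (the module name may be glued directly to `-m`, as with `-c`), so the change affects readability and greppability only, not behavior.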
@@ -25,7 +25,7 @@ function install_torchbench() {
   python install.py --continue_on_fail
 
   echo "Print all dependencies after TorchBench is installed"
-  python -mpip freeze
+  python -m pip freeze
   popd
 
   chown -R jenkins torchbench
@@ -8,8 +8,8 @@ MKLROOT=/opt/intel
 mkdir -p ${MKLROOT}
 pushd /tmp
 
-python3 -mpip install wheel
-python3 -mpip download -d . mkl-static==${MKL_VERSION}
+python3 -m pip install wheel
+python3 -m pip download -d . mkl-static==${MKL_VERSION}
 python3 -m wheel unpack mkl_static-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 python3 -m wheel unpack mkl_include-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 mv mkl_static-${MKL_VERSION}/mkl_static-${MKL_VERSION}.data/data/lib ${MKLROOT}
@@ -11,5 +11,5 @@ ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
 python -m venv /var/lib/jenkins/ci_env
 source /var/lib/jenkins/ci_env/bin/activate
 
-python -mpip install --upgrade pip
-python -mpip install -r /opt/requirements-ci.txt
+python -m pip install --upgrade pip
+python -m pip install -r /opt/requirements-ci.txt
@@ -14,7 +14,7 @@ ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/op
 
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-    python3 -mpip install cmake==3.18.4 && \
+    python3 -m pip install cmake==3.18.4 && \
     ln -s /usr/local/bin/cmake /usr/bin/cmake3
 
 FROM base as openssl
@@ -135,7 +135,7 @@ RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
 
 # cmake-3.18.4 from pip; force in case cmake3 already exists
 RUN yum install -y python3-pip && \
-    python3 -mpip install cmake==3.18.4 && \
+    python3 -m pip install cmake==3.18.4 && \
     ln -sf /usr/local/bin/cmake /usr/bin/cmake3
 
 FROM cpu_final as cuda_final
@@ -157,7 +157,7 @@ ENV ROCM_PATH /opt/rocm
 # cmake-3.28.4 from pip to get enable_language(HIP)
 # and avoid 3.21.0 cmake+ninja issues with ninja inserting "-Wl,--no-as-needed" in LINK_FLAGS for static linker
 RUN python3 -m pip install --upgrade pip && \
-    python3 -mpip install cmake==3.28.4
+    python3 -m pip install cmake==3.28.4
 # replace the libdrm in /opt/amdgpu with custom amdgpu.ids lookup path
 ADD ./common/install_rocm_drm.sh install_rocm_drm.sh
 RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh
@@ -174,7 +174,7 @@ FROM cpu_final as xpu_final
 ENV XPU_DRIVER_TYPE ROLLING
 # cmake-3.28.4 from pip
 RUN python3 -m pip install --upgrade pip && \
-    python3 -mpip install cmake==3.28.4
+    python3 -m pip install cmake==3.28.4
 ADD ./common/install_xpu.sh install_xpu.sh
 ENV XPU_VERSION 2025.2
 RUN bash ./install_xpu.sh && rm install_xpu.sh
@@ -113,7 +113,7 @@ RUN dnf install -y \
 RUN env GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=True pip3 install grpcio
 
 # cmake-3.28.0 from pip for onnxruntime
-RUN python3 -mpip install cmake==3.28.0
+RUN python3 -m pip install cmake==3.28.0
 
 ADD ./common/patch_libstdc.sh patch_libstdc.sh
 RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
@@ -288,7 +288,7 @@ else
   # or building non-XLA tests.
   if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then
     # Install numpy-2.0.2 for builds which are backward compatible with 1.X
-    python -mpip install numpy==2.0.2
+    python -m pip install numpy==2.0.2
 
     WERROR=1 python setup.py clean
 
@@ -67,13 +67,13 @@ function pip_install_whl() {
     # Loop through each path and install individually
     for path in "${paths[@]}"; do
       echo "Installing $path"
-      python3 -mpip install --no-index --no-deps "$path"
+      python3 -m pip install --no-index --no-deps "$path"
     done
   else
     # Loop through each argument and install individually
    for path in "${args[@]}"; do
      echo "Installing $path"
-     python3 -mpip install --no-index --no-deps "$path"
+     python3 -m pip install --no-index --no-deps "$path"
    done
  fi
}
@@ -182,7 +182,7 @@ checkout_install_torchbench() {
   pip uninstall -y torchao
 
   echo "Print all dependencies after TorchBench is installed"
-  python -mpip freeze
+  python -m pip freeze
 }
 
 torchbench_setup_macos() {
@@ -211,7 +211,7 @@ torchbench_setup_macos() {
 }
 
 pip_benchmark_deps() {
-  python -mpip install --no-input requests cython scikit-learn six
+  python -m pip install --no-input requests cython scikit-learn six
 }
 
 
@@ -1434,7 +1434,7 @@ EOF
   # shellcheck source=./common-build.sh
   source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh"
   python -m build --wheel --no-isolation -C--build-option=--bdist-dir="base_bdist_tmp" --outdir "base_dist"
-  python -mpip install base_dist/*.whl
+  python -m pip install base_dist/*.whl
   echo "::endgroup::"
 
   pushd test/forward_backward_compatibility
@@ -173,7 +173,7 @@ esac
 PINNED_PACKAGES=(
   "numpy${NUMPY_PINNED_VERSION}"
 )
-python -mvenv ~/${desired_python}-build
+python -m venv ~/${desired_python}-build
 source ~/${desired_python}-build/bin/activate
 retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.txt"
 retry brew install libomp
.github/scripts/prepare_vllm_wheels.sh (vendored, 6 changed lines)
@@ -24,7 +24,7 @@ change_wheel_version() {
   local t_version=$4
 
   # Extract the wheel
-  ${PYTHON_EXECUTABLE} -mwheel unpack $wheel
+  ${PYTHON_EXECUTABLE} -m wheel unpack $wheel
 
   mv "${package}-${f_version}" "${package}-${t_version}"
   # Change the version from f_version to t_version in the dist-info dir
@@ -47,7 +47,7 @@ change_wheel_version() {
   popd
 
   # Repack the wheel
-  ${PYTHON_EXECUTABLE} -mwheel pack "${package}-${t_version}"
+  ${PYTHON_EXECUTABLE} -m wheel pack "${package}-${t_version}"
 
   # Clean up
   rm -rf "${package}-${t_version}"
@@ -85,7 +85,7 @@ repackage_wheel() {
 }
 
 # Require to re-package the wheel
-${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1
+${PYTHON_EXECUTABLE} -m pip install wheel==0.45.1
 
 pushd externals/vllm/wheels
 for package in xformers flashinfer-python vllm; do
.github/workflows/_mac-test.yml (vendored, 4 changed lines)
@@ -211,7 +211,7 @@ jobs:
             $tool --version
           done
 
-          python3 -mpip install --no-index --no-deps dist/*.whl
+          python3 -m pip install --no-index --no-deps dist/*.whl
 
           set +e
           pushd "${RUNNER_TEMP}"
@@ -222,7 +222,7 @@ jobs:
           popd
 
           if [ "${RC}" -ne 0 ]; then
-            python3 -mpip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
+            python3 -m pip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
           fi
           set -e
 
.github/workflows/_win-test.yml (vendored, 2 changed lines)
@@ -204,7 +204,7 @@ jobs:
        run: |
          pushd "${PYTORCH_FINAL_PACKAGE_DIR}"
          # shellcheck disable=SC2046,SC2102
-         python3 -mpip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
+         python3 -m pip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
          popd
 
          .ci/pytorch/win-test.sh
.github/workflows/build-vllm-wheel.yml (vendored, 4 changed lines)
@@ -126,13 +126,13 @@ jobs:
            "${MANYLINUX_IMAGE}"
          )
 
-         docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip install \
+         docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install \
            --pre torch torchvision torchaudio \
            --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
 
          # I wonder if there is a command to both download and install the wheels
          # in one go
-         docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip download \
+         docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip download \
            --pre torch torchvision torchaudio \
            --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
 
.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml (generated, vendored, 14 changed lines)
@@ -106,7 +106,7 @@ jobs:
          SMOKE_TEST_PARAMS=""
 
          # shellcheck disable=SC2086
-         python -mvenv test_venv
+         python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
 
@@ -216,7 +216,7 @@ jobs:
          SMOKE_TEST_PARAMS=""
 
          # shellcheck disable=SC2086
-         python -mvenv test_venv
+         python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
 
@@ -326,7 +326,7 @@ jobs:
          SMOKE_TEST_PARAMS=""
 
          # shellcheck disable=SC2086
-         python -mvenv test_venv
+         python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
 
@@ -436,7 +436,7 @@ jobs:
          SMOKE_TEST_PARAMS=""
 
          # shellcheck disable=SC2086
-         python -mvenv test_venv
+         python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
 
@@ -546,7 +546,7 @@ jobs:
          SMOKE_TEST_PARAMS=""
 
          # shellcheck disable=SC2086
-         python -mvenv test_venv
+         python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
 
@@ -656,7 +656,7 @@ jobs:
          SMOKE_TEST_PARAMS=""
 
          # shellcheck disable=SC2086
-         python -mvenv test_venv
+         python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
 
@@ -766,7 +766,7 @@ jobs:
          SMOKE_TEST_PARAMS=""
 
          # shellcheck disable=SC2086
-         python -mvenv test_venv
+         python -m venv test_venv
          source test_venv/bin/activate
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
 
.github/workflows/operator_benchmark.yml (vendored, 24 changed lines)
@@ -52,3 +52,27 @@ jobs:
      docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
    secrets: inherit
+
+  aarch64-opbenchmark-build:
+    if: github.repository_owner == 'pytorch'
+    name: aarch64-opbenchmark-build
+    uses: ./.github/workflows/_linux-build.yml
+    with:
+      build-environment: linux-jammy-aarch64-py3.10
+      runner: linux.arm64.m7g.4xlarge
+      docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
+      test-matrix: |
+        { include: [
+          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
+        ]}
+    secrets: inherit
+
+  aarch64-opbenchmark-test:
+    name: aarch64-opbenchmark-test
+    uses: ./.github/workflows/_linux-test.yml
+    needs: aarch64-opbenchmark-build
+    with:
+      build-environment: linux-jammy-aarch64-py3.10
+      docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
+      test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
+    secrets: inherit
@@ -39,7 +39,7 @@ RUN chmod +x ~/miniconda.sh && \
     bash ~/miniconda.sh -b -p /opt/conda && \
     rm ~/miniconda.sh && \
     /opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
-    /opt/conda/bin/python -mpip install -r requirements.txt && \
+    /opt/conda/bin/python -m pip install -r requirements.txt && \
     /opt/conda/bin/conda clean -ya
 
 FROM dev-base as submodule-update
@@ -229,10 +229,10 @@ private:
   }
 
 
-  static const uint32_t kPhilox10A = 0x9E3779B9;
-  static const uint32_t kPhilox10B = 0xBB67AE85;
-  static const uint32_t kPhiloxSA = 0xD2511F53;
-  static const uint32_t kPhiloxSB = 0xCD9E8D57;
+  static constexpr uint32_t kPhilox10A = 0x9E3779B9;
+  static constexpr uint32_t kPhilox10B = 0xBB67AE85;
+  static constexpr uint32_t kPhiloxSA = 0xD2511F53;
+  static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
 };
 
 typedef philox_engine Philox4_32;
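The C++ hunks from here on are a second mechanical theme: `static const` becomes `constexpr` (or `static constexpr`), both at class scope (as above) and at function scope (as in the generator-state hunks below). A minimal sketch of the distinction, using standard C++ only; the names are illustrative, not PyTorch's:

    #include <cstddef>
    #include <cstdint>

    struct Engine {
      // Class scope: constexpr is explicit about compile-time evaluation and,
      // since C++17, implicitly inline, so no out-of-line definition is needed.
      static constexpr uint32_t kPhilox10A = 0x9E3779B9;
    };

    std::size_t state_bytes() {
      // Function scope: a `static const` local is a static object initialized
      // on first use; `constexpr` is a pure compile-time constant with no
      // static storage or initialization guard at all.
      constexpr std::size_t seed_size = sizeof(uint64_t);
      constexpr std::size_t offset_size = sizeof(int64_t);
      return seed_size + offset_size;
    }

    static_assert(Engine::kPhilox10A == 0x9E3779B9, "usable in constant expressions");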
@@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
  */
 c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
   // The RNG state comprises the seed, and an offset used for Philox.
-  static const size_t seed_size = sizeof(uint64_t);
-  static const size_t offset_size = sizeof(int64_t);
-  static const size_t total_size = seed_size + offset_size;
+  constexpr size_t seed_size = sizeof(uint64_t);
+  constexpr size_t offset_size = sizeof(int64_t);
+  constexpr size_t total_size = seed_size + offset_size;
 
   auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
   auto rng_state = state_tensor.data_ptr<uint8_t>();
@@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  * and size of the internal state.
  */
 void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
-  static const size_t seed_size = sizeof(uint64_t);
-  static const size_t offset_size = sizeof(int64_t);
-  static const size_t total_size = seed_size + offset_size;
+  constexpr size_t seed_size = sizeof(uint64_t);
+  constexpr size_t offset_size = sizeof(int64_t);
+  constexpr size_t total_size = seed_size + offset_size;
 
   detail::check_rng_state(new_state);
 
@@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
 
 namespace at::native {
 
-static const double SELU_ALPHA = 1.6732632423543772848170429916717;
-static const double SELU_SCALE = 1.0507009873554804934193349852946;
+static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
+static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
 
 DEFINE_DISPATCH(elu_stub);
 DEFINE_DISPATCH(elu_backward_stub);
@@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
 #if AT_BUILD_WITH_BLAS()
 template <>
 bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
-  auto intmax = std::numeric_limits<int>::max();
+  auto constexpr intmax = std::numeric_limits<int>::max();
   return n <= intmax && incx <= intmax;
 }
 
@@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
     int64_t incx,
     [[maybe_unused]] float beta,
     int64_t incy) {
-  auto intmax = std::numeric_limits<int>::max();
+  auto constexpr intmax = std::numeric_limits<int>::max();
   return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
       (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
 }
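`auto constexpr intmax = std::numeric_limits<int>::max();` in the two hunks above reads oddly, but declaration specifiers may appear in any order, so it is identical to `constexpr auto`; it works because `numeric_limits<int>::max()` is itself a constexpr function. A one-liner sketch:

    #include <limits>

    bool fits_in_int(long long n) {
      // Same as `constexpr auto intmax = ...`; folded at compile time.
      auto constexpr intmax = std::numeric_limits<int>::max();
      static_assert(intmax > 0, "max of a signed type is positive");
      return n <= intmax;
    }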
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <array>
 #include <ATen/native/Math.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/MathConstants.h>
@@ -127,7 +128,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
 
 template<typename scalar_t>
 C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
-  const static scalar_t kTailValues[] = {
+  constexpr static scalar_t kTailValues[] = {
     0.0810614667953272,
     0.0413406959554092,
     0.0276779256849983,
@@ -139,7 +140,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
     0.00925546218271273,
     0.00833056343336287
   };
-  if (k <= 9) {
+  if (k < std::size(kTailValues)) {
     return kTailValues[static_cast<size_t>(k)];
   }
   scalar_t kp1sq = (k + 1) * (k + 1);
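The `#include <array>` added above supplies `std::size`, which the new guard uses so the bound can never drift from the table's real length; the old literal `9` encoded the ten-entry table (indices 0 to 9) implicitly. A minimal sketch of the idiom, with toy values standing in for the real table:

    #include <array>    // std::size (also declared in <iterator>)
    #include <cstddef>

    double tail(double k) {
      // Toy values; the real kTailValues has ten entries.
      constexpr static double kTailValues[] = {0.081, 0.041, 0.027};
      // std::size is constexpr; the guard tracks the array automatically
      // if entries are ever added or removed.
      if (k < std::size(kTailValues)) {
        return kTailValues[static_cast<std::size_t>(k)];
      }
      return 1.0 / (12.0 * (k + 1.0));  // illustrative fallback only
    }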
@@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
 template <typename scalar_t>
 static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
   // lanczos approximation
-  static const scalar_t lanczos_sum_expg_scaled_num[13] = {
+  static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
     0.006061842346248906525783753964555936883222,
     0.5098416655656676188125178644804694509993,
     19.51992788247617482847860966235652136208,
@@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
     103794043.1163445451906271053616070238554,
     56906521.91347156388090791033559122686859
   };
-  static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
+  static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
     1.,
     66.,
     1925.,
@@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
 template <typename scalar_t>
 static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
-  static const scalar_t d[25][25] =
+  static constexpr scalar_t d[25][25] =
     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
       1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
       3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,
|
|||||||
@ -62,7 +62,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
static const int MIOPEN_DIM_MAX = 5;
|
static constexpr int MIOPEN_DIM_MAX = 5;
|
||||||
|
|
||||||
namespace at::meta {
|
namespace at::meta {
|
||||||
|
|
||||||
|
|||||||
@@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
   // We keep this structure for BC and consider as deprecated.
   // See HelperInterpNearestExact as replacement
 
-  static const int interp_size = 1;
+  static constexpr int interp_size = 1;
 
   static inline void init_indices_weights(
     at::ScalarType output_type,
@@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
 
 struct HelperInterpLinear : public HelperInterpBase {
 
-  static const int interp_size = 2;
+  static constexpr int interp_size = 2;
 
   // Compute indices and weights for each interpolated dimension
   // indices_weights = {
@@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
 
 struct HelperInterpCubic : public HelperInterpBase {
 
-  static const int interp_size = 4;
+  static constexpr int interp_size = 4;
 
   // Compute indices and weights for each interpolated dimension
   // indices_weights = {
|
|||||||
@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static const int BLOCK_THREADS = 256;
|
static constexpr int BLOCK_THREADS = 256;
|
||||||
|
|
||||||
template <typename scalar_t, typename accscalar_t>
|
template <typename scalar_t, typename accscalar_t>
|
||||||
#if defined (USE_ROCM)
|
#if defined (USE_ROCM)
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ namespace at::native {
 namespace {
 
 #if defined(USE_ROCM)
-static const int BLOCKDIMY = 16;
+static constexpr int BLOCKDIMY = 16;
 #else
-static const int BLOCKDIMY = 32;
+static constexpr int BLOCKDIMY = 32;
 #endif
 
 template
@@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
   // lanczos approximation
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
 
-  static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
+  constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
     0.006061842346248906525783753964555936883222,
     0.5098416655656676188125178644804694509993,
     19.51992788247617482847860966235652136208,
@@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
     103794043.1163445451906271053616070238554,
     56906521.91347156388090791033559122686859
   };
-  static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
+  constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
     1.,
     66.,
     1925.,
@@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t ax, fac, res, num, numfac;
-  static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
+  constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
     7.09782712893383996843E2 : 88.72283905206835;
-  static const accscalar_t EXP1 = 2.718281828459045;
-  static const accscalar_t lanczos_g = 6.024680040776729583740234375;
+  constexpr accscalar_t EXP1 = 2.718281828459045;
+  constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
 
   if (::fabs(a - x) > 0.4 * ::fabs(a)) {
     ax = a * ::log(x) - x - ::lgamma(a);
@@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
   // Compute igam using DLMF 8.11.4. [igam1]
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
-  static const int MAXITER = 2000;
+  constexpr int MAXITER = 2000;
 
   int i;
   accscalar_t ans, ax, c, r;
@@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
   accscalar_t fac = 1;
   accscalar_t sum = 0;
   accscalar_t term, logx;
-  static const int MAXITER = 2000;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr int MAXITER = 2000;
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
 
   for (n = 1; n < MAXITER; n++) {
@@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  static const accscalar_t d[25][25] =
+  constexpr accscalar_t d[25][25] =
     {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
     {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
     {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
 
   int k, n, sgn;
   int maxpow = 0;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
   accscalar_t lambda = x / a;
   accscalar_t sigma = (x - a) / a;
@@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
   int i;
   accscalar_t ans, ax, c, yc, r, t, y, z;
   accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
-  static const int MAXITER = 2000;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr int MAXITER = 2000;
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
-  static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
+  constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
     4.503599627370496e15 : 16777216.;
-  static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
+  constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
     2.22044604925031308085e-16 : 5.9604644775390625E-8;
 
   ax = _igam_helper_fac(a, x);
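`MACHEP`, `BIG`, and `BIGINV` select a float- or double-appropriate value with a `std::is_same_v` ternary. Both the condition and the arms are constant expressions, so `constexpr` bakes the right constant into each template instantiation with no static initializer to run, which also matters in device code. A minimal sketch, with the values mirroring the hunk above:

    #include <type_traits>

    template <typename T>
    constexpr T machep() {
      // Each instantiation folds its own constant at compile time.
      return std::is_same_v<T, double> ? T(1.11022302462515654042E-16)
                                       : T(5.9604644775390625E-8);
    }

    static_assert(machep<float>() > machep<double>(), "float epsilon is coarser");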
@@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t absxma_a;
 
-  static const accscalar_t SMALL = 20.0;
-  static const accscalar_t LARGE = 200.0;
-  static const accscalar_t SMALLRATIO = 0.3;
-  static const accscalar_t LARGERATIO = 4.5;
+  constexpr accscalar_t SMALL = 20.0;
+  constexpr accscalar_t LARGE = 200.0;
+  constexpr accscalar_t SMALLRATIO = 0.3;
+  constexpr accscalar_t LARGERATIO = 4.5;
 
   if ((x < 0) || (a < 0)) {
     // out of defined-region of the function
@@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
 
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t absxma_a;
-  static const accscalar_t SMALL = 20.0;
-  static const accscalar_t LARGE = 200.0;
-  static const accscalar_t SMALLRATIO = 0.3;
-  static const accscalar_t LARGERATIO = 4.5;
+  constexpr accscalar_t SMALL = 20.0;
+  constexpr accscalar_t LARGE = 200.0;
+  constexpr accscalar_t SMALLRATIO = 0.3;
+  constexpr accscalar_t LARGERATIO = 4.5;
 
   // boundary values following SciPy
   if ((x < 0) || (a < 0)) {
@@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
 const auto digamma_string = jiterator_stringify(
   template <typename T>
   T digamma(T x) {
-    static const double PI_f64 = 3.14159265358979323846;
+    static constexpr double PI_f64 = 3.14159265358979323846;
 
     // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
     if (x == 0) {
@@ -3072,9 +3072,9 @@ template <typename scalar_t>
 static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
   // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  static const double PI_f64 = 3.14159265358979323846;
-  const accscalar_t PSI_10 = 2.25175258906672110764;
-  const accscalar_t A[] = {
+  static constexpr double PI_f64 = 3.14159265358979323846;
+  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
+  constexpr accscalar_t A[] = {
     8.33333333333333333333E-2,
     -2.10927960927960927961E-2,
     7.57575757575757575758E-3,
@@ -1097,11 +1097,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
   // threads with different threadIdx.x are independent and will produce results for different outputs.
   // In such case, values in each loaded vector always correspond to different outputs.
   if (fastest_moving_stride == sizeof(scalar_t)) {
-#ifdef USE_ROCM
     if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
-#else
-    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
-#endif
       // Case 1: "vectorize along input"
       // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
       // we should avoid vectorization.
@@ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta
 template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
 void mean_kernel_impl(TensorIterator& iter) {
   // returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
+  constexpr bool is_16_bits = sizeof(scalar_t) == 2;
   using factor_t = typename c10::scalar_value_type<acc_t>::type;
   factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
-  gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
+  if constexpr (is_16_bits) {
+    gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
+  } else {
+    gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
+  }
 }
 
 static void mean_kernel_cuda(TensorIterator& iter) {
@@ -13,24 +13,19 @@ namespace at::native {
 template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
 struct sum_functor {
   void operator()(TensorIterator& iter) {
-#ifdef USE_ROCM
-    // Half and BFloat16 can be packed in groups of up to 8 elements and
-    // can use *_DWORDX4 instructions to achieve that.
-    const bool is_16_bits =
-      ( (std::is_same<at::Half, scalar_t>::value) ||
-        (std::is_same<at::BFloat16, scalar_t>::value) );
-    if (is_16_bits) {
+    const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
+      return a + b;
+    };
+    constexpr bool is_16_bits = sizeof(scalar_t) == 2;
+    if constexpr (is_16_bits) {
       gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(
-        iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
-          return a + b;
-        }));
-      return;
+        iter, func_wrapper<out_t>(sum_combine)
+      );
+    } else {
+      gpu_reduce_kernel<scalar_t, out_t>(
+        iter, func_wrapper<out_t>(sum_combine)
+      );
     }
-#endif
-    gpu_reduce_kernel<scalar_t, out_t>(
-      iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
-        return a + b;
-      }));
   }
 };
 
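Both reduction hunks above share one pattern: treat any 2-byte element type (`Half`, `BFloat16`) as packable with `sizeof(scalar_t) == 2`, and branch with `if constexpr` so only the selected `gpu_reduce_kernel` instantiation is generated for a given type; the sum hunk also drops the `#ifdef USE_ROCM` gate, applying the packed path on both backends. A standalone sketch of the dispatch shape (the kernel names below are stand-ins, not PyTorch's API):

    #include <cstdio>

    // Hypothetical stand-ins for the real reduction entry points.
    template <typename T, int input_vec_size>
    void reduce_vectorized() { std::printf("16-bit path, vec=%d\n", input_vec_size); }

    template <typename T>
    void reduce_default() { std::printf("default path\n"); }

    template <typename scalar_t>
    void sum_dispatch() {
      // sizeof == 2 covers half and bfloat16 alike; the discarded branch
      // is never instantiated for the other types.
      constexpr bool is_16_bits = sizeof(scalar_t) == 2;
      if constexpr (is_16_bits) {
        reduce_vectorized<scalar_t, /*input_vec_size=*/8>();
      } else {
        reduce_default<scalar_t>();
      }
    }

    int main() {
      sum_dispatch<short>();  // models a 16-bit dtype
      sum_dispatch<float>();
    }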
@@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
     return 0;
   }
 
-  static const int size = 2;
+  static constexpr int size = 2;
 };
 
 // taken from
@@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
     return 0;
   }
 
-  static const int size = 4;
+  static constexpr int size = 4;
 };
 
 template <typename accscalar_t>
@@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
   // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
   // else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
   // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
-  static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
+  constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
   if (mat1.dim() == 1 && mat2.dim() == 1) {
     // aten::dot
     return mat1.size(0) > mkldnn_gemm_min_size;
@@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
 
 #if defined(__ARM_NEON__) || defined(__aarch64__)
 
-const static int PARALLEL_THRESHOLD = 1 << 20;
+constexpr static int PARALLEL_THRESHOLD = 1 << 20;
 
 // Generic template defaults to naive quantize implementation
 template <typename T>
@@ -1388,7 +1388,7 @@ namespace at::native {
   TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
       "onednn int8 linear: act scale/zp size should be 1/<=1");
   static std::optional<at::Tensor> other = std::nullopt;
-  static const std::string_view binary_post_op = "none";
+  constexpr std::string_view binary_post_op = "none";
   int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
   return linear_int8_with_onednn_weight(
       act, act_scale.item().toDouble(), act_zp,
@@ -16,8 +16,8 @@ namespace {
 
 #ifdef USE_PYTORCH_QNNPACK
 
-const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
-const static int qnnpack_softmax_output_zero_point = 0;
+constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
+constexpr static int qnnpack_softmax_output_zero_point = 0;
 
 bool is_qnnpack_compatible(
     const Tensor& qx,
@@ -110,9 +110,9 @@ class ApplyLogSumExp {
   using ElementCompute = ElementCompute_;
   using ElementLSE = ElementLSE_;
 
-  static int const kElementsPerAccess = ElementsPerAccess;
-  static int const kCount = kElementsPerAccess;
-  static const ScaleType::Kind kScale =
+  static int constexpr kElementsPerAccess = ElementsPerAccess;
+  static int constexpr kCount = kElementsPerAccess;
+  static constexpr ScaleType::Kind kScale =
       cutlass::epilogue::thread::ScaleType::NoBetaScaling;
 
   using FragmentOutput = Array<ElementOutput, kCount>;
@@ -14,16 +14,16 @@ using namespace at;
 
 namespace {
 
-const auto int_min = std::numeric_limits<int>::min();
-const auto int_max = std::numeric_limits<int>::max();
-const auto long_min = std::numeric_limits<int64_t>::min();
-const auto long_max = std::numeric_limits<int64_t>::max();
-const auto float_lowest = std::numeric_limits<float>::lowest();
-const auto float_min = std::numeric_limits<float>::min();
-const auto float_max = std::numeric_limits<float>::max();
-const auto double_lowest = std::numeric_limits<double>::lowest();
-const auto double_min = std::numeric_limits<double>::min();
-const auto double_max = std::numeric_limits<double>::max();
+constexpr auto int_min = std::numeric_limits<int>::min();
+constexpr auto int_max = std::numeric_limits<int>::max();
+constexpr auto long_min = std::numeric_limits<int64_t>::min();
+constexpr auto long_max = std::numeric_limits<int64_t>::max();
+constexpr auto float_lowest = std::numeric_limits<float>::lowest();
+constexpr auto float_min = std::numeric_limits<float>::min();
+constexpr auto float_max = std::numeric_limits<float>::max();
+constexpr auto double_lowest = std::numeric_limits<double>::lowest();
+constexpr auto double_min = std::numeric_limits<double>::min();
+constexpr auto double_max = std::numeric_limits<double>::max();
 
 const std::vector<int> ints {
   int_min,
@@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {
 
 c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
   // The RNG state comprises the seed, and an offset used for Philox.
-  static const size_t seed_size = sizeof(uint64_t);
-  static const size_t offset_size = sizeof(uint64_t);
-  static const size_t total_size = seed_size + offset_size;
+  constexpr size_t seed_size = sizeof(uint64_t);
+  constexpr size_t offset_size = sizeof(uint64_t);
+  constexpr size_t total_size = seed_size + offset_size;
 
   // The internal state is returned as a CPU byte tensor.
   auto state_tensor = at::detail::empty_cpu(
|
|||||||
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
|
||||||
at::xpu::assertNotCapturing(
|
at::xpu::assertNotCapturing(
|
||||||
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
|
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
|
||||||
static const size_t seed_size = sizeof(uint64_t);
|
constexpr size_t seed_size = sizeof(uint64_t);
|
||||||
static const size_t offset_size = sizeof(uint64_t);
|
constexpr size_t offset_size = sizeof(uint64_t);
|
||||||
static const size_t total_size = seed_size + offset_size;
|
constexpr size_t total_size = seed_size + offset_size;
|
||||||
|
|
||||||
at::detail::check_rng_state(new_state);
|
at::detail::check_rng_state(new_state);
|
||||||
|
|
||||||
|
|||||||
(One file diff suppressed because it is too large.)
@@ -38,12 +38,16 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
 op_bench.generate_pt_test(
     configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
 )
-op_bench.generate_pt_test(
-    configs.convtranspose_1d_configs_short
-    + configs.conv_1d_configs_short
-    + configs.conv_1d_configs_long,
-    ConvTranspose1dBenchmark,
-)
+
+if not torch.backends.mkldnn.is_acl_available():
+    # convtranpose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654
+    op_bench.generate_pt_test(
+        configs.convtranspose_1d_configs_short
+        + configs.conv_1d_configs_short
+        + configs.conv_1d_configs_long,
+        ConvTranspose1dBenchmark,
+    )
 
 
 """
@@ -164,9 +164,6 @@ class TestIntTuple(TestCase):
             crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
         )  # 4 -> (1,0,0) -> 1*8 = 8
 
-        # Test with zero-length shape and strides
-        self.assertEqual(crd2idx(0, (), ()), 0)  # 0 -> () -> sum([]) = 0
-
     def test_idx2crd_basic(self):
         # Test basic int/int case
         self.assertEqual(idx2crd(2, 5, 1), 2)
@@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
     def test_remap_to_tensor(self):
        """Test the remap_to_tensor method for various scenarios."""
        # Test 1: Consecutive ranks, full world - should return logical groups directly
-       original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
+       original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
        layout1 = _Layout((2, 2), (2, 1))  # row-major 2x2
        result1 = layout1.remap_to_tensor(original_mesh)
        expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
        self.assertEqual(result1, expected1)
 
        # Test 2: Non-consecutive ranks - should map to actual ranks
-       original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
+       original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
        layout2 = _Layout((2, 2), (2, 1))
        result2 = layout2.remap_to_tensor(original_mesh)
        expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
|
|||||||
self.assertEqual(result5, expected5)
|
self.assertEqual(result5, expected5)
|
||||||
|
|
||||||
# Test 6: Tensor Cute representation of a 2D mesh
|
# Test 6: Tensor Cute representation of a 2D mesh
|
||||||
original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
|
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
|
||||||
layout6 = _Layout((2, 2), (1, 2)) # column-major style
|
layout6 = _Layout((2, 2), (1, 2)) # column-major style
|
||||||
result6 = layout6.remap_to_tensor(original_mesh)
|
result6 = layout6.remap_to_tensor(original_mesh)
|
||||||
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
|
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ def main() -> None:
     if uv and (is_uv_managed_python or not need_user_flag):
         pip_args = [uv, "pip", "install"]
     elif sys.executable:
-        pip_args = [sys.executable, "-mpip", "install"]
+        pip_args = [sys.executable, "-m", "pip", "install"]
     else:
         pip_args = ["pip3", "install"]

@@ -707,13 +707,27 @@ class _LocalDeviceMesh:
         lm = local_tensor_mode()
         assert lm is not None, "Unexpectedly not in LocalTensorMode"

+        root_mesh = self._get_root_mesh()
+        submesh_dims = self.mesh_dim_names
+
         coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
-        for r in lm.ranks:
-            rank_tensor = self._layout.remap_to_tensor(self._rank_map)
-            rank_coords = (rank_tensor == r).nonzero().tolist()
-            assert len(rank_coords) == 1
-            for d, c in enumerate(rank_coords[0][1:]):
-                coords[d][r] = c
+        old_get_rank = DeviceMesh.get_rank  # type: ignore[assignment]
+        try:
+            for r in lm.ranks:
+                DeviceMesh.get_rank = lambda self: r  # type: ignore[method-assign]
+                submesh = (
+                    root_mesh
+                    if submesh_dims is None
+                    else root_mesh.__getitem__(submesh_dims)
+                )
+                rank_coords = (submesh.mesh == r).nonzero().tolist()
+                assert len(rank_coords) in (0, 1)
+                if len(rank_coords) == 0:
+                    continue
+                for d, c in enumerate(rank_coords[0]):
+                    coords[d][r] = c
+        finally:
+            DeviceMesh.get_rank = old_get_rank  # type: ignore[method-assign]

         out = [torch.SymInt(LocalIntNode(c)) for c in coords]
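The replacement loop above works by briefly rebinding DeviceMesh.get_rank so that, for each simulated rank r, any code that asks for the current rank sees r; the finally block guarantees the real method comes back even if a lookup raises. A minimal sketch of this save/patch/restore pattern, using a hypothetical Mesh class in place of the real DeviceMesh:

class Mesh:
    def get_rank(self) -> int:
        return 0

def ranks_seen(ranks):
    # Save the original function, patch it per iteration, always restore it.
    old_get_rank = Mesh.get_rank
    seen = []
    try:
        for r in ranks:
            Mesh.get_rank = lambda self, r=r: r  # r=r pins the loop variable
            seen.append(Mesh().get_rank())
    finally:
        Mesh.get_rank = old_get_rank
    return seen

print(ranks_seen(range(4)))  # [0, 1, 2, 3]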
@@ -301,7 +301,10 @@ class _MeshLayout(Layout):
         ranks = self.all_ranks_from_zero()
         return len(ranks) == len(set(ranks))

-    def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
+    def remap_to_tensor(
+        self,
+        mesh_tensor: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Leverage layout as an index for mesh tensor that re-maps the indexes after layout
         transformation to actual device ranks.
@@ -313,7 +316,10 @@ class _MeshLayout(Layout):
         can be treated as a view or subset of mesh tensor, we do need to use the actual view or
         sub-tensor for DeviceMesh and its backend creation.

-        The shape of the `rank_map` must be 1D and contiguous.
+        The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
+        shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
+        and reconstruct the mesh tensor with the shape of the layout when accessed by users.
+        #TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.

         Examples:

@@ -330,18 +336,18 @@ class _MeshLayout(Layout):
         Return: [[[10,30],[20,40]]]

         Args:
-            rank_map: The concrete mesh tensor with actual device ranks
+            mesh_tensor: The concrete mesh tensor with actual device ranks

         Returns:
-            torch.Tensor: A tensor representing the actual device allocation from rank_map
+            torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
         """
-        assert rank_map.ndim == 1
-        assert rank_map.is_contiguous()
-        assert rank_map.numel() >= self.cosize()
-
-        complement_layout = self.complement(rank_map.numel())
-        return rank_map.as_strided(
-            flatten(complement_layout.sizes) + flatten(self.sizes),
-            flatten(complement_layout.strides) + flatten(self.strides),
-        ).reshape(-1, *self.top_level_sizes)
+        complement_layout = self.complement(mesh_tensor.numel())
+        return (
+            mesh_tensor.flatten()
+            .as_strided(
+                flatten(complement_layout.sizes) + flatten(self.sizes),
+                flatten(complement_layout.strides) + flatten(self.strides),
+            )
+            .reshape(-1, *(self[i].numel() for i in range(len(self))))
+        )
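Under the hood the remap is a single strided gather: the layout's (sizes, strides) index into the flattened mesh tensor, and the complement layout contributes the leading dimension that the final reshape keeps as the group axis. A minimal sketch reproducing the docstring example above, assuming a (2, 2) layout with strides (1, 2) over a full 4-rank world, whose complement is the trivial size-1 layout:

import torch

mesh_tensor = torch.tensor([10, 20, 30, 40], dtype=torch.int)
sizes, strides = (2, 2), (1, 2)
# Leading dim comes from the complement layout; for a full 4-rank world it is 1.
remapped = mesh_tensor.flatten().as_strided((1, *sizes), (4, *strides))
print(remapped.reshape(-1, 2, 2))  # tensor([[[10, 30], [20, 40]]], dtype=torch.int32)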
@@ -198,9 +198,7 @@ def crd2idx(
         for i in range(len(shape) - 1, 0, -1):
             result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
             crd = crd // product(shape[i])
-        if len(shape) > 0:
-            result += crd2idx(crd, shape[0], stride[0])
-        return result
+        return result + crd2idx(crd, shape[0], stride[0])
     else:  # "int" "int" "int"
         assert not is_tuple(shape) and not is_tuple(stride)
         return crd * stride  # all are ints after type checks
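This simplification is behavior-preserving for non-empty tuple shapes: the loop consumes entries len(shape)-1 down to 1, so exactly the head term remains, and it is now folded into the return. The old `if len(shape) > 0` guard only mattered for the zero-length `()` case, whose test is removed in the TestIntTuple hunk above. A self-contained sketch of the int-coordinate branch, with a `product` helper standing in for the module's own (names here are illustrative):

from math import prod

def product(s) -> int:
    # Product of a possibly nested tuple of ints (stand-in for the module helper).
    return prod(product(x) for x in s) if isinstance(s, tuple) else s

def crd2idx(crd: int, shape, stride) -> int:
    # Int-coordinate branch only; mirrors the simplified control flow above.
    if isinstance(shape, tuple):
        result = 0
        for i in range(len(shape) - 1, 0, -1):
            result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
            crd = crd // product(shape[i])
        return result + crd2idx(crd, shape[0], stride[0])
    return crd * stride

print(crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))))  # 8, i.e. 4 -> (1,0,0) -> 1*8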
@@ -173,7 +173,7 @@ else:
         """

         _device_type: str
-        _rank_map: torch.Tensor
+        _mesh: torch.Tensor
         _mesh_dim_names: Optional[tuple[str, ...]]
         _layout: _MeshLayout
         _root_mesh: Optional["DeviceMesh"] = None
@@ -190,49 +190,46 @@ else:
             _init_backend: bool = True,
             _rank: Optional[int] = None,
             _layout: Optional[_MeshLayout] = None,
-            _root_mesh: Optional["DeviceMesh"] = None,
         ) -> None:
             self._device_type = device_type
             if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
                 raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
-            mesh_tensor = (
+            self._mesh = (
                 mesh.detach().to(dtype=torch.int).contiguous()
                 if isinstance(mesh, torch.Tensor)
                 else torch.tensor(mesh, device="cpu", dtype=torch.int)
             )
-            self._rank_map = (
-                _root_mesh._rank_map
-                if _root_mesh is not None
-                else mesh_tensor.flatten()
-            )
             self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
+            if backend_override is None:
+                backend_override = ((None, None),) * self.mesh.ndim
+            elif len(backend_override) != self.mesh.ndim:
+                raise ValueError(
+                    f"backend_override should have the same length as the number of mesh dimensions, "
+                    f"but got {len(backend_override)} and {self.mesh.ndim}."
+                )
             # Internal bookkeeping for the device mesh.
             self._layout = (
                 _layout
                 if _layout
-                else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
+                else _MeshLayout(self.mesh.size(), self.mesh.stride())
             )
-            self._root_mesh = _root_mesh
             assert self._layout.check_non_overlap(), (
                 "Please use a non-overlapping layout when creating a DeviceMesh."
             )
             # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
-            assert self._layout.top_level_sizes == mesh_tensor.size(), (
+            assert self._layout.top_level_sizes == self.mesh.size(), (
                 "Please use a valid layout when creating a DeviceMesh."
-                f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}."
+                f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
             )

-            if backend_override is None:
-                backend_override = ((None, None),) * len(self._layout)
-            elif len(backend_override) != len(self._layout):
-                raise ValueError(
-                    f"backend_override should have the same length as the number of mesh dimensions, "
-                    f"but got {len(backend_override)} and {len(self._layout)}."
-                )
+            # private field to pre-generate DeviceMesh's hash
+            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
+            self._thread_id = None
+            # Initialize instance-specific flatten mapping
+            self._flatten_mapping = {}

             # Skip process group initialization if xla device or init backend is False
             # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
-            self._thread_id = None
             if device_type != "xla":
                 # always try to create default (world) pg, even if it is not initialized
                 # already. The world pg is used for device mesh identity (rank) on each
@@ -255,11 +252,6 @@ else:
                     rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
                 )

-            # private field to pre-generate DeviceMesh's hash
-            self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
-            # Initialize instance-specific flatten mapping
-            self._flatten_mapping = {}
-
         @property
         def device_type(self) -> str:
             """Returns the device type of the mesh."""
@@ -268,17 +260,7 @@ else:
         @property
         def mesh(self) -> torch.Tensor:
             """Returns the tensor representing the layout of devices."""
-            full_mesh = self._layout.remap_to_tensor(self._rank_map)
-            if full_mesh.size(0) == 1:
-                return full_mesh[0]
-            my_coords = (full_mesh == get_rank()).nonzero()
-            if my_coords.size(0) > 0:
-                return full_mesh[my_coords[0, 0]]
-            raise RuntimeError(
-                "In order to get the mesh Tensor of a DeviceMesh it needs to "
-                "either have all its original dimensions (e.g., no slicing) "
-                "or it needs to contain the local rank"
-            )
+            return self._mesh

         @property
         def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
@@ -293,9 +275,9 @@ else:
                 init_process_group()

             world_size = get_world_size()
-            if self._layout.numel() > world_size:
+            if self.mesh.numel() > world_size:
                 raise RuntimeError(
-                    f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
+                    f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
                 )

             # ONLY set the device if the current device is not initialized, if user already
@@ -346,8 +328,8 @@ else:
             default_group = _get_default_group()

             if (
-                len(self._layout) == 1
-                and self._layout.numel() == get_world_size()
+                self.mesh.ndim == 1
+                and self.mesh.numel() == get_world_size()
                 and backend_override[0] == (None, None)
             ):
                 # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@@ -366,11 +348,11 @@ else:
                     dim_group_names.append(dim_group.group_name)
             else:
                 # create sub pgs base on the mesh argument specified
-                for dim in range(len(self._layout)):
+                for dim in range(self.mesh.ndim):
                     # swap the current dim to the last dim
                     # then reshape to flatten out other dims
-                    pg_ranks_by_dim = (
-                        self._layout[dim].nest().remap_to_tensor(self._rank_map)
+                    pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
+                        -1, self.mesh.size(dim)
                     )
                     backend, pg_options = backend_override[dim]
                     # We need to explicitly pass in timeout when specified in option, otherwise
@@ -466,14 +448,14 @@ else:

         def __repr__(self) -> str:
             device_mesh_repr = (
-                f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
+                f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})"
                 if self._mesh_dim_names
-                else f"{self._layout.top_level_sizes}"
+                else f"{tuple(self._mesh.shape)}"
             )
-            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
+            device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
             # We only print the mesh tensor if the debug mode is turned on.
             if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
-                device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
+                device_mesh_repr += f", Mesh: {self._mesh.tolist()}"
             return f"{device_mesh_repr})"

         def __hash__(self):
@@ -483,7 +465,7 @@ else:
             self._hash = hash(
                 (
                     self._flatten_mesh_list,
-                    self._layout,
+                    self._mesh.shape,
                     self._device_type,
                     self._mesh_dim_names,
                     self._thread_id,
@@ -499,7 +481,7 @@ else:
                 return False
             return (
                 self._flatten_mesh_list == other._flatten_mesh_list
-                and self._layout == other._layout
+                and self._mesh.shape == other._mesh.shape
                 and self._device_type == other._device_type
                 and self._mesh_dim_names == other._mesh_dim_names
                 and self._thread_id == other._thread_id
@@ -591,16 +573,16 @@ else:
             if not hasattr(self, "_dim_group_names"):
                 raise RuntimeError("DeviceMesh process groups not initialized!")

-            if len(self._layout) > 1 and mesh_dim is None:
+            if self.mesh.ndim > 1 and mesh_dim is None:
                 raise RuntimeError(
-                    f"Found the DeviceMesh have {len(self._layout)} dimensions",
+                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
                     "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                     "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
                     "please use `get_all_groups()` instead.",
                 )

             # Quick return if the current device_mesh is a 1D mesh.
-            if len(self._layout) == 1 and mesh_dim is None:
+            if self.mesh.ndim == 1 and mesh_dim is None:
                 return not_none(_resolve_process_group(self._dim_group_names[0]))

             root_mesh = self._get_root_mesh()
@@ -626,7 +608,7 @@ else:
             Returns:
                 A list of :class:`ProcessGroup` object.
             """
-            return [self.get_group(i) for i in range(len(self._layout))]
+            return [self.get_group(i) for i in range(self.mesh.ndim)]

         def _create_sub_mesh(
             self,
@@ -653,7 +635,9 @@ else:
                 ]
             )
             cur_rank = self.get_rank()
-            pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map)
+            pg_ranks_by_dim = layout.remap_to_tensor(
+                root_mesh.mesh,
+            )
             res_submesh = DeviceMesh._create_mesh_from_ranks(
                 self._device_type,
                 pg_ranks_by_dim,
@@ -708,7 +692,9 @@ else:
             cur_rank = root_mesh.get_rank()
             # Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
             # new_group api to avoid potential hang.
-            pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map)
+            pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
+                root_mesh.mesh,
+            )
             res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
                 root_mesh._device_type,
                 pg_ranks_by_dim.flatten(
@@ -847,7 +833,9 @@ else:
             """
             mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
             layout = self._layout[mesh_dim]
-            pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
+            pg_ranks_by_dim = layout.remap_to_tensor(
+                self.mesh,
+            )
             cur_rank = self.get_rank()
             res_submeshes = []
             for mesh_1d in pg_ranks_by_dim:
@@ -908,7 +896,6 @@ else:
                 backend_override=backend_override,
                 _init_backend=_init_backend,
                 _layout=_layout,
-                _root_mesh=_root_mesh,
             )
             if cur_rank in mesh_nd:
                 res_mesh = mesh
@@ -917,6 +904,8 @@ else:
                     f"Current rank {cur_rank} not found in any mesh, "
                     f"input {pg_ranks_by_dim} does not contain all ranks in the world"
                 )
+            if _root_mesh is not None:
+                res_mesh._root_mesh = _root_mesh
             return res_mesh

         @staticmethod
@@ -1015,17 +1004,15 @@ else:
             return device_mesh

         def size(self, mesh_dim: Optional[int] = None) -> int:
-            if mesh_dim is not None:
-                return self._layout[mesh_dim].numel()
-            return self._layout.numel()
+            return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)

         @property
         def ndim(self) -> int:
-            return len(self._layout)
+            return self.mesh.ndim

         @property
         def shape(self) -> tuple[int, ...]:
-            return self._layout.top_level_sizes
+            return tuple(self.mesh.shape)

         def get_rank(self) -> int:
             """
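After this hunk, size(), ndim, and shape read straight off the stored mesh tensor rather than the layout; the public behavior is unchanged. A short usage sketch, assuming a 4-process job launched with torchrun-style environment variables (the dimension names are illustrative):

from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
print(mesh.shape)    # (2, 2)
print(mesh.ndim)     # 2
print(mesh.size())   # 4
print(mesh.size(0))  # 2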
@@ -1064,7 +1051,7 @@ else:
             """
             if self.ndim > 1 and mesh_dim is None:
                 raise RuntimeError(
-                    f"Found the DeviceMesh have {len(self._layout)} dimensions",
+                    f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
                     "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
                 )
             elif mesh_dim is None:
@@ -1128,7 +1115,9 @@ else:
             root_mesh = self._get_root_mesh()
             cur_rank = self.get_rank()
             unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
-            pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map)
+            pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
+                root_mesh.mesh,
+            )
             unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
             unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
             res_mesh = DeviceMesh._create_mesh_from_ranks(
@@ -1152,7 +1141,7 @@ else:
                 tuple(unflattened_layout.strides[dim : dim + unflatten_length]),  # type: ignore[index]
             )
             unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
-                root_mesh._rank_map
+                root_mesh.mesh,
             )
             unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
                 self.device_type,
@@ -640,9 +640,9 @@ def get_pip_packages(run_lambda, patterns=None):

     os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
     # People generally have pip as `pip` or `pip3`
-    # But here it is invoked as `python -mpip`
+    # But here it is invoked as `python -m pip`
     out = run_and_read_all(
-        run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"]
+        run_lambda, [sys.executable, "-m", "pip", "list", "--format=freeze"]
     )
     if out is None:
         return pip_version, out