Compare commits

..

6 Commits

Author SHA1 Message Date
241b702918 Fix remaining -m<module> patterns in Python code and scripts for consistency
Co-authored-by: malfet <2453524+malfet@users.noreply.github.com>
2025-10-17 14:56:15 +00:00
83df2e0610 Replace python3 -mpip and python -mpip with python3 -m pip and python -m pip for better readability
Co-authored-by: malfet <2453524+malfet@users.noreply.github.com>
2025-10-17 14:52:55 +00:00
77fe8234bb Initial plan 2025-10-17 14:45:52 +00:00
6ece527fc5 [CI] Add aarch64 operator benchmark (#165585)
Running on Graviton4
Skip ConvTranspose1d benchmarks if PyTorch is compiled with ACL, due to https://github.com/pytorch/pytorch/issues/165654
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165585
Approved by: https://github.com/huydhn
2025-10-17 14:42:14 +00:00
ce29d0d796 [ATen] Vectorize 8 elements on 16 bit data types for sum/mean (#165055)
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce:
**BF16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```

**FP16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165055
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790
2025-10-17 13:39:36 +00:00
7231118db3 Turn some const variables into constexpr in C++ code (#165401)
This PR checks the C++ code and turns some const variables into constexpr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401
Approved by: https://github.com/Skylion007
2025-10-17 13:24:46 +00:00
51 changed files with 1584 additions and 236 deletions

View File

@ -20,7 +20,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
# cmake-3.18.4 from pip
RUN yum install -y python3-pip && \
python3 -mpip install cmake==3.18.4 && \
python3 -m pip install cmake==3.18.4 && \
ln -s /usr/local/bin/cmake /usr/bin/cmake3
RUN rm -rf /usr/local/cuda-*

View File

@ -25,7 +25,7 @@ function install_torchbench() {
python install.py --continue_on_fail
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
python -m pip freeze
popd
chown -R jenkins torchbench

View File

@ -8,8 +8,8 @@ MKLROOT=/opt/intel
mkdir -p ${MKLROOT}
pushd /tmp
python3 -mpip install wheel
python3 -mpip download -d . mkl-static==${MKL_VERSION}
python3 -m pip install wheel
python3 -m pip download -d . mkl-static==${MKL_VERSION}
python3 -m wheel unpack mkl_static-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
python3 -m wheel unpack mkl_include-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
mv mkl_static-${MKL_VERSION}/mkl_static-${MKL_VERSION}.data/data/lib ${MKLROOT}

View File

@ -11,5 +11,5 @@ ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
python -m venv /var/lib/jenkins/ci_env
source /var/lib/jenkins/ci_env/bin/activate
python -mpip install --upgrade pip
python -mpip install -r /opt/requirements-ci.txt
python -m pip install --upgrade pip
python -m pip install -r /opt/requirements-ci.txt

View File

@ -14,7 +14,7 @@ ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/op
# cmake-3.18.4 from pip
RUN yum install -y python3-pip && \
python3 -mpip install cmake==3.18.4 && \
python3 -m pip install cmake==3.18.4 && \
ln -s /usr/local/bin/cmake /usr/bin/cmake3
FROM base as openssl
@ -135,7 +135,7 @@ RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
# cmake-3.18.4 from pip; force in case cmake3 already exists
RUN yum install -y python3-pip && \
python3 -mpip install cmake==3.18.4 && \
python3 -m pip install cmake==3.18.4 && \
ln -sf /usr/local/bin/cmake /usr/bin/cmake3
FROM cpu_final as cuda_final
@ -157,7 +157,7 @@ ENV ROCM_PATH /opt/rocm
# cmake-3.28.4 from pip to get enable_language(HIP)
# and avoid 3.21.0 cmake+ninja issues with ninja inserting "-Wl,--no-as-needed" in LINK_FLAGS for static linker
RUN python3 -m pip install --upgrade pip && \
python3 -mpip install cmake==3.28.4
python3 -m pip install cmake==3.28.4
# replace the libdrm in /opt/amdgpu with custom amdgpu.ids lookup path
ADD ./common/install_rocm_drm.sh install_rocm_drm.sh
RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh
@ -174,7 +174,7 @@ FROM cpu_final as xpu_final
ENV XPU_DRIVER_TYPE ROLLING
# cmake-3.28.4 from pip
RUN python3 -m pip install --upgrade pip && \
python3 -mpip install cmake==3.28.4
python3 -m pip install cmake==3.28.4
ADD ./common/install_xpu.sh install_xpu.sh
ENV XPU_VERSION 2025.2
RUN bash ./install_xpu.sh && rm install_xpu.sh

View File

@ -113,7 +113,7 @@ RUN dnf install -y \
RUN env GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=True pip3 install grpcio
# cmake-3.28.0 from pip for onnxruntime
RUN python3 -mpip install cmake==3.28.0
RUN python3 -m pip install cmake==3.28.0
ADD ./common/patch_libstdc.sh patch_libstdc.sh
RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh

View File

@ -288,7 +288,7 @@ else
# or building non-XLA tests.
if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then
# Install numpy-2.0.2 for builds which are backward compatible with 1.X
python -mpip install numpy==2.0.2
python -m pip install numpy==2.0.2
WERROR=1 python setup.py clean

View File

@ -67,13 +67,13 @@ function pip_install_whl() {
# Loop through each path and install individually
for path in "${paths[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
python3 -m pip install --no-index --no-deps "$path"
done
else
# Loop through each argument and install individually
for path in "${args[@]}"; do
echo "Installing $path"
python3 -mpip install --no-index --no-deps "$path"
python3 -m pip install --no-index --no-deps "$path"
done
fi
}

View File

@ -182,7 +182,7 @@ checkout_install_torchbench() {
pip uninstall -y torchao
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
python -m pip freeze
}
torchbench_setup_macos() {
@ -211,7 +211,7 @@ torchbench_setup_macos() {
}
pip_benchmark_deps() {
python -mpip install --no-input requests cython scikit-learn six
python -m pip install --no-input requests cython scikit-learn six
}

View File

@ -1434,7 +1434,7 @@ EOF
# shellcheck source=./common-build.sh
source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh"
python -m build --wheel --no-isolation -C--build-option=--bdist-dir="base_bdist_tmp" --outdir "base_dist"
python -mpip install base_dist/*.whl
python -m pip install base_dist/*.whl
echo "::endgroup::"
pushd test/forward_backward_compatibility

View File

@ -173,7 +173,7 @@ esac
PINNED_PACKAGES=(
"numpy${NUMPY_PINNED_VERSION}"
)
python -mvenv ~/${desired_python}-build
python -m venv ~/${desired_python}-build
source ~/${desired_python}-build/bin/activate
retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.txt"
retry brew install libomp

View File

@ -24,7 +24,7 @@ change_wheel_version() {
local t_version=$4
# Extract the wheel
${PYTHON_EXECUTABLE} -mwheel unpack $wheel
${PYTHON_EXECUTABLE} -m wheel unpack $wheel
mv "${package}-${f_version}" "${package}-${t_version}"
# Change the version from f_version to t_version in the dist-info dir
@ -47,7 +47,7 @@ change_wheel_version() {
popd
# Repack the wheel
${PYTHON_EXECUTABLE} -mwheel pack "${package}-${t_version}"
${PYTHON_EXECUTABLE} -m wheel pack "${package}-${t_version}"
# Clean up
rm -rf "${package}-${t_version}"
@ -85,7 +85,7 @@ repackage_wheel() {
}
# Require to re-package the wheel
${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1
${PYTHON_EXECUTABLE} -m pip install wheel==0.45.1
pushd externals/vllm/wheels
for package in xformers flashinfer-python vllm; do

View File

@ -211,7 +211,7 @@ jobs:
$tool --version
done
python3 -mpip install --no-index --no-deps dist/*.whl
python3 -m pip install --no-index --no-deps dist/*.whl
set +e
pushd "${RUNNER_TEMP}"
@ -222,7 +222,7 @@ jobs:
popd
if [ "${RC}" -ne 0 ]; then
python3 -mpip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
python3 -m pip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
fi
set -e

View File

@ -204,7 +204,7 @@ jobs:
run: |
pushd "${PYTORCH_FINAL_PACKAGE_DIR}"
# shellcheck disable=SC2046,SC2102
python3 -mpip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
python3 -m pip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
popd
.ci/pytorch/win-test.sh

View File

@ -126,13 +126,13 @@ jobs:
"${MANYLINUX_IMAGE}"
)
docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip install \
docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install \
--pre torch torchvision torchaudio \
--index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
# I wonder if there is a command to both download and install the wheels
# in one go
docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip download \
docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip download \
--pre torch torchvision torchaudio \
--index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"

View File

@ -106,7 +106,7 @@ jobs:
SMOKE_TEST_PARAMS=""
# shellcheck disable=SC2086
python -mvenv test_venv
python -m venv test_venv
source test_venv/bin/activate
pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@ -216,7 +216,7 @@ jobs:
SMOKE_TEST_PARAMS=""
# shellcheck disable=SC2086
python -mvenv test_venv
python -m venv test_venv
source test_venv/bin/activate
pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@ -326,7 +326,7 @@ jobs:
SMOKE_TEST_PARAMS=""
# shellcheck disable=SC2086
python -mvenv test_venv
python -m venv test_venv
source test_venv/bin/activate
pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@ -436,7 +436,7 @@ jobs:
SMOKE_TEST_PARAMS=""
# shellcheck disable=SC2086
python -mvenv test_venv
python -m venv test_venv
source test_venv/bin/activate
pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@ -546,7 +546,7 @@ jobs:
SMOKE_TEST_PARAMS=""
# shellcheck disable=SC2086
python -mvenv test_venv
python -m venv test_venv
source test_venv/bin/activate
pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@ -656,7 +656,7 @@ jobs:
SMOKE_TEST_PARAMS=""
# shellcheck disable=SC2086
python -mvenv test_venv
python -m venv test_venv
source test_venv/bin/activate
pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@ -766,7 +766,7 @@ jobs:
SMOKE_TEST_PARAMS=""
# shellcheck disable=SC2086
python -mvenv test_venv
python -m venv test_venv
source test_venv/bin/activate
pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

View File

@ -52,3 +52,27 @@ jobs:
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
secrets: inherit
aarch64-opbenchmark-build:
if: github.repository_owner == 'pytorch'
name: aarch64-opbenchmark-build
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-aarch64-py3.10
runner: linux.arm64.m7g.4xlarge
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
]}
secrets: inherit
aarch64-opbenchmark-test:
name: aarch64-opbenchmark-test
uses: ./.github/workflows/_linux-test.yml
needs: aarch64-opbenchmark-build
with:
build-environment: linux-jammy-aarch64-py3.10
docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
secrets: inherit

View File

@ -39,7 +39,7 @@ RUN chmod +x ~/miniconda.sh && \
bash ~/miniconda.sh -b -p /opt/conda && \
rm ~/miniconda.sh && \
/opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
/opt/conda/bin/python -mpip install -r requirements.txt && \
/opt/conda/bin/python -m pip install -r requirements.txt && \
/opt/conda/bin/conda clean -ya
FROM dev-base as submodule-update

View File

@ -229,10 +229,10 @@ private:
}
static const uint32_t kPhilox10A = 0x9E3779B9;
static const uint32_t kPhilox10B = 0xBB67AE85;
static const uint32_t kPhiloxSA = 0xD2511F53;
static const uint32_t kPhiloxSB = 0xCD9E8D57;
static constexpr uint32_t kPhilox10A = 0x9E3779B9;
static constexpr uint32_t kPhilox10B = 0xBB67AE85;
static constexpr uint32_t kPhiloxSA = 0xD2511F53;
static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
};
typedef philox_engine Philox4_32;

View File

@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
*/
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;
detail::check_rng_state(new_state);

View File

@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
namespace at::native {
static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);

View File

@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return n <= intmax && incx <= intmax;
}
@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
int64_t incx,
[[maybe_unused]] float beta,
int64_t incy) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
(incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

View File

@ -1,5 +1,6 @@
#pragma once
#include <array>
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
@ -127,7 +128,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
const static scalar_t kTailValues[] = {
constexpr static scalar_t kTailValues[] = {
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
@ -139,7 +140,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
0.00925546218271273,
0.00833056343336287
};
if (k <= 9) {
if (k < std::size(kTailValues)) {
return kTailValues[static_cast<size_t>(k)];
}
scalar_t kp1sq = (k + 1) * (k + 1);

View File

@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
static const scalar_t lanczos_sum_expg_scaled_num[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
static const scalar_t d[25][25] =
static constexpr scalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,

View File

@ -62,7 +62,7 @@
#include <utility>
#include <vector>
static const int MIOPEN_DIM_MAX = 5;
static constexpr int MIOPEN_DIM_MAX = 5;
namespace at::meta {

View File

@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
// We keep this structure for BC and consider as deprecated.
// See HelperInterpNearestExact as replacement
static const int interp_size = 1;
static constexpr int interp_size = 1;
static inline void init_indices_weights(
at::ScalarType output_type,
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
struct HelperInterpLinear : public HelperInterpBase {
static const int interp_size = 2;
static constexpr int interp_size = 2;
// Compute indices and weights for each interpolated dimension
// indices_weights = {
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
struct HelperInterpCubic : public HelperInterpBase {
static const int interp_size = 4;
static constexpr int interp_size = 4;
// Compute indices and weights for each interpolated dimension
// indices_weights = {

View File

@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
}
static const int BLOCK_THREADS = 256;
static constexpr int BLOCK_THREADS = 256;
template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)

View File

@ -36,9 +36,9 @@ namespace at::native {
namespace {
#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
static constexpr int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
static constexpr int BLOCKDIMY = 32;
#endif
template

View File

@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac;
static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
7.09782712893383996843E2 : 88.72283905206835;
static const accscalar_t EXP1 = 2.718281828459045;
static const accscalar_t lanczos_g = 6.024680040776729583740234375;
constexpr accscalar_t EXP1 = 2.718281828459045;
constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a);
@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const int MAXITER = 2000;
constexpr int MAXITER = 2000;
int i;
accscalar_t ans, ax, c, r;
@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
accscalar_t fac = 1;
accscalar_t sum = 0;
accscalar_t term, logx;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {
@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t d[25][25] =
constexpr accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
int k, n, sgn;
int maxpow = 0;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a;
@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
int i;
accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
4.503599627370496e15 : 16777216.;
static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) {
// out of defined-region of the function
@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy
if ((x < 0) || (a < 0)) {

View File

@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
const auto digamma_string = jiterator_stringify(
template <typename T>
T digamma(T x) {
static const double PI_f64 = 3.14159265358979323846;
static constexpr double PI_f64 = 3.14159265358979323846;
// Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
if (x == 0) {
@ -3072,9 +3072,9 @@ template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
static constexpr double PI_f64 = 3.14159265358979323846;
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
constexpr accscalar_t A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,

View File

@ -1097,11 +1097,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
// threads with different threadIdx.x are independent and will produce results for different outputs.
// In such case, values in each loaded vector always correspond to different outputs.
if (fastest_moving_stride == sizeof(scalar_t)) {
#ifdef USE_ROCM
if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
#else
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
#endif
// Case 1: "vectorize along input"
// Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
// we should avoid vectorization.

View File

@ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
// returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
using factor_t = typename c10::scalar_value_type<acc_t>::type;
factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
} else {
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
}
}
static void mean_kernel_cuda(TensorIterator& iter) {

View File

@ -13,24 +13,19 @@ namespace at::native {
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor {
void operator()(TensorIterator& iter) {
#ifdef USE_ROCM
// Half and BFloat16 can be packed in groups of up to 8 elements and
// can use *_DWORDX4 instructions to achieve that.
const bool is_16_bits =
( (std::is_same<at::Half, scalar_t>::value) ||
(std::is_same<at::BFloat16, scalar_t>::value) );
if (is_16_bits) {
const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
};
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
return;
iter, func_wrapper<out_t>(sum_combine)
);
} else {
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>(sum_combine)
);
}
#endif
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
}
};

View File

@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
return 0;
}
static const int size = 2;
static constexpr int size = 2;
};
// taken from
@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
return 0;
}
static const int size = 4;
static constexpr int size = 4;
};
template <typename accscalar_t>

View File

@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
// only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
if (mat1.dim() == 1 && mat2.dim() == 1) {
// aten::dot
return mat1.size(0) > mkldnn_gemm_min_size;

View File

@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__)
const static int PARALLEL_THRESHOLD = 1 << 20;
constexpr static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation
template <typename T>

View File

@ -1388,7 +1388,7 @@ namespace at::native {
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt;
static const std::string_view binary_post_op = "none";
constexpr std::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,

View File

@ -16,8 +16,8 @@ namespace {
#ifdef USE_PYTORCH_QNNPACK
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0;
constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
constexpr static int qnnpack_softmax_output_zero_point = 0;
bool is_qnnpack_compatible(
const Tensor& qx,

View File

@ -110,9 +110,9 @@ class ApplyLogSumExp {
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
static int constexpr kElementsPerAccess = ElementsPerAccess;
static int constexpr kCount = kElementsPerAccess;
static constexpr ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;

View File

@ -14,16 +14,16 @@ using namespace at;
namespace {
const auto int_min = std::numeric_limits<int>::min();
const auto int_max = std::numeric_limits<int>::max();
const auto long_min = std::numeric_limits<int64_t>::min();
const auto long_max = std::numeric_limits<int64_t>::max();
const auto float_lowest = std::numeric_limits<float>::lowest();
const auto float_min = std::numeric_limits<float>::min();
const auto float_max = std::numeric_limits<float>::max();
const auto double_lowest = std::numeric_limits<double>::lowest();
const auto double_min = std::numeric_limits<double>::min();
const auto double_max = std::numeric_limits<double>::max();
constexpr auto int_min = std::numeric_limits<int>::min();
constexpr auto int_max = std::numeric_limits<int>::max();
constexpr auto long_min = std::numeric_limits<int64_t>::min();
constexpr auto long_max = std::numeric_limits<int64_t>::max();
constexpr auto float_lowest = std::numeric_limits<float>::lowest();
constexpr auto float_min = std::numeric_limits<float>::min();
constexpr auto float_max = std::numeric_limits<float>::max();
constexpr auto double_lowest = std::numeric_limits<double>::lowest();
constexpr auto double_min = std::numeric_limits<double>::min();
constexpr auto double_max = std::numeric_limits<double>::max();
const std::vector<int> ints {
int_min,

View File

@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {
c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
// The internal state is returned as a CPU byte tensor.
auto state_tensor = at::detail::empty_cpu(
@ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::xpu::assertNotCapturing(
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
at::detail::check_rng_state(new_state);

View File

@ -38,12 +38,16 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(
configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
)
op_bench.generate_pt_test(
configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short
+ configs.conv_1d_configs_long,
ConvTranspose1dBenchmark,
)
if not torch.backends.mkldnn.is_acl_available():
# convtranpose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654
op_bench.generate_pt_test(
configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short
+ configs.conv_1d_configs_long,
ConvTranspose1dBenchmark,
)
"""

View File

@ -164,9 +164,6 @@ class TestIntTuple(TestCase):
crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
) # 4 -> (1,0,0) -> 1*8 = 8
# Test with zero-length shape and strides
self.assertEqual(crd2idx(0, (), ()), 0) # 0 -> () -> sum([]) = 0
def test_idx2crd_basic(self):
# Test basic int/int case
self.assertEqual(idx2crd(2, 5, 1), 2)

View File

@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1)
# Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
self.assertEqual(result5, expected5)
# Test 6: Tensor Cute representation of a 2D mesh
original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)

View File

@ -76,7 +76,7 @@ def main() -> None:
if uv and (is_uv_managed_python or not need_user_flag):
pip_args = [uv, "pip", "install"]
elif sys.executable:
pip_args = [sys.executable, "-mpip", "install"]
pip_args = [sys.executable, "-m", "pip", "install"]
else:
pip_args = ["pip3", "install"]

View File

@ -707,13 +707,27 @@ class _LocalDeviceMesh:
lm = local_tensor_mode()
assert lm is not None, "Unexpectedly not in LocalTensorMode"
root_mesh = self._get_root_mesh()
submesh_dims = self.mesh_dim_names
coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
for r in lm.ranks:
rank_tensor = self._layout.remap_to_tensor(self._rank_map)
rank_coords = (rank_tensor == r).nonzero().tolist()
assert len(rank_coords) == 1
for d, c in enumerate(rank_coords[0][1:]):
coords[d][r] = c
old_get_rank = DeviceMesh.get_rank # type: ignore[assignment]
try:
for r in lm.ranks:
DeviceMesh.get_rank = lambda self: r # type: ignore[method-assign]
submesh = (
root_mesh
if submesh_dims is None
else root_mesh.__getitem__(submesh_dims)
)
rank_coords = (submesh.mesh == r).nonzero().tolist()
assert len(rank_coords) in (0, 1)
if len(rank_coords) == 0:
continue
for d, c in enumerate(rank_coords[0]):
coords[d][r] = c
finally:
DeviceMesh.get_rank = old_get_rank # type: ignore[method-assign]
out = [torch.SymInt(LocalIntNode(c)) for c in coords]

View File

@ -301,7 +301,10 @@ class _MeshLayout(Layout):
ranks = self.all_ranks_from_zero()
return len(ranks) == len(set(ranks))
def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor:
def remap_to_tensor(
self,
mesh_tensor: torch.Tensor,
) -> torch.Tensor:
"""
Leverage layout as an index for mesh tensor that re-maps the indexes after layout
transformation to actual device ranks.
@ -313,7 +316,10 @@ class _MeshLayout(Layout):
can be treated as a view or subset of mesh tensor, we do need to use the actual view or
sub-tensor for DeviceMesh and its backend creation.
The shape of the `rank_map` must be 1D and contiguous.
The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
and reconstruct the mesh tensor with the shape of the layout when accessed by users.
#TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.
Examples:
@ -330,18 +336,18 @@ class _MeshLayout(Layout):
Return: [[[10,30],[20,40]]]
Args:
rank_map: The concrete mesh tensor with actual device ranks
mesh_tensor: The concrete mesh tensor with actual device ranks
Returns:
torch.Tensor: A tensor representing the actual device allocation from rank_map
torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
"""
assert rank_map.ndim == 1
assert rank_map.is_contiguous()
assert rank_map.numel() >= self.cosize()
complement_layout = self.complement(mesh_tensor.numel())
complement_layout = self.complement(rank_map.numel())
return rank_map.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides),
).reshape(-1, *self.top_level_sizes)
return (
mesh_tensor.flatten()
.as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides),
)
.reshape(-1, *(self[i].numel() for i in range(len(self))))
)

View File

@ -198,9 +198,7 @@ def crd2idx(
for i in range(len(shape) - 1, 0, -1):
result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
crd = crd // product(shape[i])
if len(shape) > 0:
result += crd2idx(crd, shape[0], stride[0])
return result
return result + crd2idx(crd, shape[0], stride[0])
else: # "int" "int" "int"
assert not is_tuple(shape) and not is_tuple(stride)
return crd * stride # all are ints after type checks

View File

@ -173,7 +173,7 @@ else:
"""
_device_type: str
_rank_map: torch.Tensor
_mesh: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None
@ -190,49 +190,46 @@ else:
_init_backend: bool = True,
_rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None:
self._device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = (
self._mesh = (
mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int)
)
self._rank_map = (
_root_mesh._rank_map
if _root_mesh is not None
else mesh_tensor.flatten()
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * self.mesh.ndim
elif len(backend_override) != self.mesh.ndim:
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {self.mesh.ndim}."
)
# Internal bookkeeping for the device mesh.
self._layout = (
_layout
if _layout
else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride())
else _MeshLayout(self.mesh.size(), self.mesh.stride())
)
self._root_mesh = _root_mesh
assert self._layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.top_level_sizes == mesh_tensor.size(), (
assert self._layout.top_level_sizes == self.mesh.size(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
)
if backend_override is None:
backend_override = ((None, None),) * len(self._layout)
elif len(backend_override) != len(self._layout):
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {len(self._layout)}."
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._thread_id = None
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
# Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
@ -255,11 +252,6 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property
def device_type(self) -> str:
"""Returns the device type of the mesh."""
@ -268,17 +260,7 @@ else:
@property
def mesh(self) -> torch.Tensor:
"""Returns the tensor representing the layout of devices."""
full_mesh = self._layout.remap_to_tensor(self._rank_map)
if full_mesh.size(0) == 1:
return full_mesh[0]
my_coords = (full_mesh == get_rank()).nonzero()
if my_coords.size(0) > 0:
return full_mesh[my_coords[0, 0]]
raise RuntimeError(
"In order to get the mesh Tensor of a DeviceMesh it needs to "
"either have all its original dimensions (e.g., no slicing) "
"or it needs to contain the local rank"
)
return self._mesh
@property
def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
@ -293,9 +275,9 @@ else:
init_process_group()
world_size = get_world_size()
if self._layout.numel() > world_size:
if self.mesh.numel() > world_size:
raise RuntimeError(
f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!"
f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
)
# ONLY set the device if the current device is not initialized, if user already
@ -346,8 +328,8 @@ else:
default_group = _get_default_group()
if (
len(self._layout) == 1
and self._layout.numel() == get_world_size()
self.mesh.ndim == 1
and self.mesh.numel() == get_world_size()
and backend_override[0] == (None, None)
):
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@ -366,11 +348,11 @@ else:
dim_group_names.append(dim_group.group_name)
else:
# create sub pgs base on the mesh argument specified
for dim in range(len(self._layout)):
for dim in range(self.mesh.ndim):
# swap the current dim to the last dim
# then reshape to flatten out other dims
pg_ranks_by_dim = (
self._layout[dim].nest().remap_to_tensor(self._rank_map)
pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
-1, self.mesh.size(dim)
)
backend, pg_options = backend_override[dim]
# We need to explicitly pass in timeout when specified in option, otherwise
@ -466,14 +448,14 @@ else:
def __repr__(self) -> str:
device_mesh_repr = (
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})"
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})"
if self._mesh_dim_names
else f"{self._layout.top_level_sizes}"
else f"{tuple(self._mesh.shape)}"
)
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}"
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
# We only print the mesh tensor if the debug mode is turned on.
if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
device_mesh_repr += f", Mesh: {self.mesh.tolist()}"
device_mesh_repr += f", Mesh: {self._mesh.tolist()}"
return f"{device_mesh_repr})"
def __hash__(self):
@ -483,7 +465,7 @@ else:
self._hash = hash(
(
self._flatten_mesh_list,
self._layout,
self._mesh.shape,
self._device_type,
self._mesh_dim_names,
self._thread_id,
@ -499,7 +481,7 @@ else:
return False
return (
self._flatten_mesh_list == other._flatten_mesh_list
and self._layout == other._layout
and self._mesh.shape == other._mesh.shape
and self._device_type == other._device_type
and self._mesh_dim_names == other._mesh_dim_names
and self._thread_id == other._thread_id
@ -591,16 +573,16 @@ else:
if not hasattr(self, "_dim_group_names"):
raise RuntimeError("DeviceMesh process groups not initialized!")
if len(self._layout) > 1 and mesh_dim is None:
if self.mesh.ndim > 1 and mesh_dim is None:
raise RuntimeError(
f"Found the DeviceMesh have {len(self._layout)} dimensions",
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
"If you want to get the list of all the ProcessGroups in the DeviceMesh,"
"please use `get_all_groups()` instead.",
)
# Quick return if the current device_mesh is a 1D mesh.
if len(self._layout) == 1 and mesh_dim is None:
if self.mesh.ndim == 1 and mesh_dim is None:
return not_none(_resolve_process_group(self._dim_group_names[0]))
root_mesh = self._get_root_mesh()
@ -626,7 +608,7 @@ else:
Returns:
A list of :class:`ProcessGroup` object.
"""
return [self.get_group(i) for i in range(len(self._layout))]
return [self.get_group(i) for i in range(self.mesh.ndim)]
def _create_sub_mesh(
self,
@ -653,7 +635,9 @@ else:
]
)
cur_rank = self.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map)
pg_ranks_by_dim = layout.remap_to_tensor(
root_mesh.mesh,
)
res_submesh = DeviceMesh._create_mesh_from_ranks(
self._device_type,
pg_ranks_by_dim,
@ -708,7 +692,9 @@ else:
cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map)
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
root_mesh.mesh,
)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
root_mesh._device_type,
pg_ranks_by_dim.flatten(
@ -847,7 +833,9 @@ else:
"""
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
layout = self._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map)
pg_ranks_by_dim = layout.remap_to_tensor(
self.mesh,
)
cur_rank = self.get_rank()
res_submeshes = []
for mesh_1d in pg_ranks_by_dim:
@ -908,7 +896,6 @@ else:
backend_override=backend_override,
_init_backend=_init_backend,
_layout=_layout,
_root_mesh=_root_mesh,
)
if cur_rank in mesh_nd:
res_mesh = mesh
@ -917,6 +904,8 @@ else:
f"Current rank {cur_rank} not found in any mesh, "
f"input {pg_ranks_by_dim} does not contain all ranks in the world"
)
if _root_mesh is not None:
res_mesh._root_mesh = _root_mesh
return res_mesh
@staticmethod
@ -1015,17 +1004,15 @@ else:
return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int:
if mesh_dim is not None:
return self._layout[mesh_dim].numel()
return self._layout.numel()
return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
@property
def ndim(self) -> int:
return len(self._layout)
return self.mesh.ndim
@property
def shape(self) -> tuple[int, ...]:
return self._layout.top_level_sizes
return tuple(self.mesh.shape)
def get_rank(self) -> int:
"""
@ -1064,7 +1051,7 @@ else:
"""
if self.ndim > 1 and mesh_dim is None:
raise RuntimeError(
f"Found the DeviceMesh have {len(self._layout)} dimensions",
f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
)
elif mesh_dim is None:
@ -1128,7 +1115,9 @@ else:
root_mesh = self._get_root_mesh()
cur_rank = self.get_rank()
unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
root_mesh.mesh,
)
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
res_mesh = DeviceMesh._create_mesh_from_ranks(
@ -1152,7 +1141,7 @@ else:
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
)
unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
root_mesh._rank_map
root_mesh.mesh,
)
unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
self.device_type,

View File

@ -640,9 +640,9 @@ def get_pip_packages(run_lambda, patterns=None):
os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
# People generally have pip as `pip` or `pip3`
# But here it is invoked as `python -mpip`
# But here it is invoked as `python -m pip`
out = run_and_read_all(
run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"]
run_lambda, [sys.executable, "-m", "pip", "list", "--format=freeze"]
)
if out is None:
return pip_version, out