Compare commits


6 Commits

Author SHA1 Message Date
241b702918 Fix remaining -m<module> patterns in Python code and scripts for consistency
Co-authored-by: malfet <2453524+malfet@users.noreply.github.com>
2025-10-17 14:56:15 +00:00
83df2e0610 Replace python3 -mpip and python -mpip with python3 -m pip and python -m pip for better readability
Co-authored-by: malfet <2453524+malfet@users.noreply.github.com>
2025-10-17 14:52:55 +00:00
77fe8234bb Initial plan 2025-10-17 14:45:52 +00:00
6ece527fc5 [CI] Add aarch64 operator benchmark (#165585)
Running on Graviton4
Skip ConvTranspose1d benchmarks if PyTorch is compiled with ACL, due to https://github.com/pytorch/pytorch/issues/165654
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165585
Approved by: https://github.com/huydhn
2025-10-17 14:42:14 +00:00
ce29d0d796 [ATen] Vectorize 8 elements on 16 bit data types for sum/mean (#165055)
Benchmarks cover a full reduction and a reduction along the contiguous dimension; vectorized loads do not occur along the non-contiguous dimension. Benchmarking was done for FP16/BF16, showing a ~6% improvement on average across shapes, up to ~24% for a single reduction along the contiguous dimension and ~46% for a full reduce (a rough conceptual sketch of the 8-wide load pattern follows the tables):
**BF16**
```
Tensor Shape         Operation    Before: full (ms)    Before: contig (ms)  After: full (ms)     After: contig (ms)   Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```

**FP16**
```
Tensor Shape         Operation    Before: full (ms)    Before: contig (ms)  After: full (ms)     After: contig (ms)   Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================
```
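The change itself lives in ATen's reduction kernels. As a rough, host-side conceptual sketch of the access pattern only — not the real implementation; `Pack8`, `sum16`, and the use of `uint16_t` as a stand-in for the fp16/bf16 payload are illustrative assumptions — the idea is to move eight 16-bit elements per aligned 16-byte access and accumulate in a wider type:

```cpp
// Conceptual sketch: eight 16-bit values per 16-byte access, accumulated in
// float, with a scalar tail loop. uint16_t stands in for the fp16/bf16 payload.
#include <cstdint>
#include <cstddef>
#include <cstring>
#include <numeric>
#include <vector>
#include <iostream>

struct alignas(16) Pack8 {        // one 16-byte "vectorized" load moves 8 elements
  uint16_t v[8];
};

float sum16(const uint16_t* data, std::size_t n) {
  float acc = 0.0f;
  std::size_t i = 0;
  for (; i + 8 <= n; i += 8) {    // main loop: 8 elements per iteration
    Pack8 p;
    std::memcpy(&p, data + i, sizeof(p));   // stands in for a single 128-bit load
    for (int j = 0; j < 8; ++j) acc += static_cast<float>(p.v[j]);
  }
  for (; i < n; ++i) acc += static_cast<float>(data[i]);  // tail elements
  return acc;
}

int main() {
  std::vector<uint16_t> x(1000);
  std::iota(x.begin(), x.end(), 0);
  std::cout << sum16(x.data(), x.size()) << '\n';  // prints 499500
}
```

On the GPU the same idea is expressed as one wide vectorized load per thread, which is only possible along the contiguous dimension — consistent with the larger gains in the "Contiguous diff %" column above.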
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165055
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790
2025-10-17 13:39:36 +00:00
7231118db3 Turn some const variables into constexpr in C++ code (#165401)
This PR goes through the C++ code and turns a number of const variables into constexpr, so their values are guaranteed to be computed at compile time and can be used in constant expressions.
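A minimal, self-contained illustration of the difference (not code from this PR; the function names are made up, and the numeric constants are copied from the diffs below):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

double scaled_const(double x) {
  // const object with static storage; cannot be read in a constant expression
  static const double kScale = 1.0507009873554804934193349852946;
  return kScale * x;
}

double scaled_constexpr(double x) {
  // guaranteed compile-time constant, also usable in constant expressions
  constexpr double kScale = 1.0507009873554804934193349852946;
  return kScale * x;
}

int main() {
  // constexpr values can size arrays and feed other constexpr computations
  constexpr std::size_t kSeedSize   = sizeof(std::uint64_t);
  constexpr std::size_t kOffsetSize = sizeof(std::int64_t);
  constexpr std::size_t kTotalSize  = kSeedSize + kOffsetSize;
  std::uint8_t state[kTotalSize] = {};
  std::cout << scaled_const(2.0) << ' ' << scaled_constexpr(2.0) << ' '
            << sizeof(state) << '\n';
}
```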

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165401
Approved by: https://github.com/Skylion007
2025-10-17 13:24:46 +00:00
51 changed files with 1584 additions and 236 deletions

View File

@@ -20,7 +20,7 @@ ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-python3 -mpip install cmake==3.18.4 && \
+python3 -m pip install cmake==3.18.4 && \
 ln -s /usr/local/bin/cmake /usr/bin/cmake3
 RUN rm -rf /usr/local/cuda-*

View File

@@ -25,7 +25,7 @@ function install_torchbench() {
 python install.py --continue_on_fail
 echo "Print all dependencies after TorchBench is installed"
-python -mpip freeze
+python -m pip freeze
 popd
 chown -R jenkins torchbench

View File

@@ -8,8 +8,8 @@ MKLROOT=/opt/intel
 mkdir -p ${MKLROOT}
 pushd /tmp
-python3 -mpip install wheel
-python3 -mpip download -d . mkl-static==${MKL_VERSION}
+python3 -m pip install wheel
+python3 -m pip download -d . mkl-static==${MKL_VERSION}
 python3 -m wheel unpack mkl_static-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 python3 -m wheel unpack mkl_include-${MKL_VERSION}-py2.py3-none-manylinux1_x86_64.whl
 mv mkl_static-${MKL_VERSION}/mkl_static-${MKL_VERSION}.data/data/lib ${MKLROOT}

View File

@@ -11,5 +11,5 @@ ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python
 python -m venv /var/lib/jenkins/ci_env
 source /var/lib/jenkins/ci_env/bin/activate
-python -mpip install --upgrade pip
-python -mpip install -r /opt/requirements-ci.txt
+python -m pip install --upgrade pip
+python -m pip install -r /opt/requirements-ci.txt

View File

@@ -14,7 +14,7 @@ ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/op
 # cmake-3.18.4 from pip
 RUN yum install -y python3-pip && \
-python3 -mpip install cmake==3.18.4 && \
+python3 -m pip install cmake==3.18.4 && \
 ln -s /usr/local/bin/cmake /usr/bin/cmake3
 FROM base as openssl
@@ -135,7 +135,7 @@ RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
 # cmake-3.18.4 from pip; force in case cmake3 already exists
 RUN yum install -y python3-pip && \
-python3 -mpip install cmake==3.18.4 && \
+python3 -m pip install cmake==3.18.4 && \
 ln -sf /usr/local/bin/cmake /usr/bin/cmake3
 FROM cpu_final as cuda_final
@@ -157,7 +157,7 @@ ENV ROCM_PATH /opt/rocm
 # cmake-3.28.4 from pip to get enable_language(HIP)
 # and avoid 3.21.0 cmake+ninja issues with ninja inserting "-Wl,--no-as-needed" in LINK_FLAGS for static linker
 RUN python3 -m pip install --upgrade pip && \
-python3 -mpip install cmake==3.28.4
+python3 -m pip install cmake==3.28.4
 # replace the libdrm in /opt/amdgpu with custom amdgpu.ids lookup path
 ADD ./common/install_rocm_drm.sh install_rocm_drm.sh
 RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh
@@ -174,7 +174,7 @@ FROM cpu_final as xpu_final
 ENV XPU_DRIVER_TYPE ROLLING
 # cmake-3.28.4 from pip
 RUN python3 -m pip install --upgrade pip && \
-python3 -mpip install cmake==3.28.4
+python3 -m pip install cmake==3.28.4
 ADD ./common/install_xpu.sh install_xpu.sh
 ENV XPU_VERSION 2025.2
 RUN bash ./install_xpu.sh && rm install_xpu.sh

View File

@@ -113,7 +113,7 @@ RUN dnf install -y \
 RUN env GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=True pip3 install grpcio
 # cmake-3.28.0 from pip for onnxruntime
-RUN python3 -mpip install cmake==3.28.0
+RUN python3 -m pip install cmake==3.28.0
 ADD ./common/patch_libstdc.sh patch_libstdc.sh
 RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh

View File

@@ -288,7 +288,7 @@ else
 # or building non-XLA tests.
 if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then
 # Install numpy-2.0.2 for builds which are backward compatible with 1.X
-python -mpip install numpy==2.0.2
+python -m pip install numpy==2.0.2
 WERROR=1 python setup.py clean

View File

@@ -67,13 +67,13 @@ function pip_install_whl() {
 # Loop through each path and install individually
 for path in "${paths[@]}"; do
 echo "Installing $path"
-python3 -mpip install --no-index --no-deps "$path"
+python3 -m pip install --no-index --no-deps "$path"
 done
 else
 # Loop through each argument and install individually
 for path in "${args[@]}"; do
 echo "Installing $path"
-python3 -mpip install --no-index --no-deps "$path"
+python3 -m pip install --no-index --no-deps "$path"
 done
 fi
 }

View File

@@ -182,7 +182,7 @@ checkout_install_torchbench() {
 pip uninstall -y torchao
 echo "Print all dependencies after TorchBench is installed"
-python -mpip freeze
+python -m pip freeze
 }
 torchbench_setup_macos() {
@@ -211,7 +211,7 @@ torchbench_setup_macos() {
 }
 pip_benchmark_deps() {
-python -mpip install --no-input requests cython scikit-learn six
+python -m pip install --no-input requests cython scikit-learn six
 }

View File

@@ -1434,7 +1434,7 @@ EOF
 # shellcheck source=./common-build.sh
 source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh"
 python -m build --wheel --no-isolation -C--build-option=--bdist-dir="base_bdist_tmp" --outdir "base_dist"
-python -mpip install base_dist/*.whl
+python -m pip install base_dist/*.whl
 echo "::endgroup::"
 pushd test/forward_backward_compatibility

View File

@@ -173,7 +173,7 @@ esac
 PINNED_PACKAGES=(
 "numpy${NUMPY_PINNED_VERSION}"
 )
-python -mvenv ~/${desired_python}-build
+python -m venv ~/${desired_python}-build
 source ~/${desired_python}-build/bin/activate
 retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.txt"
 retry brew install libomp

View File

@@ -24,7 +24,7 @@ change_wheel_version() {
 local t_version=$4
 # Extract the wheel
-${PYTHON_EXECUTABLE} -mwheel unpack $wheel
+${PYTHON_EXECUTABLE} -m wheel unpack $wheel
 mv "${package}-${f_version}" "${package}-${t_version}"
 # Change the version from f_version to t_version in the dist-info dir
@@ -47,7 +47,7 @@ change_wheel_version() {
 popd
 # Repack the wheel
-${PYTHON_EXECUTABLE} -mwheel pack "${package}-${t_version}"
+${PYTHON_EXECUTABLE} -m wheel pack "${package}-${t_version}"
 # Clean up
 rm -rf "${package}-${t_version}"
@@ -85,7 +85,7 @@ repackage_wheel() {
 }
 # Require to re-package the wheel
-${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1
+${PYTHON_EXECUTABLE} -m pip install wheel==0.45.1
 pushd externals/vllm/wheels
 for package in xformers flashinfer-python vllm; do

View File

@@ -211,7 +211,7 @@ jobs:
 $tool --version
 done
-python3 -mpip install --no-index --no-deps dist/*.whl
+python3 -m pip install --no-index --no-deps dist/*.whl
 set +e
 pushd "${RUNNER_TEMP}"
@@ -222,7 +222,7 @@ jobs:
 popd
 if [ "${RC}" -ne 0 ]; then
-python3 -mpip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
+python3 -m pip install --ignore-installed -r "${PIP_REQUIREMENTS_FILE}"
 fi
 set -e

View File

@@ -204,7 +204,7 @@ jobs:
 run: |
 pushd "${PYTORCH_FINAL_PACKAGE_DIR}"
 # shellcheck disable=SC2046,SC2102
-python3 -mpip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
+python3 -m pip install $(echo *.whl)[opt-einsum,optree] optree==0.13.0
 popd
 .ci/pytorch/win-test.sh

View File

@@ -126,13 +126,13 @@ jobs:
 "${MANYLINUX_IMAGE}"
 )
-docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip install \
+docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install \
 --pre torch torchvision torchaudio \
 --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"
 # I wonder if there is a command to both download and install the wheels
 # in one go
-docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -mpip download \
+docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip download \
 --pre torch torchvision torchaudio \
 --index-url "https://download.pytorch.org/whl/nightly/${BUILD_DEVICE}"

View File

@@ -106,7 +106,7 @@ jobs:
 SMOKE_TEST_PARAMS=""
 # shellcheck disable=SC2086
-python -mvenv test_venv
+python -m venv test_venv
 source test_venv/bin/activate
 pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@@ -216,7 +216,7 @@ jobs:
 SMOKE_TEST_PARAMS=""
 # shellcheck disable=SC2086
-python -mvenv test_venv
+python -m venv test_venv
 source test_venv/bin/activate
 pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@@ -326,7 +326,7 @@ jobs:
 SMOKE_TEST_PARAMS=""
 # shellcheck disable=SC2086
-python -mvenv test_venv
+python -m venv test_venv
 source test_venv/bin/activate
 pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@@ -436,7 +436,7 @@ jobs:
 SMOKE_TEST_PARAMS=""
 # shellcheck disable=SC2086
-python -mvenv test_venv
+python -m venv test_venv
 source test_venv/bin/activate
 pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@@ -546,7 +546,7 @@ jobs:
 SMOKE_TEST_PARAMS=""
 # shellcheck disable=SC2086
-python -mvenv test_venv
+python -m venv test_venv
 source test_venv/bin/activate
 pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@@ -656,7 +656,7 @@ jobs:
 SMOKE_TEST_PARAMS=""
 # shellcheck disable=SC2086
-python -mvenv test_venv
+python -m venv test_venv
 source test_venv/bin/activate
 pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
@@ -766,7 +766,7 @@ jobs:
 SMOKE_TEST_PARAMS=""
 # shellcheck disable=SC2086
-python -mvenv test_venv
+python -m venv test_venv
 source test_venv/bin/activate
 pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

View File

@@ -52,3 +52,27 @@ jobs:
       docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
       test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
     secrets: inherit
+  aarch64-opbenchmark-build:
+    if: github.repository_owner == 'pytorch'
+    name: aarch64-opbenchmark-build
+    uses: ./.github/workflows/_linux-build.yml
+    with:
+      build-environment: linux-jammy-aarch64-py3.10
+      runner: linux.arm64.m7g.4xlarge
+      docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
+      test-matrix: |
+        { include: [
+          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
+        ]}
+    secrets: inherit
+  aarch64-opbenchmark-test:
+    name: aarch64-opbenchmark-test
+    uses: ./.github/workflows/_linux-test.yml
+    needs: aarch64-opbenchmark-build
+    with:
+      build-environment: linux-jammy-aarch64-py3.10
+      docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
+      test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
+    secrets: inherit

View File

@@ -39,7 +39,7 @@ RUN chmod +x ~/miniconda.sh && \
 bash ~/miniconda.sh -b -p /opt/conda && \
 rm ~/miniconda.sh && \
 /opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
-/opt/conda/bin/python -mpip install -r requirements.txt && \
+/opt/conda/bin/python -m pip install -r requirements.txt && \
 /opt/conda/bin/conda clean -ya
 FROM dev-base as submodule-update

View File

@@ -229,10 +229,10 @@ private:
 }
-static const uint32_t kPhilox10A = 0x9E3779B9;
-static const uint32_t kPhilox10B = 0xBB67AE85;
-static const uint32_t kPhiloxSA = 0xD2511F53;
-static const uint32_t kPhiloxSB = 0xCD9E8D57;
+static constexpr uint32_t kPhilox10A = 0x9E3779B9;
+static constexpr uint32_t kPhilox10B = 0xBB67AE85;
+static constexpr uint32_t kPhiloxSA = 0xD2511F53;
+static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
 };
 typedef philox_engine Philox4_32;

View File

@@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
 */
 c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
 // The RNG state comprises the seed, and an offset used for Philox.
-static const size_t seed_size = sizeof(uint64_t);
-static const size_t offset_size = sizeof(int64_t);
-static const size_t total_size = seed_size + offset_size;
+constexpr size_t seed_size = sizeof(uint64_t);
+constexpr size_t offset_size = sizeof(int64_t);
+constexpr size_t total_size = seed_size + offset_size;
 auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
 auto rng_state = state_tensor.data_ptr<uint8_t>();
@@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
 * and size of the internal state.
 */
 void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
-static const size_t seed_size = sizeof(uint64_t);
-static const size_t offset_size = sizeof(int64_t);
-static const size_t total_size = seed_size + offset_size;
+constexpr size_t seed_size = sizeof(uint64_t);
+constexpr size_t offset_size = sizeof(int64_t);
+constexpr size_t total_size = seed_size + offset_size;
 detail::check_rng_state(new_state);

View File

@@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (
 namespace at::native {
-static const double SELU_ALPHA = 1.6732632423543772848170429916717;
-static const double SELU_SCALE = 1.0507009873554804934193349852946;
+static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
+static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
 DEFINE_DISPATCH(elu_stub);
 DEFINE_DISPATCH(elu_backward_stub);

View File

@@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
 #if AT_BUILD_WITH_BLAS()
 template <>
 bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
-auto intmax = std::numeric_limits<int>::max();
+auto constexpr intmax = std::numeric_limits<int>::max();
 return n <= intmax && incx <= intmax;
 }
@@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
 int64_t incx,
 [[maybe_unused]] float beta,
 int64_t incy) {
-auto intmax = std::numeric_limits<int>::max();
+auto constexpr intmax = std::numeric_limits<int>::max();
 return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
 (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
 }

View File

@@ -1,5 +1,6 @@
 #pragma once
+#include <array>
 #include <ATen/native/Math.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/MathConstants.h>
@@ -127,7 +128,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor
 template<typename scalar_t>
 C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
-const static scalar_t kTailValues[] = {
+constexpr static scalar_t kTailValues[] = {
 0.0810614667953272,
 0.0413406959554092,
 0.0276779256849983,
@@ -139,7 +140,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
 0.00925546218271273,
 0.00833056343336287
 };
-if (k <= 9) {
+if (k < std::size(kTailValues)) {
 return kTailValues[static_cast<size_t>(k)];
 }
 scalar_t kp1sq = (k + 1) * (k + 1);

View File

@@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
 template <typename scalar_t>
 static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
 // lanczos approximation
-static const scalar_t lanczos_sum_expg_scaled_num[13] = {
+static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
 0.006061842346248906525783753964555936883222,
 0.5098416655656676188125178644804694509993,
 19.51992788247617482847860966235652136208,
@@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
 103794043.1163445451906271053616070238554,
 56906521.91347156388090791033559122686859
 };
-static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
+static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
 1.,
 66.,
 1925.,
@@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
 template <typename scalar_t>
 static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
 // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
-static const scalar_t d[25][25] =
+static constexpr scalar_t d[25][25] =
 {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,

View File

@@ -62,7 +62,7 @@
 #include <utility>
 #include <vector>
-static const int MIOPEN_DIM_MAX = 5;
+static constexpr int MIOPEN_DIM_MAX = 5;
 namespace at::meta {

View File

@@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
 // We keep this structure for BC and consider as deprecated.
 // See HelperInterpNearestExact as replacement
-static const int interp_size = 1;
+static constexpr int interp_size = 1;
 static inline void init_indices_weights(
 at::ScalarType output_type,
@@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {
 struct HelperInterpLinear : public HelperInterpBase {
-static const int interp_size = 2;
+static constexpr int interp_size = 2;
 // Compute indices and weights for each interpolated dimension
 // indices_weights = {
@@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {
 struct HelperInterpCubic : public HelperInterpBase {
-static const int interp_size = 4;
+static constexpr int interp_size = 4;
 // Compute indices and weights for each interpolated dimension
 // indices_weights = {

View File

@@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
 }
-static const int BLOCK_THREADS = 256;
+static constexpr int BLOCK_THREADS = 256;
 template <typename scalar_t, typename accscalar_t>
 #if defined (USE_ROCM)

View File

@@ -36,9 +36,9 @@ namespace at::native {
 namespace {
 #if defined(USE_ROCM)
-static const int BLOCKDIMY = 16;
+static constexpr int BLOCKDIMY = 16;
 #else
-static const int BLOCKDIMY = 32;
+static constexpr int BLOCKDIMY = 32;
 #endif
 template

View File

@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation // lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t lanczos_sum_expg_scaled_num[13] = { constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222, 0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993, 0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208, 19.51992788247617482847860966235652136208,
@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554, 103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859 56906521.91347156388090791033559122686859
}; };
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = { constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1., 1.,
66., 66.,
1925., 1925.,
@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac; accscalar_t ax, fac, res, num, numfac;
static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ? constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
7.09782712893383996843E2 : 88.72283905206835; 7.09782712893383996843E2 : 88.72283905206835;
static const accscalar_t EXP1 = 2.718281828459045; constexpr accscalar_t EXP1 = 2.718281828459045;
static const accscalar_t lanczos_g = 6.024680040776729583740234375; constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) { if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a); ax = a * ::log(x) - x - ::lgamma(a);
@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1] // Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8; 1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const int MAXITER = 2000; constexpr int MAXITER = 2000;
int i; int i;
accscalar_t ans, ax, c, r; accscalar_t ans, ax, c, r;
@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
accscalar_t fac = 1; accscalar_t fac = 1;
accscalar_t sum = 0; accscalar_t sum = 0;
accscalar_t term, logx; accscalar_t term, logx;
static const int MAXITER = 2000; constexpr int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8; 1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) { for (n = 1; n < MAXITER; n++) {
@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t d[25][25] = constexpr accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15}, {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15}, {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15}, {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
int k, n, sgn; int k, n, sgn;
int maxpow = 0; int maxpow = 0;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8; 1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a; accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a; accscalar_t sigma = (x - a) / a;
@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
int i; int i;
accscalar_t ans, ax, c, yc, r, t, y, z; accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
static const int MAXITER = 2000; constexpr int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ? constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8; 1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ? constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
4.503599627370496e15 : 16777216.; 4.503599627370496e15 : 16777216.;
static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ? constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
2.22044604925031308085e-16 : 5.9604644775390625E-8; 2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x); ax = _igam_helper_fac(a, x);
@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a; accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0; constexpr accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0; constexpr accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3; constexpr accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5; constexpr accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) { if ((x < 0) || (a < 0)) {
// out of defined-region of the function // out of defined-region of the function
@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a; accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0; constexpr accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0; constexpr accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3; constexpr accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5; constexpr accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy // boundary values following SciPy
if ((x < 0) || (a < 0)) { if ((x < 0) || (a < 0)) {
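
These hunks replace function-local `static const` constants (MACHEP, MAXITER, BIG, BIGINV, and the SMALL/LARGE thresholds and ratios) with `constexpr`, making them plain compile-time constants without changing their values. The helpers back PyTorch's igamma/igammac ops, exposed as `torch.special.gammainc` / `torch.special.gammaincc`, so a quick user-level sanity check is possible; the sketch below runs on CPU for portability (move the inputs to a CUDA device to exercise the kernels in this hunk) and assumes SciPy is installed for a reference value.

```
# Sanity-check sketch (assumes SciPy is available): P(a, x) + Q(a, x) should equal 1,
# and P should match SciPy's reference, whose boundary conventions the comments cite.
import torch
from scipy import special

a = torch.tensor([0.5, 2.0, 10.0], dtype=torch.float64)
x = torch.tensor([0.1, 2.0, 30.0], dtype=torch.float64)

p = torch.special.gammainc(a, x)    # regularized lower incomplete gamma P(a, x)
q = torch.special.gammaincc(a, x)   # regularized upper incomplete gamma Q(a, x)

torch.testing.assert_close(p + q, torch.ones_like(p))
torch.testing.assert_close(p, torch.from_numpy(special.gammainc(a.numpy(), x.numpy())))
```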

View File

@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
const auto digamma_string = jiterator_stringify( const auto digamma_string = jiterator_stringify(
template <typename T> template <typename T>
T digamma(T x) { T digamma(T x) {
static const double PI_f64 = 3.14159265358979323846; static constexpr double PI_f64 = 3.14159265358979323846;
// Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
if (x == 0) { if (x == 0) {
@ -3072,9 +3072,9 @@ template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>; using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846; static constexpr double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764; constexpr accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = { constexpr accscalar_t A[] = {
8.33333333333333333333E-2, 8.33333333333333333333E-2,
-2.10927960927960927961E-2, -2.10927960927960927961E-2,
7.57575757575757575758E-3, 7.57575757575757575758E-3,

View File

@ -1097,11 +1097,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
// threads with different threadIdx.x are independent and will produce results for different outputs. // threads with different threadIdx.x are independent and will produce results for different outputs.
// In such case, values in each loaded vector always correspond to different outputs. // In such case, values in each loaded vector always correspond to different outputs.
if (fastest_moving_stride == sizeof(scalar_t)) { if (fastest_moving_stride == sizeof(scalar_t)) {
#ifdef USE_ROCM
if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) { if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
#else
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
#endif
// Case 1: "vectorize along input" // Case 1: "vectorize along input"
// Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case, // Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
// we should avoid vectorization. // we should avoid vectorization.
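
With the `USE_ROCM` branch removed, the "vectorize along input" case is selected by a single condition on every backend (reduction over the fastest-striding dimension, `dim0 >= 128`, one reduce dim), dropping the extra `vt0 >= input_vec_size` requirement the CUDA-only path had. The rough timing sketch below is not part of the PR; it assumes a CUDA device, absolute numbers depend on the GPU, and it only contrasts the two situations the comments distinguish: reducing over the contiguous dimension, where vectorized loads apply, versus a strided dimension, where they do not.

```
# Rough timing sketch (assumes a CUDA device; numbers are illustrative only).
import torch

def time_ms(fn, iters=50):
    fn()                                         # warm-up
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(2048, 2048, device="cuda", dtype=torch.bfloat16)
print("reduce contiguous dim:", time_ms(lambda: x.sum(dim=-1)), "ms")   # vectorized loads possible
print("reduce strided dim:   ", time_ms(lambda: x.sum(dim=0)), "ms")    # loads are not contiguous
```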

View File

@ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) { void mean_kernel_impl(TensorIterator& iter) {
// returns acc_t for all non-complex dtypes and returns T for c10::complex<T> // returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
using factor_t = typename c10::scalar_value_type<acc_t>::type; using factor_t = typename c10::scalar_value_type<acc_t>::type;
factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel(); factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
} else {
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
}
} }
static void mean_kernel_cuda(TensorIterator& iter) { static void mean_kernel_cuda(TensorIterator& iter) {

View File

@ -13,24 +13,19 @@ namespace at::native {
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t> template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor { struct sum_functor {
void operator()(TensorIterator& iter) { void operator()(TensorIterator& iter) {
#ifdef USE_ROCM const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
// Half and BFloat16 can be packed in groups of up to 8 elements and return a + b;
// can use *_DWORDX4 instructions to achieve that. };
const bool is_16_bits = constexpr bool is_16_bits = sizeof(scalar_t) == 2;
( (std::is_same<at::Half, scalar_t>::value) || if constexpr (is_16_bits) {
(std::is_same<at::BFloat16, scalar_t>::value) );
if (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>( gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { iter, func_wrapper<out_t>(sum_combine)
return a + b; );
})); } else {
return; gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>(sum_combine)
);
} }
#endif
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
} }
}; };
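
Together with the mean hunk above, both reduction functors now route 16-bit inputs (Half/BFloat16, detected as `sizeof(scalar_t) == 2`) to `gpu_reduce_kernel` with `vt0=4, input_vec_size=8` via `if constexpr`, instead of doing so only under `USE_ROCM`, and the sum lambda is hoisted into a shared `sum_combine`. From Python the affected ops are simply `sum`/`mean` on 16-bit CUDA tensors; a small agreement check against a float32 reference (sketch, assumes a CUDA device):

```
# Numerical agreement sketch (assumes a CUDA device): sum/mean on fp16/bf16 inputs
# still accumulate in float32 (acc_t), so the vectorized path should track a
# float32 reference within 16-bit rounding error.
import torch

for dtype in (torch.float16, torch.bfloat16):
    x = torch.randn(1024, 1024, device="cuda", dtype=dtype)
    torch.testing.assert_close(x.sum(dim=-1).float(), x.float().sum(dim=-1),
                               rtol=1e-2, atol=1e-1)
    torch.testing.assert_close(x.mean(dim=-1).float(), x.float().mean(dim=-1),
                               rtol=1e-2, atol=1e-2)
```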

View File

@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
return 0; return 0;
} }
static const int size = 2; static constexpr int size = 2;
}; };
// taken from // taken from
@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
return 0; return 0;
} }
static const int size = 4; static constexpr int size = 4;
}; };
template <typename accscalar_t> template <typename accscalar_t>

View File

@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k) // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n) // else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
// only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16; constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
if (mat1.dim() == 1 && mat2.dim() == 1) { if (mat1.dim() == 1 && mat2.dim() == 1) {
// aten::dot // aten::dot
return mat1.size(0) > mkldnn_gemm_min_size; return mat1.size(0) > mkldnn_gemm_min_size;
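
The minimum-size threshold becomes a plain `constexpr` local. Per the comments, the heuristic multiplies the participating dimensions and only routes to the mkldnn-optimized gemm when that volume exceeds `16 * 16 * 16`. The sketch below is a hypothetical Python restatement of that idea, not a PyTorch API; only the `aten::dot` branch is visible in this hunk, so the other branches are inferred from the comment and may differ from the real `checksize()`.

```
# Hypothetical restatement of the size heuristic described in the comments above.
MKLDNN_GEMM_MIN_SIZE = 16 * 16 * 16

def big_enough_for_mkldnn(mat1_shape, mat2_shape):
    if len(mat1_shape) == 1 and len(mat2_shape) == 1:        # aten::dot: (n) . (n)
        return mat1_shape[0] > MKLDNN_GEMM_MIN_SIZE
    if len(mat1_shape) == 2 and len(mat2_shape) == 2:        # aten::mm: (m, n) @ (n, k)
        m, n = mat1_shape
        _, k = mat2_shape
        return m * n * k > MKLDNN_GEMM_MIN_SIZE
    if len(mat1_shape) == 3 and len(mat2_shape) == 3:        # aten::bmm: (b, m, n) @ (b, n, k)
        b, m, n = mat1_shape
        _, _, k = mat2_shape
        return b * m * n * k > MKLDNN_GEMM_MIN_SIZE
    m, n = mat1_shape                                        # aten::mv: (m, n) @ (n,)
    return m * n > MKLDNN_GEMM_MIN_SIZE

assert not big_enough_for_mkldnn((16, 16), (16, 16))         # 16*16*16 is not > threshold
assert big_enough_for_mkldnn((64, 64), (64, 64))
```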

View File

@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__) #if defined(__ARM_NEON__) || defined(__aarch64__)
const static int PARALLEL_THRESHOLD = 1 << 20; constexpr static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation // Generic template defaults to naive quantize implementation
template <typename T> template <typename T>

View File

@ -1388,7 +1388,7 @@ namespace at::native {
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1, TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1"); "onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt; static std::optional<at::Tensor> other = std::nullopt;
static const std::string_view binary_post_op = "none"; constexpr std::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0; int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight( return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp, act, act_scale.item().toDouble(), act_zp,

View File

@ -16,8 +16,8 @@ namespace {
#ifdef USE_PYTORCH_QNNPACK #ifdef USE_PYTORCH_QNNPACK
const static float qnnpack_softmax_output_scale = 0x1.0p-8f; constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0; constexpr static int qnnpack_softmax_output_zero_point = 0;
bool is_qnnpack_compatible( bool is_qnnpack_compatible(
const Tensor& qx, const Tensor& qx,

View File

@ -110,9 +110,9 @@ class ApplyLogSumExp {
using ElementCompute = ElementCompute_; using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_; using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess; static int constexpr kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess; static int constexpr kCount = kElementsPerAccess;
static const ScaleType::Kind kScale = static constexpr ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling; cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>; using FragmentOutput = Array<ElementOutput, kCount>;

View File

@ -14,16 +14,16 @@ using namespace at;
namespace { namespace {
const auto int_min = std::numeric_limits<int>::min(); constexpr auto int_min = std::numeric_limits<int>::min();
const auto int_max = std::numeric_limits<int>::max(); constexpr auto int_max = std::numeric_limits<int>::max();
const auto long_min = std::numeric_limits<int64_t>::min(); constexpr auto long_min = std::numeric_limits<int64_t>::min();
const auto long_max = std::numeric_limits<int64_t>::max(); constexpr auto long_max = std::numeric_limits<int64_t>::max();
const auto float_lowest = std::numeric_limits<float>::lowest(); constexpr auto float_lowest = std::numeric_limits<float>::lowest();
const auto float_min = std::numeric_limits<float>::min(); constexpr auto float_min = std::numeric_limits<float>::min();
const auto float_max = std::numeric_limits<float>::max(); constexpr auto float_max = std::numeric_limits<float>::max();
const auto double_lowest = std::numeric_limits<double>::lowest(); constexpr auto double_lowest = std::numeric_limits<double>::lowest();
const auto double_min = std::numeric_limits<double>::min(); constexpr auto double_min = std::numeric_limits<double>::min();
const auto double_max = std::numeric_limits<double>::max(); constexpr auto double_max = std::numeric_limits<double>::max();
const std::vector<int> ints { const std::vector<int> ints {
int_min, int_min,

View File

@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {
c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const { c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox. // The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t); constexpr size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t); constexpr size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size; constexpr size_t total_size = seed_size + offset_size;
// The internal state is returned as a CPU byte tensor. // The internal state is returned as a CPU byte tensor.
auto state_tensor = at::detail::empty_cpu( auto state_tensor = at::detail::empty_cpu(
@ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::xpu::assertNotCapturing( at::xpu::assertNotCapturing(
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing."); "Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t); constexpr size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t); constexpr size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size; constexpr size_t total_size = seed_size + offset_size;
at::detail::check_rng_state(new_state); at::detail::check_rng_state(new_state);
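
`get_state`/`set_state` serialize the XPU generator as two `uint64_t` values, the seed followed by the Philox offset, in a 16-byte CPU byte tensor; the hunk only turns the size constants into `constexpr`. The standalone sketch below illustrates that layout with made-up values (using `int64` in place of `uint64` for simplicity); it is not the XPUGeneratorImpl API.

```
# Layout sketch: [ seed (8 bytes) | philox offset (8 bytes) ] as a CPU uint8 tensor.
import torch

seed, offset = 1234567890123456789, 42          # illustrative values
state = torch.empty(16, dtype=torch.uint8)
state[:8] = torch.tensor([seed], dtype=torch.int64).view(torch.uint8)
state[8:] = torch.tensor([offset], dtype=torch.int64).view(torch.uint8)

# Reading it back mirrors what set_state() has to parse.
decoded_seed = state[:8].view(torch.int64).item()
decoded_offset = state[8:].view(torch.int64).item()
assert (decoded_seed, decoded_offset) == (seed, offset)
```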

View File

@ -38,12 +38,16 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test( op_bench.generate_pt_test(
configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
) )
op_bench.generate_pt_test(
configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short if not torch.backends.mkldnn.is_acl_available():
+ configs.conv_1d_configs_long, # convtranspose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654
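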
ConvTranspose1dBenchmark, op_bench.generate_pt_test(
) configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short
+ configs.conv_1d_configs_long,
ConvTranspose1dBenchmark,
)
""" """

View File

@ -164,9 +164,6 @@ class TestIntTuple(TestCase):
crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8 crd2idx(4, ((2, 2, 2), (2, 2, 2)), ((1, 16, 4), (8, 2, 32))), 8
) # 4 -> (1,0,0) -> 1*8 = 8 ) # 4 -> (1,0,0) -> 1*8 = 8
# Test with zero-length shape and strides
self.assertEqual(crd2idx(0, (), ()), 0) # 0 -> () -> sum([]) = 0
def test_idx2crd_basic(self): def test_idx2crd_basic(self):
# Test basic int/int case # Test basic int/int case
self.assertEqual(idx2crd(2, 5, 1), 2) self.assertEqual(idx2crd(2, 5, 1), 2)

View File

@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
def test_remap_to_tensor(self): def test_remap_to_tensor(self):
"""Test the remap_to_tensor method for various scenarios.""" """Test the remap_to_tensor method for various scenarios."""
# Test 1: Consecutive ranks, full world - should return logical groups directly # Test 1: Consecutive ranks, full world - should return logical groups directly
original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int) original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2 layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
result1 = layout1.remap_to_tensor(original_mesh) result1 = layout1.remap_to_tensor(original_mesh)
expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int) expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
self.assertEqual(result1, expected1) self.assertEqual(result1, expected1)
# Test 2: Non-consecutive ranks - should map to actual ranks # Test 2: Non-consecutive ranks - should map to actual ranks
original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int) original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
layout2 = _Layout((2, 2), (2, 1)) layout2 = _Layout((2, 2), (2, 1))
result2 = layout2.remap_to_tensor(original_mesh) result2 = layout2.remap_to_tensor(original_mesh)
expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int) expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
self.assertEqual(result5, expected5) self.assertEqual(result5, expected5)
# Test 6: Tensor Cute representation of a 2D mesh # Test 6: Tensor Cute representation of a 2D mesh
original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int) original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
layout6 = _Layout((2, 2), (1, 2)) # column-major style layout6 = _Layout((2, 2), (1, 2)) # column-major style
result6 = layout6.remap_to_tensor(original_mesh) result6 = layout6.remap_to_tensor(original_mesh)
expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int) expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)

View File

@ -76,7 +76,7 @@ def main() -> None:
if uv and (is_uv_managed_python or not need_user_flag): if uv and (is_uv_managed_python or not need_user_flag):
pip_args = [uv, "pip", "install"] pip_args = [uv, "pip", "install"]
elif sys.executable: elif sys.executable:
pip_args = [sys.executable, "-mpip", "install"] pip_args = [sys.executable, "-m", "pip", "install"]
else: else:
pip_args = ["pip3", "install"] pip_args = ["pip3", "install"]
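
The installer command is chosen by a small fallback chain: `uv pip install` when uv is present and manages the interpreter, otherwise the running interpreter's own pip via `sys.executable -m pip install`, and a bare `pip3 install` only as a last resort; the change here just spells `-m pip` as two separate arguments. The condensed sketch below is a hypothetical standalone helper that omits the uv-managed-python and `--user` checks from the real `main()`.

```
# Hypothetical condensed helper (omits the uv-managed-python / --user logic above):
# prefer uv, then the running interpreter's own pip via `-m pip`, then a bare pip3.
import shutil
import sys

def pick_pip_args() -> list[str]:
    uv = shutil.which("uv")
    if uv:
        return [uv, "pip", "install"]
    if sys.executable:
        return [sys.executable, "-m", "pip", "install"]
    return ["pip3", "install"]

print(pick_pip_args())
```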

View File

@ -707,13 +707,27 @@ class _LocalDeviceMesh:
lm = local_tensor_mode() lm = local_tensor_mode()
assert lm is not None, "Unexpectedly not in LocalTensorMode" assert lm is not None, "Unexpectedly not in LocalTensorMode"
root_mesh = self._get_root_mesh()
submesh_dims = self.mesh_dim_names
coords: list[dict[int, int]] = [{} for _ in range(self.ndim)] coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
for r in lm.ranks: old_get_rank = DeviceMesh.get_rank # type: ignore[assignment]
rank_tensor = self._layout.remap_to_tensor(self._rank_map) try:
rank_coords = (rank_tensor == r).nonzero().tolist() for r in lm.ranks:
assert len(rank_coords) == 1 DeviceMesh.get_rank = lambda self: r # type: ignore[method-assign]
for d, c in enumerate(rank_coords[0][1:]): submesh = (
coords[d][r] = c root_mesh
if submesh_dims is None
else root_mesh.__getitem__(submesh_dims)
)
rank_coords = (submesh.mesh == r).nonzero().tolist()
assert len(rank_coords) in (0, 1)
if len(rank_coords) == 0:
continue
for d, c in enumerate(rank_coords[0]):
coords[d][r] = c
finally:
DeviceMesh.get_rank = old_get_rank # type: ignore[method-assign]
out = [torch.SymInt(LocalIntNode(c)) for c in coords] out = [torch.SymInt(LocalIntNode(c)) for c in coords]
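
The rewritten loop temporarily swaps `DeviceMesh.get_rank` for a lambda returning each simulated rank, slices the root mesh down to the named submesh (if any), looks up that rank's coordinates in the submesh tensor, and restores the original method in a `finally` block so an exception cannot leave the class patched. A generic sketch of that patch-and-restore pattern follows, using a hypothetical `Mesh` class rather than the real DeviceMesh API.

```
# Generic patch-and-restore sketch (hypothetical Mesh class, not DeviceMesh itself).
class Mesh:
    def get_rank(self) -> int:
        return 0  # placeholder; the real method queries the process group

def coords_for_ranks(mesh, ranks):
    original_get_rank = Mesh.get_rank
    coords = {}
    try:
        for r in ranks:
            Mesh.get_rank = lambda self, _r=r: _r   # pretend to be rank r
            coords[r] = mesh.get_rank()             # downstream code now sees rank r
    finally:
        Mesh.get_rank = original_get_rank           # always undo the patch
    return coords

assert coords_for_ranks(Mesh(), [0, 1, 2]) == {0: 0, 1: 1, 2: 2}
```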

View File

@ -301,7 +301,10 @@ class _MeshLayout(Layout):
ranks = self.all_ranks_from_zero() ranks = self.all_ranks_from_zero()
return len(ranks) == len(set(ranks)) return len(ranks) == len(set(ranks))
def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor: def remap_to_tensor(
self,
mesh_tensor: torch.Tensor,
) -> torch.Tensor:
""" """
Leverage layout as an index for mesh tensor that re-maps the indexes after layout Leverage layout as an index for mesh tensor that re-maps the indexes after layout
transformation to actual device ranks. transformation to actual device ranks.
@ -313,7 +316,10 @@ class _MeshLayout(Layout):
can be treated as a view or subset of mesh tensor, we do need to use the actual view or can be treated as a view or subset of mesh tensor, we do need to use the actual view or
sub-tensor for DeviceMesh and its backend creation. sub-tensor for DeviceMesh and its backend creation.
The shape of the `rank_map` must be 1D and contiguous. The `mesh_tensor` can be of any shape, because users can define a device mesh with arbitrary
shapes. We could further refactor the code so that internally only a 1D mesh tensor is supported
and the shaped mesh tensor is reconstructed via the layout when it is accessed by users.
#TODO: Only support a 1D mesh tensor stored internally and reconstruct the mesh tensor via the layout.
Examples: Examples:
@ -330,18 +336,18 @@ class _MeshLayout(Layout):
Return: [[[10,30],[20,40]]] Return: [[[10,30],[20,40]]]
Args: Args:
rank_map: The concrete mesh tensor with actual device ranks mesh_tensor: The concrete mesh tensor with actual device ranks
Returns: Returns:
torch.Tensor: A tensor representing the actual device allocation from rank_map torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
""" """
assert rank_map.ndim == 1 complement_layout = self.complement(mesh_tensor.numel())
assert rank_map.is_contiguous()
assert rank_map.numel() >= self.cosize()
complement_layout = self.complement(rank_map.numel()) return (
mesh_tensor.flatten()
return rank_map.as_strided( .as_strided(
flatten(complement_layout.sizes) + flatten(self.sizes), flatten(complement_layout.sizes) + flatten(self.sizes),
flatten(complement_layout.strides) + flatten(self.strides), flatten(complement_layout.strides) + flatten(self.strides),
).reshape(-1, *self.top_level_sizes) )
.reshape(-1, *(self[i].numel() for i in range(len(self))))
)
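
With this change `remap_to_tensor` flattens whatever mesh tensor it receives and strides over it with the layout (plus its complement) instead of requiring a pre-flattened, contiguous 1D `rank_map`. The core mechanism is `as_strided` over the flattened ranks; the standalone snippet below reproduces "Test 6" from the test file above and is illustrative only, not the `_MeshLayout` API.

```
# Standalone reproduction of Test 6 above: the layout's (sizes, strides) act as an
# index pattern over the flattened mesh tensor.
import torch

original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
sizes, strides = (2, 2), (1, 2)                  # column-major style layout

remapped = original_mesh.flatten().as_strided(sizes, strides)
print(remapped)                                  # tensor([[0, 1], [2, 3]], dtype=torch.int32)
# remap_to_tensor additionally prepends a complement dimension, giving [[[0, 1], [2, 3]]].
```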

View File

@ -198,9 +198,7 @@ def crd2idx(
for i in range(len(shape) - 1, 0, -1): for i in range(len(shape) - 1, 0, -1):
result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
crd = crd // product(shape[i]) crd = crd // product(shape[i])
if len(shape) > 0: return result + crd2idx(crd, shape[0], stride[0])
result += crd2idx(crd, shape[0], stride[0])
return result
else: # "int" "int" "int" else: # "int" "int" "int"
assert not is_tuple(shape) and not is_tuple(stride) assert not is_tuple(shape) and not is_tuple(stride)
return crd * stride # all are ints after type checks return crd * stride # all are ints after type checks
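
`crd2idx` is a mixed-radix conversion: trailing shape entries are peeled off first and each digit is weighted by its stride; the simplification above folds the former `len(shape) > 0` guard into the return, matching the removal of the zero-length-shape test earlier in this PR. The flat sketch below covers only the int/tuple/tuple case and is illustrative, not the module's actual code path.

```
# Flat sketch of the int/tuple/tuple case (the real function recurses into nested shapes).
def crd2idx_sketch(crd: int, shape: tuple[int, ...], stride: tuple[int, ...]) -> int:
    result = 0
    for size, s in zip(reversed(shape), reversed(stride)):
        result += (crd % size) * s   # current mixed-radix digit, weighted by its stride
        crd //= size
    return result

# Mirrors the kept assertion above: index 4 -> coordinate (1, 0, 0) in shape (2, 2, 2),
# contributing 1 * 8 = 8 with strides (8, 2, 32).
assert crd2idx_sketch(4, (2, 2, 2), (8, 2, 32)) == 8
```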

View File

@ -173,7 +173,7 @@ else:
""" """
_device_type: str _device_type: str
_rank_map: torch.Tensor _mesh: torch.Tensor
_mesh_dim_names: Optional[tuple[str, ...]] _mesh_dim_names: Optional[tuple[str, ...]]
_layout: _MeshLayout _layout: _MeshLayout
_root_mesh: Optional["DeviceMesh"] = None _root_mesh: Optional["DeviceMesh"] = None
@ -190,49 +190,46 @@ else:
_init_backend: bool = True, _init_backend: bool = True,
_rank: Optional[int] = None, _rank: Optional[int] = None,
_layout: Optional[_MeshLayout] = None, _layout: Optional[_MeshLayout] = None,
_root_mesh: Optional["DeviceMesh"] = None,
) -> None: ) -> None:
self._device_type = device_type self._device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu":
raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}")
mesh_tensor = ( self._mesh = (
mesh.detach().to(dtype=torch.int).contiguous() mesh.detach().to(dtype=torch.int).contiguous()
if isinstance(mesh, torch.Tensor) if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, device="cpu", dtype=torch.int) else torch.tensor(mesh, device="cpu", dtype=torch.int)
) )
self._rank_map = (
_root_mesh._rank_map
if _root_mesh is not None
else mesh_tensor.flatten()
)
self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None
if backend_override is None:
backend_override = ((None, None),) * self.mesh.ndim
elif len(backend_override) != self.mesh.ndim:
raise ValueError(
f"backend_override should have the same length as the number of mesh dimensions, "
f"but got {len(backend_override)} and {self.mesh.ndim}."
)
# Internal bookkeeping for the device mesh. # Internal bookkeeping for the device mesh.
self._layout = ( self._layout = (
_layout _layout
if _layout if _layout
else _MeshLayout(mesh_tensor.size(), mesh_tensor.stride()) else _MeshLayout(self.mesh.size(), self.mesh.stride())
) )
self._root_mesh = _root_mesh
assert self._layout.check_non_overlap(), ( assert self._layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh." "Please use a non-overlapping layout when creating a DeviceMesh."
) )
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here. # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
assert self._layout.top_level_sizes == mesh_tensor.size(), ( assert self._layout.top_level_sizes == self.mesh.size(), (
"Please use a valid layout when creating a DeviceMesh." "Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {mesh_tensor.size()}." f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
) )
if backend_override is None: # private field to pre-generate DeviceMesh's hash
backend_override = ((None, None),) * len(self._layout) self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
elif len(backend_override) != len(self._layout): self._thread_id = None
raise ValueError( # Initialize instance-specific flatten mapping
f"backend_override should have the same length as the number of mesh dimensions, " self._flatten_mapping = {}
f"but got {len(backend_override)} and {len(self._layout)}."
)
# Skip process group initialization if xla device or init backend is False # Skip process group initialization if xla device or init backend is False
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend. # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
self._thread_id = None
if device_type != "xla": if device_type != "xla":
# always try to create default (world) pg, even if it is not initialized # always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each # already. The world pg is used for device mesh identity (rank) on each
@ -255,11 +252,6 @@ else:
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
) )
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
# Initialize instance-specific flatten mapping
self._flatten_mapping = {}
@property @property
def device_type(self) -> str: def device_type(self) -> str:
"""Returns the device type of the mesh.""" """Returns the device type of the mesh."""
@ -268,17 +260,7 @@ else:
@property @property
def mesh(self) -> torch.Tensor: def mesh(self) -> torch.Tensor:
"""Returns the tensor representing the layout of devices.""" """Returns the tensor representing the layout of devices."""
full_mesh = self._layout.remap_to_tensor(self._rank_map) return self._mesh
if full_mesh.size(0) == 1:
return full_mesh[0]
my_coords = (full_mesh == get_rank()).nonzero()
if my_coords.size(0) > 0:
return full_mesh[my_coords[0, 0]]
raise RuntimeError(
"In order to get the mesh Tensor of a DeviceMesh it needs to "
"either have all its original dimensions (e.g., no slicing) "
"or it needs to contain the local rank"
)
@property @property
def mesh_dim_names(self) -> Optional[tuple[str, ...]]: def mesh_dim_names(self) -> Optional[tuple[str, ...]]:
@ -293,9 +275,9 @@ else:
init_process_group() init_process_group()
world_size = get_world_size() world_size = get_world_size()
if self._layout.numel() > world_size: if self.mesh.numel() > world_size:
raise RuntimeError( raise RuntimeError(
f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!" f"Mesh should not be bigger than default world size {world_size}, but found {self.mesh.numel()} ranks!"
) )
# ONLY set the device if the current device is not initialized, if user already # ONLY set the device if the current device is not initialized, if user already
@ -346,8 +328,8 @@ else:
default_group = _get_default_group() default_group = _get_default_group()
if ( if (
len(self._layout) == 1 self.mesh.ndim == 1
and self._layout.numel() == get_world_size() and self.mesh.numel() == get_world_size()
and backend_override[0] == (None, None) and backend_override[0] == (None, None)
): ):
# Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`. # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`.
@ -366,11 +348,11 @@ else:
dim_group_names.append(dim_group.group_name) dim_group_names.append(dim_group.group_name)
else: else:
# create sub pgs base on the mesh argument specified # create sub pgs base on the mesh argument specified
for dim in range(len(self._layout)): for dim in range(self.mesh.ndim):
# swap the current dim to the last dim # swap the current dim to the last dim
# then reshape to flatten out other dims # then reshape to flatten out other dims
pg_ranks_by_dim = ( pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
self._layout[dim].nest().remap_to_tensor(self._rank_map) -1, self.mesh.size(dim)
) )
backend, pg_options = backend_override[dim] backend, pg_options = backend_override[dim]
# We need to explicitly pass in timeout when specified in option, otherwise # We need to explicitly pass in timeout when specified in option, otherwise
@ -466,14 +448,14 @@ else:
def __repr__(self) -> str: def __repr__(self) -> str:
device_mesh_repr = ( device_mesh_repr = (
f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})" f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._mesh.shape))})"
if self._mesh_dim_names if self._mesh_dim_names
else f"{self._layout.top_level_sizes}" else f"{tuple(self._mesh.shape)}"
) )
device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}" device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._mesh.stride()}"
# We only print the mesh tensor if the debug mode is turned on. # We only print the mesh tensor if the debug mode is turned on.
if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL": if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL":
device_mesh_repr += f", Mesh: {self.mesh.tolist()}" device_mesh_repr += f", Mesh: {self._mesh.tolist()}"
return f"{device_mesh_repr})" return f"{device_mesh_repr})"
def __hash__(self): def __hash__(self):
@ -483,7 +465,7 @@ else:
self._hash = hash( self._hash = hash(
( (
self._flatten_mesh_list, self._flatten_mesh_list,
self._layout, self._mesh.shape,
self._device_type, self._device_type,
self._mesh_dim_names, self._mesh_dim_names,
self._thread_id, self._thread_id,
@ -499,7 +481,7 @@ else:
return False return False
return ( return (
self._flatten_mesh_list == other._flatten_mesh_list self._flatten_mesh_list == other._flatten_mesh_list
and self._layout == other._layout and self._mesh.shape == other._mesh.shape
and self._device_type == other._device_type and self._device_type == other._device_type
and self._mesh_dim_names == other._mesh_dim_names and self._mesh_dim_names == other._mesh_dim_names
and self._thread_id == other._thread_id and self._thread_id == other._thread_id
@ -591,16 +573,16 @@ else:
if not hasattr(self, "_dim_group_names"): if not hasattr(self, "_dim_group_names"):
raise RuntimeError("DeviceMesh process groups not initialized!") raise RuntimeError("DeviceMesh process groups not initialized!")
if len(self._layout) > 1 and mesh_dim is None: if self.mesh.ndim > 1 and mesh_dim is None:
raise RuntimeError( raise RuntimeError(
f"Found the DeviceMesh have {len(self._layout)} dimensions", f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
"If you want to get the list of all the ProcessGroups in the DeviceMesh," "If you want to get the list of all the ProcessGroups in the DeviceMesh,"
"please use `get_all_groups()` instead.", "please use `get_all_groups()` instead.",
) )
# Quick return if the current device_mesh is a 1D mesh. # Quick return if the current device_mesh is a 1D mesh.
if len(self._layout) == 1 and mesh_dim is None: if self.mesh.ndim == 1 and mesh_dim is None:
return not_none(_resolve_process_group(self._dim_group_names[0])) return not_none(_resolve_process_group(self._dim_group_names[0]))
root_mesh = self._get_root_mesh() root_mesh = self._get_root_mesh()
@ -626,7 +608,7 @@ else:
Returns: Returns:
A list of :class:`ProcessGroup` object. A list of :class:`ProcessGroup` object.
""" """
return [self.get_group(i) for i in range(len(self._layout))] return [self.get_group(i) for i in range(self.mesh.ndim)]
def _create_sub_mesh( def _create_sub_mesh(
self, self,
@ -653,7 +635,9 @@ else:
] ]
) )
cur_rank = self.get_rank() cur_rank = self.get_rank()
pg_ranks_by_dim = layout.remap_to_tensor(root_mesh._rank_map) pg_ranks_by_dim = layout.remap_to_tensor(
root_mesh.mesh,
)
res_submesh = DeviceMesh._create_mesh_from_ranks( res_submesh = DeviceMesh._create_mesh_from_ranks(
self._device_type, self._device_type,
pg_ranks_by_dim, pg_ranks_by_dim,
@ -708,7 +692,9 @@ else:
cur_rank = root_mesh.get_rank() cur_rank = root_mesh.get_rank()
# Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the # Due to the limitation of ProcessGroup api, we need to start from root mesh so that all ranks call the
# new_group api to avoid potential hang. # new_group api to avoid potential hang.
pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(root_mesh._rank_map) pg_ranks_by_dim = flattened_mesh_layout.remap_to_tensor(
root_mesh.mesh,
)
res_flattened_mesh = DeviceMesh._create_mesh_from_ranks( res_flattened_mesh = DeviceMesh._create_mesh_from_ranks(
root_mesh._device_type, root_mesh._device_type,
pg_ranks_by_dim.flatten( pg_ranks_by_dim.flatten(
@ -847,7 +833,9 @@ else:
""" """
mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name) mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name)
layout = self._layout[mesh_dim] layout = self._layout[mesh_dim]
pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map) pg_ranks_by_dim = layout.remap_to_tensor(
self.mesh,
)
cur_rank = self.get_rank() cur_rank = self.get_rank()
res_submeshes = [] res_submeshes = []
for mesh_1d in pg_ranks_by_dim: for mesh_1d in pg_ranks_by_dim:
@ -908,7 +896,6 @@ else:
backend_override=backend_override, backend_override=backend_override,
_init_backend=_init_backend, _init_backend=_init_backend,
_layout=_layout, _layout=_layout,
_root_mesh=_root_mesh,
) )
if cur_rank in mesh_nd: if cur_rank in mesh_nd:
res_mesh = mesh res_mesh = mesh
@ -917,6 +904,8 @@ else:
f"Current rank {cur_rank} not found in any mesh, " f"Current rank {cur_rank} not found in any mesh, "
f"input {pg_ranks_by_dim} does not contain all ranks in the world" f"input {pg_ranks_by_dim} does not contain all ranks in the world"
) )
if _root_mesh is not None:
res_mesh._root_mesh = _root_mesh
return res_mesh return res_mesh
@staticmethod @staticmethod
@ -1015,17 +1004,15 @@ else:
return device_mesh return device_mesh
def size(self, mesh_dim: Optional[int] = None) -> int: def size(self, mesh_dim: Optional[int] = None) -> int:
if mesh_dim is not None: return self.mesh.numel() if mesh_dim is None else self.mesh.size(mesh_dim)
return self._layout[mesh_dim].numel()
return self._layout.numel()
@property @property
def ndim(self) -> int: def ndim(self) -> int:
return len(self._layout) return self.mesh.ndim
@property @property
def shape(self) -> tuple[int, ...]: def shape(self) -> tuple[int, ...]:
return self._layout.top_level_sizes return tuple(self.mesh.shape)
def get_rank(self) -> int: def get_rank(self) -> int:
""" """
@ -1064,7 +1051,7 @@ else:
""" """
if self.ndim > 1 and mesh_dim is None: if self.ndim > 1 and mesh_dim is None:
raise RuntimeError( raise RuntimeError(
f"Found the DeviceMesh have {len(self._layout)} dimensions", f"Found the DeviceMesh have {self.mesh.ndim} dimensions",
"Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.",
) )
elif mesh_dim is None: elif mesh_dim is None:
@ -1128,7 +1115,9 @@ else:
root_mesh = self._get_root_mesh() root_mesh = self._get_root_mesh()
cur_rank = self.get_rank() cur_rank = self.get_rank()
unflattened_layout = self._layout.unflatten(dim, mesh_sizes) unflattened_layout = self._layout.unflatten(dim, mesh_sizes)
pg_ranks_by_dim = unflattened_layout.remap_to_tensor(root_mesh._rank_map) pg_ranks_by_dim = unflattened_layout.remap_to_tensor(
root_mesh.mesh,
)
unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names)) unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names))
unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names) unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names)
res_mesh = DeviceMesh._create_mesh_from_ranks( res_mesh = DeviceMesh._create_mesh_from_ranks(
@ -1152,7 +1141,7 @@ else:
tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index] tuple(unflattened_layout.strides[dim : dim + unflatten_length]), # type: ignore[index]
) )
unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor( unflatten_pg_ranks_by_dim = unflatten_layout.remap_to_tensor(
root_mesh._rank_map root_mesh.mesh,
) )
unflatten_submesh = DeviceMesh._create_mesh_from_ranks( unflatten_submesh = DeviceMesh._create_mesh_from_ranks(
self.device_type, self.device_type,
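
This hunk reverts the `_rank_map`/`_layout`-centric bookkeeping: the mesh tensor is stored directly as `self._mesh` again, and `ndim`, `shape`, `size()`, hashing, and per-dimension group creation read it instead of the layout. The user-facing surface is unchanged; a usage sketch (assumes a multi-process launch such as `torchrun --nproc-per-node=4` with a working gloo backend for "cpu"):

```
# Usage sketch (run under torchrun with 4 ranks): mesh tensor, shape/ndim, and
# per-dimension process groups are the pieces the constructor above wires together.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
print(mesh.shape, mesh.ndim, mesh.size(0))   # (2, 2) 2 2
dp_group = mesh.get_group("dp")              # ProcessGroup for this rank along "dp"
tp_mesh = mesh["tp"]                         # 1-D submesh along the "tp" dimension
```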

View File

@ -640,9 +640,9 @@ def get_pip_packages(run_lambda, patterns=None):
os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1" os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
# People generally have pip as `pip` or `pip3` # People generally have pip as `pip` or `pip3`
# But here it is invoked as `python -mpip` # But here it is invoked as `python -m pip`
out = run_and_read_all( out = run_and_read_all(
run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"] run_lambda, [sys.executable, "-m", "pip", "list", "--format=freeze"]
) )
if out is None: if out is None:
return pip_version, out return pip_version, out
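
`get_pip_packages` now invokes pip through the running interpreter as `python -m pip list --format=freeze` (with `-m` and the module name as separate arguments), which prints one `name==version` line per installed package for the surrounding code to filter. A small sketch of the call and the kind of filtering it feeds:

```
# Sketch of the invocation above and the freeze-format output it produces.
import subprocess
import sys

out = subprocess.run(
    [sys.executable, "-m", "pip", "list", "--format=freeze"],
    capture_output=True, text=True, check=True,
).stdout
interesting = [line for line in out.splitlines() if line.startswith(("torch", "numpy"))]
print("\n".join(interesting))
```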