Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-13 20:55:34 +08:00)

Compare commits: 166 commits (upload-tes… → ciflow/tru…)
| SHA1 | Author | Date | |
|---|---|---|---|
| a38b1ee364 | |||
| a51208c656 | |||
| ed4aa449b6 | |||
| 9eebda944d | |||
| 09d8953fb4 | |||
| 8b2365094d | |||
| 7b423c2d21 | |||
| 46b3f913b3 | |||
| f7b7f40a6f | |||
| 91337ae3ff | |||
| eea951758f | |||
| 3feea296a5 | |||
| c3c3653418 | |||
| f72772b184 | |||
| 981dd71893 | |||
| d31599f40b | |||
| 85fab6c9b0 | |||
| c08ce30d18 | |||
| e1a1aeaf5b | |||
| 943227f57b | |||
| 3a2d75a086 | |||
| 69af74972b | |||
| 7432676187 | |||
| fd5edda1ed | |||
| 872d1daec2 | |||
| 6cd57e6fc2 | |||
| d29efba8fa | |||
| a344069f2a | |||
| af829c0dad | |||
| 3869aa115b | |||
| 47eb34b7ac | |||
| 08200280ce | |||
| ad7a57262c | |||
| 711a775878 | |||
| e9a688f02e | |||
| e69aaaf45a | |||
| fd8f368d31 | |||
| 13d2cc7bd2 | |||
| c6c913d18e | |||
| ef3f953966 | |||
| ea44f12bce | |||
| a74fe75c45 | |||
| 6d30666bc1 | |||
| 8e8cbb85ee | |||
| fbd70fb84e | |||
| 6c5db82584 | |||
| 6052a01b71 | |||
| 14b153bcf2 | |||
| 641de23c96 | |||
| 89165c0a2b | |||
| dcc2ba4ca4 | |||
| ad5c7c20e0 | |||
| c86540f120 | |||
| c17aa0f113 | |||
| 4ff068c33a | |||
| 0c7a4a6b48 | |||
| f93ee16fb6 | |||
| 9c2c3dbc15 | |||
| d4dcd0354c | |||
| aba2fa3259 | |||
| d2d13bf62d | |||
| 7a6ff88196 | |||
| 59563dfe56 | |||
| 5c639466f7 | |||
| 0b4dd08e04 | |||
| edd8d356b6 | |||
| 658c5f879c | |||
| 59a6c83dfe | |||
| 431dfe8692 | |||
| c00696144d | |||
| 9ffc480c5a | |||
| 14956eaef4 | |||
| 066c5c57a9 | |||
| 08ef852a4b | |||
| 56fc99915b | |||
| 5863ba1b2e | |||
| a743f9eeb5 | |||
| 53b03f1a2b | |||
| cd5d810c3a | |||
| 01e6e35c7f | |||
| bcd159bcdd | |||
| 64ae31c5d3 | |||
| 45da6e1fe1 | |||
| 39160dba0c | |||
| f2fbc81c50 | |||
| 4271ffe918 | |||
| 7eefcfb1db | |||
| 4b12c0344d | |||
| 661b639663 | |||
| 0cd809f60c | |||
| a96728d188 | |||
| c1e91bd4c3 | |||
| d7e2d0ad30 | |||
| 81038fd326 | |||
| e020fb3431 | |||
| e8052f2f99 | |||
| a64c7d7404 | |||
| cdca63db8c | |||
| ed45c5f38d | |||
| 7f0e932136 | |||
| 2673f8b007 | |||
| 4e1bd16738 | |||
| 871d0cd196 | |||
| 2bba37309b | |||
| b4e4ee81d3 | |||
| 3283eaa5ba | |||
| 397d9fe2ae | |||
| d77c24caac | |||
| cef98ae5cb | |||
| 52ea135f77 | |||
| a5f3035aaf | |||
| 1d3f5e19da | |||
| 496277a8ff | |||
| 53f75cd5ba | |||
| 527b1109a8 | |||
| 3144713325 | |||
| eefa16342c | |||
| d02f68f484 | |||
| 68eb55c4b2 | |||
| 8d4b8ab430 | |||
| afd50bdd29 | |||
| 56dfd4c74b | |||
| 24db5c4451 | |||
| cc8bfd1206 | |||
| c45b156605 | |||
| 8fff7e36b4 | |||
| 82fa2aa269 | |||
| 09e0285608 | |||
| d980d8dc79 | |||
| c7d00de115 | |||
| d3cf90ada5 | |||
| 0e1a88904f | |||
| 3232caa078 | |||
| a6c6acea9d | |||
| 55be1cc739 | |||
| 344cebda52 | |||
| ba72c6b981 | |||
| 888efcc453 | |||
| 24aa9a2ef7 | |||
| f70faf2b9a | |||
| 167e64ba1a | |||
| 875b18d53c | |||
| eec3749c44 | |||
| 40133fe966 | |||
| f288433d3e | |||
| 864633fca0 | |||
| c21868b435 | |||
| a0a8eca01a | |||
| 0958f307d9 | |||
| 7551507c41 | |||
| f92834d477 | |||
| e1fc01bef8 | |||
| 22a745737a | |||
| ee708ea96c | |||
| 64819e3701 | |||
| 79ff2c66c8 | |||
| 665a411351 | |||
| 5c89bdb461 | |||
| 7b64ad906c | |||
| d944279def | |||
| 5048e4701d | |||
| 616314cfd5 | |||
| 2b7e4c3ef2 | |||
| 6c98657239 | |||
| 86b2d82e84 | |||
| d6b27c4cef |
@@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8

ARG DEVTOOLSET_VERSION=11
ARG DEVTOOLSET_VERSION=13

RUN yum -y update
RUN yum -y install epel-release
# install glibc-langpack-en make sure en_US.UTF-8 locale is available
RUN yum -y install glibc-langpack-en
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
# Just add everything as a safe.directory for git since these will be used in multiple places with git
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh
|
||||
# Install CUDA
|
||||
FROM base as cuda
|
||||
ARG CUDA_VERSION=12.6
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
RUN rm -rf /usr/local/cuda-*
|
||||
ADD ./common/install_cuda.sh install_cuda.sh
|
||||
COPY ./common/install_nccl.sh install_nccl.sh
|
||||
@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
|
||||
# Preserve CUDA_VERSION for the builds
|
||||
ENV CUDA_VERSION=${CUDA_VERSION}
|
||||
# Make things in our path by default
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
|
||||
FROM cuda as cuda12.6
|
||||
RUN bash ./install_cuda.sh 12.6
|
||||
@ -68,8 +70,22 @@ FROM cuda as cuda13.0
|
||||
RUN bash ./install_cuda.sh 13.0
|
||||
ENV DESIRED_CUDA=13.0
|
||||
|
||||
FROM ${ROCM_IMAGE} as rocm
|
||||
FROM ${ROCM_IMAGE} as rocm_base
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV LC_ALL en_US.UTF-8
|
||||
ENV LANG en_US.UTF-8
|
||||
ENV LANGUAGE en_US.UTF-8
|
||||
# Install devtoolset on ROCm base image
|
||||
RUN yum -y update && \
|
||||
yum -y install epel-release && \
|
||||
yum -y install glibc-langpack-en && \
|
||||
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
|
||||
RUN git config --global --add safe.directory '*'
|
||||
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
FROM rocm_base as rocm
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
|
||||
ADD ./common/install_mkl.sh install_mkl.sh
|
||||
RUN bash ./install_mkl.sh && rm install_mkl.sh
|
||||
@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
|
||||
|
||||
# Final step
|
||||
FROM ${BASE_TARGET} as final
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
COPY --from=openssl /opt/openssl /opt/openssl
|
||||
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
|
||||
COPY --from=conda /opt/conda /opt/conda
|
||||
|
||||
@@ -63,7 +63,7 @@ docker build \
--target final \
--progress plain \
--build-arg "BASE_TARGET=${BASE_TARGET}" \
--build-arg "DEVTOOLSET_VERSION=11" \
--build-arg "DEVTOOLSET_VERSION=13" \
${EXTRA_BUILD_ARGS} \
-t ${tmp_tag} \
$@ \
@@ -261,9 +261,9 @@ case "$tag" in
PYTHON_VERSION=3.10
CUDA_VERSION=12.8.1
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11)
pytorch-linux-jammy-aarch64-py3.10-gcc13)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
GCC_VERSION=13
ACL=yes
VISION=yes
OPENBLAS=yes
@@ -271,9 +271,19 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
pytorch-linux-jammy-aarch64-py3.10-clang21)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13
ACL=yes
VISION=yes
OPENBLAS=yes
@@ -1 +1 @@
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
@@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
fi
fi
@@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
# Need the official toolchain repo to get alternate packages
add-apt-repository ppa:ubuntu-toolchain-r/test
apt-get update
apt-get install -y g++-$GCC_VERSION
apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50

update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50

# Cleanup package manager
apt-get autoclean && apt-get clean
@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
|
||||
|
||||
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
|
||||
OPENBLAS_BUILD_FLAGS="
|
||||
CC=gcc
|
||||
NUM_THREADS=128
|
||||
USE_OPENMP=1
|
||||
NO_SHARED=0
|
||||
|
||||
@@ -1 +1 @@
3.5.0
3.5.1
@ -6,8 +6,8 @@ set -eou pipefail
|
||||
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
|
||||
# post merge of https://github.com/icl-utk-edu/magma/pull/65
|
||||
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
|
||||
# https://github.com/icl-utk-edu/magma/pull/65
|
||||
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
|
||||
|
||||
# Folders for the build
|
||||
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
|
||||
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
|
||||
|
||||
# Fetch magma sources and verify checksum
|
||||
pushd ${PACKAGE_DIR}
|
||||
git clone https://github.com/icl-utk-edu/magma
|
||||
git clone https://github.com/jeffdaily/magma
|
||||
pushd magma
|
||||
git checkout ${MAGMA_VERSION}
|
||||
popd
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
name: 🚀 Release highlight for proposed Feature
|
||||
name: 🚀 New Feature for Release
|
||||
description: Submit a Release highlight for proposed Feature
|
||||
labels: ["release-feature-request"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Release highlight for proposed Feature
|
||||
label: New Feature for Release
|
||||
description: >
|
||||
Example: “A torch.special module, analogous to SciPy's special module.”
|
||||
- type: input
|
||||
|
||||
.github/actions/pytest-cache-download/action.yml (vendored, 12 changed lines)
@ -38,9 +38,9 @@ runs:
|
||||
run: |
|
||||
python3 .github/scripts/pytest_cache.py \
|
||||
--download \
|
||||
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
|
||||
--pr_identifier $GITHUB_REF \
|
||||
--job_identifier $JOB_IDENTIFIER \
|
||||
--temp_dir $RUNNER_TEMP \
|
||||
--repo $REPO \
|
||||
--bucket $BUCKET \
|
||||
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
|
||||
--pr_identifier "$GITHUB_REF" \
|
||||
--job_identifier "$JOB_IDENTIFIER" \
|
||||
--temp_dir "$RUNNER_TEMP" \
|
||||
--repo "$REPO" \
|
||||
--bucket "$BUCKET" \
|
||||
|
||||
.github/actions/pytest-cache-upload/action.yml (vendored, 16 changed lines)
@ -47,11 +47,11 @@ runs:
|
||||
run: |
|
||||
python3 .github/scripts/pytest_cache.py \
|
||||
--upload \
|
||||
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
|
||||
--pr_identifier $GITHUB_REF \
|
||||
--job_identifier $JOB_IDENTIFIER \
|
||||
--sha $SHA \
|
||||
--test_config $TEST_CONFIG \
|
||||
--shard $SHARD \
|
||||
--repo $REPO \
|
||||
--temp_dir $RUNNER_TEMP \
|
||||
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
|
||||
--pr_identifier "$GITHUB_REF" \
|
||||
--job_identifier "$JOB_IDENTIFIER" \
|
||||
--sha "$SHA" \
|
||||
--test_config "$TEST_CONFIG" \
|
||||
--shard "$SHARD" \
|
||||
--repo "$REPO" \
|
||||
--temp_dir "$RUNNER_TEMP" \
|
||||
|
||||
.github/ci_commit_pins/audio.txt (vendored, 2 changed lines)
@@ -1 +1 @@
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2
ad5816f0eee1c873df1b7d371c69f1f811a89387
.github/copilot-instructions.md (vendored, new file, 125 lines)
@@ -0,0 +1,125 @@
# PyTorch Copilot Instructions

This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively.

## Architecture Overview

### Core Components

- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality
- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd
  - `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse)
  - `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry
- **torch/** - Python bindings and public API
  - `torch/csrc/` - C++ Python bindings (hand-written and generated)
  - `torch/csrc/autograd/` - Reverse-mode automatic differentiation
  - `torch/csrc/jit/` - TorchScript JIT compiler
- **torchgen/** - Code generation tooling that reads `native_functions.yaml`
- **tools/** - Build scripts, autograd derivatives, code generation

### The Code Generation Workflow

**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file:
1. Declares operator signatures, variants (function/method), and dispatch behavior
2. Gets processed by `torchgen/` to generate C++/Python bindings
3. Produces headers in `build/aten/src/ATen/` during compilation

Example entry structure:
```yaml
- func: my_op(Tensor self, Scalar alpha=1) -> Tensor
  variants: function, method
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
```

After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`).

## Development Workflows

### Building from Source

**Never run `setup.py` directly** - use pip with editable install:
```bash
python -m pip install --no-build-isolation -v -e .
```

Speed up builds:
- `DEBUG=1` - Debug symbols with `-g -O0`
- `USE_CUDA=0` - Skip CUDA compilation
- `BUILD_TEST=0` - Skip C++ test binaries
- Install `ninja` (`pip install ninja`) for faster builds
- Use `ccache` for incremental compilation caching

Rebuild specific targets: `(cd build && ninja <target>)`

### Testing

**Critical**: DO NOT run entire test suites. Run specific tests only:
```bash
python test/test_torch.py TestTorch.test_specific_case
```

**Test structure**: All tests use `torch.testing._internal.common_utils`:
```python
from torch.testing._internal.common_utils import run_tests, TestCase

class TestFeature(TestCase):
    def test_something(self):
        # Use self.assertEqual for tensor comparisons
        pass

if __name__ == "__main__":
    run_tests()
```

**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file.

### Linting

Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes)

## Project-Specific Conventions

### Memory and Storage
- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs)
- CUDA device info lives in storage objects

### Python-C++ Integration (`torch/csrc/`)
- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors
- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr`
- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion

### Dispatch System
- PyTorch uses operator dispatch to route calls to backend-specific kernels
- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops
- See `aten/src/ATen/native/README.md` for dispatch keyword guidance

## Git Workflow (AI Agent Specific)

When preparing PRs from this environment:
```bash
git stash -u
git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch
git stash pop
# Resolve conflicts if necessary
```

## Common Gotchas

1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml`
2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI
3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux)
4. **No internet access** - DO NOT attempt to install dependencies during development

## Key Files Reference

- `AGENTS.md` - Instructions specific to AI coding agents
- `CONTRIBUTING.md` - Comprehensive human contributor guide
- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript)
- `aten/src/ATen/native/README.md` - Operator implementation guide
- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd

## Performance Debugging

Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation.
@@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
    "12.6": "12.6.3",
    "12.8": "12.8.1",
    "12.9": "12.9.1",
    "13.0": "13.0.2",
    "13.0": "13.0.0",
}
CUDA_ARCHES_CUDNN_VERSION = {
    "12.6": "9",
.github/workflows/_rocm-test.yml (vendored, 4 changed lines)
@@ -97,8 +97,8 @@ jobs:
shell: bash
run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
exit 1
fi
.github/workflows/_xpu-test.yml (vendored, 16 changed lines)
@ -344,5 +344,21 @@ jobs:
|
||||
if-no-files-found: ignore
|
||||
path: ./**/core.[1-9]*
|
||||
|
||||
- name: Authenticate with AWS
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
|
||||
# The max duration enforced by the server side
|
||||
role-duration-seconds: 18000
|
||||
aws-region: us-east-1
|
||||
|
||||
- name: Upload the benchmark results
|
||||
uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
|
||||
with:
|
||||
benchmark-results-dir: test/test-reports
|
||||
dry-run: false
|
||||
schema-version: v3
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Teardown XPU
|
||||
uses: ./.github/actions/teardown-xpu
|
||||
|
||||
.github/workflows/docker-builds.yml (vendored, 6 changed lines)
@ -77,9 +77,11 @@ jobs:
|
||||
pytorch-linux-noble-riscv64-py3.12-gcc14
|
||||
]
|
||||
include:
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
timeout-minutes: 600
|
||||
# Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358
|
||||
|
||||
.github/workflows/docker-release.yml (vendored, 1 changed line)
@ -8,6 +8,7 @@ on:
|
||||
- docker.Makefile
|
||||
- .github/workflows/docker-release.yml
|
||||
- .github/scripts/generate_docker_release_matrix.py
|
||||
- .github/scripts/generate_binary_build_matrix.py
|
||||
push:
|
||||
branches:
|
||||
- nightly
|
||||
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" },
|
||||
|
||||
.github/workflows/inductor-unittest.yml (vendored, 8 changed lines)
@ -115,10 +115,10 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
.github/workflows/inductor.yml (vendored, 14 changed lines)
@ -84,13 +84,13 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
|
||||
]}
|
||||
build-additional-packages: "vision audio torchao"
|
||||
|
||||
.github/workflows/linux-aarch64.yml (vendored, 2 changed lines)
@ -33,7 +33,7 @@ jobs:
|
||||
with:
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
|
||||
.github/workflows/operator_benchmark.yml (vendored, 2 changed lines)
@ -60,7 +60,7 @@ jobs:
|
||||
with:
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
|
||||
|
||||
.github/workflows/trunk.yml (vendored, 3 changed lines)
@ -204,6 +204,7 @@ jobs:
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
@ -221,7 +222,7 @@ jobs:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
|
||||
secrets: inherit
|
||||
|
||||
inductor-build:
|
||||
|
||||
@ -211,7 +211,6 @@ exclude_patterns = [
|
||||
'**/*pb.h',
|
||||
'**/*inl.h',
|
||||
'aten/src/ATen/cpu/FlushDenormal.cpp',
|
||||
'aten/src/ATen/cpu/Utils.cpp',
|
||||
'aten/src/ATen/cpu/vml.h',
|
||||
'aten/src/ATen/CPUFixedAllocator.h',
|
||||
'aten/src/ATen/Parallel*.h',
|
||||
@ -230,8 +229,6 @@ exclude_patterns = [
|
||||
'c10/util/win32-headers.h',
|
||||
'c10/test/**/*.h',
|
||||
'third_party/**/*',
|
||||
'torch/csrc/api/include/torch/nn/modules/common.h',
|
||||
'torch/csrc/api/include/torch/linalg.h',
|
||||
'torch/csrc/autograd/generated/**',
|
||||
'torch/csrc/distributed/**/*.cu',
|
||||
'torch/csrc/distributed/c10d/WinSockUtils.hpp',
|
||||
@ -243,7 +240,6 @@ exclude_patterns = [
|
||||
'torch/csrc/utils/generated_serialization_types.h',
|
||||
'torch/csrc/utils/pythoncapi_compat.h',
|
||||
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
|
||||
'aten/src/ATen/ExpandBase.h',
|
||||
]
|
||||
init_command = [
|
||||
'python3',
|
||||
|
||||
@ -234,7 +234,17 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
|
||||
option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
|
||||
option(USE_LSAN "Use Leak Sanitizer" OFF)
|
||||
option(USE_TSAN "Use Thread Sanitizer" OFF)
|
||||
|
||||
# Track whether USE_CUDA was explicitly set by the user (before option() is called)
|
||||
# If USE_CUDA is already defined in cache, it means user explicitly set it
|
||||
if(DEFINED CACHE{USE_CUDA})
|
||||
set(_USE_CUDA_EXPLICITLY_SET TRUE)
|
||||
else()
|
||||
set(_USE_CUDA_EXPLICITLY_SET FALSE)
|
||||
endif()
|
||||
|
||||
option(USE_CUDA "Use CUDA" ON)
|
||||
|
||||
option(USE_XPU "Use XPU" ON)
|
||||
cmake_dependent_option(
|
||||
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
|
||||
|
||||
@@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `mypy`](#running-mypy)
- [Running `pyrefly`](#running-pyrefly)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `pytest` - recommended to run tests more selectively
Running
```
@@ -350,15 +350,32 @@ make lint

Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)

#### Running `mypy`
#### Running `pyrefly`

`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.

PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.

**Getting Started with Pyrefly:**

To run type checking on the PyTorch codebase:
```bash
pyrefly check
```

For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```

**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations

See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for more information on how to set up `mypy` and tackle type annotation
tasks.
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.

### C++ Unit Testing
SECURITY.md (20 changed lines)
@@ -1,7 +1,7 @@
# Security Policy

- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [**Using PyTorch Securely**](#using-pytorch-securely)
  - [Untrusted models](#untrusted-models)
  - [TorchScript models](#torchscript-models)
  - [Untrusted inputs](#untrusted-inputs)
@@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues

Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.

However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.

Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new

All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.

Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:

https://www.facebook.com/whitehat


## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).

### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].

**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).

**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.

Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.

@@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de

### TorchScript models

TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.

### Untrusted inputs during training and prediction

@@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some

### Data privacy

**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).

### Using distributed features
@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
|
||||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
|
||||
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
|
||||
|
||||
@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
|
||||
#endif
|
||||
namespace at {
|
||||
|
||||
namespace {
|
||||
|
||||
/*
|
||||
These const variables defined the fp32 precisions for different backend
|
||||
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
|
||||
@ -41,16 +39,6 @@ namespace {
|
||||
->rnn
|
||||
*/
|
||||
|
||||
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
|
||||
TORCH_WARN_ONCE(
|
||||
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
|
||||
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
|
||||
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
|
||||
);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Float32Backend str2backend(const std::string& name) {
|
||||
if (name == "generic")
|
||||
return Float32Backend::GENERIC;
|
||||
@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
|
||||
} else {
|
||||
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
|
||||
}
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_cudnn;
|
||||
}
|
||||
|
||||
@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
allow_tf32_cudnn = b;
|
||||
warn_deprecated_fp32_precision_api();
|
||||
}
|
||||
|
||||
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
|
||||
@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
|
||||
"We suggest only using the new API to set the TF32 flag. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_new;
|
||||
}
|
||||
|
||||
@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
|
||||
"We suggest only using the new API for matmul precision. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return float32_matmul_precision;
|
||||
}
|
||||
|
||||
@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
|
||||
|
||||
void Context::setFloat32MatmulPrecision(const std::string &s) {
|
||||
auto match = [this](const std::string & s_) {
|
||||
warn_deprecated_fp32_precision_api();
|
||||
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
|
||||
if (s_ == "highest") {
|
||||
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
|
||||
|
||||
@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
|
||||
auto vals = svreinterpret_u16_bf16(values);
|
||||
vals = sveor_u16_x(ptrue, vals, mask);
|
||||
return svreinterpret_bf16_u16(vals);
|
||||
};
|
||||
}
|
||||
Vectorized<BFloat16> round() const;
|
||||
Vectorized<BFloat16> tan() const;
|
||||
Vectorized<BFloat16> tanh() const;
|
||||
@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
|
||||
return convert_float_bfloat16(v1, v2); \
|
||||
}
|
||||
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(round);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(angle)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acos)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(asin)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atan)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erf)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log2)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log10)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sin)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cos)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(floor)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(round)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tan)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
|
||||
const Vectorized<BFloat16>& other) const {
|
||||
|
||||
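In the hunk above, the only change is that the trailing semicolons after each `DEFINE_BF16_FUNC_VIA_FLOAT(...)` invocation are dropped, likely because each macro expansion already ends with a complete function definition, so the extra `;` is an empty declaration that warnings such as `-Wextra-semi` flag. A minimal, self-contained C++ illustration with a hypothetical macro (not the PyTorch one):

```cpp
// Hypothetical stand-in for a DEFINE_*-style macro: the expansion is a full
// function definition ending in '}', so no semicolon is needed at the call site.
#define DEFINE_VIA_FLOAT(op) \
  float op##_via_float(float x) { return x; }

DEFINE_VIA_FLOAT(exp2_like)    // fine: expands to a complete definition
// DEFINE_VIA_FLOAT(log_like);  // the extra ';' after '}' is an empty
                                // declaration that -Wextra-semi-style warnings report

int main() { return exp2_like_via_float(1.0f) == 1.0f ? 0 : 1; }
```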
@ -191,22 +191,37 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
|
||||
}
|
||||
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
|
||||
template <typename to_type>
|
||||
inline void convertFromBf16Impl(
|
||||
const c10::BFloat16* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src);
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16;
|
||||
float tmpF;
|
||||
__builtin_memcpy(&tmpF, &tmp, sizeof(float));
|
||||
dst[i] = static_cast<to_type>(tmpF);
|
||||
}
|
||||
}
|
||||
#define CONVERT_FROM_BF16_TEMPLATE(to_type) \
|
||||
template <> \
|
||||
inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \
|
||||
return convertFromBf16Impl<to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_FROM_BF16_TEMPLATE(uint8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int16_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int32_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int64_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(float)
|
||||
CONVERT_FROM_BF16_TEMPLATE(double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_FROM_BF16_TEMPLATE(float16_t)
|
||||
#endif
|
||||
|
||||
inline void convertBoolToBfloat16Impl(
|
||||
const bool* __restrict src,
|
||||
@ -247,8 +262,6 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
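The scalar fallback above widens each bfloat16 value by placing its 16 bits in the upper half of a 32-bit float. A small standalone C++ sketch of that same bit trick, independent of the ATen templates and for illustration only:

```cpp
// bfloat16 is the upper 16 bits of an IEEE-754 binary32 value, so widening a
// raw bf16 bit pattern to float is a 16-bit left shift plus a bitwise copy.
#include <cstdint>
#include <cstdio>
#include <cstring>

float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &widened, sizeof(out));  // bit-exact reinterpretation
  return out;
}

int main() {
  // 0x3F80 is the bf16 pattern for 1.0f (sign 0, exponent 127, mantissa 0).
  std::printf("%f\n", bf16_bits_to_float(0x3F80));  // prints 1.000000
  return 0;
}
```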
@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
uint32_t mask = -1;
|
||||
#endif
|
||||
void * alpha_ptr = α
|
||||
void * beta_ptr = β
|
||||
@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
if (fp16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
mask =
|
||||
fp16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
if (bf16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
uint32_t mask =
|
||||
mask =
|
||||
bf16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
// on Blackwell+, we fake a n > 1 matmul when querying heuristics
|
||||
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
|
||||
#ifndef USE_ROCM
|
||||
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
|
||||
#else
|
||||
const bool lie_to_cublaslt = false;
|
||||
#endif
|
||||
if (lie_to_cublaslt) {
|
||||
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
|
||||
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
|
||||
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
FakeBdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
} else {
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
}
|
||||
if (returnedResult == 0) {
|
||||
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
|
||||
}
|
||||
|
||||
@ -55,6 +55,14 @@ struct numeric_limits<int8_t> {
|
||||
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<uint16_t> {
|
||||
static inline __host__ __device__ uint16_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint16_t max() { return UINT16_MAX; }
|
||||
static inline __host__ __device__ uint16_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<int16_t> {
|
||||
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
|
||||
@ -63,6 +71,14 @@ struct numeric_limits<int16_t> {
|
||||
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<uint32_t> {
|
||||
static inline __host__ __device__ uint32_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint32_t max() { return UINT32_MAX; }
|
||||
static inline __host__ __device__ uint32_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<int32_t> {
|
||||
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
|
||||
@ -71,6 +87,21 @@ struct numeric_limits<int32_t> {
|
||||
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<uint64_t> {
|
||||
#ifdef _MSC_VER
|
||||
static inline __host__ __device__ uint64_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint64_t max() { return _UI64_MAX; }
|
||||
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; }
|
||||
#else
|
||||
static inline __host__ __device__ uint64_t lowest() { return 0; }
|
||||
static inline __host__ __device__ uint64_t max() { return UINT64_MAX; }
|
||||
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
|
||||
static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; }
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<int64_t> {
|
||||
#ifdef _MSC_VER
|
||||
|
||||
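The new unsigned specializations above expose the same `lower_bound()`/`upper_bound()` interface as the existing signed ones. A self-contained sketch of how such bounds are typically used, mirroring the `uint32_t` specialization from the diff (plain C++ here, with the `__host__ __device__` qualifiers dropped so it compiles outside CUDA):

```cpp
#include <cstdint>
#include <iostream>

// Local mirror of the uint32_t specialization shown in the diff above.
template <typename T> struct numeric_limits;
template <> struct numeric_limits<uint32_t> {
  static inline uint32_t lowest()      { return 0; }
  static inline uint32_t max()         { return UINT32_MAX; }
  static inline uint32_t lower_bound() { return 0; }
  static inline uint32_t upper_bound() { return UINT32_MAX; }
};

// Example use of the bounds: clamp a signed 64-bit value into uint32_t range.
uint32_t saturate_to_u32(int64_t v) {
  if (v < static_cast<int64_t>(numeric_limits<uint32_t>::lower_bound()))
    return numeric_limits<uint32_t>::lower_bound();
  if (v > static_cast<int64_t>(numeric_limits<uint32_t>::upper_bound()))
    return numeric_limits<uint32_t>::upper_bound();
  return static_cast<uint32_t>(v);
}

int main() {
  std::cout << saturate_to_u32(-5) << ' ' << saturate_to_u32(1LL << 40) << '\n';  // 0 4294967295
  return 0;
}
```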
@@ -24,7 +24,13 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };

template<typename key_t, int value_size>
void radix_sort_pairs_impl(
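A standalone sketch checking the contract the comment above relies on: each specialization keeps the size and alignment of the value type it stands in for, so the value bytes can be reinterpreted safely (mirrors the specializations in the diff; illustration only):

```cpp
#include <cstdint>
#include <type_traits>

template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };

// Same size, same alignment, trivially copyable: reinterpreting a buffer of
// 4-byte values (e.g. float or int32_t) as OpaqueType<4> copies only the bits.
static_assert(sizeof(OpaqueType<4>) == 4 && alignof(OpaqueType<4>) == 4, "4-byte case");
static_assert(sizeof(OpaqueType<8>) == 8 && alignof(OpaqueType<8>) == 8, "8-byte case");
static_assert(std::is_trivially_copyable<OpaqueType<4>>::value, "bitwise copy is valid");

int main() { return 0; }
```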
@ -1009,12 +1009,25 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) {
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor send_to_meta(const Tensor& self, const Device& device) {
|
||||
Tensor out_meta;
|
||||
if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device));
|
||||
out_meta.unsafeGetTensorImpl()->set_wrapped_number(true);
|
||||
} else {
|
||||
out_meta = self.to(device);
|
||||
}
|
||||
return out_meta;
|
||||
}
|
||||
|
||||
Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
|
||||
auto out_device = correct_out_device(self, other);
|
||||
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
|
||||
auto device_ = Device(DeviceType::Meta);
|
||||
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
|
||||
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
|
||||
auto self_meta = send_to_meta(self, device_);
|
||||
auto other_meta = send_to_meta(other, device_);
|
||||
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta);
|
||||
return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
|
||||
}
|
||||
|
||||
@ -1023,7 +1036,9 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) {
|
||||
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
|
||||
auto device_ = Device(DeviceType::Meta);
|
||||
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
|
||||
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
|
||||
auto self_meta = send_to_meta(self, device_);
|
||||
auto other_meta = send_to_meta(other, device_);
|
||||
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta);
|
||||
|
||||
if (self._is_zerotensor()) {
|
||||
if (other._is_zerotensor()) {
|
||||
@ -1052,8 +1067,9 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const
|
||||
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
|
||||
auto device_ = Device(DeviceType::Meta);
|
||||
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
|
||||
auto meta_out = at::_ops::add_Tensor::redispatch(
|
||||
meta_dks, self.to(device_), other.to(device_), alpha);
|
||||
auto self_meta = send_to_meta(self, device_);
|
||||
auto other_meta = send_to_meta(other, device_);
|
||||
auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha);
|
||||
|
||||
auto get_out_like = [&] (const Tensor& tensor)
|
||||
{
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
@ -1710,11 +1711,37 @@ Tensor narrow_symint(
|
||||
"], but got ",
|
||||
start,
|
||||
")")
|
||||
if (start < 0) {
|
||||
start = start + cur_size;
|
||||
|
||||
auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
|
||||
auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
|
||||
|
||||
if (cond1 || cond2) {
|
||||
if (cond1) {
|
||||
start = start + cur_size;
|
||||
}
|
||||
|
||||
TORCH_SYM_CHECK(
|
||||
start.sym_le(cur_size - length),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
length,
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
}
|
||||
|
||||
// Unbacked start handling!
|
||||
|
||||
// Bounds check without converting start:
|
||||
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
|
||||
// length <= 0
|
||||
// - If start >= 0: need start + length <= cur_size
|
||||
auto end = start + length;
|
||||
TORCH_SYM_CHECK(
|
||||
start.sym_le(cur_size - length),
|
||||
(start.sym_lt(0).sym_and((end).sym_le(0)))
|
||||
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
@ -1722,7 +1749,28 @@ Tensor narrow_symint(
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
|
||||
if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
|
||||
return at::slice_symint(self, dim, start, end, 1);
|
||||
} else {
|
||||
// Cannot statically determine the condition due to unbacked.
|
||||
// This is an interesting situation; when start is negative and
|
||||
// start + length == 0, slice and narrow do different things.
|
||||
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
|
||||
// pass curr_size instead of 0. Otherwise, they would do the same thing.
|
||||
// This says at runtime: if start < 0 and end == 0, then pass curr_size
|
||||
// instead of 0.
|
||||
|
||||
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
|
||||
auto result =
|
||||
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
|
||||
|
||||
// Ensure slice allocated unbacked size is specialized to length.
|
||||
SymInt new_size = result.sym_size(dim);
|
||||
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// This overload exists purely for XLA, because they wanted to pass in
|
||||
@ -1736,8 +1784,8 @@ Tensor narrow_tensor_symint(
|
||||
start.dim() == 0 &&
|
||||
isIntegralType(start.scalar_type(), /*includeBool=*/false),
|
||||
"start must be an 0-dim integral Tensor.");
|
||||
int64_t st = start.item<int64_t>();
|
||||
return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
|
||||
c10::SymInt st = start.item().toSymInt();
|
||||
return at::narrow_symint(self, dim, std::move(st), std::move(length));
|
||||
}
|
||||
|
||||
std::
|
||||
|
||||
@ -293,7 +293,7 @@ struct ComputeLocationBase<scalar_t, /*align_corners=*/false> {
|
||||
, empty(size <= 0) {}
|
||||
|
||||
inline Vec unnormalize(const Vec &in) const {
|
||||
return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5);
|
||||
return (in + Vec(static_cast<scalar_t>(1))) * Vec(scaling_factor) - Vec(static_cast<scalar_t>(0.5));
|
||||
}
|
||||
|
||||
inline Vec clip_coordinates(const Vec &in) const {
|
||||
@ -831,7 +831,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
|
||||
|
||||
// constant used in cubic convolution
|
||||
// could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h
|
||||
const Vec A = Vec(-0.75);
|
||||
const Vec A = Vec(static_cast<scalar_t>(-0.75));
|
||||
|
||||
ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
|
||||
: inp_H(input.size(2))
|
||||
|
||||
@ -92,7 +92,8 @@ void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
|
||||
|
||||
void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
|
||||
ScalarType dtype = iter.dtype(0);
|
||||
if (dtype == kBFloat16) {
|
||||
if (at::isReducedFloatingType(dtype)) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "smooth_l1_backward_cpu_out", [&]() {
|
||||
auto norm_val = norm.to<float>();
|
||||
float beta_val(beta);
|
||||
auto norm_val_vec = Vectorized<float>(norm_val);
|
||||
@ -101,9 +102,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
const auto zero_vec = Vectorized<float>(0);
|
||||
const auto pos_1_vec = Vectorized<float>(1);
|
||||
cpu_kernel_vec(iter,
|
||||
[=](BFloat16 input, BFloat16 target, BFloat16 grad_output) -> BFloat16 {
|
||||
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
|
||||
const auto x = float(input) - float(target);
|
||||
if (x <= -beta){
|
||||
if (x <= -beta) {
|
||||
return -norm_val * float(grad_output);
|
||||
}else if (x >= beta){
|
||||
return norm_val * float(grad_output);
|
||||
@ -112,14 +113,14 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
}
|
||||
},
|
||||
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
|
||||
Vectorized<BFloat16> input, Vectorized<BFloat16> target, Vectorized<BFloat16> grad_output) -> Vectorized<BFloat16> {
|
||||
Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
|
||||
// using two blendv calls to simulate the 3 cases
|
||||
// 1 if x >= beta
|
||||
// -1 if x <= -beta
|
||||
// x / beta if |x| < beta
|
||||
auto [input0, input1] = convert_bfloat16_float(input);
|
||||
auto [target0, target1] = convert_bfloat16_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_bfloat16_float(grad_output);
|
||||
auto [input0, input1] = convert_to_float(input);
|
||||
auto [target0, target1] = convert_to_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_to_float(grad_output);
|
||||
auto x = input0 - target0;
|
||||
auto pos_or_neg_1_vec = Vectorized<float>::blendv(
|
||||
neg_1_vec, pos_1_vec, x > zero_vec);
|
||||
@ -135,9 +136,10 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
output = Vectorized<float>::blendv(
|
||||
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
|
||||
input1 = norm_val_vec * output * grad_output1;
|
||||
return convert_float_bfloat16(input0, input1);
|
||||
return convert_from_float<scalar_t>(input0, input1);
|
||||
}
|
||||
);
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
auto norm_val = norm.to<scalar_t>();
|
||||
|
||||
@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename func_t, typename vec_func_t>
|
||||
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
|
||||
template <typename func_t, typename vec_func_t, typename ident_t = double>
|
||||
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
|
||||
using traits = binary_function_traits<func_t>;
|
||||
static_assert(
|
||||
all_same<
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <ATen/native/ReduceOpsUtils.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/TensorIterator.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
@ -78,12 +79,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) {
|
||||
reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
|
||||
[=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
|
||||
AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] {
|
||||
using Vec = Vectorized<opmath_type<scalar_t>>;
|
||||
reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
|
||||
[=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); },
|
||||
[=](Vec a, Vec b) -> Vec { return minimum(a, b); });
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,12 +104,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) {
|
||||
reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
|
||||
[=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
|
||||
AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] {
|
||||
using Vec = Vectorized<opmath_type<scalar_t>>;
|
||||
reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
|
||||
[=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); },
|
||||
[=](Vec a, Vec b) -> Vec { return maximum(a, b); });
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,7 +200,7 @@ void aminmax_allreduce_kernel(
|
||||
}
|
||||
);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
|
||||
AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
|
||||
using Vec = Vectorized<opmath_type<scalar_t>>;
|
||||
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
|
||||
reduce_all_impl_vec_two_outputs<scalar_t>(
|
||||
@ -214,7 +215,7 @@ void aminmax_allreduce_kernel(
|
||||
[=](Vec a, Vec b) -> Vec { return minimum(a, b); },
|
||||
[=](Vec a, Vec b) -> Vec { return maximum(a, b); }
|
||||
);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
@ -338,43 +339,24 @@ void or_kernel_impl(TensorIterator& iter) {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
struct MinValuesOps: public at::native::MinOps<scalar_t> {
|
||||
using arg_t = typename MinOps<scalar_t>::arg_t;
|
||||
static scalar_t project(arg_t arg) {
|
||||
return arg.first;
|
||||
}
|
||||
};
|
||||
|
||||
void min_values_kernel_impl(TensorIterator& iter) {
|
||||
if (iter.dtype() == kLong) {
|
||||
// This case is special because of Vectorized<int64_t> does not
|
||||
// handle upper_bound<int64_t>().
|
||||
// See: https://github.com/pytorch/pytorch/issues/43254
|
||||
using scalar_t = int64_t;
|
||||
binary_kernel_reduce(
|
||||
iter,
|
||||
MinValuesOps<scalar_t>{},
|
||||
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
|
||||
return;
|
||||
}
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] {
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
|
||||
binary_kernel_reduce_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
|
||||
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
|
||||
static_cast<double>(upper_bound<scalar_t>()));
|
||||
});
|
||||
upper_bound<scalar_t>());
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
void max_values_kernel_impl(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] {
|
||||
AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] {
|
||||
binary_kernel_reduce_vec(
|
||||
iter,
|
||||
[](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
|
||||
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return maximum(a, b); },
|
||||
lower_bound<scalar_t>());
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
void argmax_kernel_impl(TensorIterator &iter) {
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/TensorIterator.h>
|
||||
@ -106,7 +107,7 @@ void min_kernel_impl(
|
||||
bool keepdim) {
|
||||
int64_t self_dim_size = ensure_nonempty_size(self, dim);
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
|
||||
AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] {
|
||||
compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
|
||||
scalar_t* result_data, int64_t* indice_data,
|
||||
const scalar_t* self_data, auto self_dim_stride) {
|
||||
@ -128,7 +129,7 @@ void min_kernel_impl(
|
||||
*indice_data = index;
|
||||
}
|
||||
);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
|
||||
}
|
||||
|
||||
void max_kernel_impl(
|
||||
@ -139,7 +140,7 @@ void max_kernel_impl(
|
||||
bool keepdim) {
|
||||
int64_t self_dim_size = ensure_nonempty_size(self, dim);
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] {
|
||||
AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] {
|
||||
compare_base_kernel<scalar_t>(result, indice, self, dim, keepdim, [&] (
|
||||
scalar_t* result_data, int64_t* indice_data,
|
||||
const scalar_t* self_data, auto self_dim_stride) {
|
||||
@ -161,7 +162,7 @@ void max_kernel_impl(
|
||||
*indice_data = index;
|
||||
}
|
||||
);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool);
|
||||
}
|
||||
|
||||
void aminmax_kernel(
|
||||
@ -186,7 +187,7 @@ void aminmax_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
|
||||
AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] {
|
||||
compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
|
||||
scalar_t* min_result_data, scalar_t* max_result_data,
|
||||
const scalar_t* self_data, auto self_dim_stride) {
|
||||
@ -209,7 +210,7 @@ void aminmax_kernel(
|
||||
*max_result_data = max_number;
|
||||
}
|
||||
);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half);
|
||||
}
|
||||
|
||||
void where_kernel_impl(TensorIterator &iter) {
|
||||
|
||||
@ -22,6 +22,9 @@
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <ATen/native/hip/ck_group_gemm.h>
|
||||
#endif
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
@ -666,12 +669,19 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
|
||||
use_fast_path = true;
|
||||
}
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
#ifndef USE_ROCM
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
|
||||
#endif
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
#include <ATen/core/TensorBase.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
@ -74,7 +73,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
|
||||
|
||||
char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
|
||||
char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
|
||||
|
||||
if (is_gather_like && num_indices==1) {
|
||||
const size_t element_size = iter.element_size(0);
|
||||
constexpr size_t alignment = 16;
|
||||
@ -84,16 +82,9 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
|
||||
auto ind_dim_size = index_size[0];
|
||||
auto inp_stride_bytes = index_stride[0];
|
||||
auto out_stride_bytes = iter.strides(0)[1];
|
||||
// avoid grid overflow in the fast kernel
|
||||
const int64_t vec_chunks = ceil_div(slice_size, alignment);
|
||||
const int64_t blocks_per_slice_upper = ceil_div(vec_chunks, (int64_t)launch_size_nd);
|
||||
const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
// if it's an eligible grid we use the fast path, otherwise default to slower path
|
||||
if (blocks_per_slice_upper <= max_grid_y) {
|
||||
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
|
||||
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
|
||||
return;
|
||||
}
|
||||
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
|
||||
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -13,11 +13,12 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
|
||||
if (allow_neg_indices) {
|
||||
ind = (ind < 0) ? ind + ind_dim_size : ind;
|
||||
}
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
|
||||
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
|
||||
if (off >= slice_size) return;
|
||||
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
|
||||
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
|
||||
// off is guaranteed to be within int32 limits
|
||||
for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) {
|
||||
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
|
||||
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -30,7 +31,9 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int
|
||||
auto num_threads = at::round_up(
|
||||
at::ceil_div(slice_size_in_bytes, Alignment),
|
||||
static_cast<int64_t>(C10_WARP_SIZE));
|
||||
dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
|
||||
uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
grid_y = std::min(static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y);
|
||||
dim3 grid = {static_cast<uint32_t>(num_ind), grid_y, 1};
|
||||
auto block = std::min(max_num_threads, num_threads);
|
||||
vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
|
||||
ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/ReduceAllOps.h>
|
||||
@ -28,22 +29,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
}
|
||||
|
||||
void aminmax_allreduce_launch_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] {
|
||||
AT_DISPATCH_V2(
|
||||
iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] {
|
||||
_min_max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
void aminmax_launch_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() {
|
||||
AT_DISPATCH_V2(
|
||||
iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter,
|
||||
MinMaxOps<scalar_t, scalar_t, int32_t>{},
|
||||
thrust::pair<scalar_t, scalar_t>(
|
||||
at::numeric_limits<scalar_t>::upper_bound(),
|
||||
at::numeric_limits<scalar_t>::lower_bound()));
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#define TORCH_ASSERT_NO_OPERATORS
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <ATen/native/ReduceAllOps.h>
|
||||
@ -33,27 +34,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
}
|
||||
|
||||
void max_values_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() {
|
||||
AT_DISPATCH_V2(
|
||||
iter.dtype(), "max_values_cuda", AT_WRAP([&]() {
|
||||
max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
void max_launch_kernel(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() {
|
||||
AT_DISPATCH_V2(
|
||||
iter.input_dtype(), "max_cuda", AT_WRAP([&]() {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter,
|
||||
MaxOps<scalar_t>{},
|
||||
thrust::pair<scalar_t, int64_t>(
|
||||
at::numeric_limits<scalar_t>::lower_bound(), 0));
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
void max_all_launch_kernel(TensorIterator &iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] {
|
||||
AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] {
|
||||
max_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda)
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include <ATen/NumericUtils.h>
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/NumericUtils.h>
|
||||
#include <ATen/cuda/NumericLimits.cuh>
|
||||
|
||||
@ -33,24 +34,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) {
|
||||
}
|
||||
|
||||
void min_values_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() {
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
void min_launch_kernel(TensorIterator &iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() {
|
||||
AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter,
|
||||
MinOps<scalar_t>{},
|
||||
thrust::pair<scalar_t, int64_t>(at::numeric_limits<scalar_t>::upper_bound(), 0));
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
void min_all_launch_kernel(TensorIterator &iter) {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] {
|
||||
AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
});
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda)
|
||||
|
||||
@ -59,6 +59,24 @@
|
||||
// forward declare
|
||||
class cublasCommonArgs;
|
||||
|
||||
#ifndef _WIN32
|
||||
namespace fbgemm_gpu {
|
||||
|
||||
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
|
||||
// To update supported ops means a submodule bump, which is.. painful. Instead, we
|
||||
// can simply forward-declare the methods we want to use.. Works at least as a short-term
|
||||
// thing, but should still be fixed somewhere/somehow.
|
||||
at::Tensor f4f4bf16(
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
std::optional<at::Tensor>,
|
||||
bool use_mx);
|
||||
|
||||
} // namespace fbgemm_gpu
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
@ -767,33 +785,6 @@ _scaled_rowwise_rowwise(
|
||||
return out;
|
||||
}
|
||||
|
||||
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
|
||||
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
|
||||
// and strides become somewhat meaningless
|
||||
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
|
||||
if (scale_type == ScalingType::BlockWise1x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
|
||||
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
|
||||
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
} else if (scale_type == ScalingType::BlockWise128x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(
|
||||
scale,
|
||||
0,
|
||||
ceil_div<int64_t>(t.size(0), 128),
|
||||
ceil_div<int64_t>(t.size(1), 128)),
|
||||
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
TORCH_CHECK(check_size_stride(
|
||||
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
|
||||
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
_check_deepseek_support() {
|
||||
#ifndef USE_ROCM
|
||||
@ -806,7 +797,7 @@ _check_deepseek_support() {
|
||||
}
|
||||
// Only in cublasLt >= 12.9
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
|
||||
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
|
||||
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
|
||||
);
|
||||
#endif
|
||||
@ -823,23 +814,61 @@ _scaled_block1x128_block1x128(
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// check types
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
|
||||
),
|
||||
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -861,24 +890,65 @@ _scaled_block128x128_block1x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
|
||||
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise128x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -900,24 +970,62 @@ _scaled_block1x128_block128x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
|
||||
int64_t M = mat_a.size(0);
|
||||
int64_t K = mat_a.size(1);
|
||||
int64_t N = mat_b.size(1);
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
|
||||
);
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise128x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -997,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
|
||||
#endif
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
|
||||
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
|
||||
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
|
||||
auto K_multiplier = 2;
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
|
||||
#else
|
||||
// NVIDIA
|
||||
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
|
||||
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
|
||||
#endif
|
||||
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
|
||||
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
|
||||
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
|
||||
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
|
||||
#else
|
||||
// NVIDIA
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"For Blockwise scaling both scales should be contiguous");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x32;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x32;
|
||||
|
||||
@ -1031,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
#else
|
||||
// NVIDIA
|
||||
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
|
||||
// but we have one we need to use. Two clear options are to copy into
|
||||
// our output (slow), or use a move-assignment-operator (faster).
|
||||
// However, the compiler can complain about the explicit move preventing
|
||||
// copy elision because the return from f4f4bf16 is a temporary object.
|
||||
// So we don't explicitly move, and trust the compiler here...
|
||||
// In the longer term this should be fixed on the FBGemm side.
|
||||
out = fbgemm_gpu::f4f4bf16(
|
||||
mat_a,
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
std::nullopt, /* global_scale */
|
||||
true /* use_mx */
|
||||
);
|
||||
|
||||
return out;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -1160,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
|
||||
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
|
||||
}
|
||||
|
||||
// Handle fp4 packed-K dimension
|
||||
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
|
||||
|
||||
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
|
||||
" but got ", bias->numel());
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.sizes()[1] % 16 == 0,
|
||||
K_multiplier * mat_a.sizes()[1] % 16 == 0,
|
||||
"Expected trailing dimension of mat1 to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes()[0],
|
||||
"x",
|
||||
mat_a.sizes()[1],
|
||||
K_multiplier * mat_a.sizes()[1],
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
mat_b.sizes()[1], ") must be divisible by 16");
|
||||
|
||||
// TODO(slayton): Existing checks, not sure if they should really be here.
|
||||
|
||||
19
aten/src/ATen/native/hip/ck_group_gemm.h
Normal file
19
aten/src/ATen/native/hip/ck_group_gemm.h
Normal file
@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <optional>
|
||||
|
||||
namespace at {
|
||||
namespace hip {
|
||||
namespace detail {
|
||||
void group_gemm_ck(
|
||||
const at::Tensor& mat_a,
|
||||
const at::Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::Tensor& out);
|
||||
|
||||
} // namespace detail
|
||||
} // namespace hip
|
||||
} // namespace at
|
||||
462
aten/src/ATen/native/hip/ck_group_gemm.hip
Normal file
462
aten/src/ATen/native/hip/ck_group_gemm.hip
Normal file
@ -0,0 +1,462 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/TensorAccessor.h>
|
||||
#include <c10/hip/HIPStream.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
namespace at {
|
||||
namespace hip {
|
||||
namespace detail {
|
||||
|
||||
namespace CkTypes {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
|
||||
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
|
||||
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
|
||||
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
1, 1,
|
||||
S<1,32,1,8>, 4
|
||||
>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
void launch_grouped_bgemm_ck_impl_dispatch(
|
||||
const at::Tensor& mat_a,
|
||||
const at::Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
at::Tensor& out)
|
||||
{
|
||||
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
|
||||
using PassThrough = CkTypes::PassThrough;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<const void*> p_a_ptrs, p_b_ptrs;
|
||||
std::vector<void*> p_e_ptrs;
|
||||
// Note: d_ptrs will be resized after we populate the other vectors
|
||||
|
||||
const int mat_a_dim = mat_a.dim();
|
||||
const int mat_b_dim = mat_b.dim();
|
||||
|
||||
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
|
||||
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
|
||||
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
|
||||
const size_t a_element_size = mat_a.element_size();
|
||||
const size_t b_element_size = mat_b.element_size();
|
||||
const size_t out_element_size = out.element_size();
|
||||
|
||||
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
|
||||
if (mat_a_dim == 2 && mat_b_dim == 2) {
|
||||
// 2D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
const int M = mat_a.size(0); // number of rows in A
|
||||
const int N = mat_b.size(1); // number of columns in B
|
||||
const int K = mat_a.size(1); // columns in A == rows in B
|
||||
// for 2d*2d input, output is 3d.
|
||||
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
|
||||
int end_k = offs_accessor[i];
|
||||
int k = end_k - start_k;
|
||||
|
||||
//K dimension are sliced, hence select stride(1) always.
|
||||
//K dimension is always dimension 1, regardless of memory layout (row/column major)
|
||||
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
|
||||
// Leading dimension = distance between rows = stride(0)
|
||||
ldb = mat_b.stride(0);
|
||||
} else {
|
||||
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
|
||||
// Leading dimension = distance between columns = stride(1)
|
||||
ldb = mat_b.stride(1);
|
||||
}
|
||||
|
||||
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
|
||||
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
int lda, ldc;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
}
|
||||
// Output is always row-major in 3D tensor [num_groups, M, N]
|
||||
// Leading dimension for each group's [M,N] slice = stride(1) = N
|
||||
ldc = out.stride(1);
|
||||
size_t output_group_bytes = M * N * out_element_size;
|
||||
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(k),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
|
||||
// 2D*3D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
|
||||
// 2d*3d input, output is 2d.
|
||||
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
|
||||
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
|
||||
const int K = mat_a.size(1); // columns in A
|
||||
// For 2D-3D case: The output determines N (result width)
|
||||
const int N = out.size(1); // N is the width of the output tensor
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_m = offs_accessor[i];
|
||||
int m = end_m - start_m;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (m <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A rows for group i: skip start_m rows
|
||||
const void* group_a_ptr;
|
||||
int lda;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
lda = mat_a.stride(0); // distance between rows
|
||||
} else {
|
||||
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Detect stride pattern for A tensor to determine appropriate lda calculation
|
||||
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
|
||||
|
||||
if (a_is_strided_tensor) {
|
||||
// For strided A tensors: stride(0) gives the actual leading dimension
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// For non-strided A tensors: use the M dimension (total rows)
|
||||
lda = mat_a.size(0); // Total M dimension for column-major layout
|
||||
}
|
||||
}
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
|
||||
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
|
||||
} else {
|
||||
// Detect stride pattern to determine appropriate ldb calculation
|
||||
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
|
||||
|
||||
if (is_strided_tensor) {
|
||||
// For strided tensors: stride(2) gives the actual leading dimension
|
||||
ldb = mat_b.stride(2);
|
||||
} else {
|
||||
// For non-strided tensors: use the N dimension
|
||||
ldb = mat_b.size(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
|
||||
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
|
||||
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(m),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
|
||||
// 3d*3d input, output is 3d - batched matrix multiplication
|
||||
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
|
||||
// Each batch is processed as a separate GEMM operation
|
||||
const int batch_size = mat_a.size(0);
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
|
||||
|
||||
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
|
||||
int N;
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
N = mat_b.size(2);
|
||||
} else if (mat_b.size(2) == K) {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
N = mat_b.size(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
|
||||
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
|
||||
// Select output batch for group i: Output[i, :, :]
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
|
||||
int lda, ldb, ldc;
|
||||
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
} else {
|
||||
// Column-major A: leading dimension = distance between columns = stride(2)
|
||||
lda = mat_a.stride(2);
|
||||
}
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B: leading dimension = distance between rows
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(1); // stride between K rows
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
|
||||
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
|
||||
}
|
||||
} else {
|
||||
// Column-major B: leading dimension = distance between columns
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(2); // stride between N columns
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
|
||||
}
|
||||
}
|
||||
|
||||
// Output is typically row-major: leading dimension = distance between rows = stride(1)
|
||||
ldc = out.stride(1);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
|
||||
// 3D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
// 3d*2d input, output is 3d.
|
||||
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
|
||||
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
|
||||
const int batch_size = mat_a.size(0); // n_groups
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A
|
||||
|
||||
// For row-major A and B case: B should be [K, total_N]
|
||||
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_n = offs_accessor[i];
|
||||
int n = end_n - start_n;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (n <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
// Check if B is row-major or column-major
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(0); // distance between rows (should be total_N)
|
||||
} else {
|
||||
// Column-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(1); // distance between columns (should be K)
|
||||
}
|
||||
|
||||
// Select output slice for group i: Output[:, start_n:end_n]
|
||||
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
|
||||
|
||||
int lda, ldc;
|
||||
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
// Output is row-major: leading dimension = distance between rows = stride(0)
|
||||
ldc = out.stride(0);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(n),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
|
||||
}

  TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");

  // Initialize d_ptrs with the correct size
  std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());

  static DeviceOp gemm_instance;
  auto argument = gemm_instance.MakeArgument(
      p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
      gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
  );
  TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
      "CK Group GEMM: argument unsupported (shape/strides/type config)");
  size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
  size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);

  void* gemm_arg_buf = nullptr;
  void* ws_buf = nullptr;

  hipMalloc(&gemm_arg_buf, arg_buf_size);
  hipMalloc(&ws_buf, ws_size);

  gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
  gemm_instance.SetWorkSpacePointer(&argument, ws_buf);

  auto invoker = gemm_instance.MakeInvoker();
  hipStream_t stream = c10::hip::getCurrentHIPStream();
  invoker.Run(argument, {stream});
  hipFree(gemm_arg_buf);
  hipFree(ws_buf);
}

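
Aside: the `offs` tensor consumed by the 2D*3D and 3D*2D paths above is a running prefix sum of group sizes; group i covers the half-open range [offs[i-1], offs[i]) (with offs[-1] treated as 0) along M or N respectively. A minimal host-side sketch of that bookkeeping, using a hypothetical helper name that is not part of the diff:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Recover per-group [start, end) ranges from prefix-sum offsets, mirroring
    // the (i == 0 ? 0 : offs[i - 1]) .. offs[i] logic in the loops above.
    std::vector<std::pair<int64_t, int64_t>> group_ranges(const std::vector<int64_t>& offs) {
      std::vector<std::pair<int64_t, int64_t>> ranges;
      int64_t start = 0;
      for (int64_t end : offs) {
        if (end - start > 0) {  // zero-sized groups are skipped, as in the kernel setup
          ranges.emplace_back(start, end);
        }
        start = end;
      }
      return ranges;
    }
    // Example: offs = {128, 128, 300} -> {(0, 128), (128, 300)}; the empty middle
    // group is dropped, matching the `if (m <= 0) continue;` branches above.
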
void group_gemm_ck(
|
||||
const at::Tensor& input_a,
|
||||
const at::Tensor& input_b_colmajor,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& /*bias*/,
|
||||
at::Tensor& out)
|
||||
{
|
||||
// Detect if input_a is row-major based on stride pattern
|
||||
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
|
||||
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
|
||||
// Ensure tensor A is row-major and contiguous if not already
|
||||
at::Tensor mat_a = input_a;
|
||||
if (!a_row_major) {
|
||||
// If A is not row-major, make it contiguous (row-major)
|
||||
mat_a = input_a.contiguous();
|
||||
}
|
||||
// Force tensor B to be column-major using double transpose trick
|
||||
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
|
||||
at::Tensor mat_b = input_b_colmajor;
|
||||
if (!b_col_major) {
|
||||
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
|
||||
}
|
||||
|
||||
// For 3D tensors, check the last dimension stride for row-major detection
|
||||
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
|
||||
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
|
||||
|
||||
if (mat_a.dtype() == at::kBFloat16) {
|
||||
// bf16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kHalf) {
|
||||
// fp16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kFloat) {
|
||||
// fp32 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
  }

}

} // namespace detail
} // namespace hip
} // namespace at
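
A hedged usage sketch of the dispatcher above. The shapes follow the 2D*3D path (A is [total_m, k] row-major, B is [n_groups, n, k], offs is a prefix sum over M, out is preallocated as [total_m, n]); the surrounding glue that normally produces these tensors is assumed, not shown:

    // Hypothetical call site for at::hip::detail::group_gemm_ck.
    void run_grouped_gemm_example(const at::Tensor& A,
                                  const at::Tensor& B,
                                  const at::Tensor& offs,
                                  at::Tensor& out) {
      // bias is unused by the implementation above, so pass std::nullopt.
      at::hip::detail::group_gemm_ck(A, B, offs, /*bias=*/std::nullopt, out);
    }
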
|
||||
@ -157,10 +157,10 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
return true;
|
||||
|
||||
dnnl_dims_t blocks = {0};
|
||||
int perm[DNNL_MAX_NDIMS] = {0};
|
||||
std::array<int, DNNL_MAX_NDIMS> perm = {0};
|
||||
for (int d = 0; d < md_ndims; ++d) {
|
||||
// no strides check needed for empty tensor
|
||||
if (md_padded_dims[d] == nullptr)
|
||||
if ((*md_padded_dims)[d] == 0)
|
||||
return true;
|
||||
|
||||
// no strides verification for runtime dims
|
||||
@ -178,14 +178,15 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
|
||||
// A custom comparator to yield linear order on perm
|
||||
auto idx_sorter = [&](const int a, const int b) -> bool {
|
||||
if (strides[a] == strides[b] && md_padded_dims[a] == md_padded_dims[b])
|
||||
if (strides[a] == strides[b] &&
|
||||
(*md_padded_dims)[a] == (*md_padded_dims)[b])
|
||||
return a < b;
|
||||
else if (strides[a] == strides[b])
|
||||
return md_padded_dims[a] < md_padded_dims[b];
|
||||
return (*md_padded_dims)[a] < (*md_padded_dims)[b];
|
||||
else
|
||||
return strides[a] < strides[b];
|
||||
};
|
||||
std::sort(perm, perm + md_ndims, idx_sorter);
|
||||
std::sort(perm.begin(), perm.begin() + md_ndims, idx_sorter);
|
||||
|
||||
auto min_stride = block_size;
|
||||
for (int idx = 0; idx < md_ndims; ++idx) {
|
||||
@ -199,9 +200,10 @@ bool onednn_strides_check(const Tensor& src) {
|
||||
return false;
|
||||
|
||||
// update min_stride for next iteration
|
||||
const auto padded_dim = *md_padded_dims[d];
|
||||
const auto padded_dim = (*md_padded_dims)[d];
|
||||
min_stride = block_size * strides[d] * (padded_dim / blocks[d]);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
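
Both fixes in this hunk are mechanical: `md_padded_dims` is a pointer to a dims array, so every use has to dereference before indexing, and the raw `int perm[]` becomes a `std::array` sorted through iterators. A standalone sketch of the same pattern with generic stand-in types (not the oneDNN API):

    #include <algorithm>
    #include <array>
    #include <cstdint>
    #include <numeric>

    constexpr int kMaxNdims = 12;       // assumption; stands in for DNNL_MAX_NDIMS
    using dims_t = int64_t[kMaxNdims];  // stands in for the dims array type

    // Sort a permutation of dimension indices by (stride, padded dim), reading
    // padded dims through a pointer-to-array: (*padded)[d], not padded[d].
    void sort_dims_by_stride(const dims_t* padded,
                             const int64_t* strides,
                             int ndims,
                             std::array<int, kMaxNdims>& perm) {
      std::iota(perm.begin(), perm.begin() + ndims, 0);
      std::sort(perm.begin(), perm.begin() + ndims, [&](int a, int b) {
        if (strides[a] == strides[b] && (*padded)[a] == (*padded)[b])
          return a < b;
        if (strides[a] == strides[b])
          return (*padded)[a] < (*padded)[b];
        return strides[a] < strides[b];
      });
    }
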
|
||||
|
||||
|
||||
@ -212,17 +212,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
|
||||
loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({}));
|
||||
TORCH_CHECK(loss.is_mps());
|
||||
|
||||
Tensor loss_squeezed = loss.squeeze();
|
||||
Tensor input_squeezed = input.squeeze();
|
||||
Tensor target_squeezed = target.squeeze();
|
||||
|
||||
@autoreleasepool {
|
||||
std::string key =
|
||||
op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight});
|
||||
std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target, weight});
|
||||
|
||||
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed);
|
||||
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed);
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
|
||||
newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
|
||||
MPSGraphTensor* bceLossUnweighted = nil;
|
||||
// if grad_output is defined, then it's a backward pass
|
||||
@ -252,12 +247,12 @@ static Tensor& bce_loss_out_impl(const Tensor& input,
|
||||
newCachedGraph->gradInputTensor = bceLoss;
|
||||
}
|
||||
} else {
|
||||
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size());
|
||||
newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size());
|
||||
}
|
||||
});
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed);
|
||||
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed);
|
||||
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed);
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
|
||||
Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target);
|
||||
Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss);
|
||||
|
||||
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
|
||||
|
||||
@ -370,7 +365,7 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
|
||||
onValue:-1.0f
|
||||
offValue:0.0f
|
||||
name:nil];
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, inputTensor.dataType);
|
||||
oneHotTensor = castMPSTensor(mpsGraph, oneHotTensor, [inputTensor dataType]);
|
||||
if (isWeightsArrayValid) {
|
||||
oneHotTensor = [mpsGraph multiplicationWithPrimaryTensor:oneHotTensor
|
||||
secondaryTensor:weightTensor
|
||||
@ -705,6 +700,7 @@ static void smooth_l1_loss_template(const Tensor& input,
|
||||
TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
|
||||
TORCH_CHECK(input.is_mps());
|
||||
TORCH_CHECK(target.is_mps());
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64");
|
||||
if ((input.numel() == 0) || (target.numel() == 0)) {
|
||||
reduction == Reduction::Mean ? output.fill_(std::numeric_limits<float>::quiet_NaN()) : output.zero_();
|
||||
return;
|
||||
@ -771,7 +767,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target);
|
||||
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
|
||||
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta dataType:[inputTensor dataType]];
|
||||
// xn - yn
|
||||
MPSGraphTensor* diffTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor
|
||||
secondaryTensor:targetTensor
|
||||
@ -797,7 +793,8 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
|
||||
name:@"lossTensor"];
|
||||
MPSGraphTensor* outputTensor = lossTensor;
|
||||
if (reduction == Reduction::Mean) {
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel() dataType:MPSDataTypeFloat32];
|
||||
MPSGraphTensor* numelTensor = [mpsGraph constantWithScalar:(double)input.numel()
|
||||
dataType:[lossTensor dataType]];
|
||||
outputTensor = [mpsGraph divisionWithPrimaryTensor:lossTensor secondaryTensor:numelTensor name:nil];
|
||||
}
|
||||
MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
|
||||
load("//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test")
|
||||
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
|
||||
load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX")
|
||||
load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX")
|
||||
|
||||
# Shared by internal and OSS BUCK
|
||||
def define_qnnpack(third_party, labels = []):
|
||||
@ -21,7 +21,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "requantization/*.h"),
|
||||
]),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-O2",
|
||||
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
|
||||
@ -82,7 +82,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "requantization/*.h"),
|
||||
]),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-O3",
|
||||
"-ffast-math",
|
||||
@ -129,7 +129,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "requantization/*.h"),
|
||||
]),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-O3",
|
||||
"-ffast-math",
|
||||
@ -184,7 +184,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "requantization/*.h"),
|
||||
]),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-O3",
|
||||
"-ffast-math",
|
||||
@ -236,7 +236,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
],
|
||||
),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
|
||||
],
|
||||
@ -291,7 +291,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "qnnpack/*.h"),
|
||||
("include", "*.h"),
|
||||
]),
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-O2",
|
||||
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
|
||||
@ -398,7 +398,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "requantization/*.h"),
|
||||
]),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-O3",
|
||||
"-ffast-math",
|
||||
@ -465,7 +465,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "requantization/*.h"),
|
||||
]),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
|
||||
"-Wno-unused-command-line-argument",
|
||||
@ -525,7 +525,7 @@ def define_qnnpack(third_party, labels = []):
|
||||
("src", "qnnpack/*.h"),
|
||||
]),
|
||||
header_namespace = "",
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = [
|
||||
"-O3",
|
||||
"-ffast-math",
|
||||
|
||||
aten/src/ATen/xpu/PeerToPeerAccess.cpp (new file, 63 lines)
@ -0,0 +1,63 @@
|
||||
#include <ATen/xpu/PeerToPeerAccess.h>
|
||||
#include <ATen/xpu/XPUContext.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/xpu/XPUCachingAllocator.h>
|
||||
|
||||
namespace at::xpu {
|
||||
|
||||
// p2pAccessEnabled_ is a flattened 2D matrix of size [num_devices x
|
||||
// num_devices].
|
||||
// Each element represents whether device[i] can access device[j]:
|
||||
// 1 -> access allowed
|
||||
// 0 -> access not allowed
|
||||
// -1 -> unknown (not yet queried)
|
||||
static std::vector<int8_t> p2pAccessEnabled_;
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Initializes the peer-to-peer (P2P) access capability cache.
|
||||
void init_p2p_access_cache(c10::DeviceIndex num_devices) {
|
||||
// By default, each device can always access itself (diagonal entries = 1).
|
||||
// For simplicity, all entries are initialized to -1 except the diagonal.
|
||||
static bool once [[maybe_unused]] = [num_devices]() {
|
||||
p2pAccessEnabled_.clear();
|
||||
p2pAccessEnabled_.resize(num_devices * num_devices, -1);
|
||||
|
||||
for (const auto i : c10::irange(num_devices)) {
|
||||
p2pAccessEnabled_[i * num_devices + i] = 1;
|
||||
}
|
||||
return true;
|
||||
}();
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
bool get_p2p_access(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
at::globalContext().lazyInitDevice(c10::DeviceType::XPU);
|
||||
|
||||
check_device_index(dev);
|
||||
check_device_index(dev_to_access);
|
||||
|
||||
auto& cache =
|
||||
p2pAccessEnabled_[dev * c10::xpu::device_count() + dev_to_access];
|
||||
|
||||
if (cache != -1) {
|
||||
return static_cast<bool>(cache);
|
||||
}
|
||||
|
||||
// Query the hardware to determine if P2P access is supported
|
||||
cache = static_cast<int8_t>(
|
||||
c10::xpu::get_raw_device(dev).ext_oneapi_can_access_peer(
|
||||
c10::xpu::get_raw_device(dev_to_access),
|
||||
sycl::ext::oneapi::peer_access::access_supported));
|
||||
|
||||
if (cache) {
|
||||
XPUCachingAllocator::enablePeerAccess(dev, dev_to_access);
|
||||
}
|
||||
|
||||
return static_cast<bool>(cache);
|
||||
}
|
||||
|
||||
} // namespace at::xpu
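
The cache above is a flattened row-major matrix indexed as `dev * device_count + dev_to_access`. A self-contained sketch of the same memoization pattern, with a caller-supplied query standing in for the SYCL capability check:

    #include <cstdint>
    #include <vector>

    // Flattened [n x n] cache: 1 = access allowed, 0 = denied, -1 = not yet queried.
    struct P2PCache {
      explicit P2PCache(int num_devices)
          : n_(num_devices), cache_(static_cast<size_t>(num_devices) * num_devices, -1) {
        for (int i = 0; i < n_; ++i) {
          cache_[static_cast<size_t>(i) * n_ + i] = 1;  // a device always reaches itself
        }
      }

      // `query` stands in for the hardware check (ext_oneapi_can_access_peer above).
      template <typename QueryFn>
      bool can_access(int dev, int peer, QueryFn&& query) {
        int8_t& slot = cache_[static_cast<size_t>(dev) * n_ + peer];
        if (slot == -1) {
          slot = query(dev, peer) ? 1 : 0;  // memoize the first query
        }
        return slot == 1;
      }

      int n_;
      std::vector<int8_t> cache_;
    };
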
|
||||
aten/src/ATen/xpu/PeerToPeerAccess.h (new file, 15 lines)
@ -0,0 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Device.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace at::xpu {
|
||||
namespace detail {
|
||||
void init_p2p_access_cache(c10::DeviceIndex num_devices);
|
||||
} // namespace detail
|
||||
|
||||
TORCH_XPU_API bool get_p2p_access(
|
||||
c10::DeviceIndex dev,
|
||||
c10::DeviceIndex dev_to_access);
|
||||
|
||||
} // namespace at::xpu
|
||||
@ -1,3 +1,4 @@
|
||||
#include <ATen/xpu/PeerToPeerAccess.h>
|
||||
#include <ATen/xpu/PinnedMemoryAllocator.h>
|
||||
#include <ATen/xpu/XPUContext.h>
|
||||
#include <ATen/xpu/XPUDevice.h>
|
||||
@ -12,6 +13,7 @@ void XPUHooks::init() const {
|
||||
C10_LOG_API_USAGE_ONCE("aten.init.xpu");
|
||||
const auto device_count = c10::xpu::device_count_ensure_non_zero();
|
||||
c10::xpu::XPUCachingAllocator::init(device_count);
|
||||
at::xpu::detail::init_p2p_access_cache(device_count);
|
||||
}
|
||||
|
||||
bool XPUHooks::hasXPU() const {
|
||||
|
||||
@ -53,10 +53,8 @@ class AddmmBenchmark(op_bench.TorchBenchmarkBase):
|
||||
return torch.addmm(input_one, mat1, mat2)
|
||||
|
||||
|
||||
op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
addmm_long_configs + addmm_long_configs, AddmmBenchmark
|
||||
)
|
||||
op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark)
|
||||
op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark)
|
||||
|
||||
"""Mircobenchmark for addbmm operator."""
|
||||
|
||||
@ -107,9 +105,7 @@ addbmm_short_configs = op_bench.cross_product_configs(
|
||||
)
|
||||
|
||||
op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark)
|
||||
op_bench.generate_pt_gradient_test(
|
||||
addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark
|
||||
)
|
||||
op_bench.generate_pt_gradient_test(addbmm_long_configs, AddbmmBenchmark)
|
||||
|
||||
if __name__ == "__main__":
|
||||
op_bench.benchmark_runner.main()
|
||||
|
||||
@ -8,7 +8,7 @@ load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
|
||||
load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags")
|
||||
load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
|
||||
load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
|
||||
load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX")
|
||||
load("//tools/build_defs:platform_defs.bzl", "IOS", "MACOSX")
|
||||
load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
|
||||
load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build")
|
||||
load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build", is_profile_build_ios = "is_profile_build")
|
||||
@ -1090,7 +1090,7 @@ def define_buck_targets(
|
||||
srcs = [
|
||||
"caffe2/core/common.cc",
|
||||
],
|
||||
apple_sdks = (IOS, MACOSX, APPLETVOS),
|
||||
apple_sdks = (IOS, MACOSX),
|
||||
compiler_flags = get_pt_compiler_flags(),
|
||||
labels = labels,
|
||||
# @lint-ignore BUCKLINT link_whole
|
||||
|
||||
@ -929,6 +929,7 @@ libtorch_python_core_sources = [
|
||||
"torch/csrc/dynamo/guards.cpp",
|
||||
"torch/csrc/dynamo/utils.cpp",
|
||||
"torch/csrc/dynamo/init.cpp",
|
||||
"torch/csrc/dynamo/stackref_bridge.c",
|
||||
"torch/csrc/functorch/init.cpp",
|
||||
"torch/csrc/fx/node.cpp",
|
||||
"torch/csrc/mps/Module.cpp",
|
||||
@ -1024,6 +1025,7 @@ libtorch_python_core_sources = [
|
||||
libtorch_python_distributed_core_sources = [
|
||||
"torch/csrc/distributed/c10d/init.cpp",
|
||||
"torch/csrc/distributed/c10d/python_comm_hook.cpp",
|
||||
"torch/csrc/distributed/c10d/python_callback_work.cpp",
|
||||
]
|
||||
|
||||
libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [
|
||||
|
||||
@ -59,6 +59,9 @@ constexpr DispatchKeySet nested_dispatch_keyset =
|
||||
{DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
|
||||
DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
|
||||
|
||||
constexpr DispatchKeySet functorch_batched_dispatch_keyset =
|
||||
DispatchKeySet(DispatchKey::FuncTorchBatched);
|
||||
|
||||
DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
|
||||
TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
|
||||
switch (t) {
|
||||
@ -77,6 +80,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
|
||||
return backend_dispatch_keyset;
|
||||
case DispatchKey::CompositeExplicitAutogradNonFunctional:
|
||||
return non_functional_backend_dispatch_keyset;
|
||||
case DispatchKey::FuncTorchBatchedDecomposition:
|
||||
return functorch_batched_dispatch_keyset;
|
||||
default:
|
||||
return DispatchKeySet(t);
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
#include <c10/core/SymBool.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymNodeImpl.h>
|
||||
|
||||
namespace c10 {
|
||||
@ -111,4 +112,17 @@ bool SymBool::has_hint() const {
|
||||
return toSymNodeImpl()->has_hint();
|
||||
}
|
||||
|
||||
SymInt SymBool::toSymInt() const {
|
||||
// If concrete bool, return concrete SymInt
|
||||
if (auto ma = maybe_as_bool()) {
|
||||
return SymInt(*ma ? 1 : 0);
|
||||
}
|
||||
|
||||
// Symbolic case: use sym_ite to convert bool to int (0 or 1)
|
||||
auto node = toSymNodeImpl();
|
||||
auto one_node = node->wrap_int(1);
|
||||
auto zero_node = node->wrap_int(0);
|
||||
return SymInt(node->sym_ite(one_node, zero_node));
|
||||
}
|
||||
|
||||
} // namespace c10
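
A short usage sketch of the new conversion (the guard-free behaviour matches the Python helper `cast_symbool_to_symint_guardless` named in the header comment): a concrete SymBool folds straight to 0/1, while a symbolic one becomes a `sym_ite(cond, 1, 0)` expression instead of forcing a guard.

    // Hedged example, not part of the diff: counting with a possibly-symbolic flag.
    c10::SymInt bool_to_count(const c10::SymBool& flag) {
      // Concrete flag: SymInt(0) or SymInt(1). Symbolic flag: ite(flag, 1, 0),
      // so no guard is introduced on the unknown boolean.
      return flag.toSymInt();
    }
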
|
||||
|
||||
@ -12,6 +12,8 @@
|
||||
|
||||
namespace c10 {
|
||||
|
||||
class SymInt;
|
||||
|
||||
class C10_API SymBool {
|
||||
public:
|
||||
/*implicit*/ SymBool(bool b) : data_(b) {}
|
||||
@ -80,6 +82,10 @@ class C10_API SymBool {
|
||||
return toSymNodeImplUnowned()->constant_bool();
|
||||
}
|
||||
|
||||
// Convert SymBool to SymInt (0 or 1)
|
||||
// This is the C++ equivalent of Python's cast_symbool_to_symint_guardless
|
||||
SymInt toSymInt() const;
|
||||
|
||||
bool is_heap_allocated() const {
|
||||
return ptr_;
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
@ -40,200 +41,99 @@ namespace c10 {
|
||||
///
|
||||
/// This is intended to be trivially copyable, so it should be passed by
|
||||
/// value.
|
||||
///
|
||||
/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
|
||||
/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
|
||||
/// the underlying constexpr calls, we rely on apparent-type dispatch for
|
||||
/// inheritance. This should be fine because their memory format is the same,
|
||||
/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
|
||||
/// However, you should prefer to use ArrayRef when possible, because its use
|
||||
/// of TORCH_CHECK will lead to better user-facing error messages.
|
||||
template <typename T>
|
||||
class ArrayRef final {
|
||||
class ArrayRef final : public HeaderOnlyArrayRef<T> {
|
||||
public:
|
||||
using iterator = const T*;
|
||||
using const_iterator = const T*;
|
||||
using size_type = size_t;
|
||||
using value_type = T;
|
||||
|
||||
using reverse_iterator = std::reverse_iterator<iterator>;
|
||||
|
||||
private:
|
||||
/// The start of the array, in an external buffer.
|
||||
const T* Data;
|
||||
|
||||
/// The number of elements.
|
||||
size_type Length;
|
||||
|
||||
void debugCheckNullptrInvariant() {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
Data != nullptr || Length == 0,
|
||||
"created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal");
|
||||
}
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
/// @name Constructors, all inherited from HeaderOnlyArrayRef except for
|
||||
/// SmallVector. As inherited constructors won't work with class template
|
||||
/// argument deduction (CTAD) until C++23, we add deduction guides after
|
||||
/// the class definition to enable CTAD.
|
||||
/// @{
|
||||
|
||||
/// Construct an empty ArrayRef.
|
||||
/* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}
|
||||
|
||||
/// Construct an ArrayRef from a single element.
|
||||
// TODO Make this explicit
|
||||
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
|
||||
|
||||
/// Construct an ArrayRef from a pointer and length.
|
||||
constexpr ArrayRef(const T* data, size_t length)
|
||||
: Data(data), Length(length) {
|
||||
debugCheckNullptrInvariant();
|
||||
}
|
||||
|
||||
/// Construct an ArrayRef from a range.
|
||||
constexpr ArrayRef(const T* begin, const T* end)
|
||||
: Data(begin), Length(end - begin) {
|
||||
debugCheckNullptrInvariant();
|
||||
}
|
||||
using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;
|
||||
|
||||
/// Construct an ArrayRef from a SmallVector. This is templated in order to
|
||||
/// avoid instantiating SmallVectorTemplateCommon<T> whenever we
|
||||
/// copy-construct an ArrayRef.
|
||||
/// NOTE: this is the only constructor that is not inherited from
|
||||
/// HeaderOnlyArrayRef.
|
||||
template <typename U>
|
||||
/* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
|
||||
: Data(Vec.data()), Length(Vec.size()) {
|
||||
debugCheckNullptrInvariant();
|
||||
}
|
||||
|
||||
template <
|
||||
typename Container,
|
||||
typename U = decltype(std::declval<Container>().data()),
|
||||
typename = std::enable_if_t<
|
||||
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
|
||||
/* implicit */ ArrayRef(const Container& container)
|
||||
: Data(container.data()), Length(container.size()) {
|
||||
debugCheckNullptrInvariant();
|
||||
}
|
||||
|
||||
/// Construct an ArrayRef from a std::vector.
|
||||
// The enable_if stuff here makes sure that this isn't used for
|
||||
// std::vector<bool>, because ArrayRef can't work on a std::vector<bool>
|
||||
// bitfield.
|
||||
template <typename A>
|
||||
/* implicit */ ArrayRef(const std::vector<T, A>& Vec)
|
||||
: Data(Vec.data()), Length(Vec.size()) {
|
||||
static_assert(
|
||||
!std::is_same_v<T, bool>,
|
||||
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
|
||||
}
|
||||
|
||||
/// Construct an ArrayRef from a std::array
|
||||
template <size_t N>
|
||||
/* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)
|
||||
: Data(Arr.data()), Length(N) {}
|
||||
|
||||
/// Construct an ArrayRef from a C array.
|
||||
template <size_t N>
|
||||
// NOLINTNEXTLINE(*c-arrays*)
|
||||
/* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
|
||||
|
||||
/// Construct an ArrayRef from a std::initializer_list.
|
||||
/* implicit */ constexpr ArrayRef(const std::initializer_list<T>& Vec)
|
||||
: Data(
|
||||
std::begin(Vec) == std::end(Vec) ? static_cast<T*>(nullptr)
|
||||
: std::begin(Vec)),
|
||||
Length(Vec.size()) {}
|
||||
: HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
|
||||
|
||||
/// @}
|
||||
/// @name Simple Operations
|
||||
/// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
|
||||
/// @{
|
||||
|
||||
constexpr iterator begin() const {
|
||||
return Data;
|
||||
}
|
||||
constexpr iterator end() const {
|
||||
return Data + Length;
|
||||
}
|
||||
|
||||
// These are actually the same as iterator, since ArrayRef only
|
||||
// gives you const iterators.
|
||||
constexpr const_iterator cbegin() const {
|
||||
return Data;
|
||||
}
|
||||
constexpr const_iterator cend() const {
|
||||
return Data + Length;
|
||||
}
|
||||
|
||||
constexpr reverse_iterator rbegin() const {
|
||||
return reverse_iterator(end());
|
||||
}
|
||||
constexpr reverse_iterator rend() const {
|
||||
return reverse_iterator(begin());
|
||||
}
|
||||
|
||||
/// Check if all elements in the array satisfy the given expression
|
||||
constexpr bool allMatch(const std::function<bool(const T&)>& pred) const {
|
||||
return std::all_of(cbegin(), cend(), pred);
|
||||
}
|
||||
|
||||
/// empty - Check if the array is empty.
|
||||
constexpr bool empty() const {
|
||||
return Length == 0;
|
||||
}
|
||||
|
||||
constexpr const T* data() const {
|
||||
return Data;
|
||||
}
|
||||
|
||||
/// size - Get the array size.
|
||||
constexpr size_t size() const {
|
||||
return Length;
|
||||
}
|
||||
|
||||
/// front - Get the first element.
|
||||
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
|
||||
/// STD_TORCH_CHECK
|
||||
constexpr const T& front() const {
|
||||
TORCH_CHECK(
|
||||
!empty(), "ArrayRef: attempted to access front() of empty list");
|
||||
return Data[0];
|
||||
!this->empty(), "ArrayRef: attempted to access front() of empty list");
|
||||
return this->Data[0];
|
||||
}
|
||||
|
||||
/// back - Get the last element.
|
||||
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
|
||||
/// STD_TORCH_CHECK
|
||||
constexpr const T& back() const {
|
||||
TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list");
|
||||
return Data[Length - 1];
|
||||
}
|
||||
|
||||
/// equals - Check for element-wise equality.
|
||||
constexpr bool equals(ArrayRef RHS) const {
|
||||
return Length == RHS.Length && std::equal(begin(), end(), RHS.begin());
|
||||
TORCH_CHECK(
|
||||
!this->empty(), "ArrayRef: attempted to access back() of empty list");
|
||||
return this->Data[this->Length - 1];
|
||||
}
|
||||
|
||||
/// slice(n, m) - Take M elements of the array starting at element N
|
||||
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
|
||||
/// STD_TORCH_CHECK
|
||||
constexpr ArrayRef<T> slice(size_t N, size_t M) const {
|
||||
TORCH_CHECK(
|
||||
N + M <= size(),
|
||||
N + M <= this->size(),
|
||||
"ArrayRef: invalid slice, N = ",
|
||||
N,
|
||||
"; M = ",
|
||||
M,
|
||||
"; size = ",
|
||||
size());
|
||||
return ArrayRef<T>(data() + N, M);
|
||||
this->size());
|
||||
return ArrayRef<T>(this->data() + N, M);
|
||||
}
|
||||
|
||||
/// slice(n) - Chop off the first N elements of the array.
|
||||
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
|
||||
/// STD_TORCH_CHECK
|
||||
constexpr ArrayRef<T> slice(size_t N) const {
|
||||
TORCH_CHECK(
|
||||
N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size());
|
||||
return slice(N, size() - N);
|
||||
N <= this->size(),
|
||||
"ArrayRef: invalid slice, N = ",
|
||||
N,
|
||||
"; size = ",
|
||||
this->size());
|
||||
return slice(N, this->size() - N); // should this slice be this->slice?
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Operator Overloads
|
||||
/// @{
|
||||
constexpr const T& operator[](size_t Index) const {
|
||||
return Data[Index];
|
||||
}
|
||||
|
||||
/// Vector compatibility
|
||||
/// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
|
||||
/// STD_TORCH_CHECK
|
||||
constexpr const T& at(size_t Index) const {
|
||||
TORCH_CHECK(
|
||||
Index < Length,
|
||||
Index < this->Length,
|
||||
"ArrayRef: invalid index Index = ",
|
||||
Index,
|
||||
"; Length = ",
|
||||
Length);
|
||||
return Data[Index];
|
||||
this->Length);
|
||||
return this->Data[Index];
|
||||
}
|
||||
|
||||
/// Disallow accidental assignment from a temporary.
|
||||
@ -253,16 +153,48 @@ class ArrayRef final {
|
||||
std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
|
||||
std::initializer_list<U>) = delete;
|
||||
|
||||
/// @}
|
||||
/// @name Expensive Operations
|
||||
/// @{
|
||||
std::vector<T> vec() const {
|
||||
return std::vector<T>(Data, Data + Length);
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
/// Deduction guides for ArrayRef to support CTAD with inherited constructors
|
||||
/// These mirror the constructors inherited from HeaderOnlyArrayRef
|
||||
/// @{
|
||||
|
||||
// Single element constructor
|
||||
template <typename T>
|
||||
ArrayRef(const T&) -> ArrayRef<T>;
|
||||
|
||||
// Pointer and length constructor
|
||||
template <typename T>
|
||||
ArrayRef(const T*, size_t) -> ArrayRef<T>;
|
||||
|
||||
// Range constructor (begin, end)
|
||||
template <typename T>
|
||||
ArrayRef(const T*, const T*) -> ArrayRef<T>;
|
||||
|
||||
// Generic container constructor (anything with .data() and .size())
|
||||
template <typename Container>
|
||||
ArrayRef(const Container&) -> ArrayRef<
|
||||
std::remove_pointer_t<decltype(std::declval<Container>().data())>>;
|
||||
|
||||
// std::vector constructor
|
||||
template <typename T, typename A>
|
||||
ArrayRef(const std::vector<T, A>&) -> ArrayRef<T>;
|
||||
|
||||
// std::array constructor
|
||||
template <typename T, size_t N>
|
||||
ArrayRef(const std::array<T, N>&) -> ArrayRef<T>;
|
||||
|
||||
// C array constructor
|
||||
template <typename T, size_t N>
|
||||
ArrayRef(const T (&)[N]) -> ArrayRef<T>;
|
||||
|
||||
// std::initializer_list constructor
|
||||
template <typename T>
|
||||
ArrayRef(const std::initializer_list<T>&) -> ArrayRef<T>;
|
||||
|
||||
/// @}
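
These guides exist because inherited constructors do not participate in class template argument deduction before C++23 (as noted above the constructor block), so the pre-refactor spellings keep compiling. A small sketch of the kind of code they preserve:

    #include <array>
    #include <vector>

    void ctad_examples() {
      std::vector<int> v{1, 2, 3};
      std::array<float, 2> a{0.5f, 1.5f};
      int raw[4] = {1, 2, 3, 4};

      c10::ArrayRef r1(v);            // ArrayRef<int> via the std::vector guide
      c10::ArrayRef r2(a);            // ArrayRef<float> via the std::array guide
      c10::ArrayRef r3(raw);          // ArrayRef<int> via the C-array guide
      c10::ArrayRef r4(v.data(), 2);  // ArrayRef<int> via the pointer+length guide
    }
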
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
|
||||
int i = 0;
|
||||
|
||||
@ -21,13 +21,20 @@ using stream_set = ska::flat_hash_set<xpu::XPUStream>;
|
||||
struct Block;
|
||||
typedef bool (*Comparison)(const Block*, const Block*);
|
||||
bool BlockComparatorSize(const Block* a, const Block* b);
|
||||
bool BlockComparatorAddress(const Block* a, const Block* b);
|
||||
|
||||
struct BlockPool {
|
||||
BlockPool(bool small) : blocks(BlockComparatorSize), is_small(small) {}
|
||||
BlockPool(bool small)
|
||||
: blocks(BlockComparatorSize),
|
||||
unmapped(BlockComparatorAddress),
|
||||
is_small(small) {}
|
||||
std::set<Block*, Comparison> blocks;
|
||||
std::set<Block*, Comparison> unmapped;
|
||||
const bool is_small;
|
||||
};
|
||||
|
||||
struct ExpandableSegment;
|
||||
|
||||
struct Block {
|
||||
DeviceIndex device;
|
||||
sycl::queue* queue{nullptr}; // underlying queue of the allocation stream
|
||||
@ -37,9 +44,11 @@ struct Block {
|
||||
BlockPool* pool{nullptr}; // owning memory pool
|
||||
void* ptr{nullptr}; // memory address
|
||||
bool allocated{false}; // in-use flag
|
||||
bool mapped{true}; // True if this Block is backed by physical pages
|
||||
Block* prev{nullptr}; // prev block if split from a larger allocation
|
||||
Block* next{nullptr}; // next block if split from a larger allocation
|
||||
int event_count{0}; // number of outstanding XPU events
|
||||
ExpandableSegment* expandable_segment{nullptr}; // owning expandable segment
|
||||
|
||||
Block(
|
||||
DeviceIndex device,
|
||||
@ -66,6 +75,20 @@ struct Block {
|
||||
bool is_split() const {
|
||||
return (prev != nullptr) || (next != nullptr);
|
||||
}
|
||||
|
||||
// Inserts this block between two existing blocks with [before, this, after].
|
||||
void splice(Block* before, Block* after) {
|
||||
if (before) {
|
||||
TORCH_INTERNAL_ASSERT(before->next == after);
|
||||
before->next = this;
|
||||
}
|
||||
prev = before;
|
||||
if (after) {
|
||||
TORCH_INTERNAL_ASSERT(after->prev == before);
|
||||
after->prev = this;
|
||||
}
|
||||
next = after;
|
||||
}
|
||||
};
|
||||
|
||||
bool BlockComparatorSize(const Block* a, const Block* b) {
|
||||
@ -80,6 +103,221 @@ bool BlockComparatorSize(const Block* a, const Block* b) {
|
||||
reinterpret_cast<uintptr_t>(b->ptr);
|
||||
}
|
||||
|
||||
bool BlockComparatorAddress(const Block* a, const Block* b) {
|
||||
if (a->queue != b->queue) {
|
||||
return reinterpret_cast<uintptr_t>(a->queue) <
|
||||
reinterpret_cast<uintptr_t>(b->queue);
|
||||
}
|
||||
return reinterpret_cast<uintptr_t>(a->ptr) <
|
||||
reinterpret_cast<uintptr_t>(b->ptr);
|
||||
}
|
||||
|
||||
// Represents a contiguous virtual memory segment mapped for allocation.
|
||||
struct SegmentRange {
|
||||
SegmentRange(void* addr, size_t bytes)
|
||||
: ptr(static_cast<char*>(addr)), size(bytes) {}
|
||||
char* ptr; // Starting address of the mapped range.
|
||||
size_t size; // Size in bytes of the mapped range.
|
||||
};
|
||||
|
||||
struct ExpandableSegment {
|
||||
ExpandableSegment(
|
||||
c10::DeviceIndex device,
|
||||
std::optional<sycl::queue*> queue,
|
||||
size_t segment_size,
|
||||
std::vector<c10::DeviceIndex> peers)
|
||||
: device_(device),
|
||||
queue_(queue),
|
||||
// 2MB for small pool, 20MB for large pool
|
||||
segment_size_(segment_size),
|
||||
peers_(std::move(peers)) {
|
||||
const auto device_total =
|
||||
c10::xpu::get_raw_device(device)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
// The extra 1/8 allows flexibility for remapping or moving pages within the
|
||||
// segment when unmapping earlier regions.
|
||||
constexpr float kVirtualMemOversubscriptFactor = 1.125f; // 1 + 1/8
|
||||
max_handles_ = numSegments(device_total * kVirtualMemOversubscriptFactor);
|
||||
ptr_ = sycl::ext::oneapi::experimental::reserve_virtual_mem(
|
||||
segment_size_ * max_handles_, xpu::get_device_context());
|
||||
}
|
||||
|
||||
C10_DISABLE_COPY_AND_ASSIGN(ExpandableSegment);
|
||||
ExpandableSegment(ExpandableSegment&&) = delete;
|
||||
ExpandableSegment& operator=(ExpandableSegment&&) = delete;
|
||||
|
||||
// Maps a virtual memory range to physical memory.
|
||||
SegmentRange map(SegmentRange range) {
|
||||
auto begin = segmentLeft(range.ptr);
|
||||
auto end = segmentRight(range.ptr + range.size);
|
||||
TORCH_INTERNAL_ASSERT(ptr() + begin * segment_size_ == range.ptr);
|
||||
if (begin == end) {
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Ensure handles_ vector is large enough to hold all segments.
|
||||
if (end > handles_.size()) {
|
||||
handles_.resize(end, std::nullopt);
|
||||
}
|
||||
|
||||
// Allocate and map physical memory for each segment.
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
TORCH_INTERNAL_ASSERT(!handles_.at(i));
|
||||
try {
|
||||
// Allocate physical memory for each segment. Construct the physical_mem
|
||||
// in-place to avoid copies.
|
||||
handles_.at(i).emplace(
|
||||
xpu::get_raw_device(device_),
|
||||
xpu::get_device_context(),
|
||||
segment_size_);
|
||||
// Map the allocated physical memory into the virtual address space.
|
||||
handles_.at(i).value().map(
|
||||
ptr_ + i * segment_size_,
|
||||
segment_size_,
|
||||
sycl::ext::oneapi::experimental::address_access_mode::read_write);
|
||||
} catch (const sycl::exception& e) {
|
||||
// Allocation failure: typically sycl::errc::memory_allocation.
|
||||
// Mapping failure: typically sycl::errc::runtime (e.g., OOM due to
|
||||
// over-subscription).
|
||||
// Note: constructing physical_mem may over-subscribe device memory but
|
||||
// not immediately trigger OOM. The actual OOM can occur during map().
|
||||
// Roll back all segments allocated or mapped in this operation.
|
||||
handles_.at(i) = std::nullopt;
|
||||
for (const auto j : c10::irange(begin, i)) {
|
||||
sycl::ext::oneapi::experimental::unmap(
|
||||
reinterpret_cast<void*>(ptr_ + segment_size_ * j),
|
||||
segment_size_,
|
||||
xpu::get_device_context());
|
||||
handles_.at(j) = std::nullopt;
|
||||
}
|
||||
trimHandles();
|
||||
return rangeFromHandles(begin, begin);
|
||||
}
|
||||
}
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Unmap a virtual memory range from physical memory.
|
||||
SegmentRange unmap(SegmentRange range) {
|
||||
auto begin = segmentRight(range.ptr);
|
||||
auto end = segmentLeft(range.ptr + range.size);
|
||||
if (begin >= end) {
|
||||
return SegmentRange{range.ptr, 0};
|
||||
}
|
||||
unmapHandles(begin, end);
|
||||
return rangeFromHandles(begin, end);
|
||||
}
|
||||
|
||||
// Returns the base pointer of the virtual memory segment.
|
||||
char* ptr() const {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<char*>(ptr_);
|
||||
}
|
||||
|
||||
// Returns the total size of the virtual memory segment.
|
||||
size_t size() const {
|
||||
return max_handles_ * segment_size_;
|
||||
}
|
||||
|
||||
~ExpandableSegment() {
|
||||
forEachAllocatedRange(
|
||||
[&](size_t begin, size_t end) { unmapHandles(begin, end); });
|
||||
sycl::ext::oneapi::experimental::free_virtual_mem(
|
||||
ptr_, segment_size_ * max_handles_, xpu::get_device_context());
|
||||
}
|
||||
|
||||
private:
|
||||
// Unmaps the physical memory handles in the range [begin, end) from the
|
||||
// segment.
|
||||
void unmapHandles(size_t begin, size_t end) {
|
||||
// Currently, we don't support IPC shared memory with expandable segments.
|
||||
TORCH_INTERNAL_ASSERT(queue_);
|
||||
// As explained in Note [Safe to Free Blocks on BlockPool], additional
|
||||
// synchronization is unnecessary here because the memory is already safe to
|
||||
// release.
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
// Note: physical_mem's destructor does NOT automatically unmap any mapped
|
||||
// ranges. Users must explicitly call unmap on all ranges before
|
||||
// destroying the physical_mem object.
|
||||
sycl::ext::oneapi::experimental::unmap(
|
||||
reinterpret_cast<void*>(ptr_ + segment_size_ * i),
|
||||
segment_size_,
|
||||
xpu::get_device_context());
|
||||
// Here physical_mem object is being destructed.
|
||||
handles_.at(i) = std::nullopt;
|
||||
}
|
||||
trimHandles();
|
||||
}
|
||||
|
||||
// Remove trailing unused handles from the end of handles_.
|
||||
void trimHandles() {
|
||||
while (!handles_.empty() && !handles_.back()) {
|
||||
handles_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
// Iterates over all contiguous ranges of allocated segments in `handles_`,
|
||||
// and invokes the provided function `fn(start, end)` for each range.
|
||||
// Each range is defined as a half-open interval [start, end).
|
||||
void forEachAllocatedRange(const std::function<void(size_t, size_t)>& fn) {
|
||||
size_t start = 0;
|
||||
for (const auto i : c10::irange(handles_.size())) {
|
||||
if (handles_.at(i) && (i == 0 || !handles_.at(i - 1))) {
|
||||
start = i;
|
||||
}
|
||||
if (handles_.at(i) && (i + 1 == handles_.size() || !handles_.at(i + 1))) {
|
||||
fn(start, i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the number of full segments required to cover `size` bytes.
|
||||
// Rounds up to ensure partial segments are counted.
|
||||
size_t numSegments(size_t size) const {
|
||||
return (size + segment_size_ - 1) / segment_size_;
|
||||
}
|
||||
|
||||
// Returns the index of the segment that contains the pointer `p`,
|
||||
// relative to the base pointer `ptr_`. This is the *inclusive* lower bound
|
||||
// of the segment that includes `p`.
|
||||
size_t segmentLeft(char* p) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
|
||||
size_t offset = p - ptr();
|
||||
return offset / segment_size_;
|
||||
}
|
||||
|
||||
// Returns the index of the segment just *past* the one containing pointer
|
||||
// `p`, relative to the base pointer `ptr_`. This is the *exclusive* upper
|
||||
// bound, useful for [begin, end) style ranges.
|
||||
// If `p` lies exactly on a segment boundary, this is equal to segmentLeft(p).
|
||||
// Otherwise, it rounds up and returns segmentLeft(p) + 1.
|
||||
size_t segmentRight(char* p) const {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(p >= ptr() && p < ptr() + size());
|
||||
size_t offset = p - ptr();
|
||||
return numSegments(offset);
|
||||
}
|
||||
|
||||
// Constructs a SegmentRange spanning indices [start, end).
|
||||
SegmentRange rangeFromHandles(size_t begin, size_t end) {
|
||||
return SegmentRange(
|
||||
ptr() + segment_size_ * begin, segment_size_ * (end - begin));
|
||||
}
|
||||
|
||||
c10::DeviceIndex device_{-1};
|
||||
std::optional<sycl::queue*> queue_;
|
||||
// Virtual memory address used for reservation.
|
||||
uintptr_t ptr_{0};
|
||||
// Size of each segment in bytes.
|
||||
size_t segment_size_{0};
|
||||
// Maximum number of segments that can be allocated in this segment.
|
||||
size_t max_handles_{0};
|
||||
// Physical memory handles for the segments.
|
||||
std::vector<std::optional<sycl::ext::oneapi::experimental::physical_mem>>
|
||||
handles_{};
|
||||
// Peer devices on which this memory could be accessible, reserved.
|
||||
std::vector<c10::DeviceIndex> peers_{};
|
||||
};
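
A worked example of the segment arithmetic above, assuming segment_size_ is 2 MiB (the small-pool granularity mentioned in the constructor comment); the numbers are illustrative only:

    // numSegments rounds up, segmentLeft truncates, segmentRight == numSegments(offset).
    constexpr size_t MiB = 1024 * 1024;
    constexpr size_t seg = 2 * MiB;
    static_assert((5 * MiB + seg - 1) / seg == 3);  // numSegments(5 MiB)
    static_assert((5 * MiB) / seg == 2);            // segmentLeft(ptr() + 5 MiB)
    static_assert((5 * MiB + seg - 1) / seg == 3);  // segmentRight(ptr() + 5 MiB), off a boundary
    static_assert((4 * MiB) / seg == 2);            // on a boundary, segmentRight == segmentLeft
    // So map(SegmentRange{ptr() + 4 MiB, 5 MiB}) touches handles [2, 5), and once they are
    // populated, forEachAllocatedRange reports that contiguous run as a single [2, 5) interval.
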
|
||||
|
||||
struct AllocParams {
|
||||
AllocParams(
|
||||
DeviceIndex device,
|
||||
@ -125,10 +363,12 @@ class DeviceCachingAllocator {
|
||||
DeviceIndex device_index;
|
||||
size_t allowed_memory_maximum = 0;
|
||||
bool set_fraction = false;
|
||||
std::vector<ExpandableSegment*> expandable_segments;
|
||||
std::vector<c10::DeviceIndex> devices_with_peer_access; // reserved
|
||||
|
||||
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
|
||||
if (!src || src->allocated || src->event_count > 0 ||
|
||||
!src->stream_uses.empty()) {
|
||||
!src->stream_uses.empty() || dst->mapped != src->mapped) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -147,7 +387,8 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
const size_t subsumed_size = src->size;
|
||||
dst->size += subsumed_size;
|
||||
auto erased = pool.blocks.erase(src);
|
||||
auto erased =
|
||||
src->mapped ? pool.blocks.erase(src) : pool.unmapped.erase(src);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1);
|
||||
delete src;

@ -230,12 +471,175 @@ class DeviceCachingAllocator {
    }
  }

  // Finds the first (lowest-address) block in any segment that has sufficient
  // contiguous free virtual address space to satisfy `size`. The available
  // space may span multiple adjacent blocks, which can include both free and
  // unmapped segments.
  Block* find_expandable_block(
      c10::DeviceIndex device,
      sycl::queue* queue,
      BlockPool* pool,
      size_t size) {
    Block key(device, queue, 0);

    auto allocatable = [](Block* b) {
      return b && !b->allocated && b->event_count == 0 &&
          b->stream_uses.empty();
    };
    auto has_available_address_space = [&](Block* b) {
      size_t bytes = 0;
      while (bytes < size && allocatable(b)) {
        bytes += b->size;
        b = b->next;
      }
      return bytes >= size;
    };
    for (auto it = pool->unmapped.lower_bound(&key);
         it != pool->unmapped.end() && (*it)->queue == queue;
         ++it) {
      Block* c = *it;
      // The unmapped block might have a free mapped block right before it.
      // By starting from the previous block, we can use both:
      // [Free Mapped Block] + [Unmapped Block] = More contiguous space
      if (allocatable(c->prev)) {
        c = c->prev;
      }
      if (has_available_address_space(c)) {
        return c;
      }
    }
    auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
    expandable_segments.emplace_back(new ExpandableSegment(
        device, queue, segment_size, devices_with_peer_access));

    ExpandableSegment* es = expandable_segments.back();
    Block* candidate = new Block(device, queue, es->size(), pool, es->ptr());
    candidate->mapped = false;
    candidate->expandable_segment = es;
    pool->unmapped.insert(candidate);
    return candidate;
  }

  bool map_block(Block* to_map, size_t size) {
    TORCH_INTERNAL_ASSERT(!to_map->mapped && size <= to_map->size);
    auto mapped_range =
        to_map->expandable_segment->map(SegmentRange{to_map->ptr, size});
    // Failed to map the memory
    if (mapped_range.size == 0) {
      return false;
    }
    TORCH_INTERNAL_ASSERT(
        mapped_range.ptr == to_map->ptr && mapped_range.size >= size);

    BlockPool& pool = *to_map->pool;
    pool.unmapped.erase(to_map);
    to_map->mapped = true;

    if (mapped_range.size < to_map->size) {
      // to_map -> remaining -> to_map->next(?)
      Block* remaining = new Block(
          to_map->device,
          to_map->queue,
          to_map->size - mapped_range.size,
          &pool,
          static_cast<char*>(to_map->ptr) + mapped_range.size);
      remaining->mapped = false;
      remaining->expandable_segment = to_map->expandable_segment;
      remaining->splice(to_map, to_map->next);
      pool.unmapped.insert(remaining);
      to_map->size = mapped_range.size;
    }

    try_merge_blocks(to_map, to_map->prev, pool);
    try_merge_blocks(to_map, to_map->next, pool);

    pool.blocks.insert(to_map);

    StatTypes stat_types = get_stat_types_for_pool(*to_map->pool);
    for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
      stats.reserved_bytes[stat_type].increase(mapped_range.size);
    });

    return true;
  }

  Block* try_allocate_expandable_block(
      c10::DeviceIndex device,
      sycl::queue* queue,
      BlockPool* pool,
      size_t size) {
    // Candidate points to the start of a chain of contiguous blocks with
    // sufficient virtual address space (>= size). The chain may consist of:
    // Case 1: [Unmapped Block] -> null
    // Case 2: [Unmapped Block] -> [Free Mapped Block]
    // Case 3: [Free Mapped Block] -> [Unmapped Block]
    Block* candidate = find_expandable_block(device, queue, pool, size);

    // Map first block if unmapped (Case 1 & 2), use std::min to avoid
    // over-mapping.
    if (!candidate->mapped &&
        !map_block(candidate, std::min(candidate->size, size))) {
      return nullptr;
    }
    TORCH_INTERNAL_ASSERT(candidate->mapped);

    // Map additional blocks until we have enough continuous space (Case 3).
    // Each map_block() call merges newly mapped blocks with adjacent free
    // blocks
    while (candidate->size < size) {
      auto remaining = size - candidate->size;
      auto new_candidate = candidate->next;
      // Map only what we need from the `new_candidate` block.
      if (!map_block(new_candidate, std::min(remaining, new_candidate->size))) {
        return nullptr;
      }
      candidate = new_candidate;
    }

    // Remove from the free pool; block will be marked as `allocated` in
    // alloc_found_block()
    pool->blocks.erase(candidate);
    return candidate;
  }

  bool get_free_block(AllocParams& p) {
    BlockPool& pool = *p.pool;
    auto it = pool.blocks.lower_bound(&p.search_key);
    if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
      return false;
    }
    if ((*it)->expandable_segment) {
      if (AcceleratorAllocatorConfig::use_expandable_segments()) {
        // When expandable segments are enabled, consider both the current block
        // and any immediately adjacent unmapped region as a single expandable
        // area. For "best fit" allocation, we use the total expandable size
        // instead of just the block's current size, so that blocks which can
        // grow into a larger contiguous range are preferred.
        auto expandable_size = [](Block* b) {
          // b->next may belong to pool.unmapped (reserved but not mapped)
          return b->size + (b->next && !b->next->mapped ? b->next->size : 0);
        };
        auto next = it;
        next++;
        // Looks for the best fit block with expandable size.
        while ((*it)->expandable_segment && next != pool.blocks.end() &&
               (*next)->queue == p.queue() &&
               expandable_size(*next) < expandable_size(*it)) {
          it = next++;
        }
      } else {
        // Expandable segments were previously enabled, but are now disabled
        // (e.g. to avoid IPC issues). Skip any expandable blocks and only
        // find from regular non-expandable segments.
        do {
          it++;
        } while (it != pool.blocks.end() && (*it)->expandable_segment &&
                 (*it)->queue == p.queue());
        if (it == pool.blocks.end() || (*it)->queue != p.queue()) {
          return false;
        }
      }
    }
    p.block = *it;
    pool.blocks.erase(it);
    return true;
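
The expandable-segment paths above are taken only when `AcceleratorAllocatorConfig::use_expandable_segments()` returns true. A minimal usage sketch follows, assuming the XPU allocator honors the same `expandable_segments` knob that the accelerator allocator config parses from `PYTORCH_ALLOC_CONF`; both the variable name and its applicability to XPU are assumptions here, not something this diff confirms:

```python
# Hypothetical sketch: the setting must be in place before torch
# initializes its caching allocator, so export it before importing torch.
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"  # assumed knob

import torch

# Allocations on an XPU device would then be served from expandable
# segments that can be mapped and grown in place rather than from
# fixed-size segment reservations.
x = torch.empty(1024, 1024, device="xpu")
```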
@ -252,6 +656,10 @@ class DeviceCachingAllocator {
|
||||
size >
|
||||
allowed_memory_maximum) {
|
||||
return false;
|
||||
} else if (AcceleratorAllocatorConfig::use_expandable_segments()) {
|
||||
p.block =
|
||||
try_allocate_expandable_block(device, p.queue(), p.pool, p.size());
|
||||
return bool(p.block);
|
||||
}
|
||||
void* ptr = sycl::aligned_alloc_device(
|
||||
kDeviceAlignment,
|
||||
@ -265,6 +673,7 @@ class DeviceCachingAllocator {
|
||||
for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) {
|
||||
stats.reserved_bytes[stat_type].increase(size);
|
||||
});
|
||||
TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -283,6 +692,27 @@ class DeviceCachingAllocator {
|
||||
xpu_events.clear();
|
||||
}
|
||||
|
||||
void release_expandable_segment(Block* block) {
|
||||
// See Note [Safe to Free Blocks on BlockPool], additional synchronization
|
||||
// is unnecessary here because this function is only called by
|
||||
// release_cached_blocks().
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
block->size == block->expandable_segment->size(),
|
||||
"block disagrees with segment");
|
||||
TORCH_INTERNAL_ASSERT(!block->mapped);
|
||||
|
||||
auto it = std::find(
|
||||
expandable_segments.begin(),
|
||||
expandable_segments.end(),
|
||||
block->expandable_segment);
|
||||
TORCH_INTERNAL_ASSERT(it != expandable_segments.end());
|
||||
|
||||
expandable_segments.erase(it);
|
||||
block->pool->unmapped.erase(block);
|
||||
delete block->expandable_segment;
|
||||
delete block;
|
||||
}
|
||||
|
||||
void release_block(Block* block) {
|
||||
/*
|
||||
* Note [Safe to Free Blocks on BlockPool]
|
||||
@ -293,6 +723,7 @@ class DeviceCachingAllocator {
|
||||
* We have to do a device-level synchronization before free these blocks to
|
||||
* guarantee that all kernels can access to the blocks have finished.
|
||||
*/
|
||||
TORCH_INTERNAL_ASSERT(!block->expandable_segment);
|
||||
sycl::free(block->ptr, xpu::get_device_context());
|
||||
auto* pool = block->pool;
|
||||
pool->blocks.erase(block);
|
||||
@ -305,15 +736,80 @@ class DeviceCachingAllocator {
|
||||
delete block;
|
||||
}
|
||||
|
||||
void unmap_block(Block* block) {
|
||||
auto unmapped =
|
||||
block->expandable_segment->unmap(SegmentRange{block->ptr, block->size});
|
||||
if (unmapped.size == 0) {
|
||||
return;
|
||||
}
|
||||
block->pool->blocks.erase(block);
|
||||
|
||||
ptrdiff_t before_size = unmapped.ptr - static_cast<char*>(block->ptr);
|
||||
if (before_size > 0) {
|
||||
// If the actual unmapped region starts after block->ptr due to alignment,
|
||||
// the region before unmapped.ptr is still mapped.
|
||||
// [Prev Block?] -> [Before Block] -> [Unmapped Block]
|
||||
Block* before_free = new Block(
|
||||
block->device, block->queue, before_size, block->pool, block->ptr);
|
||||
before_free->expandable_segment = block->expandable_segment;
|
||||
before_free->splice(block->prev, block);
|
||||
block->pool->blocks.insert(before_free);
|
||||
}
|
||||
|
||||
auto after_size = block->size - (before_size + unmapped.size);
|
||||
if (after_size > 0) {
|
||||
// If the actual unmapped region ends before block->ptr + block->size,
|
||||
// the region after (unmapped.ptr + unmapped.size) is still mapped.
|
||||
// [Unmapped Block] -> [After Block] -> [Next Block?]
|
||||
Block* after_free = new Block(
|
||||
block->device,
|
||||
block->queue,
|
||||
after_size,
|
||||
block->pool,
|
||||
unmapped.ptr + unmapped.size);
|
||||
after_free->expandable_segment = block->expandable_segment;
|
||||
after_free->splice(block, block->next);
|
||||
block->pool->blocks.insert(after_free);
|
||||
}
|
||||
|
||||
// [Before Mapped Block?] -> [Unmapped Block] -> [After Mapped Block?]
|
||||
block->ptr = unmapped.ptr;
|
||||
block->size = unmapped.size;
|
||||
block->mapped = false;
|
||||
|
||||
try_merge_blocks(block, block->prev, *block->pool);
|
||||
try_merge_blocks(block, block->next, *block->pool);
|
||||
block->pool->unmapped.insert(block);
|
||||
|
||||
StatTypes stat_types = get_stat_types_for_pool(*block->pool);
|
||||
for_each_selected_stat_type(stat_types, [&](size_t stat_type) {
|
||||
stats.reserved_bytes[stat_type].decrease(unmapped.size);
|
||||
});
|
||||
}
|
||||
|
||||
void release_blocks(BlockPool& pool) {
|
||||
std::vector<Block*> to_unmap;
|
||||
// Frees all non-split blocks in the given pool.
|
||||
auto it = pool.blocks.begin();
|
||||
while (it != pool.blocks.end()) {
|
||||
Block* block = *it;
|
||||
++it;
|
||||
if (!block->prev && !block->next) {
|
||||
if (block->expandable_segment) {
|
||||
// unmap_block() modifies the free pool, so collect items to free first
|
||||
// to avoid iterator invalidation.
|
||||
to_unmap.push_back(block);
|
||||
} else if (!block->prev && !block->next) {
|
||||
release_block(block);
|
||||
}
|
||||
}
|
||||
for (Block* block : to_unmap) {
|
||||
unmap_block(block);
|
||||
// After unmap_block(), expandable segment blocks with no neighbors are
|
||||
// also released.
|
||||
if (!block->prev && !block->next) {
|
||||
release_expandable_segment(block);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool release_cached_blocks() {
|
||||
@ -328,7 +824,8 @@ class DeviceCachingAllocator {
|
||||
|
||||
bool should_split(const Block* block, size_t size) {
|
||||
size_t remaining = block->size - size;
|
||||
if (block->pool->is_small) {
|
||||
if (block->pool->is_small ||
|
||||
AcceleratorAllocatorConfig::use_expandable_segments()) {
|
||||
return remaining >= kMinBlockSize;
|
||||
} else {
|
||||
return remaining > kSmallSize;
|
||||
@ -361,6 +858,7 @@ class DeviceCachingAllocator {
|
||||
remaining = block;
|
||||
|
||||
block = new Block(device, queue, size, pool, block->ptr);
|
||||
block->expandable_segment = remaining->expandable_segment;
|
||||
block->prev = remaining->prev;
|
||||
if (block->prev) {
|
||||
block->prev->next = block;
|
||||
@ -599,6 +1097,15 @@ class XPUAllocator : public DeviceAllocator {
|
||||
return block;
|
||||
}
|
||||
|
||||
void assertValidDevice(DeviceIndex device) {
|
||||
const auto device_num = device_allocators.size();
|
||||
TORCH_CHECK(
|
||||
0 <= device && device < static_cast<int64_t>(device_num),
|
||||
"Invalid device argument ",
|
||||
device,
|
||||
": did you call init?");
|
||||
}
|
||||
|
||||
public:
|
||||
std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocators;
|
||||
|
||||
@ -711,15 +1218,6 @@ class XPUAllocator : public DeviceAllocator {
|
||||
xpu::getCurrentXPUStream().queue().memcpy(dest, src, count);
|
||||
}
|
||||
|
||||
void assertValidDevice(DeviceIndex device) {
|
||||
const auto device_num = device_allocators.size();
|
||||
TORCH_CHECK(
|
||||
0 <= device && device < static_cast<int64_t>(device_num),
|
||||
"Invalid device argument ",
|
||||
device,
|
||||
": did you call init?");
|
||||
}
|
||||
|
||||
DeviceStats getDeviceStats(DeviceIndex device) override {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getStats();
|
||||
@ -735,6 +1233,13 @@ class XPUAllocator : public DeviceAllocator {
|
||||
device_allocators[device]->resetAccumulatedStats();
|
||||
}
|
||||
|
||||
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
assertValidDevice(dev);
|
||||
assertValidDevice(dev_to_access);
|
||||
c10::xpu::get_raw_device(dev).ext_oneapi_enable_peer_access(
|
||||
c10::xpu::get_raw_device(dev_to_access));
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
@ -793,6 +1298,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
|
||||
return allocator.recordStream(dataPtr, stream);
|
||||
}
|
||||
|
||||
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) {
|
||||
return allocator.enablePeerAccess(dev, dev_to_access);
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
return allocator.getMemoryFraction(device);
|
||||
}
|
||||
|
||||
@ -25,6 +25,10 @@ C10_XPU_API void raw_delete(void* ptr);
|
||||
|
||||
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
|
||||
|
||||
C10_XPU_API void enablePeerAccess(
|
||||
c10::DeviceIndex dev,
|
||||
c10::DeviceIndex dev_to_access);
|
||||
|
||||
C10_XPU_API double getMemoryFraction(DeviceIndex device);
|
||||
|
||||
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
|
||||
|
||||
@ -1307,7 +1307,7 @@ endif()
|
||||
|
||||
if(USE_MKLDNN_ACL)
|
||||
find_package(ACL REQUIRED)
|
||||
target_include_directories(torch_cpu PRIVATE ${ACL_INCLUDE_DIRS})
|
||||
target_include_directories(torch_cpu SYSTEM PRIVATE ${ACL_INCLUDE_DIRS})
|
||||
endif()
|
||||
|
||||
target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE})
|
||||
|
||||
@ -26,7 +26,7 @@ find_library(Gloo_CUDA_LIBRARY
# if Gloo + HIP is desired, Gloo_HIP_LIBRARY
# needs to be linked to desired target
find_library(Gloo_HIP_LIBRARY
    NAMES gloo_hiop
    NAMES gloo_hip
    DOC "Gloo's HIP support/code"
)

@ -18,6 +18,7 @@ IF(NOT MKLDNN_FOUND)
|
||||
|
||||
SET(IDEEP_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep")
|
||||
SET(MKLDNN_ROOT "${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn")
|
||||
SET(ONEDNN_AARCH64_TAG "v3.10-rc")
|
||||
|
||||
if(USE_XPU) # Build oneDNN GPU library
|
||||
if(WIN32)
|
||||
@ -96,6 +97,14 @@ IF(NOT MKLDNN_FOUND)
|
||||
FIND_PACKAGE(BLAS)
|
||||
FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include)
|
||||
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl)
|
||||
# Checkout the oneDNN version defined by ONEDNN_AARCH64_TAG for CPU_AARCH64
|
||||
IF(CPU_AARCH64)
|
||||
EXECUTE_PROCESS(
|
||||
COMMAND git${CMAKE_EXECUTABLE_SUFFIX} checkout ${ONEDNN_AARCH64_TAG}
|
||||
WORKING_DIRECTORY ${MKLDNN_ROOT}
|
||||
)
|
||||
FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include)
|
||||
ENDIF(CPU_AARCH64)
|
||||
IF(NOT MKLDNN_INCLUDE_DIR)
|
||||
MESSAGE("MKLDNN_INCLUDE_DIR not found")
|
||||
EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT})
|
||||
|
||||
@ -28,6 +28,15 @@ endif()
|
||||
# Find CUDA.
|
||||
find_package(CUDA)
|
||||
if(NOT CUDA_FOUND)
|
||||
# If user explicitly set USE_CUDA=1, error out instead of falling back
|
||||
if(_USE_CUDA_EXPLICITLY_SET AND USE_CUDA)
|
||||
message(FATAL_ERROR
|
||||
"PyTorch: CUDA was explicitly requested (USE_CUDA=1) but cannot be found. "
|
||||
"Please check your CUDA installation, ensure CUDA toolkit is installed, "
|
||||
"and that CUDA_HOME or CMAKE_CUDA_COMPILER is set correctly. "
|
||||
"If you want to build without CUDA, please set USE_CUDA=0.")
|
||||
endif()
|
||||
|
||||
message(WARNING
|
||||
"PyTorch: CUDA cannot be found. Depending on whether you are building "
|
||||
"PyTorch or a PyTorch dependent library, the next warning / error will "
|
||||
|
||||
@ -45,7 +45,7 @@ supported for complex tensors.
## Transition from the old representation

Users who currently worked around the lack of complex tensors with real tensors of shape {math}`(..., 2)`
can easily to switch using the complex tensors in their code using {func}`torch.view_as_complex`
can easily switch to using the complex tensors in their code using {func}`torch.view_as_complex`
and {func}`torch.view_as_real`. Note that these functions don’t perform any copy and return a
view of the input tensor.

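As a quick illustration of the transition described above, here is a minimal sketch (not part of the diff itself; `torch.view_as_complex` and `torch.view_as_real` are existing PyTorch APIs, and both return views rather than copies):

```python
import torch

x = torch.randn(4, 2)          # old-style workaround: last dim holds (real, imag)
z = torch.view_as_complex(x)   # complex64 view of the same storage, no copy
y = torch.view_as_real(z)      # back to the (..., 2) layout, still a view
assert y.data_ptr() == x.data_ptr()
```
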
@ -140,7 +140,7 @@ through the same optimizer on the {func}`torch.view_as_real` equivalent of the c

`real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical
discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers
and capturable vs default optimizers. For more details, see [numbercial accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html).
and capturable vs default optimizers. For more details, see [numerical accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html).

Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their
`p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the

@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective
.. autofunction:: new_group
```

```{eval-rst}
.. autofunction:: torch.distributed.distributed_c10d.shrink_group
```

```{eval-rst}
.. autofunction:: get_group_rank
```

@ -1720,6 +1720,16 @@ and can be used to share memory across graphs as shown::

    g1.replay()
    g2.replay()

It's also safe to share a memory pool across separate graphs that do not depend
on each other's outputs, provided they never run concurrently.
Be aware that replaying one graph can clobber another graph's outputs when
they share a pool, unless :meth:`~torch.Tensor.clone` is called on the outputs
beforehand.
This pattern is frequently used in inference servers that accept variable batch
sizes at runtime.
vLLM is a notable example; see `here <https://github.com/vllm-project/vllm/blob/938a81692ea318e59ead4750e7e7425bfd6a4896/vllm/platforms/interface.py#L508-L515>`__
and `here <https://github.com/vllm-project/vllm/blob/938a81692ea318e59ead4750e7e7425bfd6a4896/vllm/compilation/cuda_graph.py#L86-L89>`__.

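To illustrate the paragraph above, a minimal sketch (not part of the documentation change itself; it only uses the existing `torch.cuda.graph_pool_handle()` and `torch.cuda.graph(..., pool=...)` APIs) of two independent graphs sharing one pool, with the first graph's output cloned before the second replay:

```python
import torch

pool = torch.cuda.graph_pool_handle()   # one memory pool shared by both graphs
x = torch.zeros(8, device="cuda")

g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    y1 = x * 2
with torch.cuda.graph(g2, pool=pool):
    y2 = x + 1

g1.replay()
out1 = y1.clone()   # clone before replaying g2, which may reuse y1's memory
g2.replay()
```
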
With :func:`torch.cuda.make_graphed_callables`, if you want to graph several
callables and you know they'll always run in the same order (and never concurrently)
pass them as a tuple in the same order they'll run in the live workload, and

@ -12,6 +12,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS
    ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
    ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
    ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
    ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
    ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp
    ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp
    ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp
@ -44,6 +45,10 @@ endif()
# Disable unused-variable warnings for variables that are only used to test compilation
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable)
target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable)
# Add -Wno-dangling-pointer for GCC 13
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13)
  target_compile_options_if_supported(test_aoti_abi_check -Wno-dangling-pointer)
endif()

foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS})
  foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES})

52
test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp
Normal file
@ -0,0 +1,52 @@
#include <gtest/gtest.h>

#include <torch/headeronly/util/HeaderOnlyArrayRef.h>

#include <vector>

using torch::headeronly::HeaderOnlyArrayRef;

TEST(TestHeaderOnlyArrayRef, TestEmpty) {
  HeaderOnlyArrayRef<float> arr;
  ASSERT_TRUE(arr.empty());
}

TEST(TestHeaderOnlyArrayRef, TestSingleton) {
  float val = 5.0f;
  HeaderOnlyArrayRef<float> arr(val);
  ASSERT_FALSE(arr.empty());
  EXPECT_EQ(arr.size(), 1);
  EXPECT_EQ(arr[0], val);
}

TEST(TestHeaderOnlyArrayRef, TestAPIs) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr(vec);
  ASSERT_FALSE(arr.empty());
  EXPECT_EQ(arr.size(), 7);
  for (size_t i = 0; i < arr.size(); i++) {
    EXPECT_EQ(arr[i], i + 1);
    EXPECT_EQ(arr.at(i), i + 1);
  }
  EXPECT_EQ(arr.front(), 1);
  EXPECT_EQ(arr.back(), 7);
  ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3)));
}

TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr({1, 2, 3, 4, 5, 6, 7});
  auto res_vec = arr.vec();
  for (size_t i = 0; i < vec.size(); i++) {
    EXPECT_EQ(vec[i], res_vec[i]);
  }
}

TEST(TestHeaderOnlyArrayRef, TestFromRange) {
  std::vector<int> vec = {1, 2, 3, 4, 5, 6, 7};
  HeaderOnlyArrayRef<int> arr(vec.data() + 3, vec.data() + 7);
  auto res_vec = arr.vec();
  for (size_t i = 0; i < res_vec.size(); i++) {
    EXPECT_EQ(vec[i + 3], res_vec[i]);
  }
}

@ -70,6 +70,13 @@ if(NOT MSVC)
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12)
|
||||
target_compile_options_if_supported(test_api "-Wno-error=nonnull")
|
||||
endif()
|
||||
|
||||
# Add -Wno-error=array-bounds for GCC 13+
|
||||
# See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=113239
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13)
|
||||
target_compile_options_if_supported(test_api "-Wno-error=array-bounds")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
if(INSTALL_TEST)
|
||||
|
||||
@ -47,20 +47,10 @@ Tensor sgd_out_of_place(
|
||||
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
|
||||
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
|
||||
|
||||
int64_t *param_sizes;
|
||||
int64_t *param_strides;
|
||||
aoti_torch_get_sizes(param.get(), ¶m_sizes);
|
||||
aoti_torch_get_strides(param.get(), ¶m_strides);
|
||||
// testing Tensor strides + stride
|
||||
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
|
||||
|
||||
int32_t param_dtype;
|
||||
aoti_torch_get_dtype(param.get(), ¶m_dtype);
|
||||
|
||||
int32_t param_device_type;
|
||||
aoti_torch_get_device_type(param.get(), ¶m_device_type);
|
||||
|
||||
AtenTensorHandle out_ath;
|
||||
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
|
||||
auto out = Tensor(out_ath);
|
||||
auto out = new_empty(param, param.sizes());
|
||||
|
||||
sgd_math(
|
||||
reinterpret_cast<float*>(param.data_ptr()),
|
||||
@ -311,10 +301,9 @@ void boxed_fill_infinity(
|
||||
}
|
||||
|
||||
Tensor my_pad(Tensor t) {
|
||||
std::vector<int64_t> padding = {1, 2, 2, 1};
|
||||
std::string mode = "constant";
|
||||
double value = 0.0;
|
||||
return pad(t, padding, mode, value);
|
||||
return pad(t, {1, 2, 2, 1}, mode, value);
|
||||
}
|
||||
|
||||
void boxed_my_pad(
|
||||
@ -342,6 +331,11 @@ void boxed_my_narrow(
|
||||
}
|
||||
|
||||
Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
// Still using a std::vector below even though people can just pass in an
|
||||
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
|
||||
// directly.
|
||||
// This is to test that passing in a std::vector works for BC. (It gets
|
||||
// implicitly converted to HeaderOnlyArrayRef too!)
|
||||
std::vector<int64_t> sizes = {2, 5};
|
||||
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
|
||||
return new_empty(t, sizes, dtype);
|
||||
@ -353,9 +347,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
|
||||
}
|
||||
|
||||
Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
std::vector<int64_t> sizes = {2, 5};
|
||||
auto dtype = std::make_optional(at::ScalarType::Float);
|
||||
return new_zeros(t, sizes, dtype);
|
||||
return new_zeros(t, {2, 5}, dtype);
|
||||
}
|
||||
|
||||
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
@ -429,8 +422,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs)
|
||||
}
|
||||
|
||||
Tensor my_amax_vec(Tensor t) {
|
||||
std::vector<int64_t> v = {0,1};
|
||||
return amax(t, v, false);
|
||||
return amax(t, {0,1}, false);
|
||||
}
|
||||
|
||||
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
|
||||
@ -256,23 +256,25 @@ class TestSDPA(NNTestCase):
|
||||
)
|
||||
rand_upward_privateuse1 = rand_upward.to("openreg")
|
||||
grad_input_mask = [True, True, True, True]
|
||||
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
|
||||
rand_upward_privateuse1,
|
||||
q_privateuse1,
|
||||
k_privateuse1,
|
||||
v_privateuse1,
|
||||
attn_mask_privateuse1,
|
||||
grad_input_mask,
|
||||
output,
|
||||
logsumexp,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset=philox_offset,
|
||||
_grad_q, _grad_k, _grad_v, _grad_attn_mask = (
|
||||
torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward(
|
||||
rand_upward_privateuse1,
|
||||
q_privateuse1,
|
||||
k_privateuse1,
|
||||
v_privateuse1,
|
||||
attn_mask_privateuse1,
|
||||
grad_input_mask,
|
||||
output,
|
||||
logsumexp,
|
||||
cum_seq_q,
|
||||
cum_seq_k,
|
||||
max_q,
|
||||
max_k,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset=philox_offset,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -392,11 +392,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
replicate_size = self.world_size // (pp_size)
|
||||
device_mesh = init_device_mesh(
|
||||
device_type,
|
||||
mesh_shape=(replicate_size, 1, pp_size),
|
||||
mesh_dim_names=("replicate", "shard", "pp"),
|
||||
mesh_shape=(replicate_size, pp_size),
|
||||
mesh_dim_names=("replicate", "pp"),
|
||||
)
|
||||
torch.manual_seed(42)
|
||||
dp_mesh = device_mesh["replicate", "shard"]
|
||||
dp_mesh = device_mesh["replicate"]
|
||||
pp_mesh = device_mesh["pp"]
|
||||
pp_group = device_mesh["pp"].get_group()
|
||||
|
||||
@ -416,15 +416,13 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
param_dtype=MixedPrecisionParam,
|
||||
reduce_dtype=torch.float32,
|
||||
)
|
||||
replicate_config = {"mp_policy": mp_policy}
|
||||
replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
|
||||
for layer_id in range(len(partial_model)):
|
||||
replicate(
|
||||
partial_model[layer_id],
|
||||
device_mesh=dp_mesh,
|
||||
**replicate_config,
|
||||
reshard_after_forward=False,
|
||||
)
|
||||
dp_model = replicate(partial_model, device_mesh=dp_mesh, **replicate_config)
|
||||
dp_model = replicate(partial_model, **replicate_config)
|
||||
return dp_model
|
||||
|
||||
# Apply same precision to reference model (without replicate)
|
||||
@ -582,11 +580,11 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
replicate_size = self.world_size // (pp_size)
|
||||
device_mesh = init_device_mesh(
|
||||
device_type,
|
||||
mesh_shape=(replicate_size, 1, pp_size),
|
||||
mesh_dim_names=("replicate", "shard", "pp"),
|
||||
mesh_shape=(replicate_size, pp_size),
|
||||
mesh_dim_names=("replicate", "pp"),
|
||||
)
|
||||
torch.manual_seed(42)
|
||||
dp_mesh = device_mesh["replicate", "shard"]
|
||||
dp_mesh = device_mesh["replicate"]
|
||||
pp_mesh = device_mesh["pp"]
|
||||
pp_group = device_mesh["pp"].get_group()
|
||||
dp_group = device_mesh["replicate"].get_group()
|
||||
@ -648,10 +646,9 @@ class ComposabilityTest(MultiProcessTestCase):
|
||||
for layer_id in range(len(partial_model)):
|
||||
replicate(
|
||||
partial_model[layer_id],
|
||||
device_mesh=dp_mesh,
|
||||
reshard_after_forward=False,
|
||||
mesh=dp_mesh,
|
||||
)
|
||||
dp_model = replicate(partial_model, device_mesh=dp_mesh)
|
||||
dp_model = replicate(partial_model, mesh=dp_mesh)
|
||||
return dp_model
|
||||
|
||||
def pipelined_models_parameters(start_layer, model):
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -14,7 +14,6 @@ from torch.distributed.fsdp import MixedPrecisionPolicy
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
|
||||
_get_gradient_divide_factors,
|
||||
)
|
||||
from torch.distributed.tensor import Shard
|
||||
from torch.testing._internal.common_distributed import (
|
||||
requires_nccl_version,
|
||||
SaveForwardInputsModel,
|
||||
@ -46,35 +45,20 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
|
||||
def _init_models_and_optims(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
param_dtype: Optional[torch.dtype],
|
||||
reduce_dtype: Optional[torch.dtype],
|
||||
use_shard_placement_fn,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
|
||||
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
|
||||
largest_dim = -1
|
||||
largest_dim_size = -1
|
||||
for dim, dim_size in enumerate(param.shape):
|
||||
if dim_size > largest_dim_size:
|
||||
largest_dim = dim
|
||||
largest_dim_size = dim_size
|
||||
assert largest_dim >= 0, f"{param.shape}"
|
||||
return Shard(largest_dim)
|
||||
|
||||
mp_policy = MixedPrecisionPolicy(
|
||||
param_dtype=param_dtype, reduce_dtype=reduce_dtype
|
||||
)
|
||||
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mp_policy=mp_policy,
|
||||
shard_placement_fn=shard_placement_fn,
|
||||
)
|
||||
for mlp in model:
|
||||
replicate_fn(mlp)
|
||||
@ -82,27 +66,13 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
|
||||
return ref_model, ref_optim, model, optim
|
||||
|
||||
def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
|
||||
use_shard_placement_fn_vals = [False]
|
||||
if self.world_size == 2:
|
||||
# For world size >2, gradient elements get reduced in different
|
||||
# orders for the baseline vs. dim-1 sharding, leading to numeric
|
||||
# differences for bf16 reduction, so only test world size 2.
|
||||
use_shard_placement_fn_vals.append(True)
|
||||
return use_shard_placement_fn_vals
|
||||
|
||||
@skipIfRocmVersionLessThan((7, 0))
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_compute_dtype(self):
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"param_dtype": [torch.bfloat16, torch.float16],
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_compute_dtype,
|
||||
)
|
||||
@ -110,14 +80,10 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
def _test_compute_dtype(
|
||||
self,
|
||||
param_dtype: torch.dtype,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
use_shard_placement_fn: bool,
|
||||
):
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=None,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -175,39 +141,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
|
||||
def test_reduce_dtype(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": [False, True],
|
||||
},
|
||||
self._test_reduce_dtype_fp32_reduce,
|
||||
)
|
||||
use_shard_placement_fn_vals = (
|
||||
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_shard_placement_fn": use_shard_placement_fn_vals,
|
||||
},
|
||||
self._test_reduce_dtype_bf16_reduce,
|
||||
)
|
||||
self._test_reduce_dtype_fp32_reduce()
|
||||
self._test_reduce_dtype_bf16_reduce()
|
||||
|
||||
def _test_reduce_dtype_fp32_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
):
|
||||
if (
|
||||
self.world_size > 2
|
||||
and isinstance(reshard_after_forward, int)
|
||||
and use_shard_placement_fn
|
||||
):
|
||||
return
|
||||
def _test_reduce_dtype_fp32_reduce(self):
|
||||
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -249,14 +190,12 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
check_sharded_parity(self, ref_model, model)
|
||||
|
||||
def _test_reduce_dtype_bf16_reduce(
|
||||
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
|
||||
self,
|
||||
):
|
||||
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
|
||||
ref_model, ref_optim, model, optim = self._init_models_and_optims(
|
||||
reshard_after_forward,
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
use_shard_placement_fn=use_shard_placement_fn,
|
||||
)
|
||||
group = dist.distributed_c10d._get_default_group()
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
@ -321,12 +260,8 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
|
||||
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
for mlp in model:
|
||||
replicate(
|
||||
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
replicate(
|
||||
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
|
||||
)
|
||||
replicate(mlp, mp_policy=mp_policy)
|
||||
replicate(model, mp_policy=mp_policy)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
orig_reduce_scatter = dist.reduce_scatter_tensor
|
||||
|
||||
|
||||
@ -108,84 +108,70 @@ class TestReplicateRegisteredParams(FSDPTestMultiThread):
|
||||
"""Tests the parameter registration after forward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = MLP(3, device)
|
||||
# Since seed is per process, not per thread, we broadcast to ensure
|
||||
# the same parameters across ranks
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 3), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
if reshard_after_forward:
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
else:
|
||||
self._assert_tensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
torch.manual_seed(42)
|
||||
model = MLP(3, device)
|
||||
# Since seed is per process, not per thread, we broadcast to ensure
|
||||
# the same parameters across ranks
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model) # root only
|
||||
inp = torch.randn((2, 3), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
self._assert_tensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
for reshard_after_forward in (True, False, None):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(MLP(3, device), MLP(3, device))
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(MLP(3, device), MLP(3, device))
|
||||
for param in model.parameters():
|
||||
dist.broadcast(param, src=0)
|
||||
ref_model = copy.deepcopy(model)
|
||||
replicate(model[0].in_proj)
|
||||
replicate(model[0].out_proj)
|
||||
replicate(model)
|
||||
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
non_root_params = list(model[0].in_proj.parameters()) + list(
|
||||
model[0].out_proj.parameters()
|
||||
)
|
||||
root_params = list(set(model.parameters()) - set(non_root_params))
|
||||
if reshard_after_forward is None:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
elif reshard_after_forward:
|
||||
self._assert_dtensor_params(non_root_params)
|
||||
self._assert_dtensor_params(root_params)
|
||||
else:
|
||||
self._assert_tensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
model(inp)
|
||||
non_root_params = list(model[0].in_proj.parameters()) + list(
|
||||
model[0].out_proj.parameters()
|
||||
)
|
||||
root_params = list(set(model.parameters()) - set(non_root_params))
|
||||
self._assert_tensor_params(non_root_params)
|
||||
self._assert_tensor_params(root_params)
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
for module in model.modules():
|
||||
if isinstance(module, FSDPModule):
|
||||
module.reshard() # however, we can manually reshard
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
self._assert_same_params(model.parameters(), ref_model.parameters())
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_param_registration_after_backward(self):
|
||||
"""Tests the parameter registration after backward."""
|
||||
device = torch.device(device_type.type, 0)
|
||||
# Single Replicate group
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward) # root only
|
||||
inp = torch.randn((2, 8), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model = MLP(8, device)
|
||||
replicate(model) # root only
|
||||
inp = torch.randn((2, 8), device=device_type.type)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
# Multiple Replicate groups
|
||||
for reshard_after_forward in (True, False):
|
||||
model = MLP(8, device)
|
||||
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model = MLP(8, device)
|
||||
replicate(model.in_proj)
|
||||
replicate(model.out_proj)
|
||||
replicate(model)
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
model(inp).sum().backward()
|
||||
self._assert_dtensor_params(model.parameters())
|
||||
|
||||
def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
|
||||
# need to iterate over the list multiple times
|
||||
@ -287,14 +273,11 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
[(7, 15), (15, 3)],
|
||||
[(16, 17), (17, 8)],
|
||||
],
|
||||
"use_shard_placement_fn": [False],
|
||||
},
|
||||
self._test_train_parity_single_group,
|
||||
)
|
||||
|
||||
def _test_train_parity_single_group(
|
||||
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
|
||||
):
|
||||
def _test_train_parity_single_group(self, lin_shapes: list[tuple[int, int]]):
|
||||
torch.manual_seed(42)
|
||||
model = nn.Sequential(
|
||||
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
|
||||
@ -333,7 +316,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True, False],
|
||||
"test_device_type": [device_type.type],
|
||||
"offload_policy": [OffloadPolicy()],
|
||||
"delay_after_forward": [False, True],
|
||||
@ -354,7 +336,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True], # save CI time
|
||||
"offload_policy": [
|
||||
CPUOffloadPolicy(pin_memory=True),
|
||||
CPUOffloadPolicy(pin_memory=False),
|
||||
@ -371,7 +352,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
|
||||
def _test_train_parity_multi_group(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
offload_policy: OffloadPolicy,
|
||||
test_device_type: str,
|
||||
delay_after_forward: bool,
|
||||
@ -405,13 +385,12 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
|
||||
mesh = init_device_mesh(
|
||||
test_device_type,
|
||||
(self.world_size, 1),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
(self.world_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
)
|
||||
fully_shard_fn = functools.partial(
|
||||
replicate,
|
||||
device_mesh=mesh,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mesh=mesh,
|
||||
offload_policy=offload_policy,
|
||||
)
|
||||
for module in model.modules():
|
||||
@ -527,12 +506,10 @@ class TestReplicate1DTrainingCore(FSDPTest):
|
||||
Tests parity when running a module that participates multiple
|
||||
times in forward.
|
||||
"""
|
||||
self.run_subtests(
|
||||
{"reshard_after_forward": [True, False]},
|
||||
self._test_multi_forward_module,
|
||||
)
|
||||
|
||||
def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
|
||||
self._test_multi_forward_module()
|
||||
|
||||
def _test_multi_forward_module(self):
|
||||
class MultiForwardModule(nn.Module):
|
||||
def __init__(self, device: torch.device):
|
||||
super().__init__()
|
||||
@ -687,7 +664,6 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
"""
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [True, False],
|
||||
"checkpoint_impl": ["composable", "utils", "wrapper"],
|
||||
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
|
||||
"test_device_type": [device_type.type],
|
||||
@ -697,7 +673,6 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
|
||||
def _test_train_parity_with_activation_checkpointing(
|
||||
self,
|
||||
reshard_after_forward: Union[bool, int],
|
||||
checkpoint_impl: str,
|
||||
module_grouping: str,
|
||||
test_device_type: str,
|
||||
@ -740,12 +715,11 @@ class TestReplicateTrainingCompose(FSDPTest):
|
||||
# Apply Replicate
|
||||
device_mesh = init_device_mesh(
|
||||
test_device_type,
|
||||
(self.world_size, 1),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
(self.world_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
)
|
||||
fsdp_kwargs = {
|
||||
"reshard_after_forward": reshard_after_forward,
|
||||
"device_mesh": device_mesh,
|
||||
"mesh": device_mesh,
|
||||
}
|
||||
if module_grouping == "mem_eff":
|
||||
assert model_args.n_layers == 3
|
||||
@ -809,7 +783,6 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
def test_train_parity_with_shared_params(self):
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
},
|
||||
self._test_train_shared_params,
|
||||
@ -817,7 +790,6 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
|
||||
def _test_train_shared_params(
|
||||
self,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
):
|
||||
torch.manual_seed(42)
|
||||
@ -830,8 +802,8 @@ class TestReplicateSharedParams(FSDPTest):
|
||||
if isinstance(module, TransformerBlock):
|
||||
if use_activation_checkpointing:
|
||||
checkpoint(module)
|
||||
replicate(module, reshard_after_forward=reshard_after_forward)
|
||||
replicate(model, reshard_after_forward=reshard_after_forward)
|
||||
replicate(module)
|
||||
replicate(model)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
|
||||
|
||||
torch.manual_seed(42 + self.rank + 1)
|
||||
@ -868,11 +840,11 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
with/without resharding after backward.
|
||||
"""
|
||||
|
||||
shard_size, replicate_size = 1, self.world_size
|
||||
replicate_size = self.world_size
|
||||
meshes = init_device_mesh(
|
||||
device_type.type,
|
||||
(replicate_size, shard_size),
|
||||
mesh_dim_names=("replicate", "shard"),
|
||||
(replicate_size,),
|
||||
mesh_dim_names=("replicate",),
|
||||
)
|
||||
self.run_subtests(
|
||||
{
|
||||
@ -928,8 +900,7 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
ref_model = copy.deepcopy(model).to(device_type)
|
||||
replicate_fn = functools.partial(
|
||||
replicate,
|
||||
device_mesh=mesh,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
mesh=mesh,
|
||||
offload_policy=offload_policy,
|
||||
)
|
||||
for mlp in model[1:]:
|
||||
@ -1040,8 +1011,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
|
||||
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
|
||||
for module in model.modules():
|
||||
if isinstance(module, TransformerBlock):
|
||||
replicate(module, reshard_after_forward=False)
|
||||
replicate(model, reshard_after_forward=False)
|
||||
replicate(module)
|
||||
replicate(model)
|
||||
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
|
||||
|
||||
num_microbatches = 3
|
||||
@ -1145,8 +1116,8 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
def init_global_mesh(self) -> DeviceMesh:
|
||||
return init_device_mesh(
|
||||
device_type.type,
|
||||
(2, 1, 2),
|
||||
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
|
||||
(2, 2),
|
||||
mesh_dim_names=("dp_replicate", "tp"),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(8)
|
||||
@ -1154,7 +1125,6 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
global_mesh = self.init_global_mesh()
|
||||
self.run_subtests(
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
"mlp_dim": [3, 5, 16, 17],
|
||||
"foreach": [False],
|
||||
@ -1165,12 +1135,11 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
def _test_replicate_tp(
|
||||
self,
|
||||
global_mesh: DeviceMesh,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
mlp_dim: int,
|
||||
foreach: bool,
|
||||
):
|
||||
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
|
||||
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
|
||||
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`
|
||||
|
||||
torch.manual_seed(42)
|
||||
@ -1197,8 +1166,8 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
continue
|
||||
if use_activation_checkpointing:
|
||||
checkpoint(module)
|
||||
replicate(module, device_mesh=dp_mesh)
|
||||
replicate(model, device_mesh=dp_mesh)
|
||||
replicate(module, mesh=dp_mesh)
|
||||
replicate(model, mesh=dp_mesh)
|
||||
|
||||
# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
|
||||
# strided-sharded layers.
|
||||
@ -1229,11 +1198,9 @@ class TestReplicateTPTraining(FSDPTest):
|
||||
|
||||
for _, p in model.named_parameters():
|
||||
self.assertIsInstance(p, DTensor)
|
||||
self.assertEqual(p.device_mesh.ndim, 3)
|
||||
self.assertEqual(len(p.placements), 3)
|
||||
self.assertEqual(
|
||||
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
|
||||
)
|
||||
self.assertEqual(p.device_mesh.ndim, 2)
|
||||
self.assertEqual(len(p.placements), 2)
|
||||
self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -120,7 +120,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
if i % 2 == 0:
|
||||
self.assertTrue("replicate" in _get_registry(layer))
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
elif i % 2 == 1:
|
||||
self.assertTrue("fully_shard" in _get_registry(layer))
|
||||
for parameter in layer.parameters():
|
||||
@ -197,14 +197,14 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
]
|
||||
|
||||
global_mesh = self.init_replicate_tp_mesh()
|
||||
replicate_mesh = global_mesh["replicate", "shard"]
|
||||
replicate_mesh = global_mesh["replicate"]
|
||||
|
||||
for layer in layers:
|
||||
replicate(layer, device_mesh=replicate_mesh)
|
||||
replicate(layer, mesh=replicate_mesh)
|
||||
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.device_mesh.shape, (2, 1))
|
||||
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
|
||||
self.assertEqual(parameter.device_mesh.shape, (2,))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_train_replicate_fsdp(self):
|
||||
@ -263,7 +263,6 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
run_subtests(
|
||||
self,
|
||||
{
|
||||
"reshard_after_forward": [False, True],
|
||||
"use_activation_checkpointing": [False, True],
|
||||
"mlp_dim": [3, 16, 17],
|
||||
},
|
||||
@ -273,7 +272,6 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
def _test_train_parity_2d_mlp(
|
||||
self,
|
||||
global_mesh: DeviceMesh,
|
||||
reshard_after_forward: bool,
|
||||
use_activation_checkpointing: bool,
|
||||
mlp_dim: int,
|
||||
):
|
||||
@ -287,13 +285,12 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
torch.manual_seed(42)
|
||||
model = MLPStack(mlp_dim)
|
||||
ref_model = copy.deepcopy(model).cuda()
|
||||
replicate(ref_model, device_mesh=replicate_shard_mesh)
|
||||
replicate(ref_model, mesh=replicate_mesh)
|
||||
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
|
||||
model.parallelize(
|
||||
tp_mesh,
|
||||
replicate_shard_mesh,
|
||||
use_activation_checkpointing,
|
||||
reshard_after_forward=reshard_after_forward,
|
||||
)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
|
||||
|
||||
|
||||
@ -1,16 +1,26 @@
|
||||
# Owner(s): ["oncall: distributed checkpointing"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.testing._internal.common_utils as common
|
||||
from torch import distributed as dist
|
||||
from torch.distributed.checkpoint._async_process_executor import (
|
||||
_ProcessBasedAsyncCheckpointExecutor,
|
||||
_ProcessGroupInitInfo,
|
||||
)
|
||||
from torch.distributed.checkpoint.api import CheckpointException
|
||||
from torch.distributed.checkpoint.storage import StorageWriter
|
||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
|
||||
from torch.testing._internal.common_distributed import skip_if_win32
|
||||
from torch.testing._internal.common_utils import (
|
||||
retry_on_connect_failures,
|
||||
run_tests,
|
||||
TEST_WITH_DEV_DBG_ASAN,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
with_comms,
|
||||
@ -110,47 +120,184 @@ class TestAsyncProcessExecutor(DTensorTestBase):
|
||||
"epoch": 5,
|
||||
}
|
||||
|
||||
# 1. Simulate a failure in creating PG in background process.
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=-1,
|
||||
):
|
||||
with self.assertRaises(ValueError) as _:
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
|
||||
# 1. Simulate a failure in creating PG in background process.
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=-1,
|
||||
):
|
||||
with self.assertRaises(ValueError) as _:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
)
|
||||
fut.result()
|
||||
|
||||
# 2. Attempt save with failing storage writer
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=get_free_port(),
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="fail_once"),
|
||||
)
|
||||
fut.result()
|
||||
self.assertIn(
|
||||
"fail_once policy triggered failure", str(fut.exception())
|
||||
)
|
||||
# Verify new process was created for this attempt
|
||||
if dist.get_rank() == 0:
|
||||
mock_get_free_port.assert_called_once()
|
||||
|
||||
# 2. Attempt save with failing storage writer
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
return_value=get_free_port(),
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="fail_once"),
|
||||
)
|
||||
self.assertIn("fail_once policy triggered failure", str(fut.exception()))
|
||||
# Verify new process was created for this attempt
|
||||
if dist.get_rank() == 0:
|
||||
mock_get_free_port.assert_called_once()
|
||||
# 3. Second save attempt with successful storage writer - process should still be alive
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
# Verify process is still alive
|
||||
mock_get_free_port.assert_not_called()
|
||||
# Verify successful save
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
# 3. Second save attempt with successful storage writer - process should still be alive
|
||||
with patch(
|
||||
"torch.distributed.checkpoint._async_process_executor.get_free_port",
|
||||
) as mock_get_free_port:
|
||||
proc_executor = _ProcessBasedAsyncCheckpointExecutor()
|
||||
fut = proc_executor.execute_save(
|
||||
staging_future_or_state_dict=test_state_dict,
|
||||
storage_writer=TestStorageWriter(behavior="success"),
|
||||
)
|
||||
result = fut.result()
|
||||
# Verify process is still alive
|
||||
mock_get_free_port.assert_not_called()
|
||||
# Verify successful save
|
||||
self.assertIsNotNone(result)
|
||||
|
||||
class TestAsyncProcessExecutorPrefixStore(TestCase):
    @skip_if_win32()
    @retry_on_connect_failures
    def test_checkpoint_save_with_prefix_store_enabled(self) -> None:
        """Test that checkpoint save works when DCP_USE_PREFIX_STORE is enabled."""

        test_state_dict = {
            "model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)},
            "optimizer": {"param_groups": [{"lr": 0.01}]},
            "epoch": 5,
        }

        master_addr = "localhost"
        master_port = str(common.find_free_port())

        with patch.dict(
            os.environ,
            {
                "DCP_USE_PREFIX_STORE": "1",
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": master_port,
            },
        ):
            with patch(
                "torch.distributed.checkpoint._async_process_executor.get_free_port"
            ) as mock_get_free_port:
                dist.init_process_group(
                    backend=dist.Backend.GLOO,
                    rank=0,
                    world_size=1,
                )

                proc_executor = _ProcessBasedAsyncCheckpointExecutor()
                fut = proc_executor.execute_save(
                    staging_future_or_state_dict=test_state_dict,
                    storage_writer=TestStorageWriter(behavior="success"),
                )
                result = fut.result()
                self.assertIsNotNone(result)
                mock_get_free_port.assert_not_called()


class TestProcessGroupInitInfo(DTensorTestBase):
    """Test suite for _ProcessGroupInitInfo."""

    @with_comms
    def test_process_group_init_info_with_default_pg(self) -> None:
        """Test that ProcessGroupInitInfo correctly initializes."""
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("DCP_USE_PREFIX_STORE", None)

            pg_init_info = _ProcessGroupInitInfo()

            self.assertEqual(pg_init_info.global_rank, dist.get_rank())
            self.assertEqual(pg_init_info.world_size, dist.get_world_size())
            self.assertIsNotNone(pg_init_info.tcp_store_master_addr)
            self.assertGreater(pg_init_info.tcp_store_master_port, 0)
            self.assertEqual(pg_init_info.use_prefix_store, False)

@with_comms
|
||||
def test_process_group_init_info_with_prefix_store_env_var(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo handles DCP_USE_PREFIX_STORE environment variable."""
|
||||
|
||||
# Flag enabled, addr/port correctly defined
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
},
|
||||
):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertTrue(pg_init_info.use_prefix_store)
|
||||
|
||||
# Missing port
|
||||
with patch.dict(
|
||||
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_ADDR": "localhost"}
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
# Missing addr
|
||||
with patch.dict(
|
||||
os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_PORT": "12345"}
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
# Invalid port
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DCP_USE_PREFIX_STORE": "1",
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "a",
|
||||
},
|
||||
):
|
||||
with self.assertRaises(CheckpointException):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
|
||||
@with_comms
|
||||
def test_process_group_init_info_without_prefix_store_env_var(self) -> None:
|
||||
"""Test that ProcessGroupInitInfo defaults to not using prefix store."""
|
||||
|
||||
# Env var set to 0
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "0"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
# Missing env var
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("DCP_USE_PREFIX_STORE", None)
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
# Invalid env var
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "2"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "true"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "false"}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
|
||||
with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": ""}):
|
||||
pg_init_info = _ProcessGroupInitInfo()
|
||||
self.assertFalse(pg_init_info.use_prefix_store)
|
||||
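# Behavioral sketch implied by the assertions above (not the actual
# _ProcessGroupInitInfo implementation): the prefix-store path is taken only
# when DCP_USE_PREFIX_STORE is exactly "1", and in that case MASTER_ADDR and a
# numeric MASTER_PORT must be present, otherwise initialization raises.
import os


def _use_prefix_store_sketch() -> bool:
    if os.environ.get("DCP_USE_PREFIX_STORE", "") != "1":
        return False
    addr = os.environ.get("MASTER_ADDR")
    port = os.environ.get("MASTER_PORT")
    if not addr or not port or not port.isdigit():
        # The tests above expect a CheckpointException in this situation.
        raise ValueError(
            "DCP_USE_PREFIX_STORE=1 requires MASTER_ADDR and a numeric MASTER_PORT"
        )
    return True
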
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,8 +5,16 @@ import contextlib
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
|
||||
from torch.distributed.tensor import (
|
||||
DeviceMesh,
|
||||
distribute_tensor,
|
||||
DTensor,
|
||||
Partial,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
@@ -42,22 +50,24 @@ class TestDTensorDebugMode(TestCase):
|
||||
x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
|
||||
y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False)
|
||||
|
||||
with DebugMode(record_torchfunction=True) as debug_mode:
|
||||
with DebugMode(
|
||||
record_torchfunction=True, record_ids=True, record_output=True
|
||||
) as debug_mode:
|
||||
torch.mm(x_dtensor, y_dtensor).sum()
|
||||
|
||||
self.assertExpectedInline(
|
||||
debug_mode.debug_string(),
|
||||
"""\
|
||||
torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0))
|
||||
aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0))
|
||||
torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0)
|
||||
aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))
|
||||
redistribute_input(1, S(0) -> R)
|
||||
redistribute_input(t: f32[1, 32], trace: S(0)->R)
|
||||
_c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0)
|
||||
_c10d_functional::wait_tensor(t: f32[8, 32])
|
||||
aten::mm(t: f32[1, 8], t: f32[8, 32])
|
||||
<method 'sum' of 'torch._C.TensorBase' objects>(dt: f32[8, 32]| S(0))
|
||||
aten::sum(dt: f32[8, 32]| S(0))
|
||||
aten::sum(t: f32[1, 32])""",
|
||||
redistribute_input(t$2: f32[1, 32], trace: S(0)->R)
|
||||
_c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32]
|
||||
_c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32]
|
||||
aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32]
|
||||
<method 'sum' of 'torch._C.TensorBase' objects>(dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P
|
||||
aten::sum(dt$6: f32[8, 32]| S(0))
|
||||
aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""",
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(debug_mode.operators[0], _OpCall))
|
||||
@@ -415,6 +425,40 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])""",
|
||||
)
|
||||
|
||||
with DebugMode(record_stack_trace=True) as debug_mode:
|
||||
out = mod(inp).sum()
|
||||
out.backward()
|
||||
|
||||
sum_op = [
|
||||
op for op in debug_mode.operators if str(op.op) == "aten.sum.dim_IntList"
|
||||
][-1]
|
||||
self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace)
|
||||
|
||||
def test_pretty_print_dtensor_make_fx(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
A = torch.randn(8, 32)
|
||||
B = torch.randn(32, 32)
|
||||
dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_()
|
||||
dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_()
|
||||
|
||||
def f(dA, dB):
|
||||
dy = dA @ dB
|
||||
loss = dy.sum()
|
||||
loss.backward()
|
||||
return dA.grad, dB.grad
|
||||
|
||||
# We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode.
|
||||
# make_fx has some logic to ensure we don't accidentally stash real tensors in the graph
|
||||
# so we won't stash our DTensors properly if they don't hold Fake inner tensors
|
||||
gm = make_fx(f, tracing_mode="fake")(dA, dB)
|
||||
# DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph
|
||||
gm.graph.eliminate_dead_code()
|
||||
gm.recompile()
|
||||
# Colored is nice for actual viewing, not using in this test though
|
||||
gm_str = gm.print_readable(colored=False, print_output=False)
|
||||
self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestDTensorDebugMode)
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
from typing import Any, Callable, ClassVar, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
@@ -1019,6 +1019,28 @@ class DTensorMeshTest(DTensorTestBase):
|
||||
except ValueError:
|
||||
self.fail("Unexpected ValueError raised with run_check=False")
|
||||
|
||||
@with_comms
|
||||
def test_as_strided_identity(self):
|
||||
# Test calling as_strided with the same size/stride/offset as input tensor
|
||||
# This should be a no-op but currently fails
|
||||
device_mesh = self.build_device_mesh()
|
||||
placements = [Shard(0)]
|
||||
local_tensor = torch.randn(3, 4, device=self.device_type)
|
||||
dtensor = DTensor.from_local(local_tensor, device_mesh, placements)
|
||||
|
||||
# Get the current size, stride, and storage_offset
|
||||
size = dtensor.size()
|
||||
stride = dtensor.stride()
|
||||
storage_offset = dtensor.storage_offset()
|
||||
|
||||
# Call as_strided with the exact same parameters
|
||||
result = dtensor.as_strided(size, stride, storage_offset)
|
||||
|
||||
# The result should be identical to the input
|
||||
self.assertEqual(result.size(), dtensor.size())
|
||||
self.assertEqual(result.stride(), dtensor.stride())
|
||||
self.assertEqual(result.to_local(), dtensor.to_local())
|
||||
|
||||
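# For comparison, the same identity-shaped as_strided call on a plain tensor is
# a no-op view; the DTensor test above expects matching behavior (sketch only):
import torch

t = torch.randn(3, 4)
same = t.as_strided(t.size(), t.stride(), t.storage_offset())
assert same.size() == t.size() and same.stride() == t.stride()
assert torch.equal(same, t)
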
|
||||
DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
|
||||
DTensorMeshTest,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
@@ -21,6 +22,7 @@ from unittest import mock, SkipTest
|
||||
import torch
|
||||
import torch.distributed as c10d
|
||||
import torch.distributed._functional_collectives as _functional_collectives
|
||||
from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT
|
||||
|
||||
|
||||
if not c10d.is_available() or not c10d.is_nccl_available():
|
||||
@@ -47,12 +49,15 @@ from torch._C._distributed_c10d import ErrorType, OpType, WorkResult
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU
|
||||
from torch.testing._internal.common_distributed import (
|
||||
get_required_world_size,
|
||||
get_timeout,
|
||||
init_multigpu_helper,
|
||||
MultiProcessTestCase,
|
||||
requires_multicast_support,
|
||||
requires_nccl,
|
||||
requires_nccl_shrink,
|
||||
requires_nccl_version,
|
||||
requires_world_size,
|
||||
skip_if_lt_x_gpu,
|
||||
skip_if_rocm_multiprocess,
|
||||
sm_is_or_higher_than,
|
||||
@@ -88,6 +93,53 @@ BFLOAT16_AVAILABLE = torch.cuda.is_available() and (
|
||||
)
|
||||
|
||||
|
||||
_start_time = time.time()
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ts():
|
||||
return time.time() - _start_time
|
||||
|
||||
|
||||
def configure(level=logging.INFO, force=False):
|
||||
try:
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||
force=force,
|
||||
)
|
||||
except TypeError:
|
||||
logging.basicConfig(
|
||||
level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s"
|
||||
)
|
||||
|
||||
|
||||
def log_test_info(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_success(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_validation(rank, message):
|
||||
_logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_warning(rank, message):
|
||||
_logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
def log_test_error(rank, message):
|
||||
_logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message)
|
||||
|
||||
|
||||
_log_configure = configure
|
||||
|
||||
|
||||
_log_configure(level=logging.INFO, force=True)
|
||||
|
||||
|
||||
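# Example of how the logging helpers above are used in the tests below:
# timestamps are relative to module import time and every message carries the
# calling rank, e.g. "[  0.001s][Rank 0] starting shrink test".
configure(level=logging.INFO, force=True)
log_test_info(0, "starting shrink test")
log_test_warning(1, "falling back to abort path")
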
class RendezvousEnvTest(TestCase):
|
||||
@retry_on_connect_failures
|
||||
@requires_nccl()
|
||||
@@ -317,7 +369,7 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return 2
|
||||
return get_required_world_size(self, 2)
|
||||
|
||||
@property
|
||||
def rank_to_GPU(self):
|
||||
@@ -1255,6 +1307,628 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
pg_2 = c10d.new_group([0, 1])
|
||||
self.assertEqual(pg_2.group_desc, "undefined")
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_basic(self):
|
||||
"""Test basic shrink_group functionality."""
|
||||
self._perform_shrink_test([1], "Basic shrink test")
|
||||
|
||||
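    # Usage pattern distilled from the shrink tests in this class (illustrative
    # sketch, not authoritative documentation). It assumes init_process_group has
    # already been called, as in _setup_shrink_test below: every rank sets up the
    # default group, but only the surviving ranks call shrink_group(); excluded
    # ranks simply tear down their process group and exit.
    #
    #     import torch
    #     import torch.distributed as c10d
    #
    #     def shrink_default_group_sketch(rank: int, world_size: int):
    #         ranks_to_exclude = [world_size - 1]
    #         if rank in ranks_to_exclude:
    #             c10d.destroy_process_group()
    #             return None
    #         shrunk_pg = c10d.shrink_group(ranks_to_exclude)
    #         # Remaining ranks can run collectives on the smaller group right away.
    #         t = torch.ones(1, device=f"cuda:{rank}")
    #         c10d.all_reduce(t, group=shrunk_pg)
    #         return shrunk_pg
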
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_validation(self):
|
||||
"""Test input validation in shrink_group."""
|
||||
device, pg = self._setup_shrink_test("validation")
|
||||
|
||||
def _test_invalid_input(ranks, description, expected_exception):
|
||||
"""Helper to test invalid inputs."""
|
||||
try:
|
||||
c10d.shrink_group(ranks)
|
||||
self.fail(f"Expected {expected_exception.__name__} for {description}")
|
||||
except expected_exception:
|
||||
log_test_validation(self.rank, f"✓ {description}")
|
||||
except Exception:
|
||||
if expected_exception is Exception: # Accept any exception
|
||||
log_test_validation(self.rank, f"✓ {description}")
|
||||
else:
|
||||
raise
|
||||
|
||||
# Test cases
|
||||
_test_invalid_input([], "Empty exclusion list", ValueError)
|
||||
if self.world_size > 1:
|
||||
_test_invalid_input([0, 0, 1], "Duplicate ranks", Exception)
|
||||
_test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception)
|
||||
|
||||
log_test_success(self.rank, "All validation tests passed")
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_backend_properties(self):
|
||||
"""Test that backend properties are preserved after shrinking."""
|
||||
|
||||
test_name = "Backend Properties Test"
|
||||
ranks_to_exclude = [0]
|
||||
|
||||
# Reuse _setup_shrink_test for complete setup (device, environment, and process group)
|
||||
device, pg = self._setup_shrink_test("backend_properties")
|
||||
|
||||
# Follow _perform_shrink_test pattern from here
|
||||
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
|
||||
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
||||
)
|
||||
|
||||
# Store original backend property values (not references) before shrinking
|
||||
original_timeout = None
|
||||
original_high_priority = None
|
||||
if not is_excluded:
|
||||
original_backend = pg._get_backend(device)
|
||||
original_timeout = original_backend.options._timeout
|
||||
original_high_priority = original_backend.options.is_high_priority_stream
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}",
|
||||
)
|
||||
|
||||
if is_excluded:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
|
||||
)
|
||||
dist.destroy_process_group() # hang without it
|
||||
return
|
||||
|
||||
# Only non-excluded ranks proceed with shrink (same as _perform_shrink_test)
|
||||
log_test_info(self.rank, "Non-excluded rank calling shrink_group")
|
||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
|
||||
|
||||
# Reuse _validate_shrunk_group helper (same as _perform_shrink_test)
|
||||
expected_size = self.world_size - len(ranks_to_exclude)
|
||||
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
|
||||
|
||||
# Add custom backend properties validation
|
||||
new_backend = shrunk_pg._get_backend(device)
|
||||
log_test_info(self.rank, "Validating backend properties are preserved")
|
||||
|
||||
new_timeout = new_backend.options._timeout
|
||||
new_high_priority = new_backend.options.is_high_priority_stream
|
||||
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Timeout comparison - original: {original_timeout}, new: {new_timeout}",
|
||||
)
|
||||
self.assertEqual(
|
||||
original_timeout, new_timeout, f"{test_name}: timeout not preserved"
|
||||
)
|
||||
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}",
|
||||
)
|
||||
self.assertEqual(
|
||||
original_high_priority,
|
||||
new_high_priority,
|
||||
f"{test_name}: high_priority_stream not preserved",
|
||||
)
|
||||
|
||||
log_test_validation(
|
||||
self.rank, f"{test_name}: Backend properties preserved successfully"
|
||||
)
|
||||
log_test_success(
|
||||
self.rank, f"{test_name} successful (shrink + backend validation)"
|
||||
)
|
||||
|
||||
# Cleanup (same as _perform_shrink_test)
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_multiple_comms(self):
|
||||
"""Test shrink_group with multiple communicators and subgroup invalidation."""
|
||||
|
||||
device, pg = self._setup_shrink_test("multiple_comms")
|
||||
|
||||
# Create subgroup [0, 1] and test shrinking it
|
||||
subgroup = c10d.new_group([0, 1])
|
||||
if self.rank <= 1:
|
||||
# Shrink subgroup: exclude rank 1
|
||||
if self.rank == 0: # Only rank 0 remains
|
||||
shrunk_subgroup = c10d.shrink_group([1], group=subgroup)
|
||||
self.assertEqual(shrunk_subgroup.size(), 1)
|
||||
# Test communication on shrunk subgroup
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
c10d.all_reduce(tensor, group=shrunk_subgroup)
|
||||
self.assertEqual(tensor.item(), 0) # Only rank 0
|
||||
log_test_success(self.rank, "Subgroup shrinking successful")
|
||||
|
||||
dist.barrier() # Sync before default group test
|
||||
|
||||
# Shrink default group: exclude last rank
|
||||
ranks_to_exclude = [self.world_size - 1]
|
||||
if self.rank not in ranks_to_exclude:
|
||||
shrunk_default = c10d.shrink_group(ranks_to_exclude)
|
||||
expected_size = self.world_size - 1
|
||||
self.assertEqual(shrunk_default.size(), expected_size)
|
||||
|
||||
# Test collective on shrunk default group
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
c10d.all_reduce(tensor, group=shrunk_default)
|
||||
expected_sum = sum(
|
||||
range(self.world_size - 1)
|
||||
) # 0 + 1 + ... + (world_size-2)
|
||||
self.assertEqual(tensor.item(), expected_sum)
|
||||
log_test_success(self.rank, "Default group shrinking successful")
|
||||
|
||||
# Note: After shrinking default group, the old subgroup is invalid
|
||||
# due to global rank reassignment
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude):
|
||||
"""Helper method to test shrink_group with a specific flag."""
|
||||
if self.world_size < 2:
|
||||
log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})")
|
||||
return
|
||||
ranks_to_exclude = [rank_to_exclude]
|
||||
log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})")
|
||||
if flag_name == "NCCL_SHRINK_ABORT":
|
||||
log_test_info(
|
||||
self.rank,
|
||||
"ABORT flag will terminate ongoing operations before shrinking",
|
||||
)
|
||||
|
||||
self._perform_shrink_test(
|
||||
ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag
|
||||
)
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_flags(self):
|
||||
"""Test shrink_group with different shrink flags."""
|
||||
# Test ABORT flags
|
||||
log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag")
|
||||
self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1)
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_nccl_config(self):
|
||||
"""Verify that passing NCCL config via pg_options influences the shrunk group's backend options."""
|
||||
device, pg = self._setup_shrink_test("config")
|
||||
if self.rank == self.world_size - 1:
|
||||
# excluded rank should not call shrink_group
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
# Prepare pg_options with NCCL config overrides
|
||||
# Capture parent's current backend options to ensure we can prove override vs inherit
|
||||
parent_backend = pg._get_backend(torch.device("cuda"))
|
||||
parent_hp = parent_backend.options.is_high_priority_stream
|
||||
parent_blocking = parent_backend.options.config.blocking
|
||||
|
||||
# Choose overrides that differ from the parent (flip where possible)
|
||||
override_hp = not parent_hp
|
||||
if parent_blocking in (0, 1):
|
||||
override_blocking = 1 - parent_blocking
|
||||
else:
|
||||
# If undefined or unexpected, set to 1 which is a concrete value
|
||||
override_blocking = 1
|
||||
|
||||
opts = c10d.ProcessGroupNCCL.Options()
|
||||
opts.is_high_priority_stream = override_hp
|
||||
opts.config.blocking = override_blocking
|
||||
|
||||
shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts)
|
||||
|
||||
# Validate backend options propagated
|
||||
backend = shrunk_pg._get_backend(torch.device("cuda"))
|
||||
# is_high_priority_stream should exactly match our override and differ from parent
|
||||
self.assertEqual(backend.options.is_high_priority_stream, override_hp)
|
||||
self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp)
|
||||
# config is a struct; check representative field and difference from parent when meaningful
|
||||
self.assertEqual(backend.options.config.blocking, override_blocking)
|
||||
if parent_blocking in (0, 1):
|
||||
self.assertNotEqual(backend.options.config.blocking, parent_blocking)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(2)
|
||||
def test_shrink_group_performance(self):
|
||||
"""Test shrink_group performance and regression detection."""
|
||||
import time
|
||||
|
||||
ranks_to_exclude = self._get_default_ranks_to_exclude()
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
|
||||
if not ranks_to_exclude:
|
||||
log_test_info(self.rank, "Skipping performance test (world_size=1)")
|
||||
return
|
||||
|
||||
log_test_info(self.rank, f"Performance test with {self.world_size} processes")
|
||||
device, pg = self._setup_shrink_test("performance")
|
||||
|
||||
if not is_excluded:
|
||||
log_test_info(self.rank, "Measuring shrink_group performance")
|
||||
start_time = time.time()
|
||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude)
|
||||
end_time = time.time()
|
||||
|
||||
elapsed_time = end_time - start_time
|
||||
log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s")
|
||||
|
||||
# Regression check: should complete within reasonable time
|
||||
self.assertLess(
|
||||
elapsed_time,
|
||||
30.0,
|
||||
f"shrink_group took {elapsed_time:.3f}s, possible regression",
|
||||
)
|
||||
|
||||
# Test collective performance
|
||||
expected_size = self.world_size - len(ranks_to_exclude)
|
||||
self._validate_shrunk_group(shrunk_pg, expected_size, "performance")
|
||||
|
||||
collective_start = time.time()
|
||||
_ = self._test_collective_on_shrunk_group(
|
||||
shrunk_pg, device, ranks_to_exclude, "performance"
|
||||
)
|
||||
collective_time = time.time() - collective_start
|
||||
|
||||
log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s")
|
||||
log_test_success(self.rank, "Performance test passed")
|
||||
else:
|
||||
log_test_info(self.rank, "Excluded rank - waiting")
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(4)
|
||||
def test_shrink_group_multiple_exclusions(self):
|
||||
"""Test shrink_group with multiple ranks excluded at once."""
|
||||
# Scale exclusions with world size
|
||||
ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2
|
||||
|
||||
self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test")
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(3)
|
||||
def test_shrink_group_multiple_iterations(self):
|
||||
"""Test multiple shrink operations in sequence."""
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}",
|
||||
)
|
||||
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
_ = self._create_process_group_nccl(store, self.opts(), device_id=device)
|
||||
|
||||
# Track current effective world size throughout shrinking operations
|
||||
current_world_size = self.world_size
|
||||
log_test_info(self.rank, f"Initial world_size: {current_world_size}")
|
||||
|
||||
# First shrinking: exclude the last rank(s)
|
||||
first_exclusion = [self.world_size - 1]
|
||||
if self.world_size >= 6:
|
||||
first_exclusion.append(
|
||||
self.world_size - 2
|
||||
) # Exclude last two ranks for larger sizes
|
||||
|
||||
log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}")
|
||||
|
||||
if self.rank not in first_exclusion:
|
||||
# Only non-excluded ranks should call shrink_group
|
||||
first_pg = c10d.shrink_group(first_exclusion)
|
||||
self.assertIsNotNone(first_pg)
|
||||
# IMPORTANT: Update world size after first shrinking
|
||||
current_world_size = first_pg.size()
|
||||
expected_first_size = self.world_size - len(first_exclusion)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"After first shrinking: world_size {self.world_size} -> {current_world_size}",
|
||||
)
|
||||
self.assertEqual(first_pg.size(), expected_first_size)
|
||||
|
||||
# Second shrinking: exclude another rank from the remaining group
|
||||
# Choose a rank that's in the middle range
|
||||
if current_world_size >= 3:
|
||||
second_exclusion = [
|
||||
current_world_size - 1
|
||||
] # Exclude the new "last" rank
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}",
|
||||
)
|
||||
|
||||
if self.rank not in second_exclusion:
|
||||
# Only non-excluded ranks should call shrink_group for second iteration
|
||||
second_pg = c10d.shrink_group(second_exclusion, group=first_pg)
|
||||
self.assertIsNotNone(second_pg)
|
||||
# IMPORTANT: Update world size after second shrinking
|
||||
final_world_size = second_pg.size()
|
||||
expected_final_size = current_world_size - len(second_exclusion)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"After second shrinking: world_size {current_world_size} -> {final_world_size}",
|
||||
)
|
||||
self.assertEqual(second_pg.size(), expected_final_size)
|
||||
|
||||
# Test collective on final group
|
||||
tensor = torch.full((1,), self.rank).cuda(device)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}",
|
||||
)
|
||||
c10d.all_reduce(tensor, group=second_pg)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Final all_reduce completed, result: {tensor.item()}",
|
||||
)
|
||||
|
||||
# Calculate expected sum of remaining ranks
|
||||
all_excluded = set(first_exclusion + second_exclusion)
|
||||
remaining_ranks = [
|
||||
r for r in range(self.world_size) if r not in all_excluded
|
||||
]
|
||||
expected_sum = sum(remaining_ranks)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}",
|
||||
)
|
||||
self.assertEqual(tensor.item(), expected_sum)
|
||||
log_test_info(self.rank, "Final verification passed")
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
"This rank excluded in second shrinking, not calling shrink_group",
|
||||
)
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank, "Skipping second shrinking (remaining group too small)"
|
||||
)
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
"This rank excluded in first shrinking, not calling shrink_group",
|
||||
)
|
||||
|
||||
log_test_info(self.rank, "Destroying process group")
|
||||
dist.destroy_process_group()
|
||||
log_test_info(self.rank, "test_shrink_group_multiple_iterations completed")
|
||||
|
||||
# Helper methods for optimized shrink group tests
|
||||
def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True):
|
||||
"""Common setup for shrink group tests."""
|
||||
os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1"
|
||||
world_size = world_size or self.world_size
|
||||
store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size)
|
||||
device = torch.device(f"cuda:{self.rank}")
|
||||
c10d.init_process_group(
|
||||
"nccl",
|
||||
world_size=world_size,
|
||||
rank=self.rank,
|
||||
store=store,
|
||||
pg_options=self.opts(),
|
||||
device_id=device,
|
||||
)
|
||||
pg = c10d.distributed_c10d._get_default_group()
|
||||
|
||||
if warmup:
|
||||
c10d.all_reduce(torch.ones(1).cuda(device), group=pg)
|
||||
|
||||
return device, pg
|
||||
|
||||
def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""):
|
||||
"""Validate properties of a shrunk process group."""
|
||||
self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None")
|
||||
actual_size = shrunk_pg.size()
|
||||
self.assertEqual(
|
||||
actual_size, expected_size, f"{test_name}: group size mismatch"
|
||||
)
|
||||
|
||||
new_rank = shrunk_pg.rank()
|
||||
self.assertTrue(
|
||||
0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}"
|
||||
)
|
||||
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}",
|
||||
)
|
||||
return new_rank
|
||||
|
||||
def _test_collective_on_shrunk_group(
|
||||
self, shrunk_pg, device, ranks_to_exclude, test_name=""
|
||||
):
|
||||
"""Test collective communication on shrunk group and verify correctness."""
|
||||
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
|
||||
c10d.all_reduce(test_tensor, group=shrunk_pg)
|
||||
|
||||
result = test_tensor.item()
|
||||
expected_sum = sum(
|
||||
r for r in range(self.world_size) if r not in ranks_to_exclude
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result, expected_sum, f"{test_name}: collective result mismatch"
|
||||
)
|
||||
log_test_info(
|
||||
self.rank, f"{test_name}: collective passed ({result} == {expected_sum})"
|
||||
)
|
||||
return result
|
||||
|
||||
def _perform_shrink_test(
|
||||
self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True
|
||||
):
|
||||
"""Complete shrink test flow: setup, shrink, validate, test collective, cleanup.
|
||||
|
||||
Consistent API: All ranks perform setup to initialize distributed environment.
|
||||
ONLY non-excluded ranks call shrink_group() for both default and non-default groups.
|
||||
Excluded ranks perform setup, then exit without calling shrink_group() or waiting.
|
||||
"""
|
||||
log_test_info(self.rank, f"{test_name} (world_size={self.world_size})")
|
||||
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
||||
)
|
||||
|
||||
# All ranks (including excluded ones) perform setup to initialize distributed environment
|
||||
device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_"))
|
||||
is_default_group = pg == c10d.distributed_c10d._get_default_group()
|
||||
|
||||
if is_excluded:
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluded rank {self.rank} - setup complete, skipping shrink operation",
|
||||
)
|
||||
if shrink_flags & NCCL_SHRINK_ABORT:
|
||||
log_test_info(self.rank, f"Using abort for excluded rank {self.rank}")
|
||||
pg._get_backend(torch.device(device)).abort()
|
||||
log_test_info(
|
||||
self.rank, f"cleanup resources for excluded rank {self.rank}"
|
||||
)
|
||||
dist.destroy_process_group()
|
||||
log_test_info(self.rank, f"Excluded rank {self.rank} - exit")
|
||||
else:
|
||||
log_test_info(
|
||||
self.rank, f"Using regular destroy for excluded rank {self.rank}"
|
||||
)
|
||||
dist.destroy_process_group()
|
||||
return None
|
||||
|
||||
# Only non-excluded ranks proceed with shrink
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Non-excluded rank calling shrink_group (default_group={is_default_group})",
|
||||
)
|
||||
shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags)
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done",
|
||||
)
|
||||
|
||||
# Non-excluded ranks: validate and test the new group
|
||||
expected_size = self.world_size - len(ranks_to_exclude)
|
||||
_ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name)
|
||||
|
||||
if with_collective:
|
||||
_ = self._test_collective_on_shrunk_group(
|
||||
shrunk_pg, device, ranks_to_exclude, test_name
|
||||
)
|
||||
log_test_success(self.rank, f"{test_name} successful (shrink + collective)")
|
||||
else:
|
||||
log_test_success(self.rank, f"{test_name} successful (shrink only)")
|
||||
|
||||
dist.destroy_process_group()
|
||||
return shrunk_pg
|
||||
|
||||
def _get_default_ranks_to_exclude(self):
|
||||
"""Get default ranks to exclude based on world size."""
|
||||
if self.world_size <= 1:
|
||||
return []
|
||||
return [self.world_size - 1] # Exclude last rank by default
|
||||
|
||||
@requires_nccl_shrink()
|
||||
@requires_world_size(3)
|
||||
def test_shrink_group_vs_abort_reinit_performance(self):
|
||||
"""Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability)."""
|
||||
log_test_info(self.rank, "=== TEST 1: abort+reinit ===")
|
||||
|
||||
device, pg1 = self._setup_shrink_test("_perf_reinit")
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Test 1: Traditional abort + reinit
|
||||
start_time = time.perf_counter()
|
||||
dist.destroy_process_group()
|
||||
|
||||
device, new_pg = self._setup_shrink_test("perf_shrink_test1")
|
||||
reinit_time = time.perf_counter() - start_time
|
||||
|
||||
# Test collective with original rank values for fair comparison (non-blocking mode)
|
||||
test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32)
|
||||
work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True)
|
||||
work.wait()
|
||||
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
# Verify correctness
|
||||
expected_sum = sum(r for r in range(self.world_size))
|
||||
self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed")
|
||||
|
||||
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
|
||||
dist.destroy_process_group(new_pg)
|
||||
|
||||
# Test 2: shrink_group with NCCL_SHRINK_ABORT
|
||||
log_test_info(self.rank, "=== TEST 2: shrink_group ===")
|
||||
|
||||
ranks_to_exclude = [self.world_size - 1]
|
||||
is_excluded = self.rank in ranks_to_exclude
|
||||
log_test_info(
|
||||
self.rank,
|
||||
f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}",
|
||||
)
|
||||
|
||||
device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix
|
||||
|
||||
shrink_time = 0
|
||||
if not is_excluded:
|
||||
torch.cuda.synchronize(device) # Ensure accurate timing
|
||||
start_time = time.perf_counter()
|
||||
shrunk_pg = c10d.shrink_group(
|
||||
ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT
|
||||
)
|
||||
c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg)
|
||||
shrink_time = time.perf_counter() - start_time
|
||||
|
||||
# Test collective communication on shrunk group (non-blocking mode)
|
||||
test_tensor = torch.full(
|
||||
(1,), self.rank, device=device, dtype=torch.float32
|
||||
)
|
||||
work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True)
|
||||
work.wait()
|
||||
|
||||
# Verify correctness
|
||||
expected_sum = sum(
|
||||
r for r in range(self.world_size) if r not in ranks_to_exclude
|
||||
)
|
||||
self.assertEqual(
|
||||
test_tensor.item(),
|
||||
expected_sum,
|
||||
"shrink_test: collective result mismatch",
|
||||
)
|
||||
|
||||
torch.cuda.synchronize(device) # Ensure operations complete
|
||||
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
|
||||
dist.destroy_process_group()
|
||||
else:
|
||||
log_test_info(self.rank, "Excluded from shrink test - exiting immediately")
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
# Performance analysis (only for participating ranks)
|
||||
if shrink_time > 0 and reinit_time > 0:
|
||||
speedup = reinit_time / shrink_time
|
||||
time_saved = reinit_time - shrink_time
|
||||
|
||||
log_test_info(self.rank, "=== PERFORMANCE RESULTS ===")
|
||||
log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s")
|
||||
log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s")
|
||||
log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s")
|
||||
log_test_info(self.rank, f"speedup: {speedup:.2f}x")
|
||||
|
||||
if speedup > 1.1:
|
||||
log_test_success(self.rank, "shrink_group significantly faster")
|
||||
elif speedup > 0.9:
|
||||
log_test_info(self.rank, "≈ comparable performance")
|
||||
else:
|
||||
log_test_warning(self.rank, "abort+reinit faster")
|
||||
|
||||
log_test_info(self.rank, "Performance test completed")
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_deterministic_mode_no_break(self):
|
||||
@@ -5115,6 +5789,229 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
else:
|
||||
self.assertTrue("duration_ms" not in t["entries"][0])
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_circular_buffer_full(self, timing_enabled):
|
||||
"""
|
||||
Test that when the circular buffer in entries_ is full and we call reset,
|
||||
then fill the buffer with new entries, dump_entries returns only the new
|
||||
entries and not the old ones.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill the buffer completely with 10 entries
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify buffer is full with 10 entries
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 10)
|
||||
|
||||
# Now reset the flight recorder
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Add new entries after reset - fill the buffer completely again
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify we get exactly 10 new entries, not 20
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 10)
|
||||
|
||||
# Verify all entries have the expected properties (from after reset)
|
||||
# After reset, record IDs should start from 0 again
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("profiling_name", entry)
|
||||
self.assertEqual(entry["profiling_name"], "nccl:all_reduce")
|
||||
self.assertIn("record_id", entry)
|
||||
# Record IDs should be sequential starting from 0 after reset
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_partial_overwrite(self, timing_enabled):
|
||||
"""
|
||||
Test that when the circular buffer is full, we reset, and then add fewer
|
||||
entries than the buffer size, we only get the new entries.
|
||||
This tests that old entries at the end of the circular buffer are properly
|
||||
filtered out based on reset_epoch.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill the buffer completely
|
||||
for _ in range(10):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Reset the flight recorder
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Add only 3 new entries (much less than buffer size)
|
||||
for _ in range(3):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Verify we only get the 3 new entries, not 10
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 3)
|
||||
|
||||
# Verify record IDs start from 0 after reset
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_reset_wraparound(self, timing_enabled):
|
||||
"""
|
||||
Test that when we reset in the middle of the circular buffer and then
|
||||
wrap around, dump_entries correctly returns only entries from the current
|
||||
epoch in the correct order.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# Fill half the buffer
|
||||
for _ in range(5):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Reset at this point (reset happens at index 5)
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Now add 8 entries, which will wrap around
|
||||
# (5->9 fills rest of buffer, then 0->2 wraps around)
|
||||
for _ in range(8):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Should get exactly 8 entries, properly ordered
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 8)
|
||||
|
||||
# Entries should be in chronological order
|
||||
# The dump_entries() method returns entries from next_ to end, then 0 to next_
|
||||
# After filtering old entries, we should have 8 entries in order
|
||||
# Verify record IDs start from 0 after reset (id_ is reset in reset_all())
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("profiling_name", entry)
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_fr_record_multiple_resets(self, timing_enabled):
|
||||
"""
|
||||
Test multiple consecutive resets to ensure each reset properly increments
|
||||
the epoch and filters out entries from previous epochs.
|
||||
"""
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
|
||||
# Override buffer size to 10 for faster testing
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10"
|
||||
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
self.set_thread_name("fr_test_thread")
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
|
||||
# First batch: 2 entries
|
||||
for _ in range(2):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# First reset
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Second batch: 3 entries
|
||||
for _ in range(3):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Second reset
|
||||
torch._C._distributed_c10d._reset_fr_recording_nccl()
|
||||
|
||||
# Third batch: 4 entries
|
||||
for _ in range(4):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
time.sleep(1)
|
||||
|
||||
# Should only see the last 4 entries
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
self.assertEqual(len(t["entries"]), 4)
|
||||
|
||||
# Verify record IDs start from 0 after the last reset
|
||||
for i, entry in enumerate(t["entries"]):
|
||||
self.assertIn("record_id", entry)
|
||||
self.assertEqual(entry["record_id"], i)
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
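# A small Python model of the reset/epoch behavior asserted by the flight
# recorder tests above (sketch only; the real recorder lives in C++ and its
# internals may differ). It captures the three properties the tests check:
# a fixed-capacity circular buffer, record ids restarting at 0 after reset,
# and dump filtering out entries recorded before the latest reset.
class _FlightRecorderModel:
    def __init__(self, capacity: int = 10) -> None:
        self.capacity = capacity
        self.entries: list[tuple[int, int, str]] = []  # (epoch, record_id, name)
        self.next_ = 0
        self.epoch = 0
        self.next_id = 0

    def record(self, name: str) -> None:
        entry = (self.epoch, self.next_id, name)
        self.next_id += 1
        if len(self.entries) < self.capacity:
            self.entries.append(entry)
        else:
            self.entries[self.next_] = entry  # circular overwrite
        self.next_ = (self.next_ + 1) % self.capacity

    def reset(self) -> None:
        # Bump the epoch and restart record ids; stale slots stay in place
        # but are filtered out on dump.
        self.epoch += 1
        self.next_id = 0

    def dump(self) -> list[tuple[int, int, str]]:
        ordered = self.entries[self.next_ :] + self.entries[: self.next_]
        return [e for e in ordered if e[0] == self.epoch]
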
def check_if_test_is_skipped(fn):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
@@ -5446,6 +6343,14 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase):
|
||||
if self.rank == 6 or self.rank == 7:
|
||||
dist.broadcast(tensor2, 6, group=ng2)
|
||||
self.assertEqual(tensor2, torch.full((1,), 6))
|
||||
|
||||
# Test the case when the split changes the pg option of split group
|
||||
# while the parent pg option is not changed.
|
||||
new_pg = c10d.new_group([0, 1, 2, 3, 4, 5, 6, 7], device_id=device)
|
||||
backend_new_pg = new_pg._get_backend(torch.device(device))
|
||||
self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8)
|
||||
c10d.split_group(new_pg, [[0, 2, 4, 6], [1, 3, 5, 7]])
|
||||
self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8)
|
||||
# a barrier and a cuda sync before destroying all pgs.
|
||||
dist.barrier(pg)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@@ -7,6 +7,8 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed._local_tensor import (
|
||||
local_tensor_mode,
|
||||
LocalIntNode,
|
||||
LocalRunnerMode,
|
||||
LocalTensor,
|
||||
LocalTensorMode,
|
||||
)
|
||||
@@ -17,8 +19,10 @@ from torch.distributed.tensor import (
|
||||
Partial,
|
||||
Replicate,
|
||||
Shard,
|
||||
zeros,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import reduce_local_int
|
||||
|
||||
|
||||
class LocalTensorTestBase(TestCase):
|
||||
@@ -411,5 +415,78 @@ class TestLocalTensorWorld8(LocalTensorTestBase):
|
||||
self.assertEqual(full_tensor, local_res)
|
||||
|
||||
|
||||
from torch.distributed._local_tensor._c10d import local_p2p_op, wait_all
|
||||
|
||||
|
||||
class TestLocalRunner(LocalTensorTestBase):
|
||||
world_size = 6
|
||||
|
||||
@staticmethod
|
||||
def _get_pp_peer(pp_index, mesh, dim, dir):
|
||||
pp_meshes = mesh._get_all_submeshes(dim)
|
||||
pp_ret = {}
|
||||
for pp_mesh in pp_meshes:
|
||||
global_rank = pp_mesh.mesh[pp_index].item()
|
||||
global_peer = pp_mesh.mesh[(pp_index + dir) % pp_mesh.size()].item()
|
||||
pp_ret[global_rank] = global_peer
|
||||
|
||||
return torch.SymInt(LocalIntNode(pp_ret))
|
||||
|
||||
def _run_dp_pp(
|
||||
self,
|
||||
mesh: DeviceMesh,
|
||||
pp_index: int,
|
||||
actual: list[torch.Tensor | None],
|
||||
expected: list[torch.Tensor | None],
|
||||
) -> None:
|
||||
ltm = LocalTensorMode(mesh.size())
|
||||
with ltm:
|
||||
dp_mesh = mesh["dp"]
|
||||
pp_mesh = mesh["pp"]
|
||||
|
||||
x = torch.rand(2, 4)
|
||||
xd = distribute_tensor(x, dp_mesh, [Shard(0)])
|
||||
xd = xd * 2
|
||||
x = x * 2
|
||||
|
||||
yd = zeros(*xd.shape, device_mesh=dp_mesh, placements=[Shard(0)])
|
||||
|
||||
if pp_index != pp_mesh.size(0) - 1:
|
||||
# Send to next pp rank
|
||||
pp_next_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", +1)
|
||||
local_p2p_op(pp_next_rank, xd, dist.isend)
|
||||
expected[pp_index + 1] = ltm.tensor_map(
|
||||
x,
|
||||
lambda r, t: t
|
||||
if reduce_local_int(pp_next_rank, lambda vals: r in vals.values())
|
||||
else torch.zeros_like(t),
|
||||
)
|
||||
|
||||
if pp_index != 0:
|
||||
# Receive from prev pp rank
|
||||
pp_prev_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", -1)
|
||||
rw = local_p2p_op(pp_prev_rank, yd, dist.irecv)
|
||||
wait_all(rw)
|
||||
|
||||
y = yd.full_tensor()
|
||||
actual[pp_index] = y
|
||||
|
||||
def test_dp_pp(self):
|
||||
pp_size = 3
|
||||
mesh = init_device_mesh(
|
||||
"cpu", (self.world_size // pp_size, pp_size), mesh_dim_names=("dp", "pp")
|
||||
)
|
||||
actual: list[torch.Tensor | None] = [None] * pp_size
|
||||
expected: list[torch.Tensor | None] = [None] * pp_size
|
||||
with LocalRunnerMode(
|
||||
self.world_size,
|
||||
pp_size,
|
||||
lambda pp_index: self._run_dp_pp(mesh, pp_index, actual, expected),
|
||||
):
|
||||
pass
|
||||
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
diff --git a/test/dynamo/cpython/3_13/test_heapq.py b/test/dynamo/cpython/3_13/test_heapq.py
|
||||
index 1aa8e4e2897..94315fa68b4 100644
|
||||
index 1aa8e4e2897..bc177c2943e 100644
|
||||
--- a/test/dynamo/cpython/3_13/test_heapq.py
|
||||
+++ b/test/dynamo/cpython/3_13/test_heapq.py
|
||||
@@ -1,3 +1,23 @@
|
||||
@@ -35,7 +35,7 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
def test_py_functions(self):
|
||||
for fname in func_names:
|
||||
self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
|
||||
@@ -27,24 +47,7 @@ class TestModules(TestCase):
|
||||
@@ -27,24 +47,12 @@ class TestModules(TestCase):
|
||||
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
|
||||
|
||||
|
||||
@@ -46,12 +46,15 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
- # However, doctest can't easily find all docstrings in the module (loading
|
||||
- # it through import_fresh_module seems to confuse it), so we specifically
|
||||
- # create a finder which returns the doctests from the merge method.
|
||||
-
|
||||
+@torch._dynamo.disable
|
||||
+def randrange(*args):
|
||||
+ return random.randrange(*args)
|
||||
|
||||
- class HeapqMergeDocTestFinder:
|
||||
- def find(self, *args, **kwargs):
|
||||
- dtf = doctest.DocTestFinder()
|
||||
- return dtf.find(py_heapq.merge)
|
||||
-
|
||||
|
||||
- tests.addTests(doctest.DocTestSuite(py_heapq,
|
||||
- test_finder=HeapqMergeDocTestFinder()))
|
||||
- return tests
|
||||
@@ -61,7 +64,155 @@ index 1aa8e4e2897..94315fa68b4 100644
|
||||
|
||||
def test_push_pop(self):
|
||||
# 1) Push 256 random numbers and pop them off, verifying all's OK.
|
||||
@@ -264,12 +267,12 @@ class TestHeap:
|
||||
@@ -52,7 +60,8 @@ class TestHeap:
|
||||
data = []
|
||||
self.check_invariant(heap)
|
||||
for i in range(256):
|
||||
- item = random.random()
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ item = random.random()
|
||||
data.append(item)
|
||||
self.module.heappush(heap, item)
|
||||
self.check_invariant(heap)
|
||||
@@ -83,14 +92,16 @@ class TestHeap:
|
||||
|
||||
def test_heapify(self):
|
||||
for size in list(range(30)) + [20000]:
|
||||
- heap = [random.random() for dummy in range(size)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ heap = [random.random() for dummy in range(size)]
|
||||
self.module.heapify(heap)
|
||||
self.check_invariant(heap)
|
||||
|
||||
self.assertRaises(TypeError, self.module.heapify, None)
|
||||
|
||||
def test_naive_nbest(self):
|
||||
- data = [random.randrange(2000) for i in range(1000)]
|
||||
+ with torch._dynamo.error_on_graph_break(False):
|
||||
+ data = [randrange(2000) for i in range(1000)]
|
||||
heap = []
|
||||
for item in data:
|
||||
self.module.heappush(heap, item)
|
||||
@@ -113,7 +124,8 @@ class TestHeap:
|
||||
# heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps).
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@@ -126,7 +138,8 @@ class TestHeap:
self.assertRaises(IndexError, self.module.heapreplace, [], None)

def test_nbest_with_pushpop(self):
- data = [random.randrange(2000) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@@ -163,8 +176,9 @@ class TestHeap:
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
for trial in range(100):
- size = random.randrange(50)
- data = [random.randrange(25) for i in range(size)]
+ with torch._dynamo.error_on_graph_break(False):
+ size = randrange(50)
+ data = [randrange(25) for i in range(size)]
if trial & 1: # Half of the time, use heapify
heap = data[:]
self.module.heapify(heap)
@@ -177,12 +191,13 @@ class TestHeap:

def test_merge(self):
inputs = []
- for i in range(random.randrange(25)):
- row = []
- for j in range(random.randrange(100)):
- tup = random.choice('ABC'), random.randrange(-500, 500)
- row.append(tup)
- inputs.append(row)
+ with torch._dynamo.error_on_graph_break(False):
+ for i in range(randrange(25)):
+ row = []
+ for j in range(randrange(100)):
+ tup = random.choice('ABC'), randrange(-500, 500)
+ row.append(tup)
+ inputs.append(row)

for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
for reverse in [False, True]:
@@ -209,12 +224,14 @@ class TestHeap:
list(self.module.merge(iterable(), iterable()))

def test_merge_stability(self):
- class Int(int):
- pass
+ with torch._dynamo.error_on_graph_break(False):
+ class Int(int):
+ pass
inputs = [[], [], [], []]
for i in range(20000):
- stream = random.randrange(4)
- x = random.randrange(500)
+ with torch._dynamo.error_on_graph_break(False):
+ stream = randrange(4)
+ x = randrange(500)
obj = Int(x)
obj.pair = (x, stream)
inputs[stream].append(obj)
@@ -224,7 +241,8 @@ class TestHeap:
self.assertEqual(result, sorted(result))

def test_nsmallest(self):
- data = [(random.randrange(2000), i) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nsmallest(n, data)),
@@ -233,7 +251,8 @@ class TestHeap:
sorted(data, key=f)[:n])

def test_nlargest(self):
- data = [(random.randrange(2000), i) for i in range(1000)]
+ with torch._dynamo.error_on_graph_break(False):
+ data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nlargest(n, data)),
@@ -248,28 +267,29 @@ class TestHeap:
data = [comp(x) for x in data]
self.module.heapify(data)
return [self.module.heappop(data).x for i in range(len(data))]
- class LT:
- def __init__(self, x):
- self.x = x
- def __lt__(self, other):
- return self.x > other.x
- class LE:
- def __init__(self, x):
- self.x = x
- def __le__(self, other):
- return self.x >= other.x
- data = [random.random() for i in range(100)]
+ with torch._dynamo.error_on_graph_break(False):
+ class LT:
+ def __init__(self, x):
+ self.x = x
+ def __lt__(self, other):
+ return self.x > other.x
+ class LE:
+ def __init__(self, x):
+ self.x = x
+ def __le__(self, other):
+ return self.x >= other.x
+ data = [random.random() for i in range(100)]
target = sorted(data, reverse=True)
self.assertEqual(hsort(data, LT), target)
self.assertRaises(TypeError, data, LE)

@ -76,7 +227,7 @@ index 1aa8e4e2897..94315fa68b4 100644
module = c_heapq

@@ -374,7 +377,7 @@ class SideEffectLT:
@@ -374,7 +394,7 @@ class SideEffectLT:
return self.value < other.value

@ -85,7 +236,48 @@ index 1aa8e4e2897..94315fa68b4 100644

def test_non_sequence(self):
for f in (self.module.heapify, self.module.heappop):
@@ -464,13 +467,13 @@ class TestErrorHandling:
@@ -435,10 +455,11 @@ class TestErrorHandling:
def test_comparison_operator_modifiying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
- class EvilClass(int):
- def __lt__(self, o):
- heap.clear()
- return NotImplemented
+ with torch._dynamo.error_on_graph_break(False):
+ class EvilClass(int):
+ def __lt__(self, o):
+ heap.clear()
+ return NotImplemented

heap = []
self.module.heappush(heap, EvilClass(0))
@@ -446,15 +467,16 @@ class TestErrorHandling:

def test_comparison_operator_modifiying_heap_two_heaps(self):

- class h(int):
- def __lt__(self, o):
- list2.clear()
- return NotImplemented
+ with torch._dynamo.error_on_graph_break(False):
+ class h(int):
+ def __lt__(self, o):
+ list2.clear()
+ return NotImplemented

- class g(int):
- def __lt__(self, o):
- list1.clear()
- return NotImplemented
+ class g(int):
+ def __lt__(self, o):
+ list1.clear()
+ return NotImplemented

list1, list2 = [], []

@@ -464,13 +486,13 @@ class TestErrorHandling:
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1))
self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1))
@ -47,6 +47,11 @@ class TestModules(__TestCase):
self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')

@torch._dynamo.disable
def randrange(*args):
return random.randrange(*args)

class _TestHeap:

def test_push_pop(self):
@ -55,7 +60,8 @@ class _TestHeap:
data = []
self.check_invariant(heap)
for i in range(256):
item = random.random()
with torch._dynamo.error_on_graph_break(False):
item = random.random()
data.append(item)
self.module.heappush(heap, item)
self.check_invariant(heap)
@ -86,14 +92,16 @@ class _TestHeap:

def test_heapify(self):
for size in list(range(30)) + [20000]:
heap = [random.random() for dummy in range(size)]
with torch._dynamo.error_on_graph_break(False):
heap = [random.random() for dummy in range(size)]
self.module.heapify(heap)
self.check_invariant(heap)

self.assertRaises(TypeError, self.module.heapify, None)

def test_naive_nbest(self):
data = [random.randrange(2000) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
heap = []
for item in data:
self.module.heappush(heap, item)
@ -116,7 +124,8 @@ class _TestHeap:
# heap instead of a min heap, it could go faster still via
# heapify'ing all of data (linear time), then doing 10 heappops
# (10 log-time steps).
data = [random.randrange(2000) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@ -129,7 +138,8 @@ class _TestHeap:
self.assertRaises(IndexError, self.module.heapreplace, [], None)

def test_nbest_with_pushpop(self):
data = [random.randrange(2000) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [randrange(2000) for i in range(1000)]
heap = data[:10]
self.module.heapify(heap)
for item in data[10:]:
@ -166,8 +176,9 @@ class _TestHeap:
def test_heapsort(self):
# Exercise everything with repeated heapsort checks
for trial in range(100):
size = random.randrange(50)
data = [random.randrange(25) for i in range(size)]
with torch._dynamo.error_on_graph_break(False):
size = randrange(50)
data = [randrange(25) for i in range(size)]
if trial & 1: # Half of the time, use heapify
heap = data[:]
self.module.heapify(heap)
@ -180,12 +191,13 @@ class _TestHeap:

def test_merge(self):
inputs = []
for i in range(random.randrange(25)):
row = []
for j in range(random.randrange(100)):
tup = random.choice('ABC'), random.randrange(-500, 500)
row.append(tup)
inputs.append(row)
with torch._dynamo.error_on_graph_break(False):
for i in range(randrange(25)):
row = []
for j in range(randrange(100)):
tup = random.choice('ABC'), randrange(-500, 500)
row.append(tup)
inputs.append(row)

for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
for reverse in [False, True]:
@ -212,12 +224,14 @@ class _TestHeap:
list(self.module.merge(iterable(), iterable()))

def test_merge_stability(self):
class Int(int):
pass
with torch._dynamo.error_on_graph_break(False):
class Int(int):
pass
inputs = [[], [], [], []]
for i in range(20000):
stream = random.randrange(4)
x = random.randrange(500)
with torch._dynamo.error_on_graph_break(False):
stream = randrange(4)
x = randrange(500)
obj = Int(x)
obj.pair = (x, stream)
inputs[stream].append(obj)
@ -227,7 +241,8 @@ class _TestHeap:
self.assertEqual(result, sorted(result))

def test_nsmallest(self):
data = [(random.randrange(2000), i) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nsmallest(n, data)),
@ -236,7 +251,8 @@ class _TestHeap:
sorted(data, key=f)[:n])

def test_nlargest(self):
data = [(random.randrange(2000), i) for i in range(1000)]
with torch._dynamo.error_on_graph_break(False):
data = [(randrange(2000), i) for i in range(1000)]
for f in (None, lambda x: x[0] * 547 % 2000):
for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
self.assertEqual(list(self.module.nlargest(n, data)),
@ -251,17 +267,18 @@ class _TestHeap:
data = [comp(x) for x in data]
self.module.heapify(data)
return [self.module.heappop(data).x for i in range(len(data))]
class LT:
def __init__(self, x):
self.x = x
def __lt__(self, other):
return self.x > other.x
class LE:
def __init__(self, x):
self.x = x
def __le__(self, other):
return self.x >= other.x
data = [random.random() for i in range(100)]
with torch._dynamo.error_on_graph_break(False):
class LT:
def __init__(self, x):
self.x = x
def __lt__(self, other):
return self.x > other.x
class LE:
def __init__(self, x):
self.x = x
def __le__(self, other):
return self.x >= other.x
data = [random.random() for i in range(100)]
target = sorted(data, reverse=True)
self.assertEqual(hsort(data, LT), target)
self.assertRaises(TypeError, data, LE)
@ -438,10 +455,11 @@ class _TestErrorHandling:
def test_comparison_operator_modifiying_heap(self):
# See bpo-39421: Strong references need to be taken
# when comparing objects as they can alter the heap
class EvilClass(int):
def __lt__(self, o):
heap.clear()
return NotImplemented
with torch._dynamo.error_on_graph_break(False):
class EvilClass(int):
def __lt__(self, o):
heap.clear()
return NotImplemented

heap = []
self.module.heappush(heap, EvilClass(0))
@ -449,15 +467,16 @@ class _TestErrorHandling:

def test_comparison_operator_modifiying_heap_two_heaps(self):

class h(int):
def __lt__(self, o):
list2.clear()
return NotImplemented
with torch._dynamo.error_on_graph_break(False):
class h(int):
def __lt__(self, o):
list2.clear()
return NotImplemented

class g(int):
def __lt__(self, o):
list1.clear()
return NotImplemented
class g(int):
def __lt__(self, o):
list1.clear()
return NotImplemented

list1, list2 = [], []
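The recurring pattern in the two heapq test files above is to route `random.randrange` through a helper decorated with `@torch._dynamo.disable` and to wrap the untraceable random setup code in `with torch._dynamo.error_on_graph_break(False):`. A minimal standalone sketch of that pattern follows, assuming those two APIs behave as they are used in the diff; it is illustrative only and not part of the patch.

```python
import random

import torch


@torch._dynamo.disable  # keep the RNG call out of the traced graph entirely
def randrange(*args):
    return random.randrange(*args)


@torch.compile(backend="eager")
def build_and_sum(n):
    # Allow the graph break caused by the disabled helper instead of raising,
    # mirroring how the patched tests wrap their random data setup.
    with torch._dynamo.error_on_graph_break(False):
        data = [randrange(2000) for _ in range(n)]
    return torch.tensor(data).sum()


print(build_and_sum(100))
```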
@ -283,7 +283,7 @@ class TestCompilerBisector(TestCase):
)
def test_bisect_pre_grad_graph(self):
def f(x):
for i in range(5):
for _ in range(5):
x = x + 1
return x.relu()
@ -36,6 +36,15 @@ class DummyUserDict(UserDict):
pass

class FakeMapping:
def __init__(self, value: Any) -> None:
self._value = value
self.keys = lambda: ["a", "b", "c"] # not required to be a method

def __getitem__(self, key: str) -> Any:
return self._value

class DictTests(torch._dynamo.test_case.TestCase):
def test_dict_subclass_instantiation(self):
def fn(x):
@ -666,6 +675,18 @@ class DictTests(torch._dynamo.test_case.TestCase):
for k1, m2 in zip(modules, module_dict.children()):
self.assertTrue(modules[k1] is m2)

# FIXME: see comment in torch/_dynamo/polyfills/__init__.py:mutable_mapping_update
@unittest.expectedFailure
def test_dict_construct_from_mapping_like(self):
def fn(x):
fm = FakeMapping(x)
d = dict(fm, x=x)
return d

x = torch.randn(4)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
self.assertEqual(fn(x), opt_fn(x))

def test_dict_subclass_initialization_in_graph(self):
for super_class in (
OrderedDict,
@ -1087,12 +1108,52 @@ class DictTests(torch._dynamo.test_case.TestCase):

self.assertEqual(ref, res)

@unittest.expectedFailure
def test_newly_constructed_default_dict_no_default_factory(self):
def f1(x):
d = defaultdict()
try:
d[1] += 42
except KeyError:
d[1] = 1
return x + 1, d

x = torch.ones(2)
ref = f1(x)
res = torch.compile(f1, backend="eager", fullgraph=True)(x)

self.assertEqual(ref, res)

def f2(x):
d = defaultdict(None)
try:
d[1] += 42
except KeyError:
d[1] = 1
return x + 1, d

ref = f2(x)
res = torch.compile(f2, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)

def f3(x):
d = defaultdict(None, {1: 10})
d[1] += 42
try:
d[2] += 24
except KeyError:
d[2] = 1
return x + 1, d

ref = f3(x)
res = torch.compile(f3, backend="eager", fullgraph=True)(x)
self.assertEqual(ref, res)

def test_newly_constructed_default_dict_with_dict(self):
def f(x):
d = defaultdict(dict, {2: {"a": 1}})
d[0] = {"b": 2}
return x + 1, d
d = dict([("a", 1), ("b", 2)], c=3) # noqa: C406
dd = defaultdict(list, d, d=4, e=5)
dd["x"].append(42)
return x + 1, d, dd

x = torch.ones(2)
ref = f(x)
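For reference, the `FakeMapping` test above leans on plain-Python dict semantics: `dict()` accepts any object that exposes `keys()` and `__getitem__`, plus keyword overrides. A small eager-only illustration, independent of `torch.compile`:

```python
# Eager-only illustration of what dict(fm, x=x) exercises in the new test above.
class FakeMapping:
    def __init__(self, value):
        self._value = value
        self.keys = lambda: ["a", "b", "c"]  # not required to be a method

    def __getitem__(self, key):
        return self._value


fm = FakeMapping(0)
print(dict(fm, x=1))  # {'a': 0, 'b': 0, 'c': 0, 'x': 1}
```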
@ -427,17 +427,29 @@ from user code:
optree.tree_flatten_with_path(d)
return torch.sin(x)

def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return re.sub(
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
)

fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
self.assertExpectedInline(
first_graph_break,
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree._C.PyCapsule.flatten_with_path.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py

Developer debug context: module: optree._C, qualname: PyCapsule.flatten_with_path, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>

For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)
@ -8,21 +8,11 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet

def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs

def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))

@ -40,7 +30,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()

def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)
return extract_graph(fn, *args, **kwargs)[0:3]

def run_and_get_simple_graph(self):
def fn(x, y):
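The hunk above drops the file-local `extract_graph` helper in favor of the shared `torch._dynamo.testing.extract_graph`, which appears to return more than the three values the old helper did (hence the `[0:3]` slice in `run_and_return_graphs`). A sketch of the pattern the removed helper implemented, recording Dynamo and AOT forward graphs with `AotEagerAndRecordGraphs`; the lambda and inputs are illustrative only:

```python
import torch
from torch._dynamo.testing import AotEagerAndRecordGraphs


def extract_graph_local(fn, *args, **kwargs):
    # Same shape as the helper removed in the hunk above: run the function under
    # torch.compile with a recording backend and return what it captured.
    backend = AotEagerAndRecordGraphs()
    result = torch.compile(backend=backend)(fn)(*args, **kwargs)
    return result, backend.graphs, backend.fw_graphs


result, dynamo_graphs, fw_graphs = extract_graph_local(lambda x: x.sin() + 1, torch.randn(4))
print(len(dynamo_graphs), len(fw_graphs))
```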
@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import Any, Union

import torch
import torch._dynamo

@ -69,6 +69,7 @@ from torch.fx.experimental.symbolic_shapes import (
constrain_unify,
ConstraintViolationError,
expect_true,
guard_or_false,
guard_size_oblivious,
ShapeEnv,
)
@ -100,7 +101,6 @@ from torch.testing._internal.common_utils import (
wrapDeterministicFlagAPITest,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.logging_utils import logs_to_string

pytree_modules = {
@ -13194,6 +13194,30 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):

self.assertEqual(actual, expected)

@parametrize_pytree_module
def test_pytree_tree_map_dict_order(self, pytree):
def fn(tree):
new_tree = pytree.tree_map(lambda x: x, tree)
return list(new_tree.keys()), list(new_tree.values())

x = torch.randn(3, 2)
fn_opt = torch.compile(fullgraph=True)(fn)

tree1 = {"b": x + 2, "a": x, "c": x - 1}
expected1 = fn(tree1)
actual1 = fn_opt(tree1)
self.assertEqual(actual1, expected1)

tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)])
expected2 = fn(tree2)
actual2 = fn_opt(tree2)
self.assertEqual(actual2, expected2)

tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1})
expected3 = fn(tree3)
actual3 = fn_opt(tree3)
self.assertEqual(actual3, expected3)

@parametrize_pytree_module
def test_pytree_tree_map_only(self, pytree):
if not callable(getattr(pytree, "tree_map_only", None)):
@ -13219,6 +13243,27 @@ class MiscTestsPyTree(torch._inductor.test_case.TestCase):
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)

def test_pytree_register_constant_with_side_effect(self):
class Foo:
pass

class Bar:
def __eq__(self, other):
return super().__eq__(other)

def __hash__(self):
return 0

python_pytree.register_constant(Bar)

@torch.compile(backend="eager", fullgraph=True)
def fn(x, obj):
obj.attr = {3: Bar()}
return x + 1

inp = torch.ones(3)
self.assertEqual(fn(inp, Foo()), inp + 1)

class TestTracer(JitTestCase):
def test_jit_save(self):
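The eager reference that `test_pytree_tree_map_dict_order` compares against can be reproduced outside the test harness. A minimal sketch using `torch.utils._pytree`, one of the pytree modules the test is parametrized over:

```python
import torch
import torch.utils._pytree as pytree

x = torch.randn(3, 2)
tree = {"b": x + 2, "a": x, "c": x - 1}

# tree_map rebuilds the dict after mapping the identity over its leaves; the
# test checks that a compiled run reproduces the same key order as this eager call.
new_tree = pytree.tree_map(lambda t: t, tree)
print(list(new_tree.keys()))
```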
@ -13636,6 +13681,74 @@ instantiate_device_type_tests(
)

class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul(self):
def symbool_mul_fn(x_bool, sentinel):
result = x_bool * sentinel
return result

x_true = torch.tensor([True], device="cuda")
x_false = torch.tensor([False], device="cuda")
sentinel = torch.tensor(2.0, requires_grad=True, device="cuda")
eager_result_true = symbool_mul_fn(x_true, sentinel)
eager_result_false = symbool_mul_fn(x_false, sentinel)
compiled_fn = torch.compile(symbool_mul_fn, fullgraph=True, dynamic=True)
compiled_result_true = compiled_fn(x_true, sentinel)
compiled_result_false = compiled_fn(x_false, sentinel)
self.assertEqual(eager_result_true, compiled_result_true)
self.assertEqual(eager_result_false, compiled_result_false)
self.assertEqual(compiled_result_true.item(), 2.0)
self.assertEqual(compiled_result_false.item(), 0.0)

@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_guard_or_false(self):
def symbool_guard_fn(a_bool_tensor, b):
u0 = a_bool_tensor.item()
# Make sure guard_or_false still handles SymBool produced by .item()
if guard_or_false(u0):
return b * 10
else:
return b * 100

compiled_guard_fn = torch.compile(
symbool_guard_fn, backend="eager", dynamic=True
)
a_true = torch.tensor(True, device="cuda")
a_false = torch.tensor(False, device="cuda")
b = torch.randn(6, device="cuda")
eager_res_true = symbool_guard_fn(a_true, b)
compiled_res_true = compiled_guard_fn(a_true, b)
self.assertEqual(eager_res_true, compiled_res_true)
eager_res_false = symbool_guard_fn(a_false, b)
compiled_res_false = compiled_guard_fn(a_false, b)
self.assertEqual(eager_res_false, compiled_res_false)
self.assertEqual(compiled_res_true, b * 10)
self.assertEqual(compiled_res_false, b * 100)

@unittest.skipIf(not TEST_CUDA, "This test requires a CUDA device")
def test_symbool_tensor_mul_does_not_fail(self):
def fuzzed_program(arg_0, sentinel):
var_node_2 = arg_0
var_node_1 = torch.squeeze(var_node_2)
var_node_0 = var_node_1.item()
result = var_node_0 * sentinel
if result.is_complex():
result = result.real
return result

sentinel = torch.tensor(1.0, requires_grad=True, device="cuda")
arg_0 = torch.tensor([True], dtype=torch.bool, device="cuda")
args = (arg_0,) + (sentinel,)
try:
compiled_program = torch.compile(
fuzzed_program, fullgraph=True, dynamic=True
)
compiled_program(*args)
except Exception as e:
self.fail(f"torch.compile failed with error: {e}")

if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional, TYPE_CHECKING

import torch
import torch._dynamo
@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same

if TYPE_CHECKING:
from collections.abc import Callable

"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo
@ -48,6 +48,7 @@ from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
expectedFailureDynamic,
rand_strided,
same,
skipIfNotPy312,
@ -1000,6 +1001,18 @@ class ReproTests(torch._dynamo.test_case.TestCase):
self.exit_stack.close()
super().tearDown()

def test_compiled_module_truthiness(self):
# Test with empty ModuleList
original_empty = nn.ModuleList()
compiled_empty = torch.compile(original_empty)
self.assertEqual(bool(original_empty), bool(compiled_empty))
self.assertFalse(bool(compiled_empty))
# Test with non-empty ModuleList
original_filled = nn.ModuleList([nn.Linear(10, 5)])
compiled_filled = torch.compile(original_filled)
self.assertEqual(bool(original_filled), bool(compiled_filled))
self.assertTrue(bool(compiled_filled))

def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder):
root = guard_manager_wrapper.root
cloned_root = root.clone_manager(lambda x: True)
@ -7443,6 +7456,93 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
msg,
)

@expectedFailureDynamic
def test_dynamo_default_lru_cache_behavior(self):
@torch.compile(backend="eager")
def fn(x):
return x + 10

torch._dynamo.reset()
assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)

# Step 1: Compile a static shapes graph
x = torch.randn(10, 10)
fn(x)
a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(a), 1)
static_shapes_cache_entry = a[0]

# Step 2: Compile a dynamic shapes graph
y = torch.randn(20, 20)
fn(y)
b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(b), 2)
self.assertEqual(b[1], static_shapes_cache_entry)
dynamic_shapes_cache_entry = b[0]

# Step 3: Run with Step 1's inputs
# LRU cache will match against dynamic shape graph first
fn(x)
c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(c), 2)
self.assertEqual(c[0], dynamic_shapes_cache_entry)
self.assertEqual(c[1], static_shapes_cache_entry)

@expectedFailureDynamic
def test_dynamo_disable_lru_cache_behavior(self):
@torch.compile(backend="eager")
def fn(x):
return x + 10

def run():
torch._dynamo.reset()
assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)

# Step 1: Compile a static shapes graph
x = torch.randn(10, 10)
fn(x)
a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(a), 1)
static_shapes_cache_entry = a[0]

# Step 2: Compile a dynamic shapes graph
y = torch.randn(20, 20)
fn(y)
b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(b), 2)
self.assertEqual(b[0], static_shapes_cache_entry)
dynamic_shapes_cache_entry = b[1]

# Step 3: Run with Step 1's inputs
# LRU cache is disabled, we should still have static entry first
fn(x)
c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(
fn._torchdynamo_orig_callable.__code__
)
self.assertEqual(len(c), 2)
self.assertEqual(c[0], static_shapes_cache_entry)
self.assertEqual(c[1], dynamic_shapes_cache_entry)

try:
torch._C._dynamo.eval_frame._set_lru_cache(False)
run()
finally:
torch._C._dynamo.eval_frame._set_lru_cache(True)

class ReproTestsDevice(torch._dynamo.test_case.TestCase):
def test_sub_alpha_scalar_repro(self, device):
Some files were not shown because too many files have changed in this diff.