Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)

Compare commits: 1 commit on branch msaroufim/.../revert_alw

| Author | SHA1 | Date |
|---|---|---|
|  | fe3ee2e446 |  |
| @ -69,8 +69,7 @@ RUN bash ./install_cuda.sh 13.0 | ||||
| ENV DESIRED_CUDA=13.0 | ||||
|  | ||||
| FROM ${ROCM_IMAGE} as rocm | ||||
| ARG PYTORCH_ROCM_ARCH | ||||
| ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} | ||||
| ENV PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" | ||||
| ADD ./common/install_mkl.sh install_mkl.sh | ||||
| RUN bash ./install_mkl.sh && rm install_mkl.sh | ||||
| ENV MKLROOT /opt/intel | ||||
|  | ||||
| @ -36,12 +36,6 @@ case ${DOCKER_TAG_PREFIX} in | ||||
|     ;; | ||||
|   rocm*) | ||||
|     BASE_TARGET=rocm | ||||
|     PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" | ||||
|     # add gfx950 conditionally starting in ROCm 7.0 | ||||
|     if [[ "$ROCM_VERSION" == *"7.0"* ]]; then | ||||
|         PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950" | ||||
|     fi | ||||
|     EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}" | ||||
|     ;; | ||||
|   *) | ||||
|     echo "ERROR: Unknown docker tag ${DOCKER_TAG_PREFIX}" | ||||
|  | ||||
| @ -40,16 +40,12 @@ case ${DOCKER_TAG_PREFIX} in | ||||
|         ;; | ||||
|     rocm*) | ||||
|         # we want the patch version of 6.4 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||
|         if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         BASE_TARGET=rocm | ||||
|         GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete | ||||
|         PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" | ||||
|         # add gfx950 conditionally starting in ROCm 7.0 | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||
|             PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950" | ||||
|         fi | ||||
|         DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}" | ||||
|         ;; | ||||
|     *) | ||||
|  | ||||
| @ -82,7 +82,7 @@ case ${image} in | ||||
|         ;; | ||||
|     manylinux2_28-builder:rocm*) | ||||
|         # we want the patch version of 6.4 instead | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then | ||||
|         if [[ $(ver $GPU_ARCH_VERSION) -eq $(ver 6.4) ]]; then | ||||
|             GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" | ||||
|         fi | ||||
|         TARGET=rocm_final | ||||
| @ -90,10 +90,6 @@ case ${image} in | ||||
|         DEVTOOLSET_VERSION="11" | ||||
|         GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete | ||||
|         PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" | ||||
|         # add gfx950 conditionally starting in ROCm 7.0 | ||||
|         if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then | ||||
|             PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950" | ||||
|         fi | ||||
|         DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}" | ||||
|         ;; | ||||
|     manylinux2_28-builder:xpu) | ||||
|  | ||||
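The hunks above replace a substring match (`*"6.4"*`) with a numeric comparison through a `ver` helper. The helper itself is not part of this diff; the sketch below shows one common way such a helper is written, assuming it simply zero-pads each dotted component so that bash's `-eq`/`-lt` operators can compare versions numerically. Treat the `ver` body as an assumption, not the repository's actual definition.

```bash
#!/usr/bin/env bash
# Hypothetical `ver` helper (the real definition lives elsewhere in the
# build scripts): pad each dotted component so versions compare numerically.
ver() {
    printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ')
}

GPU_ARCH_VERSION=6.4
# Exact match on 6.4: unlike the old substring test, an already-bumped
# "6.4.2" does not match again, so the patch suffix is appended only once.
if [[ $(ver "$GPU_ARCH_VERSION") -eq $(ver 6.4) ]]; then
    GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2"
fi
echo "$GPU_ARCH_VERSION"   # -> 6.4.2
```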
| @ -1,11 +1,11 @@ | ||||
| SHELL=/usr/bin/env bash | ||||
|  | ||||
| DOCKER_CMD ?= docker | ||||
| DESIRED_ROCM ?= 7.0 | ||||
| DESIRED_ROCM ?= 6.4 | ||||
| DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM)) | ||||
| PACKAGE_NAME = magma-rocm | ||||
| # inherit this from underlying docker image, do not pass this env var to docker | ||||
| #PYTORCH_ROCM_ARCH ?= gfx900;gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201 | ||||
| #PYTORCH_ROCM_ARCH ?= gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201 | ||||
|  | ||||
| DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \ | ||||
| 	-v $(shell git rev-parse --show-toplevel)/.ci:/builder \ | ||||
| @ -16,7 +16,6 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \ | ||||
| 	magma-rocm/build_magma.sh | ||||
|  | ||||
| .PHONY: all | ||||
| all: magma-rocm70 | ||||
| all: magma-rocm64 | ||||
| all: magma-rocm63 | ||||
|  | ||||
| @ -25,11 +24,6 @@ clean: | ||||
| 	$(RM) -r magma-* | ||||
| 	$(RM) -r output | ||||
|  | ||||
| .PHONY: magma-rocm70 | ||||
| magma-rocm70: DESIRED_ROCM := 7.0 | ||||
| magma-rocm70: | ||||
| 	$(DOCKER_RUN) | ||||
|  | ||||
| .PHONY: magma-rocm64 | ||||
| magma-rocm64: DESIRED_ROCM := 6.4 | ||||
| magma-rocm64: | ||||
|  | ||||
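With the `magma-rocm70` target removed, `all` now covers only the ROCm 6.x variants. A usage sketch follows; the working directory holding this Makefile is an assumption (it likely sits under `.ci/` in the repo).

```bash
# Run from the directory that contains this Makefile (path assumed).
make magma-rocm64   # builds MAGMA for ROCm 6.4 (target sets DESIRED_ROCM := 6.4)
make magma-rocm63   # builds MAGMA for ROCm 6.3
make                # `all` now builds only the 6.x targets above
make clean          # removes the magma-* and output directories
```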
| @ -6,8 +6,8 @@ set -eou pipefail | ||||
| # The script expects DESIRED_CUDA and PACKAGE_NAME to be set | ||||
| ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" | ||||
|  | ||||
| # https://github.com/icl-utk-edu/magma/pull/65 | ||||
| MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec | ||||
| # Version 2.7.2 + ROCm related updates | ||||
| MAGMA_VERSION=a1625ff4d9bc362906bd01f805dbbe12612953f6 | ||||
|  | ||||
| # Folders for the build | ||||
| PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata | ||||
| @ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE | ||||
|  | ||||
| # Fetch magma sources and verify checksum | ||||
| pushd ${PACKAGE_DIR} | ||||
| git clone https://github.com/jeffdaily/magma | ||||
| git clone https://bitbucket.org/icl/magma.git | ||||
| pushd magma | ||||
| git checkout ${MAGMA_VERSION} | ||||
| popd | ||||
|  | ||||
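After this change the script fetches MAGMA from the upstream ICL Bitbucket repository and pins it to the "2.7.2 + ROCm related updates" commit. Reproducing the fetch step by hand looks roughly like the following (hash taken from the diff above).

```bash
# Manual equivalent of the fetch-and-pin step in build_magma.sh.
git clone https://bitbucket.org/icl/magma.git
pushd magma
git checkout a1625ff4d9bc362906bd01f805dbbe12612953f6
popd
```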
| @ -58,7 +58,7 @@ time python tools/setup_helpers/generate_code.py \ | ||||
|  | ||||
| # Build the docs | ||||
| pushd docs/cpp | ||||
| time make VERBOSE=1 html | ||||
| time make VERBOSE=1 html -j | ||||
|  | ||||
| popd | ||||
| popd | ||||
|  | ||||
.github/actionlint.yaml (3 lines changed, vendored)
							| @ -22,9 +22,6 @@ self-hosted-runner: | ||||
|     - linux.arm64.m7g.4xlarge | ||||
|     - linux.arm64.m7g.4xlarge.ephemeral | ||||
|     - linux.arm64.r7g.12xlarge.memory | ||||
|     - linux.aws.h100 | ||||
|     - linux.aws.h100.4 | ||||
|     - linux.aws.h100.8 | ||||
|     - linux.4xlarge.nvidia.gpu | ||||
|     - linux.8xlarge.nvidia.gpu | ||||
|     - linux.16xlarge.nvidia.gpu | ||||
|  | ||||
.github/ci_commit_pins/vllm.txt (2 lines changed, vendored)
							| @ -1 +1 @@ | ||||
| 9fe4c2bdb9859c14ad7f7479e1db7e01083bada3 | ||||
| 1983609239caaab24ab1ed2bfa2aa92e8c76c1b1 | ||||
|  | ||||
.github/workflows/_docs.yml (2 lines changed, vendored)
							| @ -67,7 +67,7 @@ jobs: | ||||
|             # an OOM issue when running the job, so this upgrades the runner from 4xlarge | ||||
|             # to the next available tier of 12xlarge. So much memory just to generate cpp | ||||
|             # doc | ||||
|             runner: ${{ inputs.runner_prefix }}linux.12xlarge.memory | ||||
|             runner: ${{ inputs.runner_prefix }}linux.12xlarge | ||||
|             # TODO: Nightly cpp docs take longer and longer to finish (more than 3h now) | ||||
|             # Let's try to figure out how this can be improved | ||||
|             timeout-minutes: 360 | ||||
|  | ||||
.github/workflows/build-almalinux-images.yml (2 lines changed, vendored)
							| @ -36,7 +36,7 @@ jobs: | ||||
|     runs-on: linux.9xlarge.ephemeral | ||||
|     strategy: | ||||
|       matrix: | ||||
|         tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.3", "rocm6.4", "rocm7.0", "cpu"] | ||||
|         tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.3", "rocm6.4", "cpu"] | ||||
|     steps: | ||||
|       - name: Build docker image | ||||
|         uses: pytorch/pytorch/.github/actions/binary-docker-build@main | ||||
|  | ||||
.github/workflows/build-libtorch-images.yml (1 line changed, vendored)
							| @ -54,7 +54,6 @@ jobs: | ||||
|           { tag: "cuda12.6" }, | ||||
|           { tag: "rocm6.3"  }, | ||||
|           { tag: "rocm6.4"  }, | ||||
|           { tag: "rocm7.0"  }, | ||||
|           { tag: "cpu"      }, | ||||
|         ] | ||||
|     steps: | ||||
|  | ||||
.github/workflows/build-magma-rocm-linux.yml (2 lines changed, vendored)
							| @ -34,7 +34,7 @@ jobs: | ||||
|       id-token: write | ||||
|     strategy: | ||||
|       matrix: | ||||
|         rocm_version: ["70", "64", "63"] | ||||
|         rocm_version: ["64", "63"] | ||||
|     steps: | ||||
|       - name: Checkout PyTorch | ||||
|         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||||
|  | ||||
.github/workflows/build-manywheel-images.yml (1 line changed, vendored)
							| @ -54,7 +54,6 @@ jobs: | ||||
|           { name: "manylinuxaarch64-builder",       tag: "cuda12.6",          runner: "linux.arm64.2xlarge.ephemeral" }, | ||||
|           { name: "manylinux2_28-builder",          tag: "rocm6.3",           runner: "linux.9xlarge.ephemeral" }, | ||||
|           { name: "manylinux2_28-builder",          tag: "rocm6.4",           runner: "linux.9xlarge.ephemeral" }, | ||||
|           { name: "manylinux2_28-builder",          tag: "rocm7.0",           runner: "linux.9xlarge.ephemeral" }, | ||||
|           { name: "manylinux2_28-builder",          tag: "cpu",               runner: "linux.9xlarge.ephemeral" }, | ||||
|           { name: "manylinux2_28_aarch64-builder",  tag: "cpu-aarch64",       runner: "linux.arm64.2xlarge.ephemeral" }, | ||||
|           { name: "manylinuxcxx11-abi-builder",     tag: "cpu-cxx11-abi",     runner: "linux.9xlarge.ephemeral" }, | ||||
|  | ||||
.github/workflows/create_release.yml (59 lines changed, vendored)
							| @ -35,7 +35,6 @@ jobs: | ||||
|       contents: write | ||||
|     outputs: | ||||
|       pt_release_name: ${{ steps.release_name.outputs.pt_release_name }} | ||||
|       pt_pep517_release_name: ${{ steps.release_name.outputs.pt_pep517_release_name }} | ||||
|     steps: | ||||
|       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||||
|         with: | ||||
| @ -54,12 +53,8 @@ jobs: | ||||
|           tag_or_branch="${tag_or_branch#refs/heads/}" | ||||
|           # replace directory separators with _ in branch name | ||||
|           tag_or_branch="${tag_or_branch//\//_}" | ||||
|           torch_version="$(python -c 'from tools.generate_torch_version import get_torch_version; print(get_torch_version())')" | ||||
|           { | ||||
|             echo "PT_RELEASE_NAME=pytorch-$tag_or_branch"; | ||||
|             echo "PT_RELEASE_FILE=pytorch-$tag_or_branch.tar.gz"; | ||||
|             echo "PT_PEP517_RELEASE_FILE=torch-${torch_version}.tar.gz"; | ||||
|           } >> "$GITHUB_ENV" | ||||
|           echo "PT_RELEASE_NAME=pytorch-$tag_or_branch" >> "$GITHUB_ENV" | ||||
|           echo "PT_RELEASE_FILE=pytorch-$tag_or_branch.tar.gz" >> "$GITHUB_ENV" | ||||
|       - name: Checkout optional submodules | ||||
|         run: python3 tools/optional_submodules.py | ||||
|       - name: Copy docs requirements for inclusion | ||||
| @ -69,47 +64,30 @@ jobs: | ||||
|           cp .ci/docker/requirements-docs.txt docs/requirements.txt | ||||
|       - name: Create source distribution | ||||
|         run: | | ||||
|           # Create new folder with specified name so extracting the archive yields that | ||||
|           rm -rf "/tmp/$PT_RELEASE_NAME" | ||||
|           cp -r "$PWD" "/tmp/$PT_RELEASE_NAME" | ||||
|           mv "/tmp/$PT_RELEASE_NAME" . | ||||
|           # Cleanup | ||||
|           rm -rf "$PT_RELEASE_NAME"/{.circleci,.ci} | ||||
|           find "$PT_RELEASE_NAME" -name '.git*' -exec rm -rv {} \; || true | ||||
|           # Create archive | ||||
|           tar -czf "$PT_RELEASE_FILE" "$PT_RELEASE_NAME" | ||||
|           echo "Created source archive $PT_RELEASE_FILE with content: $(ls -a "$PT_RELEASE_NAME")" | ||||
|       - name: Create PEP 517 compatible source distribution | ||||
|         run: | | ||||
|           pip install build==1.2.2.post1 || exit 1 | ||||
|           python -m build --sdist || exit 1 | ||||
|           cd dist || exit 1 | ||||
|             # Create new folder with specified name so extracting the archive yields that | ||||
|             rm -rf "/tmp/$PT_RELEASE_NAME" | ||||
|             cp -r "$PWD" "/tmp/$PT_RELEASE_NAME" | ||||
|             mv "/tmp/$PT_RELEASE_NAME" . | ||||
|             # Cleanup | ||||
|             rm -rf "$PT_RELEASE_NAME"/{.circleci,.ci} | ||||
|             find "$PT_RELEASE_NAME" -name '.git*' -exec rm -rv {} \; || true | ||||
|             # Create archive | ||||
|             tar -czf "$PT_RELEASE_FILE" "$PT_RELEASE_NAME" | ||||
|             echo "Created source archive $PT_RELEASE_FILE with content: $(ls -a "$PT_RELEASE_NAME")" | ||||
|       - name: Upload source distribution for release | ||||
|         if: ${{ github.event_name == 'release' }} | ||||
|         uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # v2.2.2 | ||||
|         with: | ||||
|           files: | | ||||
|             ${{ env.PT_RELEASE_FILE }} | ||||
|             ${{ env.PT_PEP517_RELEASE_FILE }} | ||||
|       - name: Upload source distribution to GHA artifacts  # for release tags | ||||
|           files: ${{env.PT_RELEASE_FILE}} | ||||
|       - name: Upload source distribution to GHA artifacts for release tags | ||||
|         if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} | ||||
|         uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 | ||||
|         with: | ||||
|           name: ${{ env.PT_RELEASE_FILE }} | ||||
|           path: ${{ env.PT_RELEASE_FILE }} | ||||
|       - name: Upload PEP 517 source distribution to GHA artifacts  # for release tags | ||||
|         if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} | ||||
|         uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 | ||||
|         with: | ||||
|           name: ${{ env.PT_PEP517_RELEASE_FILE }} | ||||
|           path: dist/${{ env.PT_PEP517_RELEASE_FILE }} | ||||
|       - name: Set output | ||||
|         id: release_name | ||||
|         run: | | ||||
|           { | ||||
|             echo "pt_release_name=${{ env.PT_RELEASE_FILE }}"; | ||||
|             echo "pt_pep517_release_name=${{ env.PT_PEP517_RELEASE_FILE }}"; | ||||
|           } >> "${GITHUB_OUTPUT}" | ||||
|         run: echo "pt_release_name=${{ env.PT_RELEASE_NAME }}.tar.gz" >> "${GITHUB_OUTPUT}" | ||||
|  | ||||
|   upload_source_code_to_s3: | ||||
|     if: ${{ github.repository == 'pytorch/pytorch' && github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }} | ||||
| @ -125,9 +103,6 @@ jobs: | ||||
|       - uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # v4.1.7 | ||||
|         with: | ||||
|           name: ${{ needs.release.outputs.pt_release_name }} | ||||
|       - uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # v4.1.7 | ||||
|         with: | ||||
|           name: ${{ needs.release.outputs.pt_pep517_release_name }} | ||||
|       - name: Configure AWS credentials(PyTorch account) | ||||
|         uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 | ||||
|         with: | ||||
| @ -138,9 +113,7 @@ jobs: | ||||
|           s3-bucket: pytorch | ||||
|           s3-prefix: source_code/test | ||||
|           if-no-files-found: warn | ||||
|           path: | | ||||
|             ${{ needs.release.outputs.pt_release_name }} | ||||
|             ${{ needs.release.outputs.pt_pep517_release_name }} | ||||
|           path: ${{ needs.release.outputs.pt_release_name }} | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }} | ||||
|  | ||||
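For reference, the deleted "Create PEP 517 compatible source distribution" step amounted to the commands below, reconstructed from the removed workflow lines; the exact artifact name depends on `get_torch_version()` and is left as a placeholder.

```bash
# Sketch of the removed PEP 517 sdist step.
pip install build==1.2.2.post1
python -m build --sdist    # writes dist/torch-<version>.tar.gz
ls dist/                   # this file was uploaded as PT_PEP517_RELEASE_FILE
```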
.github/workflows/pull.yml (2 lines changed, vendored)
							| @ -127,8 +127,6 @@ jobs: | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       # More memory is needed to build with asan | ||||
|       runner: linux.2xlarge.memory | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-py3.10-clang18-asan | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan | ||||
|  | ||||
.github/workflows/slow.yml (2 lines changed, vendored)
							| @ -140,8 +140,6 @@ jobs: | ||||
|     uses: ./.github/workflows/_linux-build.yml | ||||
|     needs: get-label-type | ||||
|     with: | ||||
|       # More memory is needed to build with asan | ||||
|       runner: linux.2xlarge.memory | ||||
|       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" | ||||
|       build-environment: linux-jammy-py3.10-clang18-asan | ||||
|       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan | ||||
|  | ||||
.gitignore (1 line changed, vendored)
							| @ -82,7 +82,6 @@ torch/return_types.pyi | ||||
| torch/nn/functional.pyi | ||||
| torch/utils/data/datapipes/datapipe.pyi | ||||
| torch/csrc/autograd/generated/* | ||||
| torch/csrc/functionalization/generated/* | ||||
| torch/csrc/lazy/generated/*.[!m]* | ||||
| torch_compile_debug/ | ||||
| # Listed manually because some files in this directory are not generated | ||||
|  | ||||
| @ -1453,7 +1453,7 @@ init_command = [ | ||||
|     '--dry-run={{DRYRUN}}', | ||||
|     'usort==1.0.8.post1', | ||||
|     'isort==6.0.1', | ||||
|     'ruff==0.13.1',  # sync with RUFF | ||||
|     'ruff==0.12.9',  # sync with RUFF | ||||
| ] | ||||
| is_formatter = true | ||||
|  | ||||
| @ -1587,7 +1587,7 @@ init_command = [ | ||||
|     'python3', | ||||
|     'tools/linter/adapters/pip_init.py', | ||||
|     '--dry-run={{DRYRUN}}', | ||||
|     'ruff==0.13.1',  # sync with PYFMT | ||||
|     'ruff==0.12.9',  # sync with PYFMT | ||||
| ] | ||||
| is_formatter = true | ||||
|  | ||||
|  | ||||
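These pins are consumed when lintrunner initializes its environment, so downgrading the `ruff` pin here changes which ruff version developers get locally. A hedged usage sketch, assuming the standard lintrunner workflow used with this config file:

```bash
# Re-sync the local linter environment after the pin change, then lint.
pip install lintrunner   # if not already installed
lintrunner init          # installs ruff==0.12.9 as pinned in .lintrunner.toml
lintrunner -a            # run the linters and apply autofixes
```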
| @ -91,8 +91,6 @@ generated_cpu_cpp = [ | ||||
|     "aten/src/ATen/NativeMetaFunctions.h", | ||||
|     "aten/src/ATen/RegistrationDeclarations.h", | ||||
|     "aten/src/ATen/VmapGeneratedPlumbing.h", | ||||
|     "aten/src/ATen/ViewMetaClasses.h", | ||||
|     "aten/src/ATen/ViewMetaClasses.cpp", | ||||
|     "aten/src/ATen/core/aten_interned_strings.h", | ||||
|     "aten/src/ATen/core/enum_tag.h", | ||||
|     "aten/src/ATen/core/TensorBody.h", | ||||
| @ -1077,7 +1075,6 @@ test_suite( | ||||
|         "aten/src/ATen/templates/LazyNonNativeIr.h", | ||||
|         "aten/src/ATen/templates/RegisterDispatchKey.cpp", | ||||
|         "aten/src/ATen/templates/RegisterDispatchDefinitions.ini", | ||||
|         "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp", | ||||
|         "aten/src/ATen/native/native_functions.yaml", | ||||
|         "aten/src/ATen/native/tags.yaml", | ||||
|         "aten/src/ATen/native/ts_native_functions.yaml", | ||||
|  | ||||
MANIFEST.in (105 lines changed)
							| @ -1,61 +1,20 @@ | ||||
| # Reference: https://setuptools.pypa.io/en/latest/userguide/miscellaneous.html | ||||
|  | ||||
| # Include individual top-level files | ||||
| include CITATION.cff | ||||
| include CODEOWNERS | ||||
| include Dockerfile | ||||
| include LICENSE | ||||
| include MANIFEST.in | ||||
| include Makefile | ||||
| include NOTICE | ||||
| include .bc-linter.yml | ||||
| include .clang-format .clang-tidy | ||||
| include .cmakelintrc | ||||
| include .coveragerc | ||||
| include .dockerignore | ||||
| include .editorconfig | ||||
| include .flake8 | ||||
| include .gdbinit | ||||
| include .lintrunner.toml | ||||
| include .lldbinit | ||||
| include codex_setup.sh | ||||
| include docker.Makefile | ||||
| include pyrefly.toml | ||||
| include ubsan.supp | ||||
|  | ||||
| # Include bazel and BUCK related files | ||||
| include BUILD.bazel BUCK.oss | ||||
| include WORKSPACE | ||||
| include *.bzl | ||||
| include .bazelignore .bazelrc .bazelversion | ||||
|  | ||||
| # Include general configuration files | ||||
| include *.ini | ||||
| # Include important top-level information | ||||
| include *.md | ||||
| # Include technical text files at the moment, comprises | ||||
| # version.txt, CMakeLists.txt, requirements.txt | ||||
| include *.txt | ||||
|  | ||||
| # Include ctags configuration | ||||
| include .ctags.d/*.ctags | ||||
|  | ||||
| # Include subfolders completely | ||||
| graft .devcontainer | ||||
| graft .vscode | ||||
| # Include source files in SDist | ||||
| include CMakeLists.txt | ||||
| include *.bzl *.bazel .bazel* BUILD *.BUILD BUILD.* WORKSPACE | ||||
| include BUCK BUCK.* | ||||
| include requirements*.txt | ||||
| include version.txt | ||||
| include [Mm]akefile *.[Mm]akefile [Mm]akefile.* | ||||
| include [Dd]ockerfile *.[Dd]ockerfile [Dd]ockerfile.* .dockerignore | ||||
| graft android | ||||
| graft aten | ||||
| graft benchmarks | ||||
| graft binaries | ||||
| graft c10 | ||||
| graft caffe2 | ||||
| graft cmake | ||||
| graft docs | ||||
| graft functorch | ||||
| graft ios | ||||
| graft mypy_plugins | ||||
| graft scripts | ||||
| graft test | ||||
| graft third_party | ||||
| graft tools | ||||
| graft torch | ||||
| @ -63,37 +22,29 @@ graft torchgen | ||||
| # FIXME: torch-xla build during codegen will fail if include this file in wheel | ||||
| exclude torchgen/BUILD.bazel | ||||
|  | ||||
| # The following exclusions omit parts from third-party dependencies that | ||||
| # contain invalid symlinks[1] and that are not needed for pytorch, such as | ||||
| # bindings for unused languages | ||||
| prune third_party/flatbuffers/java | ||||
| prune third_party/flatbuffers/kotlin | ||||
| prune third_party/ittapi/rust | ||||
| prune third_party/nccl/pkg/debian | ||||
| prune third_party/opentelemetry-cpp/third_party/prometheus-cpp/cmake/project-import-* | ||||
|  | ||||
| # The following document is also an invalid symlink[1] and superfluous | ||||
| exclude third_party/flatbuffers/docs/source/CONTRIBUTING.md | ||||
|  | ||||
| # Omit autogenerated code | ||||
| prune torchgen/packaged | ||||
|  | ||||
| # Omit caches, compiled, and scm related content | ||||
| prune */__pycache__ | ||||
| prune **/.github | ||||
| prune **/.gitlab | ||||
| global-exclude *.o *.obj *.so *.dylib *.a *.pxd *.dll *.lib | ||||
| global-exclude *.py[cod] *.swp *~ | ||||
| global-exclude .git .git-blame-ignore-revs .gitattributes .gitignore .gitmodules | ||||
| global-exclude .gitlab-ci.yml | ||||
| # Misc files and directories in SDist | ||||
| include *.md | ||||
| include CITATION.cff | ||||
| include LICENSE NOTICE | ||||
| include mypy*.ini | ||||
| graft benchmarks | ||||
| graft docs | ||||
| graft mypy_plugins | ||||
| graft scripts | ||||
|  | ||||
| # Misc files needed for custom setuptools command | ||||
| include .gitignore | ||||
| include .gitmodules | ||||
|  | ||||
| # [1] Invalid symlinks for the purposes of Python source distributions are, | ||||
| # according to the source distribution format[2] links pointing outside the | ||||
| # destination directory or links with a `..` component, which is those of | ||||
| # concern here. | ||||
| # Include test suites in SDist | ||||
| graft test | ||||
| include pytest.ini | ||||
| include .coveragerc | ||||
|  | ||||
| # [2] https://packaging.python.org/en/latest/specifications/source-distribution-format/#source-distribution-archive-features | ||||
| # Prune generated/compiled files | ||||
| prune torchgen/packaged | ||||
| prune */__pycache__ | ||||
| global-exclude *.o *.obj *.so *.a *.dylib *.pxd *.dll *.lib *.py[cod] | ||||
|  | ||||
| prune */.git | ||||
| global-exclude .git *~ *.swp | ||||
|  | ||||
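A quick way to verify what the rewritten MANIFEST.in actually ships is to build an sdist and list its contents. A sketch, assuming a PEP 517 build from the repository root:

```bash
# Build an sdist and inspect which files MANIFEST.in pulled in.
pip install build
python -m build --sdist
tar tzf dist/*.tar.gz | head -n 40
```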
| @ -468,7 +468,7 @@ inline Tensor _sum_to( | ||||
|       // if we assume no reduction due to unbacked we ensure that at runtime. | ||||
|       TORCH_MAYBE_SYM_CHECK( | ||||
|           sym_eq(shape[i - leading_dims], sizes[i]), | ||||
|           "non-reduction path was assumed due to unbacked symbols expected those two sizes to be the same:", | ||||
|           "non-reduction path was assumed due to unabcked symbols expected those two sizes to be the same:", | ||||
|           shape[i - leading_dims], | ||||
|           ", ", | ||||
|           sizes[i]) | ||||
|  | ||||
| @ -9,6 +9,11 @@ | ||||
|  | ||||
| namespace at::functionalization { | ||||
|  | ||||
| ViewMeta ViewMeta::to_out_idx(int64_t out_idx) { | ||||
|   if (out_idx == this->out_index) return *this; | ||||
|   return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx); | ||||
| } | ||||
|  | ||||
| // Note [Functionalization: Alias Removal Part 2] | ||||
| // See Note [Functionalization: Alias Removal] for more details. | ||||
| // This function applies a single update from one of the views to the StorageImpl. | ||||
| @ -37,12 +42,12 @@ namespace at::functionalization { | ||||
| static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) { | ||||
|   at::Tensor t = update.new_val; | ||||
|   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); | ||||
|   if (update.view_metas.empty()) { return t; } | ||||
|   if (update.view_metas.empty()) return t; | ||||
|  | ||||
|   std::vector<at::Tensor> tmp_values({base}); | ||||
|   tmp_values.reserve(update.view_metas.size()); | ||||
|   for (size_t i = 0; i < update.view_metas.size() - 1; ++i) { | ||||
|     at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back()); | ||||
|     at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index); | ||||
|     // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided | ||||
|     // All of these ops require additional information to recover the sizes of the original tensor. | ||||
|     // If need to, we could probably apply this optimization and only bother computing tmp_values | ||||
| @ -50,8 +55,9 @@ static const Tensor apply_update(const FunctionalStorageImpl::Update& update, co | ||||
|     tmp_values.push_back(std::move(next_view)); | ||||
|   } | ||||
|   for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) { | ||||
|     int64_t out_idx = update.view_metas[i].out_index; | ||||
|     // Each view inverse is implemented in ViewInverses.cpp. | ||||
|     t = update.view_metas[i]->reverse(tmp_values[i], t); | ||||
|     t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx); | ||||
|   } | ||||
|   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); | ||||
|   return t; | ||||
| @ -105,13 +111,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base) | ||||
|   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_)); | ||||
| } | ||||
|  | ||||
| void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) { | ||||
| void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) { | ||||
|   TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage"); | ||||
|  | ||||
|   if (metas.size() > 1) { | ||||
|     for (size_t i = 1; i < metas.size(); ++i) { | ||||
|       // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI | ||||
|       TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided, | ||||
|       TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided, | ||||
| "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i, | ||||
| " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today," | ||||
| "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you " | ||||
|  | ||||
| @ -8,89 +8,44 @@ namespace at::functionalization { | ||||
|  | ||||
| // See Note [Functionalization Pass In Core] | ||||
|  | ||||
| enum class InverseReturnMode { | ||||
|   /// Specifies that functional inverses should always return a view. | ||||
|   AlwaysView, | ||||
|   /// Specifies that functional inverses should always return a non-view / copy. | ||||
|   NeverView, | ||||
|   /// Specifies that functional inverses should return a view unless a (copying) | ||||
|   /// scatter | ||||
|   /// inverse exists, in which case that will be used instead. | ||||
|   /// This avoids as_strided() calls that can be difficult for subclasses to | ||||
|   /// handle. | ||||
|   ViewOrScatterInverse, | ||||
| }; | ||||
|  | ||||
| #define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \ | ||||
|   static const char* name() {                 \ | ||||
|     return #TYPE;                             \ | ||||
|   } | ||||
|  | ||||
| #define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \ | ||||
|   using SerializableTuple = std::tuple<__VA_ARGS__> | ||||
|  | ||||
| // ViewMeta is a class used by the functionalization pass to navigate between | ||||
| // a base tensor and a view tensor. | ||||
| // For example, if I call `b = a.view1(...)` | ||||
| // the functionalization pass will generate and store a ViewMeta specialization | ||||
| // for `view1` operation on b that looks like: | ||||
| // the functionalization pass will generate and store a ViewMeta on b that looks | ||||
| // like: | ||||
| // | ||||
| // struct TORCH_API view1_ViewMeta : public ViewMeta { | ||||
| //   FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta); | ||||
| //   FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( | ||||
| //       bool /* reapply_views */, | ||||
| //       const std::vector<int64_t>&); | ||||
| // | ||||
| //   view1_ViewMeta(const SerializableTuple& tpl) | ||||
| //       : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} | ||||
| // | ||||
| //   view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size) | ||||
| //       : ViewMeta(/*has_symbolic_inputs=*/false), | ||||
| //         reapply_views(reapply_views), | ||||
| //         size(size) {} | ||||
| // | ||||
| //   Tensor forward(const Tensor& base) override { | ||||
| //       return base.view1(...); | ||||
| // ViewMeta( | ||||
| //   [<captures>](const Tensor& base, int64_t mutated_view_idx) { | ||||
| //     return base.view1(...); | ||||
| //   }, | ||||
| //   [<captures>](const at::Tensor& base, const at::Tensor& mutated_view, | ||||
| //   int64_t mutated_view_idx) -> at::Tensor { | ||||
| //     return at::functionalization::impl::view1_inverse(base, mutated_view, | ||||
| //     ...); | ||||
| //   } | ||||
| // | ||||
| //   Tensor reverse(const Tensor& base, const Tensor& mutated_view) override { | ||||
| //       return at::functionalization::impl::view1_inverse(base, mutated_view, | ||||
| //       ...); | ||||
| //   } | ||||
| // The forward_fn lambda describes how to replay view1 on a tensor. | ||||
| // | ||||
| //   SerializableTuple to_serializable_tuple() { | ||||
| //     return std::make_tuple(reapply_views, size); | ||||
| //   } | ||||
| // | ||||
| //   bool reapply_views; | ||||
| //   std::vector<int64_t> size; | ||||
| // }; | ||||
| // | ||||
| // The forward function describes how to replay view1 on a tensor. | ||||
| // | ||||
| // The reverse function describes how, given a tensor that is already a view, | ||||
| // The reverse_fn lambda describes how, given a tensor that is already a view, | ||||
| // how to get the corresponding base tensor. See Note [Functionalization Pass: | ||||
| // View Inverses] for details. | ||||
| // | ||||
| // `SerializedTuple` is a typedef that defines an `std::tuple<...>` type | ||||
| // representing the `ViewMeta` instance state. Methods that take in/return such | ||||
| // a type are used for supporting pickle serialization. | ||||
| struct ViewMeta { | ||||
|   ViewMeta( | ||||
|       std::function<Tensor(const Tensor&, int64_t)> forward, | ||||
|       std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse, | ||||
|       bool has_symbolic_inputs, | ||||
|       bool is_multi_output = false, | ||||
|       bool is_as_strided = false, | ||||
|       int64_t out_idx = 0) | ||||
|       : out_index(out_idx), | ||||
|       : forward_fn(std::move(forward)), | ||||
|         reverse_fn(std::move(reverse)), | ||||
|         out_index(out_idx), | ||||
|         is_multi_output(is_multi_output), | ||||
|         is_as_strided(is_as_strided), | ||||
|         has_symbolic_inputs(has_symbolic_inputs) {} | ||||
|  | ||||
|   virtual ~ViewMeta() = default; | ||||
|  | ||||
|   virtual Tensor forward(const Tensor& base) = 0; | ||||
|   virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0; | ||||
|  | ||||
|   std::function<Tensor(const Tensor&, int64_t)> forward_fn; | ||||
|   std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn; | ||||
|   // See Note [out_idx in ViewMeta] | ||||
|   int64_t out_index; | ||||
|  | ||||
| @ -102,17 +57,10 @@ struct ViewMeta { | ||||
|   // Tells us if this view operation has any symbolic inputs | ||||
|   bool has_symbolic_inputs; | ||||
|  | ||||
|   // Returns a new ViewMeta with the same forward/reverse | ||||
|   // Returns a copy of the current ViewMeta, if out_idx matches the current | ||||
|   // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse | ||||
|   // functions, but a new out index. | ||||
|   // | ||||
|   // This method should be implemented by those `ViewMeta` that have more than | ||||
|   // one output. | ||||
|   virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) { | ||||
|     TORCH_CHECK_NOT_IMPLEMENTED( | ||||
|         false, | ||||
|         "ViewMeta::to_out_index not implemented. ", | ||||
|         "Likely because there's only one output."); | ||||
|   } | ||||
|   ViewMeta to_out_idx(int64_t out_idx); | ||||
| }; | ||||
|  | ||||
| // FunctionalStorageImpl is a subclass of StorageImpl used by the | ||||
| @ -145,14 +93,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { | ||||
|     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) | ||||
|     const at::Tensor new_val; | ||||
|     // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) | ||||
|     const std::vector<std::shared_ptr<ViewMeta>> view_metas; | ||||
|     const std::vector<ViewMeta> view_metas; | ||||
|   }; | ||||
|  | ||||
|   explicit FunctionalStorageImpl(const Tensor& value); | ||||
|  | ||||
|   void add_update( | ||||
|       const Tensor& updated_val, | ||||
|       const std::vector<std::shared_ptr<ViewMeta>>& view_metas); | ||||
|       const std::vector<ViewMeta>& view_metas); | ||||
|   bool apply_updates(); | ||||
|   const Tensor& base() { | ||||
|     return base_; | ||||
|  | ||||
| @ -129,19 +129,17 @@ void FunctionalTensorWrapper::freeze_storage() const { | ||||
| // - view_value: The output tensor that we need to wrap. | ||||
| // - base: The "base" of the view that `view_value` was generated from. | ||||
| // See Note [Functionalization: Alias Removal Part 2] for more details on the mutation replay logic. | ||||
| FunctionalTensorWrapper::FunctionalTensorWrapper( | ||||
|     const Tensor& view_value, | ||||
|     const FunctionalTensorWrapper* base, | ||||
|     const std::shared_ptr<functionalization::ViewMeta>& meta) | ||||
|     : c10::TensorImpl( | ||||
|           c10::DispatchKeySet(DispatchKey::Functionalize), | ||||
|           view_value.dtype(), | ||||
|           base->storage().data_ptr().device()), | ||||
|       value_(view_value), | ||||
|       is_multi_output_view_( | ||||
|           base->is_multi_output_view_ || meta->is_multi_output), | ||||
|       was_storage_changed_(base->was_storage_changed_), | ||||
|       is_symbolic_(base->is_symbolic_) { | ||||
| FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const FunctionalTensorWrapper* base, const functionalization::ViewMeta& meta) | ||||
|   : c10::TensorImpl( | ||||
|       c10::DispatchKeySet(DispatchKey::Functionalize), | ||||
|       view_value.dtype(), | ||||
|       base->storage().data_ptr().device() | ||||
|     ), | ||||
|     value_(view_value), | ||||
|     is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output), | ||||
|     was_storage_changed_(base->was_storage_changed_), | ||||
|     is_symbolic_(base->is_symbolic_) | ||||
| { | ||||
|   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_)); | ||||
|   TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); | ||||
|   set_constructor_metadata(); | ||||
| @ -150,10 +148,11 @@ FunctionalTensorWrapper::FunctionalTensorWrapper( | ||||
|       view_metas_ = base->view_metas_;  // copy | ||||
|   } | ||||
|   view_metas_.push_back(meta); | ||||
|   maybe_mark_symbolic(meta.get()); | ||||
|   maybe_mark_symbolic(meta); | ||||
|   storage_ = base->storage_; // alias this tensor's storage with the base tensor's | ||||
| } | ||||
|  | ||||
|  | ||||
| functionalization::FunctionalStorageImpl* FunctionalTensorWrapper::functional_storage_impl() const { | ||||
|   return static_cast<functionalization::FunctionalStorageImpl*>(storage_.unsafeGetStorageImpl()); | ||||
| } | ||||
| @ -177,18 +176,18 @@ bool FunctionalTensorWrapper::is_up_to_date() const { | ||||
| } | ||||
|  | ||||
| // See Note [Functionalization Pass - Inplace View Ops] | ||||
| void FunctionalTensorWrapper::mutate_view_meta(const std::shared_ptr<at::functionalization::ViewMeta>& meta) { | ||||
| void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::ViewMeta& meta) { | ||||
|   view_metas_.push_back(meta); | ||||
|   // Manually track the fact that this tensor received a metadata mutation! | ||||
|   has_metadata_mutation_ = true; | ||||
|   // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation. | ||||
|   maybe_mark_symbolic(meta.get()); | ||||
|   maybe_mark_symbolic(meta); | ||||
|   // Note [Functionalization Pass - Inplace View Ops] | ||||
|   // So, these ops are special - they're mutation AND view ops. They get special codegen. | ||||
|   // An example is transpose_, e.g. `a.transpose_()` | ||||
|   // Calling transpose_() should ensure that a gets an alias, and append the new ViewMeta to a's current list of ViewMetas. | ||||
|   at::AutoDispatchSkipFunctionalize guard; | ||||
|   value_ = meta->forward(value_); | ||||
|   value_ = meta.forward_fn(value_, meta.out_index); | ||||
|   TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize)); | ||||
| } | ||||
|  | ||||
| @ -369,8 +368,15 @@ void FunctionalTensorWrapper::sync_() { | ||||
|   regenerate_from_base(); | ||||
| } | ||||
|  | ||||
| const std::vector<std::shared_ptr<functionalization::ViewMeta>>& FunctionalTensorWrapper::view_metas() const { | ||||
|   return view_metas_; | ||||
| Tensor FunctionalTensorWrapper::apply_view_metas(const Tensor& base) { | ||||
|   auto t = base; | ||||
|  | ||||
|   // Reapply views to get the viewed tensor from the base in alias_ | ||||
|   for (auto& view_meta: view_metas_) { | ||||
|     t = view_meta.forward_fn(t, view_meta.out_index); | ||||
|   } | ||||
|  | ||||
|   return t; | ||||
| } | ||||
|  | ||||
| void FunctionalTensorWrapper::regenerate_from_base() { | ||||
| @ -379,7 +385,7 @@ void FunctionalTensorWrapper::regenerate_from_base() { | ||||
|   auto t = storage_impl->base(); | ||||
|  | ||||
|   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); | ||||
|   t = at::functionalization::impl::apply_view_meta_sequence(t, view_metas_); | ||||
|   t = apply_view_metas(t); | ||||
|   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t)); | ||||
|  | ||||
|   replace_(t, /*from_lazy_regenerate=*/true); | ||||
| @ -721,11 +727,11 @@ bool isFunctionalTensor(const std::optional<Tensor>& t) { | ||||
| } | ||||
|  | ||||
| bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) { | ||||
|   if (t_list.empty()) { return false; } | ||||
|   if (t_list.empty()) return false; | ||||
|   auto functional_count = 0; | ||||
|   for (const auto i : c10::irange(t_list.size())) { | ||||
|     auto const & e= t_list[i]; | ||||
|     if (!e.has_value() || !e->defined()) { continue; } | ||||
|     if (!e.has_value() || !e->defined()) continue; | ||||
|     if (isFunctionalTensor(e)) { | ||||
|       ++functional_count; | ||||
|     } | ||||
| @ -735,10 +741,10 @@ bool isFunctionalTensor(const c10::List<::std::optional<Tensor>>& t_list) { | ||||
|  | ||||
| template <typename T> | ||||
| static bool isFunctionalTensorIListRef(c10::IListRef<T> list) { | ||||
|   if (list.size() == 0) { return false; } | ||||
|   if (list.size() == 0) return false; | ||||
|   auto functional_count = 0; | ||||
|   for (const auto& tensor : list) { | ||||
|     if (!tensor.defined()) { continue; } | ||||
|     if (!tensor.defined()) continue; | ||||
|     if (isFunctionalTensor(tensor)) { | ||||
|       ++functional_count; | ||||
|     } | ||||
| @ -756,28 +762,20 @@ void freeze_functional_tensor(const Tensor& tensor) { | ||||
|   functional_base_impl->freeze_storage(); | ||||
| } | ||||
|  | ||||
| Tensor create_functional_tensor_with_view_meta( | ||||
|     const at::Tensor& view_to_wrap, | ||||
|     const at::Tensor& base, | ||||
|     const std::shared_ptr<functionalization::ViewMeta>& meta, | ||||
|     int64_t out_idx) { | ||||
| Tensor create_functional_tensor_with_view_meta(const at::Tensor& view_to_wrap, const at::Tensor& base, functionalization::ViewMeta meta, int64_t out_idx) { | ||||
|   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(view_to_wrap)); | ||||
|   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(base)); | ||||
|   auto functional_base_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(base); | ||||
|   auto meta_ = meta; | ||||
|   if (out_idx != 0) { | ||||
|     // Note [out_idx in ViewMeta] | ||||
|     // When a view op outputs multiple tensors, each output needs its own separate ViewMeta. | ||||
|     // Each ViewMeta also tracks the index of the particular output tensor, which is needed in the reverse function. | ||||
|     meta_ = meta->to_out_index(out_idx); | ||||
|     meta = meta.to_out_idx(out_idx); | ||||
|   } | ||||
|   return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta_); | ||||
|   return at::detail::make_tensor<FunctionalTensorWrapper>(view_to_wrap, functional_base_impl, meta); | ||||
| } | ||||
|  | ||||
| std::vector<Tensor> create_functional_tensor_with_view_meta( | ||||
|     ITensorListRef view_to_wrap, | ||||
|     const at::Tensor& base, | ||||
|     const std::shared_ptr<functionalization::ViewMeta>& meta) { | ||||
| std::vector<Tensor> create_functional_tensor_with_view_meta(ITensorListRef view_to_wrap, const at::Tensor& base, const functionalization::ViewMeta& meta) { | ||||
|   std::vector<Tensor> outputs(view_to_wrap.size()); | ||||
|   int64_t i = 0; | ||||
|   for (const auto& tensor : view_to_wrap) { | ||||
| @ -787,22 +785,12 @@ std::vector<Tensor> create_functional_tensor_with_view_meta( | ||||
|   return outputs; | ||||
| } | ||||
|  | ||||
| void mutate_view_meta(const at::Tensor& self, const std::shared_ptr<functionalization::ViewMeta>& meta) { | ||||
| void mutate_view_meta(const at::Tensor& self, const functionalization::ViewMeta& meta) { | ||||
|   TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); | ||||
|   auto self_impl = at::functionalization::impl::unsafeGetFunctionalWrapper(self); | ||||
|   self_impl->mutate_view_meta(meta); | ||||
| } | ||||
|  | ||||
| Tensor apply_view_meta_sequence( | ||||
|     const Tensor& base, | ||||
|     const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence) { | ||||
|   Tensor r = base; | ||||
|   for (auto& vm : sequence) { | ||||
|     r = vm->forward(r); | ||||
|   } | ||||
|   return r; | ||||
| } | ||||
|  | ||||
| // Note [Propagating strides in the functionalization pass] | ||||
| // In order to properly compute stride information, the functionalization pass | ||||
| // calls each {view} reference implementations with meta tensors. | ||||
| @ -896,7 +884,7 @@ void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* s | ||||
|     const auto& ivalue = returns[idx]; | ||||
|     if (ivalue.isTensor()) { | ||||
|       const auto& t = ivalue.toTensor(); | ||||
|       if (!t.defined()) { continue; } | ||||
|       if (!t.defined()) continue; | ||||
|       at::functionalization::impl::sync(t); | ||||
|       auto t_new = c10::IValue(at::functionalization::impl::from_functional_tensor(t)); | ||||
|       (*stack)[returns_begin + idx] = t_new; | ||||
|  | ||||
| @ -56,7 +56,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { | ||||
|   explicit FunctionalTensorWrapper( | ||||
|       const Tensor& view_value, | ||||
|       const FunctionalTensorWrapper* base, | ||||
|       const std::shared_ptr<functionalization::ViewMeta>& meta); | ||||
|       const functionalization::ViewMeta& meta); | ||||
|  | ||||
|   // Get the underlying, actual tensor, that doesn't know anything about | ||||
|   // functionalization. | ||||
| @ -99,17 +99,17 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { | ||||
|         ->are_all_mutations_under_no_grad_or_inference_mode(); | ||||
|   } | ||||
|  | ||||
|   void maybe_mark_symbolic(functionalization::ViewMeta* meta) { | ||||
|     is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs; | ||||
|   void maybe_mark_symbolic(const functionalization::ViewMeta& meta) { | ||||
|     is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs; | ||||
|   } | ||||
|  | ||||
|   bool is_symbolic() const { | ||||
|     return is_symbolic_; | ||||
|   } | ||||
|  | ||||
|   // Retrieves the ViewMeta sequence of this tensor. | ||||
|   const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas() | ||||
|       const; | ||||
|   // Runs the forward_fn of every ViewMeta collected in the current instance | ||||
|   // to some other base. | ||||
|   Tensor apply_view_metas(const Tensor& base); | ||||
|  | ||||
|   // Sync's the underlying tensor with its alias, if it's out of date. This | ||||
|   // involves two steps: 1) Apply any pending updates/mutations to the alias 2) | ||||
| @ -146,8 +146,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { | ||||
|   // from the base tensor. This method is used by inplace-view ops like | ||||
|   // transpose_. It appends a ViewMeta to the existing stack, and refreshes the | ||||
|   // tensor by replaying the views off of the alias. | ||||
|   void mutate_view_meta( | ||||
|       const std::shared_ptr<at::functionalization::ViewMeta>& meta); | ||||
|   void mutate_view_meta(const at::functionalization::ViewMeta& meta); | ||||
|  | ||||
|   // Custom implementation of self.set_(src) | ||||
|   void set__impl(const FunctionalTensorWrapper* other); | ||||
| @ -286,7 +285,7 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { | ||||
|   bool is_symbolic_ = false; | ||||
|  | ||||
|   size_t generation_ = 0; | ||||
|   std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_; | ||||
|   std::vector<at::functionalization::ViewMeta> view_metas_; | ||||
|  | ||||
|  protected: | ||||
|   static void copy_tensor_metadata( | ||||
| @ -378,20 +377,16 @@ TORCH_API void propagate_xla_data_direct( | ||||
| Tensor create_functional_tensor_with_view_meta( | ||||
|     const Tensor& view_to_wrap, | ||||
|     const Tensor& base, | ||||
|     const std::shared_ptr<functionalization::ViewMeta>& meta, | ||||
|     functionalization::ViewMeta meta, | ||||
|     int64_t out_idx = 0); | ||||
| std::vector<Tensor> create_functional_tensor_with_view_meta( | ||||
|     ITensorListRef view_to_wrap, | ||||
|     const Tensor& base, | ||||
|     const std::shared_ptr<functionalization::ViewMeta>& meta); | ||||
|     const functionalization::ViewMeta& meta); | ||||
|  | ||||
| void mutate_view_meta( | ||||
|     const Tensor& self, | ||||
|     const std::shared_ptr<functionalization::ViewMeta>& meta); | ||||
|  | ||||
| TORCH_API Tensor apply_view_meta_sequence( | ||||
|     const Tensor& base, | ||||
|     const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence); | ||||
|     const functionalization::ViewMeta& meta); | ||||
|  | ||||
| void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); | ||||
| void set_sizes_strides_offset( | ||||
|  | ||||
| @ -1,5 +1,3 @@ | ||||
| #include <ATen/FunctionalizeFallbackKernel.h> | ||||
|  | ||||
| #include <ATen/core/dispatch/Dispatcher.h> | ||||
| #include <ATen/core/LegacyTypeDispatch.h> | ||||
| #include <ATen/EmptyTensor.h> | ||||
| @ -9,6 +7,7 @@ | ||||
| #include <torch/library.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <c10/util/strides.h> | ||||
| #include <ATen/EmptyTensor.h> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/ATen.h> | ||||
| @ -29,31 +28,6 @@ | ||||
| #include <utility> | ||||
| #endif | ||||
|  | ||||
| namespace at::functionalization { | ||||
|  | ||||
| Tensor resize__ViewMeta::forward(const Tensor& base) { | ||||
|   if (reapply_views) { | ||||
|     return base.as_strided(size, c10::contiguous_strides(size)); | ||||
|   } else { | ||||
|     return at::as_strided_copy(base, size, c10::contiguous_strides(size)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| Tensor resize__ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { | ||||
|   return base.as_strided_scatter( | ||||
|       mutated_view, size, c10::contiguous_strides(size)); | ||||
| } | ||||
|  | ||||
| Tensor _unsafe_view_ViewMeta::forward(const Tensor& base) { | ||||
|   return at::_unsafe_view_symint(base, size); | ||||
| } | ||||
|  | ||||
| Tensor _unsafe_view_ViewMeta::reverse(const Tensor& base, const Tensor& mutated_view) { | ||||
|   return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); | ||||
| } | ||||
|  | ||||
| } // namespace at::functionalization | ||||
|  | ||||
| namespace { | ||||
|   void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet [[maybe_unused]], torch::jit::Stack* stack) { | ||||
|     const auto& schema = op.schema(); | ||||
| @ -132,9 +106,7 @@ namespace { | ||||
|       const auto& ivalue = returns[idx]; | ||||
|       if (ivalue.isTensor() && should_wrap_outputs) { | ||||
|         const auto& t = ivalue.toTensor(); | ||||
|         if (!t.defined()) { | ||||
|           continue; | ||||
|         } | ||||
|         if (!t.defined()) continue; | ||||
|         auto t_new = c10::IValue(at::functionalization::impl::to_functional_tensor(t)); | ||||
|         (*stack)[returns_begin + idx] = t_new; | ||||
|       } else if (ivalue.isTensorList() && should_wrap_outputs) { | ||||
| @ -197,8 +169,19 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch | ||||
|   // The output of resizing is equivalent to taking a slice of a larger tensor. | ||||
|   // We have to emulate this "slicing" with an as_strided call. | ||||
|   auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); | ||||
|   auto view_meta = std::make_shared<at::functionalization::resize__ViewMeta>( | ||||
|       reapply_views, size.vec()); | ||||
|   at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( | ||||
|     [reapply_views = reapply_views, size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { | ||||
|       if (reapply_views) { | ||||
|         return base.as_strided(size, c10::contiguous_strides(size)); | ||||
|       } else { | ||||
|         return at::as_strided_copy(base, size, c10::contiguous_strides(size)); | ||||
|       } | ||||
|     }, | ||||
|     [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { | ||||
|       return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size)); | ||||
|     }, | ||||
|     /*has_symbolic_inputs=*/false | ||||
|   ); | ||||
|   at::functionalization::impl::mutate_view_meta(self, view_meta); | ||||
|   return self; | ||||
| } | ||||
| @ -317,11 +300,17 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt | ||||
|     tmp_output = at::_unsafe_view_symint(self_, size); | ||||
|   } | ||||
|  | ||||
|   bool has_symbolic_inputs = std::any_of( | ||||
|       size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); | ||||
|   auto view_meta = | ||||
|       std::make_shared<at::functionalization::_unsafe_view_ViewMeta>( | ||||
|           has_symbolic_inputs, size.vec()); | ||||
|   bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); }); | ||||
|  | ||||
|   at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( | ||||
|     [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { | ||||
|       return at::_unsafe_view_symint(base, size); | ||||
|     }, | ||||
|     [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx [[maybe_unused]]) -> at::Tensor { | ||||
|       return at::_unsafe_view_symint(mutated_view, base.sym_sizes()); | ||||
|     }, | ||||
|     /*has_symbolic_inputs=*/has_symbolic_inputs | ||||
|   ); | ||||
|  | ||||
|   auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta)); | ||||
|   // See  Note [Propagating strides in the functionalization pass] | ||||
|  | ||||
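For reference, a small standalone sketch (the helper name `any_symbolic` is assumed, not from the patch) of the symbolic-size check used above: a view counts as having symbolic inputs if any requested dimension is a symbolic SymInt rather than a concrete integer.

#include <algorithm>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>

bool any_symbolic(c10::SymIntArrayRef size) {
  // true if at least one requested dimension is symbolic (e.g. comes from
  // dynamic shapes) rather than a plain integer
  return std::any_of(size.begin(), size.end(),
                     [](const c10::SymInt& s) { return s.is_symbolic(); });
}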
| @ -1,58 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/FunctionalStorageImpl.h> | ||||
|  | ||||
| namespace at::functionalization { | ||||
|  | ||||
| // `ViewMeta` implementation for `resize_` operation. | ||||
| struct TORCH_API resize__ViewMeta : public ViewMeta { | ||||
|   FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta) | ||||
|   FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( | ||||
|       bool /* reapply_views */, | ||||
|       const std::vector<int64_t>&); | ||||
|  | ||||
|   resize__ViewMeta(const SerializableTuple& tpl) | ||||
|       : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} | ||||
|  | ||||
|   resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size) | ||||
|       : ViewMeta(/*has_symbolic_inputs=*/false), | ||||
|         reapply_views(reapply_views), | ||||
|         size(size) {} | ||||
|  | ||||
|   Tensor forward(const Tensor& base) override; | ||||
|   Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; | ||||
|  | ||||
|   SerializableTuple to_serializable_tuple() { | ||||
|     return std::make_tuple(reapply_views, size); | ||||
|   } | ||||
|  | ||||
|   bool reapply_views; | ||||
|   std::vector<int64_t> size; | ||||
| }; | ||||
|  | ||||
| // `ViewMeta` implementation for `_unsafe_view` operation. | ||||
| struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta { | ||||
|   FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta) | ||||
|   FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( | ||||
|       bool /* has_symbolic_inputs */, | ||||
|       const std::vector<c10::SymInt>&); | ||||
|  | ||||
|   _unsafe_view_ViewMeta(const SerializableTuple& tpl) | ||||
|       : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} | ||||
|  | ||||
|   _unsafe_view_ViewMeta( | ||||
|       bool has_symbolic_inputs, | ||||
|       const std::vector<c10::SymInt>& size) | ||||
|       : ViewMeta(has_symbolic_inputs), size(size) {} | ||||
|  | ||||
|   Tensor forward(const Tensor& base) override; | ||||
|   Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; | ||||
|  | ||||
|   SerializableTuple to_serializable_tuple() { | ||||
|     return std::make_tuple(has_symbolic_inputs, size); | ||||
|   } | ||||
|  | ||||
|   std::vector<c10::SymInt> size; | ||||
| }; | ||||
|  | ||||
| } // namespace at::functionalization | ||||
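The SerializableTuple pattern used by the two structs above can be illustrated with a toy type (editor's sketch; `ToyViewMeta` is not a PyTorch class): the object's state is flattened to a plain tuple and can be rebuilt from it, which is what made these ViewMeta subclasses serializable for caching purposes.

#include <cstdint>
#include <tuple>
#include <vector>

struct ToyViewMeta {
  using SerializableTuple = std::tuple<bool, std::vector<int64_t>>;

  ToyViewMeta(bool reapply_views, std::vector<int64_t> size)
      : reapply_views(reapply_views), size(std::move(size)) {}

  // rebuild the object from its flattened representation
  explicit ToyViewMeta(const SerializableTuple& tpl)
      : ToyViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}

  // flatten the state so it can be cached or serialized
  SerializableTuple to_serializable_tuple() const {
    return std::make_tuple(reapply_views, size);
  }

  bool reapply_views;
  std::vector<int64_t> size;
};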
| @ -1,22 +1,32 @@ | ||||
| #include <ATen/core/PythonOpRegistrationTrampoline.h> | ||||
| #include <c10/core/impl/PyInterpreterHooks.h> | ||||
|  | ||||
| // TODO: delete this | ||||
| namespace at::impl { | ||||
|  | ||||
| c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::interpreter_ = nullptr; | ||||
| // The strategy is that all python interpreters attempt to register themselves | ||||
| // as the main interpreter, but only one wins.  Only that interpreter is | ||||
| // allowed to interact with the C++ dispatcher.  Furthermore, when we execute | ||||
| // logic on that interpreter, we do so hermetically, never setting pyobj field | ||||
| // on Tensor. | ||||
|  | ||||
| std::atomic<c10::impl::PyInterpreter*> | ||||
|     PythonOpRegistrationTrampoline::interpreter_{nullptr}; | ||||
|  | ||||
| c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::getInterpreter() { | ||||
|   return c10::impl::getGlobalPyInterpreter(); | ||||
|   return PythonOpRegistrationTrampoline::interpreter_.load(); | ||||
| } | ||||
|  | ||||
| bool PythonOpRegistrationTrampoline::registerInterpreter( | ||||
|     c10::impl::PyInterpreter* interp) { | ||||
|   if (interpreter_ != nullptr) { | ||||
|   c10::impl::PyInterpreter* expected = nullptr; | ||||
|   interpreter_.compare_exchange_strong(expected, interp); | ||||
|   if (expected != nullptr) { | ||||
|     // This is the second (or later) Python interpreter, which means we need | ||||
|     // non-trivial hermetic PyObject TLS | ||||
|     c10::impl::HermeticPyObjectTLS::init_state(); | ||||
|     return false; | ||||
|   } else { | ||||
|     return true; | ||||
|   } | ||||
|   interpreter_ = interp; | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| } // namespace at::impl | ||||
|  | ||||
| @ -2,21 +2,19 @@ | ||||
|  | ||||
| #include <ATen/core/dispatch/Dispatcher.h> | ||||
|  | ||||
| // TODO: We can get rid of this | ||||
| // TODO: this can probably live in c10 | ||||
|  | ||||
|  | ||||
| namespace at::impl { | ||||
|  | ||||
| // Manages the single Python interpreter instance for PyTorch. | ||||
| class TORCH_API PythonOpRegistrationTrampoline final { | ||||
|   static c10::impl::PyInterpreter* interpreter_; | ||||
|   static std::atomic<c10::impl::PyInterpreter*> interpreter_; | ||||
|  | ||||
| public: | ||||
|   // Register the Python interpreter. Returns true on first registration, | ||||
|   // false if an interpreter was already registered. | ||||
|   //  Returns true if you successfully registered yourself (that means | ||||
|   //  you are in the hot seat for doing the operator registrations!) | ||||
|   static bool registerInterpreter(c10::impl::PyInterpreter*); | ||||
|  | ||||
|   // Returns the registered interpreter via the global PyInterpreter hooks. | ||||
|   // Returns nullptr if no interpreter has been registered yet. | ||||
|   static c10::impl::PyInterpreter* getInterpreter(); | ||||
| }; | ||||
|  | ||||
| @ -149,105 +149,5 @@ static inline void pack_vnni4( | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // This is a helper function for transpose_pack_vnni4 | ||||
| // Transform a [4, 16] block (with incontiguous output) | ||||
| // Src: | ||||
| // a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 a16 | ||||
| // b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 b16 | ||||
| // c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 c16 | ||||
| // d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 d16 | ||||
| // Dst: | ||||
| // a1 a2 a3 a4 b1 b2 b3 b4 c1 c2 c3 c4 d1 d2 d3 d4 | ||||
| // a5 a6 a7 a8 b5 b6 b7 b8 c5 c6 c7 c8 d5 d6 d7 d8 | ||||
| // a9 a10 a11 a12 b9 b10 b11 b12 c9 c10 c11 c12 d9 d10 d11 d12 | ||||
| // a13 a14 a15 a16 b13 b14 b15 b16 c13 c14 c15 c16 d13 d14 d15 d16 | ||||
| template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>> | ||||
| static inline void transpose_vnni4_pad_4x16_block( | ||||
|     const scalar_t* src, | ||||
|     scalar_t* dst, | ||||
|     int64_t ld_src, | ||||
|     int64_t ld_dst, | ||||
|     int krem = 4) { | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
|   __m128i r[4]; | ||||
|   for (int i = 0; i < krem; ++i) { | ||||
|     r[i] = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * ld_src)); | ||||
|   } | ||||
|   for (int i = krem; i < 4; ++i) { | ||||
|     r[i] = _mm_setzero_si128(); | ||||
|   } | ||||
|  | ||||
|   // Transpose 4x16 bytes using unpack and shuffle | ||||
|   __m128i t0 = _mm_unpacklo_epi32(r[0], r[1]); | ||||
|   __m128i t1 = _mm_unpackhi_epi32(r[0], r[1]); | ||||
|   __m128i t2 = _mm_unpacklo_epi32(r[2], r[3]); | ||||
|   __m128i t3 = _mm_unpackhi_epi32(r[2], r[3]); | ||||
|  | ||||
|   __m128i r0 = _mm_unpacklo_epi64(t0, t2); | ||||
|   __m128i r1 = _mm_unpackhi_epi64(t0, t2); | ||||
|   __m128i r2 = _mm_unpacklo_epi64(t1, t3); | ||||
|   __m128i r3 = _mm_unpackhi_epi64(t1, t3); | ||||
|  | ||||
|   // Store output | ||||
|   if (krem == 4) { | ||||
|     // normal case | ||||
|     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), r0); | ||||
|     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst), r1); | ||||
|     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 2), r2); | ||||
|     _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + ld_dst * 3), r3); | ||||
|   } else { | ||||
|     // masked case | ||||
|     __mmask16 mask = (1ULL << (krem * 4)) - 1; | ||||
|     _mm_mask_storeu_epi8(dst, mask, r0); | ||||
|     _mm_mask_storeu_epi8(reinterpret_cast<__m128i*>(dst + ld_dst), mask, r1); | ||||
|     _mm_mask_storeu_epi8( | ||||
|         reinterpret_cast<__m128i*>(dst + ld_dst * 2), mask, r2); | ||||
|     _mm_mask_storeu_epi8( | ||||
|         reinterpret_cast<__m128i*>(dst + ld_dst * 3), mask, r3); | ||||
|   } | ||||
| #else | ||||
|   TORCH_CHECK( | ||||
|       false, | ||||
|       "transpose_vnni4_pad_4x16_block is only supported when AVX-512 is supported") | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // Do the transpose packing fusion with VNNI4 | ||||
| // Reorder [K, N] → [N/4, K, 4] (VNNI4-style layout for bit8) | ||||
| template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>> | ||||
| static inline void transpose_pack_vnni4( | ||||
|     const scalar_t* src, | ||||
|     scalar_t* dst, | ||||
|     int64_t ld_src, | ||||
|     int64_t K, | ||||
|     int64_t N) { | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
|   TORCH_CHECK( | ||||
|       N % 16 == 0, "N needs to be multiple of 16 for transpose_pack_vnni4"); | ||||
|   int64_t bk = 0; | ||||
|   int64_t _K = K / 4 * 4; | ||||
|   for (; bk < _K; bk += 4) { | ||||
|     int64_t bn = 0; | ||||
|     for (; bn < N; bn += 16) { | ||||
|       transpose_vnni4_pad_4x16_block( | ||||
|           src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Handle leftover K rows (< 4) | ||||
|   if (K % 4 != 0) { | ||||
|     int krem = K - bk; | ||||
|     int64_t bn = 0; | ||||
|     for (; bn < N; bn += 16) { | ||||
|       transpose_vnni4_pad_4x16_block( | ||||
|           src + bk * ld_src + bn, dst + bn * K + bk * 4, ld_src, K * 4, krem); | ||||
|     } | ||||
|   } | ||||
| #else | ||||
|   TORCH_CHECK( | ||||
|       false, "transpose_pack_vnni4 is only supported when AVX-512 is supported") | ||||
| #endif | ||||
| } | ||||
|  | ||||
| } // namespace CPU_CAPABILITY | ||||
| } // namespace at::vec | ||||
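As a reference for the layout produced by the removed AVX-512 fast path, here is a plain scalar sketch (editor's addition; `transpose_pack_vnni4_ref` is a hypothetical name): it repacks a row-major int8 matrix of shape [K, N] into the [N/4, K, 4] VNNI4 layout one element at a time.

#include <cstdint>

// dst is laid out as [N/4][K][4]; assumes N % 4 == 0 (the AVX-512 path above
// additionally requires N % 16 == 0 and handles K % 4 != 0 with masked stores).
void transpose_pack_vnni4_ref(const int8_t* src, int8_t* dst, int64_t K, int64_t N) {
  for (int64_t k = 0; k < K; ++k) {
    for (int64_t n = 0; n < N; ++n) {
      dst[(n / 4) * K * 4 + k * 4 + (n % 4)] = src[k * N + n];
    }
  }
}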
|  | ||||
| @ -281,9 +281,6 @@ bool CUDAHooks::compiledWithMIOpen() const { | ||||
|  | ||||
| bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const { | ||||
| #if AT_CUDNN_ENABLED() | ||||
|   if (!hasCUDA()) { | ||||
|     return false; | ||||
|   } | ||||
|   // NOTE: extra parentheses around numbers disable clang warnings about | ||||
|   // dead code | ||||
|   return true; | ||||
| @ -294,9 +291,6 @@ bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const { | ||||
|  | ||||
| bool CUDAHooks::supportsDepthwiseConvolutionWithCuDNN() const { | ||||
| #if AT_CUDNN_ENABLED() | ||||
|   if (!hasCUDA()) { | ||||
|     return false; | ||||
|   } | ||||
|   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); | ||||
|   // Check for Volta cores | ||||
|   if (prop->major >= 7) { | ||||
| @ -311,9 +305,6 @@ bool CUDAHooks::supportsDepthwiseConvolutionWithCuDNN() const { | ||||
|  | ||||
| bool CUDAHooks::supportsBFloat16ConvolutionWithCuDNNv8() const { | ||||
| #if AT_CUDNN_ENABLED() | ||||
|   if (!hasCUDA()) { | ||||
|     return false; | ||||
|   } | ||||
|   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); | ||||
|   // Check for Volta cores | ||||
|   if (prop->major >= 8) { | ||||
|  | ||||
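The checks above gate cuDNN features on the SM major version of the current device. A minimal host-side sketch of that pattern (illustrative helper, not the CUDAHooks API):

#include <cuda_runtime.h>

// True if the current CUDA device has SM major version >= `major`
// (7 = Volta, 8 = Ampere and newer).
bool device_is_at_least_sm(int major) {
  int device = 0;
  cudaDeviceProp prop{};
  if (cudaGetDevice(&device) != cudaSuccess) {
    return false;
  }
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;
  }
  return prop.major >= major;
}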
| @ -465,11 +465,8 @@ inline bool mps_conv_use_channels_last(const at::Tensor& input, const at::Tensor | ||||
|     return false; | ||||
|   } | ||||
|  | ||||
|   auto is_channel_last = [](const at::Tensor& t) { | ||||
|     auto fmt = t.suggest_memory_format(); | ||||
|     return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d; | ||||
|   }; | ||||
|   return is_channel_last(input) || is_channel_last(weight); | ||||
|   auto fmt = input.suggest_memory_format(); | ||||
|   return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d; | ||||
| } | ||||
|  | ||||
| } // namespace at::native | ||||
|  | ||||
| @ -32,6 +32,10 @@ | ||||
| #include <ATen/native/mkldnn/Utils.h> | ||||
| #endif | ||||
|  | ||||
| #ifdef USE_MPS | ||||
| #include <ATen/mps/MPSDevice.h> | ||||
| #endif | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| @ -1425,8 +1429,12 @@ static inline at::MemoryFormat determine_backend_memory_format( | ||||
|       } | ||||
|       break; | ||||
|     case ConvBackend::Mps: | ||||
|     case ConvBackend::MpsTranspose: | ||||
|       if (mps_conv_use_channels_last(input, weight)) { | ||||
| #ifdef USE_MPS | ||||
|         if (!mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) { | ||||
|           break; | ||||
|         } | ||||
| #endif | ||||
|         backend_memory_format = (k == 5) ? MemoryFormat::ChannelsLast3d : MemoryFormat::ChannelsLast; | ||||
|       } | ||||
|       break; | ||||
|  | ||||
| @ -9,7 +9,6 @@ | ||||
| #include <ATen/native/TransposeType.h> | ||||
| #include <ATen/native/Unfold3d.h> | ||||
| #include <c10/util/irange.h> | ||||
| #include <c10/util/safe_numerics.h> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Functions.h> | ||||
| @ -175,23 +174,6 @@ static inline void slow_conv3d_shape_check( | ||||
|   const int64_t input_height = input.size(dim_height); | ||||
|   const int64_t input_width = input.size(dim_width); | ||||
|  | ||||
|   constexpr int64_t MAX_SAFE_PAD = (1LL << 61); | ||||
|  | ||||
|   TORCH_CHECK_VALUE( | ||||
|     pad_height <= MAX_SAFE_PAD, | ||||
|     "Padding height too large: pad_height=", | ||||
|     pad_height); | ||||
|  | ||||
|   TORCH_CHECK_VALUE( | ||||
|     pad_width <= MAX_SAFE_PAD, | ||||
|     "Padding width too large: pad_width=", | ||||
|     pad_width); | ||||
|  | ||||
|   TORCH_CHECK_VALUE( | ||||
|     pad_depth <= MAX_SAFE_PAD, | ||||
|     "Padding depth too large: pad_depth=", | ||||
|     pad_depth); | ||||
|  | ||||
|   const int64_t exact_input_depth = input_depth + 2 * pad_depth; | ||||
|   const int64_t exact_input_height = input_height + 2 * pad_height; | ||||
|   const int64_t exact_input_width = input_width + 2 * pad_width; | ||||
| @ -239,14 +221,6 @@ static inline void slow_conv3d_shape_check( | ||||
|       output_width, | ||||
|       "). Output size is too small"); | ||||
|  | ||||
|   uint64_t kernel_product; | ||||
|   TORCH_CHECK( | ||||
|     !c10::mul_overflows(kernel_height, kernel_width, &kernel_product), | ||||
|     "Kernel height x width product is too large: kernel_height=", | ||||
|     kernel_height, | ||||
|     ", kernel_width=", | ||||
|     kernel_width); | ||||
|  | ||||
|   if (weight.defined()) { | ||||
|     int64_t n_input_plane = weight.size(1); | ||||
|     if (weight.dim() == 2) { | ||||
|  | ||||
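For context, a hedged sketch in the style of the overflow guards removed above (the helper name is illustrative): c10::mul_overflows reports whether an integer product overflows instead of letting it wrap silently.

#include <cstdint>
#include <c10/util/Exception.h>
#include <c10/util/safe_numerics.h>

void check_kernel_product(uint64_t kernel_height, uint64_t kernel_width) {
  uint64_t kernel_product = 0;
  // mul_overflows reports overflow instead of letting the product silently wrap
  TORCH_CHECK(
      !c10::mul_overflows(kernel_height, kernel_width, &kernel_product),
      "Kernel height x width product is too large: ", kernel_height, " x ", kernel_width);
}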
| @ -23,7 +23,6 @@ | ||||
| #include <ATen/ops/linspace.h> | ||||
| #endif | ||||
|  | ||||
| #include <cmath> | ||||
| #include <numeric> | ||||
| #include <tuple> | ||||
| #include <vector> | ||||
| @ -203,46 +202,6 @@ select_outer_bin_edges(const Tensor& input, std::optional<c10::ArrayRef<double>> | ||||
|     return std::make_pair(leftmost_edges, rightmost_edges); | ||||
| } | ||||
|  | ||||
|  | ||||
| /* Bin edges correction based on the precision representation. | ||||
|  * To maintain backward compatibility, for floating-point scalar types each degenerate | ||||
|  * edge is widened by whichever is larger: a single representable step (std::nexttoward) | ||||
|  * or +/- 1. For all other types the edges are simply widened by +/- 1, as before. | ||||
|  */ | ||||
| void bins_edges_correction(const ScalarType& t, double &leftmost_edge, double &rightmost_edge) | ||||
| { | ||||
| #define UPDATE_WITH_LIMIT(real_type, scalartype) \ | ||||
|   case ScalarType::scalartype:                   \ | ||||
|     leftmost_edge = std::min(                    \ | ||||
|         static_cast<double>(                     \ | ||||
|             std::nexttoward(                     \ | ||||
|                 static_cast<real_type>(leftmost_edge),   \ | ||||
|                 std::numeric_limits<real_type>::lowest() \ | ||||
|             )                                    \ | ||||
|         ),                                       \ | ||||
|         leftmost_edge - 1.                       \ | ||||
|     );                                           \ | ||||
|     rightmost_edge = std::max(                   \ | ||||
|         static_cast<double>(                     \ | ||||
|             std::nexttoward(                     \ | ||||
|                 static_cast<real_type>(rightmost_edge), \ | ||||
|                 std::numeric_limits<real_type>::max()   \ | ||||
|             )                                    \ | ||||
|         ),                                       \ | ||||
|         rightmost_edge + 1.                      \ | ||||
|     );                                           \ | ||||
|     break; | ||||
|  | ||||
|     switch (t) { | ||||
|         UPDATE_WITH_LIMIT(double, Double) | ||||
|         UPDATE_WITH_LIMIT(float, Float) | ||||
|         default: | ||||
|             // Fallback to the default behavior for other types | ||||
|             leftmost_edge -= 1; | ||||
|             rightmost_edge += 1; | ||||
|     } | ||||
| #undef UPDATE_WITH_LIMIT | ||||
| } | ||||
|  | ||||
| /* histc's version of the logic for outermost bin edges. | ||||
|  */ | ||||
| std::pair<double, double> histc_select_outer_bin_edges(const Tensor& input, | ||||
| @ -257,7 +216,8 @@ std::pair<double, double> histc_select_outer_bin_edges(const Tensor& input, | ||||
|     } | ||||
|  | ||||
|     if (leftmost_edge == rightmost_edge) { | ||||
|         bins_edges_correction(input.dtype().toScalarType(), leftmost_edge, rightmost_edge); | ||||
|         leftmost_edge -= 1; | ||||
|         rightmost_edge += 1; | ||||
|     } | ||||
|  | ||||
|     TORCH_CHECK(!(std::isinf(leftmost_edge) || std::isinf(rightmost_edge) || | ||||
|  | ||||
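A short standalone example (editor's addition) of why the removed correction reached for std::nexttoward instead of a plain +/- 1: at large magnitudes the spacing between adjacent doubles exceeds 1, so subtracting 1 can leave a degenerate bin edge unchanged.

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  double edge = 1e17;                  // degenerate range: leftmost == rightmost == 1e17
  double widened_by_one = edge - 1.0;  // rounds back to 1e17: the ulp here is 16
  double widened_by_ulp = std::nexttoward(edge, std::numeric_limits<double>::lowest());
  std::printf("%d %d\n", widened_by_one == edge, widened_by_ulp < edge);  // prints "1 1"
  return 0;
}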
| @ -42,19 +42,6 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) { | ||||
|     }); | ||||
| } | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) { | ||||
|     gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) { | ||||
|         return static_cast<float>(value); | ||||
|     }); | ||||
| } | ||||
| void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) { | ||||
|     gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) { | ||||
|         return static_cast<float>(value); | ||||
|     }); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| void float8_copy_kernel_cuda(TensorIteratorBase &iter) { | ||||
|   ScalarType dtype = iter.dtype(0); | ||||
|   ScalarType other_dtype = iter.dtype(1); | ||||
| @ -200,17 +187,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) { | ||||
|      } else { | ||||
|        float16_copy_kernel_cuda(iter); | ||||
|      } | ||||
|   } | ||||
| #ifdef USE_ROCM | ||||
|   else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) { | ||||
|     if (iter.dtype(1) == kBFloat16) { | ||||
|       bfloat16tofloat32_copy_kernel_cuda(iter); | ||||
|     } else { | ||||
|       float16tofloat32_copy_kernel_cuda(iter); | ||||
|     } | ||||
|   } | ||||
| #endif | ||||
|   else if (isBitsType(dtype)) { | ||||
|   } else if (isBitsType(dtype)) { | ||||
|     TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting " | ||||
|       "bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype); | ||||
|     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] { | ||||
|  | ||||
| @ -52,7 +52,9 @@ static void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor* | ||||
|                                      NSUInteger dilationRateInX, | ||||
|                                      NSUInteger dilationRateInY, | ||||
|                                      NSUInteger paddingHorizontal, | ||||
|                                      NSUInteger paddingVertical) { | ||||
|                                      NSUInteger paddingVertical, | ||||
|                                      c10::MemoryFormat memory_format, | ||||
|                                      NSUInteger groups) { | ||||
|   descriptor_.strides = | ||||
|       @[ @1, [[NSNumber alloc] initWithInteger:strideInY], [[NSNumber alloc] initWithInteger:strideInX] ]; | ||||
|   descriptor_.dilationRates = | ||||
| @ -101,7 +103,7 @@ static void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_, | ||||
|   descriptor_.groups = groups; | ||||
| } | ||||
|  | ||||
| static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
| static Tensor _mps_convolution_impl(const Tensor& input_t_, | ||||
|                                     const Tensor& weight_t, | ||||
|                                     const std::optional<Tensor>& bias_opt, | ||||
|                                     IntArrayRef padding, | ||||
| @ -109,15 +111,12 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|                                     IntArrayRef dilation, | ||||
|                                     int64_t groups, | ||||
|                                     std::optional<IntArrayRef> input_shape) { | ||||
|   constexpr auto kChannelsLast = MemoryFormat::ChannelsLast; | ||||
|   constexpr auto kContiguous = MemoryFormat::Contiguous; | ||||
|   const bool is_macos_15_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); | ||||
|  | ||||
|   const bool is3DConv = input_t.dim() == 5; | ||||
|   const auto memory_format = input_t.suggest_memory_format(); | ||||
|   const auto input_suggested_layout = memory_format == kChannelsLast && is_macos_15_plus ? kChannelsLast : kContiguous; | ||||
|   const bool is_channels_last = mps_conv_use_channels_last(input_t, weight_t) && !is3DConv; | ||||
|   const bool bias_defined = bias_opt ? bias_opt->defined() : false; | ||||
|   const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); | ||||
|   Tensor input_t = input_t_; | ||||
|   bool is3DConv = input_t.dim() == 5; | ||||
|   if (!is_macOS_15_0_or_newer || is3DConv) { | ||||
|     input_t = input_t.contiguous(); | ||||
|   } | ||||
|  | ||||
|   TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types"); | ||||
|  | ||||
| @ -127,6 +126,15 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|   checkAllSameType(c, {input, weight}); | ||||
|   checkAllSameGPU(c, {input, weight}); | ||||
|  | ||||
|   bool bias_defined; | ||||
|  | ||||
|   if (bias_opt == std::nullopt) | ||||
|     bias_defined = false; | ||||
|   else | ||||
|     bias_defined = bias_opt->defined(); | ||||
|  | ||||
|   auto memory_format = input_t.suggest_memory_format(); | ||||
|   bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv; | ||||
|   auto output_t = | ||||
|       at::empty(input_shape.has_value() ? input_shape.value() | ||||
|                                         : conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation), | ||||
| @ -134,18 +142,12 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|                 std::nullopt, | ||||
|                 kMPS, | ||||
|                 std::nullopt, | ||||
|                 is_channels_last ? kChannelsLast : kContiguous); | ||||
|                 is_macOS_15_0_or_newer ? memory_format : MemoryFormat::Contiguous); | ||||
|   if (output_t.numel() == 0) { | ||||
|     return output_t; | ||||
|   } | ||||
|   TensorArg output{output_t, "result", 0}; | ||||
|  | ||||
|   // TODO: Remove me when MacOS-14 is no longer supported | ||||
|   std::optional<Tensor> output_c; | ||||
|   if (!is_macos_15_plus && is_channels_last) { | ||||
|     output_c = at::empty_like(output_t, output_t.options().memory_format(kContiguous)); | ||||
|   } | ||||
|  | ||||
|   if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_1_PLUS)) { | ||||
|     // On macOS < 15.1, MPS convolution kernel does not support output channels > 2^16 | ||||
|     for (auto elem : output_t.sizes()) { | ||||
| @ -184,22 +186,32 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|                                   getArrayRefString(dilation), | ||||
|                                   getArrayRefString(padding), | ||||
|                                   groups, | ||||
|                                   input_suggested_layout == kChannelsLast, | ||||
|                                   is_channels_last, | ||||
|                                   mps::getTensorsStringKey({input_t, weight_t}), | ||||
|                                   bias_defined, | ||||
|                                   bias_shape_key); | ||||
|  | ||||
|     auto inputShape = mps::getMPSShape(input_t, input_suggested_layout); | ||||
|     auto outputShape = mps::getMPSShape(output_t, input_suggested_layout); | ||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||
|       bool isDepthwiseConv = | ||||
|           (groups > 1 && weight_t.size(1) == 1) && input_t.dim() >= 4 && weight_t.dim() >= 4 && !is_channels_last; | ||||
|     MPSShape* inputShape = mps::getMPSShape(input_t, memory_format); | ||||
|     MPSShape* outputShape = mps::getMPSShape(output_t, memory_format); | ||||
|     MPSNDArray* inputNDArray = nil; | ||||
|     MPSNDArray* outputNDArray = nil; | ||||
|  | ||||
|       auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input_t), inputShape); | ||||
|       auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t); | ||||
|       MPSGraphTensor* outputTensor = nil; | ||||
|     if (input_t.is_contiguous(memory_format) && output_t.is_contiguous(memory_format) && is_macOS_15_0_or_newer) { | ||||
|       inputNDArray = getMPSNDArray(input_t, inputShape); | ||||
|       outputNDArray = getMPSNDArray(output_t, outputShape); | ||||
|     } | ||||
|  | ||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||
|       MPSShape* weightShape = mps::getMPSShape(weight_t); | ||||
|       bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 && | ||||
|                               weightShape.count >= 4 && !is_channels_last); | ||||
|  | ||||
|       MPSGraphTensor* inputTensor = | ||||
|           mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input_t.scalar_type()), inputShape); | ||||
|       MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t); | ||||
|       MPSGraphTensor* outputTensor; | ||||
|       if (is3DConv) { | ||||
|         auto conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease]; | ||||
|         MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease]; | ||||
|         fill_conv3d_desc(conv3dDescriptor_, | ||||
|                          stride[2], | ||||
|                          stride[1], | ||||
| @ -217,9 +229,17 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|                                                     descriptor:conv3dDescriptor_ | ||||
|                                                           name:nil]; | ||||
|       } else if (isDepthwiseConv) { | ||||
|         auto depthWiseConv3dDescriptor_ = [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; | ||||
|         fill_depthwise_conv_desc( | ||||
|             depthWiseConv3dDescriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0]); | ||||
|         MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ = | ||||
|             [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; | ||||
|         fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, | ||||
|                                  stride[1], | ||||
|                                  stride[0], | ||||
|                                  dilation[1], | ||||
|                                  dilation[0], | ||||
|                                  padding[1], | ||||
|                                  padding[0], | ||||
|                                  memory_format, | ||||
|                                  groups); | ||||
|  | ||||
|         MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor | ||||
|                                                                 dimension:-3 | ||||
| @ -238,7 +258,7 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|                        dilation[0], | ||||
|                        padding[1], | ||||
|                        padding[0], | ||||
|                        input_suggested_layout, | ||||
|                        memory_format, | ||||
|                        groups); | ||||
|  | ||||
|         outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor | ||||
| @ -250,6 +270,13 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|       MPSGraphTensor* biasTensor = nil; | ||||
|       if (bias_defined) { | ||||
|         biasTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(bias_opt.value())); | ||||
|       } | ||||
|  | ||||
|       if (is_channels_last && !is_macOS_15_0_or_newer) { | ||||
|         outputTensor = mps::convertNHWCtoNCHW(mpsGraph, outputTensor); | ||||
|       } | ||||
|  | ||||
|       if (bias_defined) { | ||||
|         outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil]; | ||||
|       } | ||||
|       newCachedGraph->inputTensor_ = inputTensor; | ||||
| @ -258,26 +285,27 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|       newCachedGraph->outputTensor_ = outputTensor; | ||||
|     }); | ||||
|  | ||||
|     auto inputPlaceholder = input_suggested_layout == kContiguous | ||||
|         ? Placeholder(cachedGraph->inputTensor_, output_c || is3DConv ? input_t.contiguous() : input_t) | ||||
|         : Placeholder(cachedGraph->inputTensor_, getMPSNDArray(input_t, inputShape)); | ||||
|     auto outputPlaceholder = input_suggested_layout == kContiguous | ||||
|         ? Placeholder(cachedGraph->outputTensor_, output_c ? *output_c : output_t) | ||||
|         : Placeholder(cachedGraph->outputTensor_, getMPSNDArray(output_t, outputShape)); | ||||
|     auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, output_c ? weight_t.contiguous() : weight_t); | ||||
|     auto inputPlaceholder = inputNDArray ? Placeholder(cachedGraph->inputTensor_, inputNDArray) | ||||
|                                          : Placeholder(cachedGraph->inputTensor_, input_t, inputShape); | ||||
|     auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t); | ||||
|     auto biasPlaceholder = Placeholder(); | ||||
|     // Reshape the bias to be broadcastable with output of conv2d or conv3d | ||||
|     if (bias_defined) { | ||||
|       if (is3DConv) { | ||||
|         biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias_opt->view({1, bias_shape[0], 1, 1, 1})); | ||||
|       } else if (input_suggested_layout == kChannelsLast) { | ||||
|         biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias_opt->view({1, 1, 1, bias_shape[0]})); | ||||
|         biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1, 1})); | ||||
|       } else { | ||||
|         biasPlaceholder = Placeholder(cachedGraph->biasTensor_, bias_opt->view({1, bias_shape[0], 1, 1})); | ||||
|         if (is_channels_last && is_macOS_15_0_or_newer) { | ||||
|           biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, 1, 1, bias_shape[0]})); | ||||
|         } else { | ||||
|           biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1})); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     auto outputPlaceholder = outputNDArray ? Placeholder(cachedGraph->outputTensor_, outputNDArray) | ||||
|                                            : Placeholder(cachedGraph->outputTensor_, output_t); | ||||
|  | ||||
|     auto feeds = [[[NSMutableDictionary alloc] initWithCapacity:3] autorelease]; | ||||
|     NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = | ||||
|         [[[NSMutableDictionary alloc] initWithCapacity:3] autorelease]; | ||||
|     feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); | ||||
|     feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData(); | ||||
|     if (bias_defined) { | ||||
| @ -287,10 +315,6 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, | ||||
|     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); | ||||
|   } | ||||
|  | ||||
|   if (output_c) { | ||||
|     output_t.copy_(*output_c); | ||||
|   } | ||||
|  | ||||
|   return output_t; | ||||
| } | ||||
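A compact sketch (illustrative helper, not the MPS code path itself) of the bias broadcast reshape performed above: the 1-D bias of length C has to be viewed so its channel axis lines up with the channel axis of the convolution output, which differs between NCHW, NHWC and NCDHW layouts.

#include <ATen/ATen.h>

at::Tensor reshape_bias_for_conv(const at::Tensor& bias, bool is_3d, bool channels_last) {
  const int64_t C = bias.size(0);
  if (is_3d) {
    return bias.view({1, C, 1, 1, 1});            // NCDHW output
  }
  return channels_last ? bias.view({1, 1, 1, C})  // NHWC output
                       : bias.view({1, C, 1, 1}); // NCHW output
}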
|  | ||||
| @ -327,21 +351,14 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, | ||||
|   TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2}; | ||||
|   checkAllSameType(c, {grad_output, weight}); | ||||
|   checkAllSameGPU(c, {grad_output, weight}); | ||||
|   constexpr auto kChannelsLast = at::MemoryFormat::ChannelsLast; | ||||
|   bool is_channels_last = mps_conv_use_channels_last(grad_output_t, weight_t) && !is3DConv; | ||||
|   auto grad_input_t = | ||||
|       at::empty(input_size, grad_output_t.options(), is_channels_last ? std::optional(kChannelsLast) : std::nullopt); | ||||
|   auto memory_format = grad_output_t.suggest_memory_format(); | ||||
|   bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv; | ||||
|   auto grad_input_t = at::empty(input_size, grad_output_t.options(), std::nullopt); | ||||
|  | ||||
|   // Avoid "grad_input" when this is being used as transposed convolution | ||||
|   TensorArg grad_input{grad_input_t, "result", 0}; | ||||
|   convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); | ||||
|  | ||||
|   // TODO: Remove me when MacOS-14 is no longer supported | ||||
|   std::optional<Tensor> grad_input_c; | ||||
|   if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && is_channels_last) { | ||||
|     grad_input_c = at::empty_like(grad_input_t, grad_input_t.options().memory_format(MemoryFormat::Contiguous)); | ||||
|   } | ||||
|  | ||||
|   // Derive from MPSCachedGraph | ||||
|   struct CachedGraph : public MPSCachedGraph { | ||||
|     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} | ||||
| @ -353,6 +370,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, | ||||
|   // Add backward with input | ||||
|   @autoreleasepool { | ||||
|     MPSStream* stream = getCurrentMPSStream(); | ||||
|  | ||||
|     MPSShape* mps_input_shape = getMPSShape(input_size); | ||||
|     std::string key = fmt::format("mps_{}_convolution_backward_input:{}:{}:{}:{}:{}:{}", | ||||
|                                   is3DConv ? "3d_" : "", | ||||
| @ -393,8 +411,15 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, | ||||
|       } else if (isDepthwiseConv) { | ||||
|         MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ = | ||||
|             [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; | ||||
|         fill_depthwise_conv_desc( | ||||
|             depthWiseConv3dDescriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0]); | ||||
|         fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, | ||||
|                                  stride[1], | ||||
|                                  stride[0], | ||||
|                                  dilation[1], | ||||
|                                  dilation[0], | ||||
|                                  padding[1], | ||||
|                                  padding[0], | ||||
|                                  at::MemoryFormat::Contiguous, | ||||
|                                  groups); | ||||
|         MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor | ||||
|                                                                 dimension:-3 | ||||
|                                                             withDimension:-4 | ||||
| @ -429,18 +454,14 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, | ||||
|       newCachedGraph->gradInputTensor_ = gradInputTensor; | ||||
|     }); | ||||
|  | ||||
|     auto gradOutputPlaceholder = | ||||
|         Placeholder(cachedGraph->gradOutputTensor_, grad_input_c ? grad_output_t.contiguous() : grad_output_t); | ||||
|     auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, grad_input_c ? weight_t.contiguous() : weight_t); | ||||
|     auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input_c ? *grad_input_c : grad_input_t); | ||||
|     auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t); | ||||
|     auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t); | ||||
|     auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input); | ||||
|  | ||||
|     auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, weightsPlaceholder); | ||||
|     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); | ||||
|   } | ||||
|   if (grad_input_c) { | ||||
|     grad_input_t.copy_(*grad_input_c); | ||||
|   } | ||||
|   return grad_input_t; | ||||
|   return *grad_input; | ||||
| } | ||||
|  | ||||
| static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, | ||||
| @ -453,11 +474,9 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, | ||||
|                                                bool bias_defined) { | ||||
|   using namespace at::native::mps; | ||||
|   using namespace mps; | ||||
|   const bool is3DConv = input_t.dim() == 5; | ||||
|   bool is3DConv = input_t.dim() == 5; | ||||
|   TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types"); | ||||
|   CheckedFrom c = "mps_convolution_backward_weights"; | ||||
|   constexpr auto kChannelsLast = at::MemoryFormat::ChannelsLast; | ||||
|   bool is_channels_last = mps_conv_use_channels_last(input_t, grad_output_t) && !is3DConv; | ||||
|  | ||||
|   // For uniformity with everything else, although it seems grad_weight | ||||
|   // would be unambiguous too. | ||||
| @ -468,8 +487,7 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, | ||||
|   checkAllSameGPU(c, {grad_output, input}); | ||||
|  | ||||
|   auto grad_weight_t = | ||||
|       at::empty(weight_size, grad_output_t.options(), is_channels_last ? std::optional(kChannelsLast) : std::nullopt); | ||||
|  | ||||
|       at::empty(weight_size, grad_output_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); | ||||
|   TensorArg grad_weight{grad_weight_t, "result", 0}; | ||||
|  | ||||
|   convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups); | ||||
| @ -482,23 +500,16 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, | ||||
|     MPSGraphTensor* gradWeightTensor_ = nil; | ||||
|   }; | ||||
|  | ||||
|   // TODO: Remove me when MacOS-14 is no longer supported | ||||
|   std::optional<Tensor> grad_weight_c; | ||||
|   if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && is_channels_last) { | ||||
|     grad_weight_c = at::empty_like(grad_weight_t, grad_weight_t.options().memory_format(MemoryFormat::Contiguous)); | ||||
|   } | ||||
|  | ||||
|   @autoreleasepool { | ||||
|     MPSStream* stream = getCurrentMPSStream(); | ||||
|  | ||||
|     MPSShape* mps_weight_shape = getMPSShape(weight_size); | ||||
|     std::string key = fmt::format("mps_{}convolution_backward_weights:{}:{}:{}:{}:{}:{}", | ||||
|     std::string key = fmt::format("mps_{}convolution_backward_weights:{}:{}:{}:{}:{}", | ||||
|                                   is3DConv ? "3d_" : "", | ||||
|                                   getArrayRefString(stride), | ||||
|                                   getArrayRefString(dilation), | ||||
|                                   getArrayRefString(padding), | ||||
|                                   groups, | ||||
|                                   is_channels_last, | ||||
|                                   getTensorsStringKey({grad_output_t, input_t, grad_weight_t})); | ||||
|     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) { | ||||
|       MPSShape* inputShape = getMPSShape(input_t); | ||||
| @ -530,8 +541,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, | ||||
|       } else if (isDepthwiseConv) { | ||||
|         MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ = | ||||
|             [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease]; | ||||
|         fill_depthwise_conv_desc( | ||||
|             depthWiseConv3dDescriptor_, stride[1], stride[0], dilation[1], dilation[0], padding[1], padding[0]); | ||||
|         fill_depthwise_conv_desc(depthWiseConv3dDescriptor_, | ||||
|                                  stride[1], | ||||
|                                  stride[0], | ||||
|                                  dilation[1], | ||||
|                                  dilation[0], | ||||
|                                  padding[1], | ||||
|                                  padding[0], | ||||
|                                  at::MemoryFormat::Contiguous, | ||||
|                                  groups); | ||||
|         NSNumber* outputFeatChannelDim = mps_weight_shape[0]; | ||||
|         MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ]; | ||||
|         MPSGraphTensor* gradWeightTensorTranspose = | ||||
| @ -565,19 +583,14 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size, | ||||
|       newCachedGraph->gradWeightTensor_ = gradWeightTensor; | ||||
|     }); | ||||
|  | ||||
|     auto gradOutputPlaceholder = | ||||
|         Placeholder(cachedGraph->gradOutputTensor_, grad_weight_c ? grad_output_t.contiguous() : grad_output_t); | ||||
|     auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, grad_weight_c ? input_t.contiguous() : input_t); | ||||
|     auto outputPlaceholder = | ||||
|         Placeholder(cachedGraph->gradWeightTensor_, grad_weight_c ? *grad_weight_c : grad_weight_t); | ||||
|     auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t); | ||||
|     auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); | ||||
|     auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t); | ||||
|  | ||||
|     auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, inputPlaceholder); | ||||
|     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); | ||||
|   } | ||||
|  | ||||
|   if (grad_weight_c) { | ||||
|     grad_weight_t.copy_(*grad_weight_c); | ||||
|   } | ||||
|   return grad_weight_t; | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -158,46 +158,12 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack( | ||||
|   return packed_ptr; | ||||
| } | ||||
|  | ||||
| #ifdef USE_FBGEMM | ||||
| namespace { | ||||
| /// Number of columns in the rowwise min/max buffer passed to the quantization function(s) | ||||
| constexpr int kRowwiseMinMaxNumCols = 2; | ||||
|  | ||||
| bool _validate_rowwise_min_max( | ||||
|   const at::Tensor& weight, | ||||
|   const std::optional<at::Tensor>& rowwise_min_max_opt) { | ||||
|   const auto is_valid_rowwise_min_max = rowwise_min_max_opt.has_value(); | ||||
|  | ||||
|   if (is_valid_rowwise_min_max) { | ||||
|       TORCH_CHECK( | ||||
|         (rowwise_min_max_opt->dim() == 2 && | ||||
|         rowwise_min_max_opt->size(0) == weight.size(0) && | ||||
|         rowwise_min_max_opt->size(1) == kRowwiseMinMaxNumCols), | ||||
|         "'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2]."); | ||||
|   } | ||||
|  | ||||
|   return is_valid_rowwise_min_max; | ||||
| } | ||||
|  | ||||
| auto _get_rowwise_min_max_contig( | ||||
|   const std::optional<at::Tensor>& rowwise_min_max_opt) { | ||||
|     return rowwise_min_max_opt.has_value() | ||||
|       ? rowwise_min_max_opt->expect_contiguous(rowwise_min_max_opt->suggest_memory_format()) | ||||
|       : at::borrow_from_optional_tensor(rowwise_min_max_opt); | ||||
| } | ||||
| } | ||||
| #endif // USE_FBGEMM | ||||
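Illustrative sketch (not from the patch; the helper name is hypothetical) of a rowwise_min_max tensor that satisfies the shape check above, i.e. [num_rows(weight), 2] with one (min, max) pair per embedding row.

#include <ATen/ATen.h>

at::Tensor make_rowwise_min_max(const at::Tensor& weight) {
  auto row_min = std::get<0>(weight.min(/*dim=*/1));  // per-row minimum, shape [R]
  auto row_max = std::get<0>(weight.max(/*dim=*/1));  // per-row maximum, shape [R]
  return at::stack({row_min, row_max}, /*dim=*/1);    // shape [R, 2], as the check requires
}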
|  | ||||
| namespace at::native { | ||||
|  | ||||
| // Note - This is a temporary pack function for embedding bag which quantizes | ||||
| // and packs the float weight tensor. In the next step it will be replaced by a | ||||
| // quantize and pack function once we support FP scale and FP zero_point | ||||
| // | ||||
| // The optional rowwise_min_max argument is to support callers to pass in the min/max | ||||
| // values of the weight tensor. If the rowwise_min_max is not provided, the min/max | ||||
| // values will be computed from the weight tensor. | ||||
| // | ||||
| // Python example examining a packed 8bit zero_point and scale: | ||||
| // | ||||
| // >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]], | ||||
| @ -255,10 +221,7 @@ namespace at::native { | ||||
| // | ||||
| //        [[50.        , 60.00000035], | ||||
| //         [70.        , 80.00000035]]]) | ||||
| Tensor& qembeddingbag_byte_prepack_out( | ||||
|     Tensor& output, | ||||
|     const Tensor& weight, | ||||
|     const std::optional<Tensor>& rowwise_min_max_opt) { | ||||
| Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) { | ||||
|   // The "last" dimension of an N-Dimensioned batch of embedding bags is | ||||
|   // quantization channel. E.g. for a 2D embedding bag, this has | ||||
|   // [ row, col ] dimensions, for batched of embedding bags, dimensions might be | ||||
| @ -293,16 +256,9 @@ Tensor& qembeddingbag_byte_prepack_out( | ||||
|   auto* output_data = output.data_ptr<uint8_t>(); | ||||
|  | ||||
| #ifdef USE_FBGEMM | ||||
|   // Move these outside of the ifdef when we support non-FBGEMM flow. | ||||
|   const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt); | ||||
|   const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt); | ||||
|  | ||||
|   if (weight_contig->scalar_type() == at::ScalarType::Half) { | ||||
|     const auto weight_data = | ||||
|         static_cast<fbgemm::float16*>(weight_contig->data_ptr()); | ||||
|     const auto rowwise_min_max_data = is_valid_rowwise_min_max | ||||
|         ? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr()) | ||||
|         : nullptr; | ||||
|     at::parallel_for( | ||||
|         0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { | ||||
|           fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat< | ||||
| @ -310,21 +266,17 @@ Tensor& qembeddingbag_byte_prepack_out( | ||||
|               weight_data + start_idx * embedding_cols, | ||||
|               end_idx - start_idx, | ||||
|               embedding_cols, | ||||
|               output_data + start_idx * output_columns, | ||||
|               (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); | ||||
|               output_data + start_idx * output_columns); | ||||
|         }); | ||||
|   } else { | ||||
|     const auto weight_data = weight_contig->data_ptr<float>(); | ||||
|     const auto rowwise_min_max_data = | ||||
|         is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr; | ||||
|     at::parallel_for( | ||||
|         0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { | ||||
|           fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>( | ||||
|               weight_data + start_idx * embedding_cols, | ||||
|               end_idx - start_idx, | ||||
|               embedding_cols, | ||||
|               output_data + start_idx * output_columns, | ||||
|               (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); | ||||
|               output_data + start_idx * output_columns); | ||||
|         }); | ||||
|   } | ||||
|  | ||||
| @ -374,22 +326,6 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) { | ||||
|   return output; | ||||
| } | ||||
|  | ||||
| static Tensor qembeddingbag_byte_prepack_with_rowwise_min_max( | ||||
|     const Tensor& weight, | ||||
|     const Tensor& rowwise_min_max) { | ||||
|   const auto weight_contig = | ||||
|       weight.expect_contiguous(weight.suggest_memory_format()); | ||||
|   Tensor output = at::detail::empty_cpu( | ||||
|       {0}, | ||||
|       at::kByte, | ||||
|       weight_contig->layout(), | ||||
|       weight_contig->device(), | ||||
|       std::nullopt, | ||||
|       std::nullopt); | ||||
|   qembeddingbag_byte_prepack_out(output, weight, rowwise_min_max); | ||||
|   return output; | ||||
| } | ||||
|  | ||||
| Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { | ||||
|   const auto weight_contig = | ||||
|       weight.expect_contiguous(weight.suggest_memory_format()); | ||||
| @ -399,7 +335,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { | ||||
|       "'embedding_bag_byte_prepack' only support float32 or float16."); | ||||
|   const auto weight_sizes = weight.sym_sizes(); | ||||
|   const auto cols_dim = weight.ndimension() - 1; | ||||
|   const auto& embedding_cols = weight_sizes[cols_dim]; | ||||
|   const auto embedding_cols = weight_sizes[cols_dim]; | ||||
|   // Add 8 bytes per row (2 * sizeof(float) extra columns) to store the FP32 scale and zero_point. | ||||
|   const auto output_columns = embedding_cols + 2 * sizeof(float); | ||||
|  | ||||
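Worked example (editor's addition) of the per-row layout implied by the line above: each quantized row stores embedding_cols uint8 values followed by an FP32 scale and an FP32 zero point, i.e. 8 trailing bytes per row.

#include <cstdint>

// e.g. 128 quantized uint8 values per row, plus a 4-byte scale and a 4-byte zero point:
constexpr int64_t embedding_cols = 128;
constexpr int64_t output_columns = embedding_cols + 2 * sizeof(float);  // 128 + 8 = 136 bytes per row
static_assert(output_columns == 136, "8 trailing bytes per row for scale + zero point");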
| @ -423,8 +359,7 @@ Tensor _qembeddingbag_nbit_prepack_helper( | ||||
|     int bit_width, | ||||
|     const bool optimized_qparams, | ||||
|     const int64_t nbins, | ||||
|     const double ratio, | ||||
|     const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt) { | ||||
|     const double ratio) { | ||||
|   TORCH_CHECK( | ||||
|       weight.scalar_type() == at::ScalarType::Float || | ||||
|           weight.scalar_type() == at::ScalarType::Half, | ||||
| @ -466,17 +401,10 @@ Tensor _qembeddingbag_nbit_prepack_helper( | ||||
|   auto* output_data = output.data_ptr<uint8_t>(); | ||||
|  | ||||
| #ifdef USE_FBGEMM | ||||
|   // Move these outside of the ifdef when we support non-FBGEMM flow. | ||||
|   const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt); | ||||
|   const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt); | ||||
|  | ||||
|   if (!optimized_qparams) { | ||||
|     if (weight_contig.scalar_type() == at::ScalarType::Half) { | ||||
|       const auto weight_data = | ||||
|           static_cast<fbgemm::float16*>(weight_contig.data_ptr()); | ||||
|       const auto rowwise_min_max_data = is_valid_rowwise_min_max | ||||
|           ? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr()) | ||||
|           : nullptr; | ||||
|       at::parallel_for( | ||||
|           0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { | ||||
|             fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf< | ||||
| @ -485,13 +413,10 @@ Tensor _qembeddingbag_nbit_prepack_helper( | ||||
|                 weight_data + start_idx * embedding_cols, | ||||
|                 end_idx - start_idx, | ||||
|                 static_cast<int>(embedding_cols), | ||||
|                 output_data + start_idx * output_shape[1], | ||||
|                 (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); | ||||
|                 output_data + start_idx * output_shape[1]); | ||||
|           }); | ||||
|     } else { | ||||
|       const auto weight_data = weight_contig.data_ptr<float>(); | ||||
|       const auto rowwise_min_max_data = | ||||
|           is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr; | ||||
|       at::parallel_for( | ||||
|           0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { | ||||
|             fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>( | ||||
| @ -499,8 +424,7 @@ Tensor _qembeddingbag_nbit_prepack_helper( | ||||
|                 weight_data + start_idx * embedding_cols, | ||||
|                 end_idx - start_idx, | ||||
|                 static_cast<int>(embedding_cols), | ||||
|                 output_data + start_idx * output_shape[1], | ||||
|                 (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); | ||||
|                 output_data + start_idx * output_shape[1]); | ||||
|           }); | ||||
|     } | ||||
|   } else { | ||||
| @ -590,16 +514,6 @@ Tensor qembeddingbag_4bit_prepack( | ||||
|       weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio); | ||||
| } | ||||
|  | ||||
| Tensor qembeddingbag_4bit_prepack_with_rowwise_min_max( | ||||
|     const Tensor& weight, | ||||
|     const Tensor& rowwise_min_max, | ||||
|     const bool optimized_qparams, | ||||
|     const int64_t nbins, | ||||
|     const double ratio) { | ||||
|   return _qembeddingbag_nbit_prepack_helper( | ||||
|       weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max); | ||||
| } | ||||
|  | ||||
| // Applies 2-bit row-wise quantization by determining the range | ||||
| // (maximum - minimum) and bias (minimum value) of each row in the input | ||||
| // matrix, and then scaling each element to an 2-bit number between 0 and | ||||
| @ -617,16 +531,6 @@ Tensor qembeddingbag_2bit_prepack( | ||||
|       weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio); | ||||
| } | ||||
|  | ||||
| Tensor qembeddingbag_2bit_prepack_with_rowwise_min_max( | ||||
|     const Tensor& weight, | ||||
|     const Tensor& rowwise_min_max, | ||||
|     const bool optimized_qparams, | ||||
|     const int64_t nbins, | ||||
|     const double ratio) { | ||||
|   return _qembeddingbag_nbit_prepack_helper( | ||||
|       weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max); | ||||
| } | ||||
|  | ||||
| class QEmbeddingPackWeights final { | ||||
|  public: | ||||
|   static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(const at::Tensor& weight) { | ||||
| @ -638,21 +542,12 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) { | ||||
|   m.impl( | ||||
|       TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"), | ||||
|       TORCH_FN(qembeddingbag_byte_prepack)); | ||||
|   m.impl( | ||||
|       TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack_with_rowwise_min_max"), | ||||
|       TORCH_FN(qembeddingbag_byte_prepack_with_rowwise_min_max)); | ||||
|   m.impl( | ||||
|       TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"), | ||||
|       TORCH_FN(qembeddingbag_4bit_prepack)); | ||||
|   m.impl( | ||||
|       TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max"), | ||||
|       TORCH_FN(qembeddingbag_4bit_prepack_with_rowwise_min_max)); | ||||
|   m.impl( | ||||
|       TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"), | ||||
|       TORCH_FN(qembeddingbag_2bit_prepack)); | ||||
|   m.impl( | ||||
|       TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max"), | ||||
|       TORCH_FN(qembeddingbag_2bit_prepack_with_rowwise_min_max)); | ||||
| } | ||||
|  | ||||
| TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { | ||||
|  | ||||
| @ -3,10 +3,7 @@ | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| Tensor& qembeddingbag_byte_prepack_out( | ||||
|     Tensor& output, | ||||
|     const Tensor& weight, | ||||
|     const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt); | ||||
| Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight); | ||||
|  | ||||
| Tensor qembeddingbag_byte_prepack(const Tensor& weight); | ||||
|  | ||||
|  | ||||
| @ -121,12 +121,9 @@ TORCH_LIBRARY(quantized, m) { | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag}); | ||||
|  | ||||
| @ -120,7 +120,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { | ||||
|   // buffer (in bytes) | ||||
|   size_t orig_m = sparse_input.size(0); | ||||
|   size_t div = orig_m * sparse_input.itemsize(); | ||||
|   size_t new_n = (compressed_size + div - 1) / div; // ceil(s,d) = (s+d-1)/d | ||||
|   size_t new_n = (compressed_size + div - 1) / div; // floor | ||||
|   auto compressed_tensor = sparse_input.new_empty({(int64_t)orig_m, (int64_t)new_n}); | ||||
|  | ||||
|   auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); | ||||
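As an aside on the arithmetic above, `(compressed_size + div - 1) / div` is the standard integer ceiling-division idiom: it rounds up, never down. A tiny illustration, with the helper name `ceil_div` used purely for demonstration:

```cpp
#include <cstddef>

// Integer ceiling division for positive d, written without floating point,
// matching the (compressed_size + div - 1) / div pattern above.
constexpr std::size_t ceil_div(std::size_t s, std::size_t d) {
  return (s + d - 1) / d;
}

static_assert(ceil_div(10, 4) == 3);  // 10/4 = 2.5 rounds up to 3
static_assert(ceil_div(8, 4) == 2);   // exact division is unchanged
```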
| @ -155,7 +155,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl( | ||||
|     TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); | ||||
|     handle_initialized = true; | ||||
|   } | ||||
|   // cuSPARSELt constructs | ||||
|   // cupsarselt constructs | ||||
|   cusparseLtMatmulDescriptor_t matmul; | ||||
|   cusparseLtMatmulPlan_t plan; | ||||
|   cusparseLtMatmulAlgSelection_t alg_sel; | ||||
|  | ||||
| @ -176,28 +176,6 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { | ||||
|     } | ||||
|     return false; | ||||
|   } | ||||
|   if constexpr(caller_is_meff) { | ||||
|     bool is_half = (params.query.dtype() == at::kHalf) || | ||||
|       (params.query.dtype() == at::kBFloat16); | ||||
|     const int64_t alignment = is_half ? 8 : 4; | ||||
|     if (!(query_size_last % alignment == 0 && query_size_last > 0 && | ||||
|           value_size_last % alignment == 0 && value_size_last > 0)) { | ||||
|       if (debug) { | ||||
|         TORCH_WARN( | ||||
|             "Mem efficient attention requires last dimension of inputs to be divisible by ", | ||||
|             alignment, | ||||
|             ". ", | ||||
|             "Got Query.size(-1): ", | ||||
|             query_size_last, | ||||
|             ", Key.size(-1): ", | ||||
|             params.key.sym_size(-1), | ||||
|             ", Value.size(-1): ", | ||||
|             params.value.sym_size(-1), | ||||
|             " instead."); | ||||
|       } | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|   return true; | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -462,11 +462,10 @@ mha_varlen_fwd_aot(const at::Tensor &q,  // total_q x num_heads x head_size, tot | ||||
|     using sdp::aotriton_adapter::mk_aotensor; | ||||
|     using sdp::aotriton_adapter::mk_aoscalartensor; | ||||
|     using sdp::aotriton_adapter::mk_philoxtensor; | ||||
|     using sdp::aotriton_adapter::mk_atomictensor; | ||||
|     using sdp::aotriton_adapter::cast_dtype; | ||||
|     at::Tensor atomic_counter; | ||||
|     if (is_causal) { | ||||
|       atomic_counter = at::zeros({1}, q.options().dtype(at::kInt)); | ||||
|       atomic_counter = at::zeros({1}, q.options()); | ||||
|     } | ||||
|     aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); | ||||
|     auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); | ||||
| @ -475,7 +474,7 @@ mha_varlen_fwd_aot(const at::Tensor &q,  // total_q x num_heads x head_size, tot | ||||
|     auto nullscalar = mk_philoxtensor(nullptr); | ||||
|     auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar; | ||||
|     auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar; | ||||
|     auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr); | ||||
|     auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar; | ||||
|     if (uses_swa || AOTRITON_ALWAYS_V3_API) { | ||||
| #if AOTRITON_V3_API | ||||
|       using aotriton::v3::flash::CausalType; | ||||
|  | ||||
| @ -2,12 +2,22 @@ | ||||
|  | ||||
| // ${generated_comment} | ||||
|  | ||||
| #include <ATen/FunctionalStorageImpl.h> | ||||
| #include <ATen/Tensor.h> | ||||
|  | ||||
| namespace at { | ||||
| namespace functionalization { | ||||
|  | ||||
| enum class InverseReturnMode { | ||||
|   /// Specifies that functional inverses should always return a view. | ||||
|   AlwaysView, | ||||
|   /// Specifies that functional inverses should always return a non-view / copy. | ||||
|   NeverView, | ||||
|   /// Specifies that functional inverses should return a view unless a (copying) scatter | ||||
|   /// inverse exists, in which case that will be used instead. | ||||
|   /// This avoids as_strided() calls that can be difficult for subclasses to handle. | ||||
|   ViewOrScatterInverse, | ||||
| }; | ||||
|  | ||||
| struct FunctionalInverses { | ||||
|  | ||||
| ${view_inverse_declarations} | ||||
|  | ||||
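The `InverseReturnMode` enum above controls whether a generated functional inverse hands back an aliasing view or a freshly materialized tensor. A hedged sketch of the difference for the `slice` view op, whose copying "scatter" inverse is `at::slice_scatter` (the wrapper function below is illustrative, not generated code):

```cpp
#include <ATen/ATen.h>

// Hedged sketch (not the generated inverse code): for the `slice` view op a
// copying "scatter" inverse exists, at::slice_scatter. Roughly, AlwaysView
// keeps the aliasing path below, while NeverView and ViewOrScatterInverse
// prefer the copying path, avoiding as_strided-style views that tensor
// subclasses can struggle with.
at::Tensor put_slice_back(
    const at::Tensor& base,
    const at::Tensor& mutated_view,
    bool prefer_scatter_inverse) {
  const int64_t end = mutated_view.size(0);
  if (prefer_scatter_inverse) {
    // Copying inverse: returns a brand-new tensor, never aliases `base`.
    return at::slice_scatter(base, mutated_view, /*dim=*/0, /*start=*/0, end);
  }
  // View-based inverse: write through a view of `base`, keeping the aliasing.
  base.slice(/*dim=*/0, /*start=*/0, end).copy_(mutated_view);
  return base;
}
```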
| @ -4,7 +4,7 @@ | ||||
| #include <ATen/core/LegacyTypeDispatch.h> | ||||
| #include <ATen/EmptyTensor.h> | ||||
| #include <ATen/FunctionalTensorWrapper.h> | ||||
| #include <ATen/ViewMetaClasses.h> | ||||
| #include <ATen/FunctionalInverses.h> | ||||
| #include <ATen/MemoryOverlap.h> | ||||
| #include <torch/library.h> | ||||
|  | ||||
|  | ||||
| @ -1,19 +0,0 @@ | ||||
| // ${generated_comment} | ||||
|  | ||||
| #include <ATen/FunctionalInverses.h> | ||||
| #include <ATen/ViewMetaClasses.h> | ||||
|  | ||||
| #ifndef AT_PER_OPERATOR_HEADERS | ||||
| #include <ATen/Operators.h> | ||||
| #include <ATen/NativeFunctions.h> | ||||
| #else | ||||
| ${op_headers} | ||||
| #endif | ||||
|  | ||||
| namespace at { | ||||
| namespace functionalization { | ||||
|  | ||||
| ${view_meta_implementations} | ||||
|  | ||||
| } // namespace functionalization | ||||
| } // namespace at | ||||
| @ -1,12 +0,0 @@ | ||||
| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS | ||||
| // ${generated_comment} | ||||
|  | ||||
| #include <ATen/FunctionalStorageImpl.h> | ||||
|  | ||||
| namespace at { | ||||
| namespace functionalization { | ||||
|  | ||||
| ${view_meta_declarations} | ||||
|  | ||||
| } // namespace functionalization | ||||
| } // namespace at | ||||
| @ -1,11 +0,0 @@ | ||||
| #include <ATen/ViewMetaClasses.h> | ||||
| #include <torch/csrc/functionalization/Module.h> | ||||
|  | ||||
| namespace torch::functionalization { | ||||
|  | ||||
| void initGenerated(PyObject* module) { | ||||
|   auto functionalization = py::handle(module).cast<py::module>(); | ||||
|   $view_meta_bindings | ||||
| } | ||||
|  | ||||
| } // namespace torch::functionalization | ||||
| @ -1561,38 +1561,6 @@ namespace { | ||||
|               << "Failure Details:\nTest Seed to reproduce: " << seed; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
|     TYPED_TEST(Quantization8BitTests, TransposePackVNNI4) { | ||||
|         using VT = ValueType<TypeParam>; | ||||
|         constexpr auto K = 197; | ||||
|         constexpr auto N = 64; | ||||
|         constexpr auto L = K * N; | ||||
|         constexpr auto ld_src = N; | ||||
|         constexpr auto ld_dst = K * 4; | ||||
|         CACHE_ALIGN VT x[L]; | ||||
|         CACHE_ALIGN VT y[L]; | ||||
|         CACHE_ALIGN VT ref[L]; | ||||
|         auto seed = TestSeed(); | ||||
|         ValueGen<VT> generator(VT(-100), VT(100), seed); | ||||
|         for (const auto i : c10::irange(L)) { | ||||
|           x[i] = generator.get(); | ||||
|         } | ||||
|         at::vec::transpose_pack_vnni4(x, y, ld_src, K, N); | ||||
|         int64_t _N = N / 4; | ||||
|         for (int64_t k = 0; k < K; k++) { | ||||
|           for(int64_t n = 0; n < _N; n++) { | ||||
|             for(int64_t l = 0; l < 4; l++) { | ||||
|               ref[n * ld_dst + k * 4 + l] = | ||||
|                   c10::load(&(x[k * ld_src + n * 4 + l])); | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|         for (const auto i : c10::irange(L)) { | ||||
|           ASSERT_EQ(y[i], ref[i]) | ||||
|               << "Failure Details:\nTest Seed to reproduce: " << seed; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|     TYPED_TEST(FunctionalTests, Map) { | ||||
|         using vec = TypeParam; | ||||
|  | ||||
| @ -318,7 +318,7 @@ timm_vovnet,pass,0 | ||||
|  | ||||
|  | ||||
|  | ||||
| torch_multimodal_clip,pass,0 | ||||
| torch_multimodal_clip,pass,3 | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -391,8 +391,6 @@ def get_aten_generated_files(enabled_backends): | ||||
|         "CompositeExplicitAutogradFunctions_inl.h", | ||||
|         "CompositeExplicitAutogradNonFunctionalFunctions.h", | ||||
|         "CompositeExplicitAutogradNonFunctionalFunctions_inl.h", | ||||
|         "ViewMetaClasses.h", | ||||
|         "ViewMetaClasses.cpp", | ||||
|         "VmapGeneratedPlumbing.h", | ||||
|         "core/ATenOpList.cpp", | ||||
|         "core/TensorBody.h", | ||||
| @ -1194,7 +1192,6 @@ def define_buck_targets( | ||||
|             "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]", | ||||
|             "Operators.h": ":gen_aten[Operators.h]", | ||||
|             "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]", | ||||
|             "ViewMetaClasses.h": ":gen_aten[ViewMetaClasses.h]", | ||||
|             "core/TensorBody.h": ":gen_aten[core/TensorBody.h]", | ||||
|             "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]", | ||||
|             "core/enum_tag.h": ":gen_aten[core/enum_tag.h]", | ||||
|  | ||||
| @ -118,9 +118,6 @@ def define_targets(rules): | ||||
|             ":LazyNonNativeIr.h", | ||||
|             ":RegisterDispatchDefinitions.ini", | ||||
|             ":RegisterDispatchKey.cpp", | ||||
|             ":ViewMetaClassesPythonBinding.cpp", | ||||
|             ":ViewMetaClasses.cpp", | ||||
|             ":ViewMetaClasses.h", | ||||
|             ":native_functions.yaml", | ||||
|             ":shape_inference.h", | ||||
|             ":tags.yaml", | ||||
| @ -173,7 +170,6 @@ GENERATED_H = [ | ||||
|     "FunctionalInverses.h", | ||||
|     "RedispatchFunctions.h", | ||||
|     "RegistrationDeclarations.h", | ||||
|     "ViewMetaClasses.h", | ||||
|     "VmapGeneratedPlumbing.h", | ||||
| ] | ||||
|  | ||||
| @ -250,7 +246,6 @@ GENERATED_CPP = [ | ||||
|     "RegisterFunctionalization_1.cpp", | ||||
|     "RegisterFunctionalization_2.cpp", | ||||
|     "RegisterFunctionalization_3.cpp", | ||||
|     "ViewMetaClasses.cpp", | ||||
| ] | ||||
|  | ||||
| GENERATED_CPP_CORE = [ | ||||
| @ -312,7 +307,6 @@ _GENERATED_AUTOGRAD_PYTHON_CPP = [ | ||||
|     "torch/csrc/autograd/generated/python_torch_functions_1.cpp", | ||||
|     "torch/csrc/autograd/generated/python_torch_functions_2.cpp", | ||||
|     "torch/csrc/autograd/generated/python_variable_methods.cpp", | ||||
|     "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" | ||||
| ] | ||||
|  | ||||
| GENERATED_AUTOGRAD_PYTHON = _GENERATED_AUTOGRAD_PYTHON_HEADERS + _GENERATED_AUTOGRAD_PYTHON_CPP | ||||
|  | ||||
| @ -1010,7 +1010,6 @@ libtorch_python_core_sources = [ | ||||
|     "torch/csrc/utils/disable_torch_function.cpp", | ||||
|     "torch/csrc/utils/verbose.cpp", | ||||
|     "torch/csrc/cpu/Module.cpp", | ||||
|     "torch/csrc/functionalization/Module.cpp", | ||||
|     "torch/csrc/instruction_counter/Module.cpp", | ||||
|     "torch/nativert/python/Bindings.cpp", | ||||
| ] + lazy_tensor_core_python_sources | ||||
| @ -1053,7 +1052,6 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"): | ||||
|         "torch/csrc/autograd/generated/python_torch_functions_1.cpp", | ||||
|         "torch/csrc/autograd/generated/python_torch_functions_2.cpp", | ||||
|         "torch/csrc/autograd/generated/python_variable_methods.cpp", | ||||
|         "torch/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp", | ||||
|     ]] | ||||
|  | ||||
|     _libtorch_python_sources.extend(libtorch_python_core_sources) | ||||
|  | ||||
| @ -3244,7 +3244,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { | ||||
|     are_equal<sizeof(autograd_meta_),      4,  FieldNameEnum::autograd_meta_>(); | ||||
|     are_equal<sizeof(extra_meta_),         4,  FieldNameEnum::extra_meta_>(); | ||||
|     are_equal<sizeof(version_counter_),    4,  FieldNameEnum::version_counter_>(); | ||||
|     are_equal<sizeof(pyobj_slot_),         4,  FieldNameEnum::pyobj_slot_>(); | ||||
|     are_equal<sizeof(pyobj_slot_),    8,  FieldNameEnum::pyobj_slot_>(); | ||||
|     is_le<sizeof(sizes_and_strides_),     88, FieldNameEnum::sizes_and_strides_>(); | ||||
|     are_equal<sizeof(storage_offset_),     8,  FieldNameEnum::storage_offset_>(); | ||||
|     are_equal<sizeof(numel_),              8,  FieldNameEnum::numel_>(); | ||||
| @ -3269,7 +3269,7 @@ class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl { | ||||
|     is_le<sizeof(autograd_meta_),         16,  FieldNameEnum::autograd_meta_>(); | ||||
|     is_le<sizeof(extra_meta_),            16,  FieldNameEnum::extra_meta_>(); | ||||
|     are_equal<sizeof(version_counter_),    8,  FieldNameEnum::version_counter_>(); | ||||
|     are_equal<sizeof(pyobj_slot_),         8,  FieldNameEnum::pyobj_slot_>(); | ||||
|     are_equal<sizeof(pyobj_slot_),   16,  FieldNameEnum::pyobj_slot_>(); | ||||
|     are_equal<sizeof(sizes_and_strides_), 88,  FieldNameEnum::sizes_and_strides_>(); | ||||
|     are_equal<sizeof(storage_offset_),     8,  FieldNameEnum::storage_offset_>(); | ||||
|     are_equal<sizeof(numel_),              8,  FieldNameEnum::numel_>(); | ||||
|  | ||||
							
								
								
									
c10/core/impl/HermeticPyObjectTLS.cpp (new file, 21 lines)
							| @ -0,0 +1,21 @@ | ||||
| #include <c10/core/impl/HermeticPyObjectTLS.h> | ||||
|  | ||||
| namespace c10::impl { | ||||
|  | ||||
| thread_local static std::atomic<bool> hermeticPyObjectState{false}; | ||||
|  | ||||
| std::atomic<bool> HermeticPyObjectTLS::haveState_{false}; | ||||
|  | ||||
| void HermeticPyObjectTLS::set_state(bool state) { | ||||
|   hermeticPyObjectState = state; | ||||
| } | ||||
|  | ||||
| bool HermeticPyObjectTLS::get_tls_state() { | ||||
|   return hermeticPyObjectState; | ||||
| } | ||||
|  | ||||
| void HermeticPyObjectTLS::init_state() { | ||||
|   haveState_ = true; | ||||
| } | ||||
|  | ||||
| } // namespace c10::impl | ||||
							
								
								
									
c10/core/impl/HermeticPyObjectTLS.h (new file, 62 lines)
							| @ -0,0 +1,62 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/macros/Export.h> | ||||
| #include <atomic> | ||||
|  | ||||
| namespace c10::impl { | ||||
|  | ||||
| // This TLS controls whether or not we permanently associate PyObject | ||||
| // with Tensor the first time it is allocated.  When hermetic PyObject | ||||
| // TLS is enabled (state is true), we DO NOT save PyObjects to Tensor, | ||||
| // meaning you get a distinct PyObject whenever you execute the code in | ||||
| // question. | ||||
| struct C10_API HermeticPyObjectTLS { | ||||
|   static void set_state(bool state); | ||||
|   static bool get_state() { | ||||
|     // Hypothetical fastpath if torchdeploy/multipy // codespell:ignore multipy | ||||
|     // isn't used. Per | ||||
|     // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf | ||||
|     // this qualifies relaxed access because it is a single-location data | ||||
|     // structure (only the boolean here). | ||||
|     // | ||||
|     // Forgetting about data races for a moment, is there a logical race? | ||||
|     // | ||||
|     //  - Boolean only ever transitions from false to true.  So the | ||||
|     //    critical situation is when one interpreter is already running | ||||
|     //    when a second interpreter switches haveState from false to true. | ||||
|     // | ||||
|     //  - The first interpreter is indifferent to whether it sees | ||||
|     //    haveState as true or false; false obviously works (this is what the | ||||
|     //    interpreter was previously using; more directly, the interpreter | ||||
|     //    calls into itself as the handler, so being hermetic is not | ||||
|     //    required), and true simply means serviced python operator calls will | ||||
|     //    be hermetic; in these cases it is expected to be functionally | ||||
|     //    equivalent. | ||||
|     // | ||||
|     //  - The second interpreter MUST see haveState true (as its requests will | ||||
|     //    be forwarded to the first interpreter), but it is assumed that there | ||||
|     //    is a synchronization between the interpreter initialization, and | ||||
|     //    when we actually perform operations, so it is guaranteed to see | ||||
|     //    hasState true. | ||||
|     // | ||||
|     // QED. | ||||
|     // | ||||
|     // This fastpath is currently disabled so that we can more easily test that | ||||
|     // hermetic mode works correctly even on stock build of PyTorch. | ||||
|     if (false && !haveState_.load(std::memory_order_relaxed)) | ||||
|       return false; | ||||
|     return get_tls_state(); | ||||
|   } | ||||
|   // Call this from the multipy/torchdeploy // codespell:ignore multipy | ||||
|   // top level | ||||
|   static void init_state(); | ||||
|  | ||||
|  private: | ||||
|   // This only flipped once from false to true during | ||||
|   // torchdeploy/multipy initialization, // codespell:ignore multipy | ||||
|   // and never again. | ||||
|   static std::atomic<bool> haveState_; | ||||
|   static bool get_tls_state(); | ||||
| }; | ||||
|  | ||||
| } // namespace c10::impl | ||||
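Since the header above only exposes static `set_state`/`get_state`, callers typically flip hermetic mode for a bounded scope. A minimal RAII sketch, assuming a hypothetical guard class (this guard is not part of the c10 API shown here):

```cpp
#include <c10/core/impl/HermeticPyObjectTLS.h>

// Illustrative RAII helper: enable hermetic PyObject mode for one scope so
// that Tensors crossing into Python inside it get fresh, non-cached PyObjects.
struct HermeticPyObjectGuard {
  HermeticPyObjectGuard()
      : prev_(c10::impl::HermeticPyObjectTLS::get_state()) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
  }
  ~HermeticPyObjectGuard() {
    c10::impl::HermeticPyObjectTLS::set_state(prev_);
  }
  HermeticPyObjectGuard(const HermeticPyObjectGuard&) = delete;
  HermeticPyObjectGuard& operator=(const HermeticPyObjectGuard&) = delete;

 private:
  bool prev_;  // restore the previous TLS state on scope exit
};
```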
| @ -13,10 +13,11 @@ struct C10_API PyInterpreterHooksInterface { | ||||
|  | ||||
|   // Get the PyInterpreter instance | ||||
|   // Stub implementation throws error when Python is not available | ||||
|   // We return nullptr rather than throwing an error since there are bits of c10 | ||||
|   // that expect an empty PyObjectSlot when python is not available. | ||||
|   virtual PyInterpreter* getPyInterpreter() const { | ||||
|     return nullptr; | ||||
|     TORCH_CHECK( | ||||
|         false, | ||||
|         "PyTorch was compiled without Python support. " | ||||
|         "Cannot access Python interpreter from C++."); | ||||
|   } | ||||
| }; | ||||
|  | ||||
|  | ||||
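With the stub above returning `nullptr` instead of throwing, c10 call sites are expected to null-check the interpreter before using it. A hedged sketch of that calling pattern; `get_py_interpreter_hooks()` is a placeholder accessor, not an actual c10 symbol:

```cpp
#include <c10/core/impl/PyInterpreter.h>
#include <c10/util/python_stub.h>

// Sketch of the expected caller-side pattern once the Python-less stub
// returns nullptr: treat "no interpreter" as "no PyObject to manage" rather
// than as an error.
void maybe_decref(PyObject* obj) {
  c10::impl::PyInterpreter* interp =
      get_py_interpreter_hooks().getPyInterpreter();  // placeholder accessor
  if (interp == nullptr) {
    // Python is not available (e.g. a pure C++ build); nothing to do.
    return;
  }
  (*interp)->decref(obj, /*has_pyobj_slot=*/false);
}
```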
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| namespace c10::impl { | ||||
|  | ||||
| PyObjectSlot::PyObjectSlot() : pyobj_(nullptr) {} | ||||
| PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} | ||||
|  | ||||
| PyObjectSlot::~PyObjectSlot() { | ||||
|   maybe_destroy_pyobj(); | ||||
| @ -10,9 +10,9 @@ PyObjectSlot::~PyObjectSlot() { | ||||
|  | ||||
| void PyObjectSlot::maybe_destroy_pyobj() { | ||||
|   if (owns_pyobj()) { | ||||
|     TORCH_INTERNAL_ASSERT(getGlobalPyInterpreter() != nullptr); | ||||
|     TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr); | ||||
|     TORCH_INTERNAL_ASSERT(pyobj_ != nullptr); | ||||
|     (*getGlobalPyInterpreter()) | ||||
|     (*pyobj_interpreter_.load(std::memory_order_acquire)) | ||||
|         ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true); | ||||
|     // NB: this destructor can only be entered when there are no | ||||
|     // references to this C++ object (obviously), NOR any references | ||||
| @ -25,7 +25,7 @@ void PyObjectSlot::maybe_destroy_pyobj() { | ||||
| } | ||||
|  | ||||
| PyInterpreter* PyObjectSlot::pyobj_interpreter() { | ||||
|   return getGlobalPyInterpreter(); | ||||
|   return pyobj_interpreter_.load(std::memory_order_acquire); | ||||
| } | ||||
|  | ||||
| PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { | ||||
| @ -35,7 +35,7 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { | ||||
| } | ||||
|  | ||||
| PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { | ||||
|   auto interpreter = getGlobalPyInterpreter(); | ||||
|   auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); | ||||
|   if (interpreter) { | ||||
|     return *interpreter; | ||||
|   } | ||||
|  | ||||
| @ -1,21 +1,15 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <c10/core/impl/HermeticPyObjectTLS.h> | ||||
| #include <c10/core/impl/PyInterpreter.h> | ||||
| #include <c10/core/impl/PyInterpreterHooks.h> | ||||
| #include <c10/util/python_stub.h> | ||||
| #include <optional> | ||||
|  | ||||
| #include <atomic> | ||||
|  | ||||
| namespace c10::impl { | ||||
|  | ||||
| // Function pointer type for getting the global interpreter | ||||
| using GetPyInterpreterFn = PyInterpreter* (*)(); | ||||
|  | ||||
| // Global function pointer (set by csrc initialization) | ||||
| C10_API extern GetPyInterpreterFn g_get_pyinterpreter_fn; | ||||
|  | ||||
| // Helper function to get the global interpreter | ||||
| C10_API PyInterpreter* getGlobalPyInterpreter(); | ||||
|  | ||||
| struct C10_API PyObjectSlot { | ||||
|  public: | ||||
|   PyObjectSlot(); | ||||
| @ -32,6 +26,8 @@ struct C10_API PyObjectSlot { | ||||
|   // NB: THIS FUNCTION CAN RAISE AN EXCEPTION.  Make sure to clean up after | ||||
|   // PyObject if necessary! | ||||
|   void init_pyobj(PyObject* pyobj) { | ||||
|     pyobj_interpreter_.store( | ||||
|         getGlobalPyInterpreter(), std::memory_order_relaxed); | ||||
|     pyobj_ = pyobj; | ||||
|   } | ||||
|  | ||||
| @ -41,16 +37,36 @@ struct C10_API PyObjectSlot { | ||||
|  | ||||
|   PyObject* _unchecked_untagged_pyobj() const; | ||||
|  | ||||
|   // Test the interpreter / PyObj as they may be null | ||||
|   // Test the interpreter tag.  If tagged for the current interpreter, return | ||||
|   // a non-nullopt (but possibly null) PyObject.  If (possibly) untagged, | ||||
|   // returns a nullopt.  If it is definitely invalid, raises an error. | ||||
|   // | ||||
|   // If `ignore_hermetic_tls` is false and this function is called from a | ||||
|   // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then | ||||
|   // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic | ||||
|   // context is ignored, allowing you to check the interpreter tag of a | ||||
|   // nonhermetic PyObject from within a hermetic context. This is necessary | ||||
|   // because there are some cases where the deallocator function of a | ||||
|   // nonhermetic PyObject is called from within a hermetic context, so it must | ||||
|   // be properly treated as a nonhermetic PyObject. | ||||
|   // | ||||
|   // NB: this lives in header so that we can avoid actually creating the | ||||
|   // std::optional | ||||
|  | ||||
|   std::optional<PyObject*> check_pyobj() const { | ||||
|     impl::PyInterpreter* interpreter = getGlobalPyInterpreter(); | ||||
|     if (interpreter == nullptr || pyobj_ == nullptr) { | ||||
|   // @todo alban: I'm not too sure what's going on here, we can probably delete | ||||
|   // it but it's worthwhile making sure | ||||
|   std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const { | ||||
|     impl::PyInterpreter* interpreter = | ||||
|         pyobj_interpreter_.load(std::memory_order_acquire); | ||||
|     if (interpreter == nullptr) { | ||||
|       return std::nullopt; | ||||
|     } | ||||
|     return _unchecked_untagged_pyobj(); | ||||
|  | ||||
|     if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { | ||||
|       return std::nullopt; | ||||
|     } else { | ||||
|       return _unchecked_untagged_pyobj(); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   PyInterpreter& load_pyobj_interpreter() const; | ||||
| @ -60,6 +76,30 @@ struct C10_API PyObjectSlot { | ||||
|   void set_owns_pyobj(bool b); | ||||
|  | ||||
|  private: | ||||
|   // This field contains the interpreter tag for this object.  See | ||||
|   // Note [Python interpreter tag] for general context | ||||
|   // | ||||
|   // Note [Memory ordering on Python interpreter tag] | ||||
|   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
|   // What memory_order do we need when accessing this atomic?  We don't | ||||
|   // need a single total modification order (as provided by | ||||
|   // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only | ||||
|   // transition from -1 to some positive integer and never changes afterwards. | ||||
|   // Because there is only one modification, it trivially already has a total | ||||
|   // modification order (e.g., we don't need fences or locked instructions on | ||||
|   // x86) | ||||
|   // | ||||
|   // In fact, one could make a reasonable argument that relaxed reads are OK, | ||||
|   // due to the presence of external locking (GIL) to ensure that interactions | ||||
|   // with other data structures are still correctly synchronized, so that | ||||
|   // we fall in the "Single-Location Data Structures" case as described in | ||||
|   // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf | ||||
|   // However, on x86, it doesn't matter if I use acquire or relaxed on the load | ||||
|   // as I get the same assembly in both cases.  So I just use the more | ||||
|   // conservative acquire (which will impede compiler optimizations but I don't | ||||
|   // care) | ||||
|   std::atomic<PyInterpreter*> pyobj_interpreter_; | ||||
|  | ||||
|   // This field contains a reference to a PyObject representing this Tensor. | ||||
|   // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new | ||||
|   // PyObject for it and set this field.  This field does not have to be | ||||
|  | ||||
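Tying the pieces above together: `init_pyobj` tags the slot with the current interpreter, and readers combine `check_pyobj()` with the hermetic TLS to decide whether a cached PyObject can be reused. A hedged usage sketch (`allocate_new_pyobject` is a placeholder, and this is not the actual THPVariable wrapping code):

```cpp
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyObjectSlot.h>

// Illustrative reader: reuse the cached PyObject when the slot is tagged and
// we are not in hermetic mode; otherwise hand out a fresh wrapper.
PyObject* get_or_create_wrapper(c10::impl::PyObjectSlot& slot) {
  if (auto maybe = slot.check_pyobj(/*ignore_hermetic_tls=*/false)) {
    // Tagged and non-hermetic: a cached wrapper may exist (it can still be
    // null if the slot was never initialized).
    if (*maybe != nullptr) {
      return *maybe;
    }
  }
  PyObject* fresh = allocate_new_pyobject();  // placeholder allocation
  if (!c10::impl::HermeticPyObjectTLS::get_state()) {
    slot.init_pyobj(fresh);  // cache it and tag the slot's interpreter
  }
  return fresh;
}
```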
| @ -316,7 +316,6 @@ set(GENERATED_CXX_PYTHON | ||||
|   "${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp" | ||||
|   "${TORCH_SRC_DIR}/csrc/autograd/generated/python_return_types.cpp" | ||||
|   "${TORCH_SRC_DIR}/csrc/autograd/generated/python_enum_tag.cpp" | ||||
|   "${TORCH_SRC_DIR}/csrc/functionalization/generated/ViewMetaClassesPythonBinding.cpp" | ||||
|   ) | ||||
|  | ||||
| set(GENERATED_H_PYTHON | ||||
| @ -380,9 +379,6 @@ add_custom_command( | ||||
|     "${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h" | ||||
|     "${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h" | ||||
|     "${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp" | ||||
|     "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.h" | ||||
|     "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClasses.cpp" | ||||
|     "${TORCH_ROOT}/aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp" | ||||
|     ${autograd_python} | ||||
|     ${autograd_yaml} | ||||
|     ${autograd_templates} | ||||
|  | ||||
| @ -38,7 +38,7 @@ def unroll(num_unrolls, IndexType, InType, OutType): | ||||
|     code = [] | ||||
|  | ||||
|     if num_unrolls == 1: | ||||
|         code.append("    // tail loop") | ||||
|         code.append(f"    // tail loop") | ||||
|         code.append("    if (j < end_offset) {") | ||||
|     else: | ||||
|         code.append(f"    // unrolling {num_unrolls} times") | ||||
|  | ||||
| @ -153,6 +153,7 @@ _ZN3c104impl12PyObjectSlot10owns_pyobjEv | ||||
| _ZN3c104impl12PyObjectSlot19maybe_destroy_pyobjEv | ||||
| _ZN3c104impl12PyObjectSlotC1Ev | ||||
| _ZN3c104impl12PyObjectSlotD2Ev | ||||
| _ZN3c104impl19HermeticPyObjectTLS13get_tls_stateEv | ||||
| _ZN3c104impl20TorchDispatchModeTLS13any_modes_setEb | ||||
| _ZN3c104impl23ExcludeDispatchKeyGuardC1ENS_14DispatchKeySetE | ||||
| _ZN3c104impl23ExcludeDispatchKeyGuardD2Ev | ||||
|  | ||||
| @ -40,34 +40,7 @@ extensions = [ | ||||
|     "sphinx.ext.intersphinx", | ||||
| ] + (["breathe", "exhale"] if run_doxygen else []) | ||||
|  | ||||
| intersphinx_mapping = {"pytorch": ("https://docs.pytorch.org/docs/main", None)} | ||||
|  | ||||
| # Configure Sphinx warnings and error handling | ||||
| suppress_warnings = [ | ||||
|     "ref.citation", | ||||
|     "ref.footnote", | ||||
|     "ref.doc", | ||||
|     "toc.excluded", | ||||
|     "toc.not_readable", | ||||
|     "misc.highlighting_failure", | ||||
| ] | ||||
|  | ||||
| # Configure Breathe | ||||
| breathe_show_define_initializer = True | ||||
| breathe_show_enumvalue_initializer = True | ||||
| breathe_default_members = ("members", "undoc-members") | ||||
|  | ||||
|  | ||||
| # Fix for Python 3.10+ compatibility with exhale 2.3.0 | ||||
| # MutableMapping was moved from collections to collections.abc in Python 3.10 | ||||
| try: | ||||
|     import collections | ||||
|     from collections.abc import MutableMapping | ||||
|  | ||||
|     if not hasattr(collections, "MutableMapping"): | ||||
|         collections.MutableMapping = MutableMapping | ||||
| except ImportError: | ||||
|     pass | ||||
| intersphinx_mapping = {"pytorch": ("https://pytorch.org/docs/main", None)} | ||||
|  | ||||
| # Setup absolute paths for communicating with breathe / exhale where | ||||
| # items are expected / should be trimmed by. | ||||
| @ -128,21 +101,6 @@ exhale_args = { | ||||
|         Welcome to the developer reference for the PyTorch C++ API. | ||||
|     """ | ||||
|     ), | ||||
|     ############################################################################ | ||||
|     # Duplicate handling and error management.                                 # | ||||
|     ############################################################################ | ||||
|     # Note: Using Doxyfile instead of stdin configuration | ||||
|     # "exhaleDoxygenStdin" is not compatible with "exhaleUseDoxyfile" | ||||
|     # Handle unresolved references more gracefully | ||||
|     "unabridgedOrphanKinds": { | ||||
|         "function", | ||||
|         "define", | ||||
|         "enum", | ||||
|         "enumvalue", | ||||
|         "typedef", | ||||
|         "variable", | ||||
|     }, | ||||
|     "fullToctreeMaxDepth": 2, | ||||
| } | ||||
|  | ||||
| # Tell sphinx what the primary language being documented is. | ||||
|  | ||||
| @ -1093,9 +1093,6 @@ The set of leaf modules can be customized by overriding | ||||
| ```{eval-rst} | ||||
| .. autofunction:: torch.fx.replace_pattern | ||||
| ``` | ||||
| ```{eval-rst} | ||||
| .. autofunction:: torch.fx.traceback.annotate | ||||
| ``` | ||||
|  | ||||
| <!-- The experimental and passes submodules are missing docs. --> | ||||
| <!-- Adding it here for coverage but this doesn't add anything to the --> | ||||
|  | ||||
| @ -156,7 +156,6 @@ def get_generate_code_bin_outs(): | ||||
|             "autograd/generated/python_torch_functions_1.cpp": ["autograd/generated/python_torch_functions_1.cpp"], | ||||
|             "autograd/generated/python_torch_functions_2.cpp": ["autograd/generated/python_torch_functions_2.cpp"], | ||||
|             "autograd/generated/python_variable_methods.cpp": ["autograd/generated/python_variable_methods.cpp"], | ||||
|             "functionalization/generated/ViewMetaClassesPythonBinding.cpp": ["functionalization/generated/ViewMetaClassesPythonBinding.cpp"], | ||||
|         }) | ||||
|     return outs | ||||
|  | ||||
|  | ||||
| @ -182,6 +182,7 @@ ignore = [ | ||||
|     "SIM117", | ||||
|     "SIM118", | ||||
|     "UP007", # keep-runtime-typing | ||||
|     "UP038", # Was removed from newer versions, results in slower code | ||||
|     "UP045", # keep-runtime-typing | ||||
|     "TC006", | ||||
|     # TODO: Remove Python-3.10 specific suppressions | ||||
|  | ||||
							
								
								
									
setup.py
							| @ -1704,18 +1704,7 @@ def main() -> None: | ||||
|     package_data = { | ||||
|         "torch": torch_package_data, | ||||
|     } | ||||
|     # some win libraries are excluded | ||||
|     # these are statically linked | ||||
|     exclude_windows_libs = [ | ||||
|         "lib/dnnl.lib", | ||||
|         "lib/kineto.lib", | ||||
|         "lib/libprotobuf-lite.lib", | ||||
|         "lib/libprotobuf.lib", | ||||
|         "lib/libprotoc.lib", | ||||
|     ] | ||||
|     exclude_package_data = { | ||||
|         "torch": exclude_windows_libs, | ||||
|     } | ||||
|     exclude_package_data = {} | ||||
|  | ||||
|     if not BUILD_LIBTORCH_WHL: | ||||
|         package_data["torchgen"] = torchgen_package_data | ||||
|  | ||||
| @ -1,7 +1,9 @@ | ||||
| if(WIN32) | ||||
|   set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/${CMAKE_IMPORT_LIBRARY_PREFIX}torch_python${CMAKE_IMPORT_LIBRARY_SUFFIX}") | ||||
|   set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib") | ||||
| elseif(APPLE) | ||||
|   set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib") | ||||
| else() | ||||
|   set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/${CMAKE_SHARED_LIBRARY_PREFIX}torch_python${CMAKE_SHARED_LIBRARY_SUFFIX}") | ||||
|   set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so") | ||||
| endif() | ||||
|  | ||||
| add_library(torch_python SHARED IMPORTED) | ||||
|  | ||||
| @ -11,12 +11,7 @@ from typing import Union | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| import torch.nn as nn | ||||
| from torch.distributed._composable import checkpoint | ||||
| from torch.distributed._composable.replicate_with_fsdp import replicate | ||||
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( | ||||
|     _CHECKPOINT_PREFIX, | ||||
|     apply_activation_checkpointing, | ||||
| ) | ||||
| from torch.distributed.fsdp import CPUOffloadPolicy, FSDPModule, OffloadPolicy | ||||
| from torch.distributed.tensor import DTensor, init_device_mesh | ||||
| from torch.testing._internal.common_distributed import skip_if_lt_x_gpu | ||||
| @ -657,190 +652,5 @@ class TestReplicate1DTrainingCore(FSDPTest): | ||||
|             self.assertEqual(ref_loss, loss) | ||||
|  | ||||
|  | ||||
| class TestReplicateTrainingCompose(FSDPTest): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         # Since these tests run with a larger transformer model, they may see | ||||
|         # some numeric drift with >2 GPUs | ||||
|         return min(torch.get_device_module(device_type).device_count(), 2) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @compiled_fsdp_test(compile_compute_on_module=Transformer) | ||||
|     def test_train_parity_with_activation_checkpointing(self): | ||||
|         """ | ||||
|         Tests train parity against DDP when composing with activation | ||||
|         checkpointing. | ||||
|         """ | ||||
|         self.run_subtests( | ||||
|             { | ||||
|                 "reshard_after_forward": [True, False], | ||||
|                 "checkpoint_impl": ["composable", "utils", "wrapper"], | ||||
|                 "module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"], | ||||
|                 "test_device_type": [device_type.type], | ||||
|             }, | ||||
|             self._test_train_parity_with_activation_checkpointing, | ||||
|         ) | ||||
|  | ||||
|     def _test_train_parity_with_activation_checkpointing( | ||||
|         self, | ||||
|         reshard_after_forward: Union[bool, int], | ||||
|         checkpoint_impl: str, | ||||
|         module_grouping: str, | ||||
|         test_device_type: str, | ||||
|     ): | ||||
|         assert checkpoint_impl in ("composable", "utils", "wrapper") | ||||
|         testing_compile = replicate != torch.distributed._composable.replicate_with_fsdp | ||||
|         if testing_compile and checkpoint_impl == "composable": | ||||
|             return | ||||
|         torch.manual_seed(42) | ||||
|         vocab_size = 1024 | ||||
|         with torch.device(device_type): | ||||
|             model_args = ModelArgs( | ||||
|                 n_layers=3, | ||||
|                 n_heads=4, | ||||
|                 vocab_size=vocab_size, | ||||
|                 max_seq_len=64, | ||||
|                 dropout_p=0, | ||||
|                 checkpoint_activations=(checkpoint_impl == "utils"), | ||||
|                 # For the mem-efficient module grouping, we separate the | ||||
|                 # embeddings from the output projection, which does not support | ||||
|                 # weight tying | ||||
|                 weight_tying=module_grouping != "mem_eff", | ||||
|             ) | ||||
|             model = Transformer(model_args) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) | ||||
|  | ||||
|         # Apply activation checkpointing | ||||
|         prefixes_to_ignore = () | ||||
|         if checkpoint_impl == "wrapper": | ||||
|             prefixes_to_ignore = (_CHECKPOINT_PREFIX,) | ||||
|             apply_activation_checkpointing( | ||||
|                 model, check_fn=lambda m: isinstance(m, TransformerBlock) | ||||
|             ) | ||||
|         elif checkpoint_impl == "composable": | ||||
|             for module in model.modules(): | ||||
|                 if isinstance(module, TransformerBlock): | ||||
|                     checkpoint(module) | ||||
|  | ||||
|         # Apply Replicate | ||||
|         device_mesh = init_device_mesh( | ||||
|             test_device_type, | ||||
|             (self.world_size, 1), | ||||
|             mesh_dim_names=("replicate", "shard"), | ||||
|         ) | ||||
|         fsdp_kwargs = { | ||||
|             "reshard_after_forward": reshard_after_forward, | ||||
|             "device_mesh": device_mesh, | ||||
|         } | ||||
|         if module_grouping == "mem_eff": | ||||
|             assert model_args.n_layers == 3 | ||||
|             replicate(model.layers[0], **fsdp_kwargs) | ||||
|             replicate([model.layers[1], model.layers[2]], **fsdp_kwargs) | ||||
|             replicate([model.tok_embeddings, model.pos_embeddings], **fsdp_kwargs) | ||||
|             # Embedding weights are not needed for embedding backward | ||||
|             model.tok_embeddings.set_unshard_in_backward(False) | ||||
|             replicate([model.norm, model.output], **fsdp_kwargs) | ||||
|         elif module_grouping == "mem_eff_weight_tied": | ||||
|             replicate([model.tok_embeddings, model.output], **fsdp_kwargs) | ||||
|             for layer in model.layers: | ||||
|                 replicate(layer, **fsdp_kwargs) | ||||
|         elif module_grouping == "block": | ||||
|             for layer in model.layers: | ||||
|                 replicate(layer, **fsdp_kwargs) | ||||
|         else: | ||||
|             raise NotImplementedError(f"Unknown module grouping: {module_grouping}") | ||||
|         replicate(model, **fsdp_kwargs) | ||||
|         optim = torch.optim.Adam(model.parameters(), lr=1e-2) | ||||
|  | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         # Reuse the same input across iterations to avoid loss explosion from | ||||
|         # trying to learn from random inputs | ||||
|         inp = torch.randint(0, vocab_size, (3, 64), device=device_type.type) | ||||
|         check_sharded_parity( | ||||
|             self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore | ||||
|         ) | ||||
|         for iter_idx in range(10): | ||||
|             losses: list[torch.Tensor] = [] | ||||
|             for _model in (ref_model, model): | ||||
|                 torch.manual_seed(iter_idx + 1)  # for dropout determinism | ||||
|                 losses.append(_model(inp).sum()) | ||||
|                 losses[-1].backward() | ||||
|  | ||||
|             for param in ref_model.parameters(): | ||||
|                 if param.grad is not None: | ||||
|                     dist.all_reduce(param.grad) | ||||
|                     param.grad.div_(self.world_size) | ||||
|  | ||||
|             if not testing_compile: | ||||
|                 check_sharded_parity( | ||||
|                     self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore | ||||
|                 ) | ||||
|             self.assertEqual(losses[0], losses[1]) | ||||
|             for _optim in (ref_optim, optim): | ||||
|                 _optim.step() | ||||
|                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||||
|             if not testing_compile: | ||||
|                 check_sharded_parity( | ||||
|                     self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| class TestReplicateSharedParams(FSDPTest): | ||||
|     @property | ||||
|     def world_size(self) -> int: | ||||
|         return min(4, torch.get_device_module(device_type).device_count()) | ||||
|  | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_train_parity_with_shared_params(self): | ||||
|         self.run_subtests( | ||||
|             { | ||||
|                 "reshard_after_forward": [False, True], | ||||
|                 "use_activation_checkpointing": [False, True], | ||||
|             }, | ||||
|             self._test_train_shared_params, | ||||
|         ) | ||||
|  | ||||
|     def _test_train_shared_params( | ||||
|         self, | ||||
|         reshard_after_forward: bool, | ||||
|         use_activation_checkpointing: bool, | ||||
|     ): | ||||
|         torch.manual_seed(42) | ||||
|         model_args = ModelArgs(n_layers=3, dropout_p=0.0, weight_tying=True) | ||||
|         model = Transformer(model_args) | ||||
|         ref_model = copy.deepcopy(model).to(device_type) | ||||
|  | ||||
|         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) | ||||
|         for module in model.modules(): | ||||
|             if isinstance(module, TransformerBlock): | ||||
|                 if use_activation_checkpointing: | ||||
|                     checkpoint(module) | ||||
|                 replicate(module, reshard_after_forward=reshard_after_forward) | ||||
|         replicate(model, reshard_after_forward=reshard_after_forward) | ||||
|         optim = torch.optim.Adam(model.parameters(), lr=1e-2) | ||||
|  | ||||
|         torch.manual_seed(42 + self.rank + 1) | ||||
|         for iter_idx in range(10): | ||||
|             inp = torch.randint( | ||||
|                 0, model_args.vocab_size, (2, 16), device=device_type.type | ||||
|             ) | ||||
|             losses: list[torch.Tensor] = [] | ||||
|             for _model in (ref_model, model): | ||||
|                 losses.append(_model(inp).sum()) | ||||
|                 losses[-1].backward() | ||||
|  | ||||
|             for param in ref_model.parameters(): | ||||
|                 if param.grad is not None: | ||||
|                     dist.all_reduce(param.grad) | ||||
|                     param.grad.div_(self.world_size) | ||||
|  | ||||
|             for _optim in (ref_optim, optim): | ||||
|                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) | ||||
|                 _optim.step() | ||||
|  | ||||
|             self.assertEqual(losses[0], losses[1]) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
| # Owner(s): ["module: unknown"] | ||||
| import unittest | ||||
|  | ||||
| import torch | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["module: fsdp"] | ||||
| # Owner(s): ["module: unknown"] | ||||
| import functools | ||||
| import gc | ||||
| from typing import Union | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
| # Owner(s): ["module: unknown"] | ||||
| import gc | ||||
| import unittest | ||||
|  | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
| # Owner(s): ["module: unknown"] | ||||
|  | ||||
| from copy import copy | ||||
|  | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
| # Owner(s): ["module: unknown"] | ||||
| import unittest | ||||
| from dataclasses import dataclass | ||||
| from typing import Any, Callable, cast, Union | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
| # Owner(s): ["module: unknown"] | ||||
| import unittest | ||||
|  | ||||
| import torch | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| # Owner(s): ["oncall: distributed"] | ||||
| # Owner(s): ["module: unknown"] | ||||
| import copy | ||||
| import unittest | ||||
|  | ||||
|  | ||||
| @ -143,19 +143,6 @@ class FlightRecorderEventTest(TestCase): | ||||
|             match_one_event(e11, e12, membership, "0").state, | ||||
|             MatchState.FULLY_MATCHED, | ||||
|         ) | ||||
|         e13 = create_one_event( | ||||
|             "gather", | ||||
|             ("0", "default"), | ||||
|             [[4, 4]], | ||||
|             [[4, 4]], | ||||
|             "completed", | ||||
|             1, | ||||
|             output_dtypes="", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             match_one_event(e11, e13, membership, "0").state, | ||||
|             MatchState.FULLY_MATCHED, | ||||
|         ) | ||||
|  | ||||
|     def test_all_events(self): | ||||
|         for collective in sorted(COLLECTIVES): | ||||
|  | ||||
| @ -202,62 +202,6 @@ class ScheduleTest(TestCase): | ||||
|  | ||||
|         torch.distributed.destroy_process_group() | ||||
|  | ||||
|     @parametrize( | ||||
|         "ScheduleClass", | ||||
|         [ | ||||
|             Schedule1F1B, | ||||
|             ScheduleGPipe, | ||||
|             ScheduleInterleaved1F1B, | ||||
|             ScheduleInterleavedZeroBubble, | ||||
|             ScheduleLoopedBFS, | ||||
|         ], | ||||
|     ) | ||||
|     def test_schedule_eval_then_train(self, ScheduleClass): | ||||
|         """ | ||||
|         Test that simply runs evaluation followed by training. | ||||
|         """ | ||||
|         store = FakeStore() | ||||
|         torch.distributed.init_process_group( | ||||
|             backend="fake", rank=0, world_size=1, store=store | ||||
|         ) | ||||
|         d_hid, batch_size = 512, 256 | ||||
|         n_stages = 1 | ||||
|         device = "cpu" | ||||
|         full_mod = MultiMLP(d_hid, n_layers=n_stages) | ||||
|         full_mod.to(device) | ||||
|  | ||||
|         x = torch.randn(batch_size, d_hid, device=device) | ||||
|         target = torch.randn(batch_size, d_hid, device=device) | ||||
|  | ||||
|         def loss_fn(y, target): | ||||
|             return torch.nn.functional.cross_entropy(y, target) | ||||
|  | ||||
|         submod_name = "layers.0" | ||||
|         stage_module = full_mod.get_submodule(submod_name) | ||||
|  | ||||
|         # Create a pipeline stage to wrap that submodule | ||||
|         num_microbatches = 2 | ||||
|         stages = [PipelineStage(stage_module, 0, n_stages, device)] | ||||
|  | ||||
|         if issubclass(ScheduleClass, PipelineScheduleSingle): | ||||
|             stages = stages[0] | ||||
|  | ||||
|         # Attach to a schedule | ||||
|         schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn) | ||||
|         # Run eval | ||||
|         for _ in range(2): | ||||
|             # Zero gradients | ||||
|             stage_module.zero_grad() | ||||
|             losses = [] | ||||
|             schedule.eval(x, target=target, losses=losses) | ||||
|         # Run training | ||||
|         try: | ||||
|             for _ in range(2): | ||||
|                 losses = [] | ||||
|                 schedule.step(x, target=target, losses=losses) | ||||
|         finally: | ||||
|             torch.distributed.destroy_process_group() | ||||
|  | ||||
|     def test_zero_bubble_schedule_errors_with_compile(self): | ||||
|         """ | ||||
|         Test that zero bubble schedules raise an error when used with torch.compile. | ||||
|  | ||||
| @ -248,16 +248,6 @@ class TestDTensorDebugMode(TestCase): | ||||
|             "redistribute_input(1, [S(0)] -> [R])" in debug_mode.debug_string() | ||||
|         ) | ||||
|  | ||||
|     def test_debug_mode_higher_order_cond(self): | ||||
|         """Test DebugMode with higher order operation.""" | ||||
|         x = torch.randn(1, 8, requires_grad=True) | ||||
|  | ||||
|         with DebugMode(record_torchfunction=True) as debug_mode: | ||||
|             torch.cond(torch.tensor(True), lambda x: x + 1, lambda x: x - 1, [x]) | ||||
|  | ||||
|         # Verify that cond operations are captured in debug mode | ||||
|         self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string()) | ||||
|  | ||||
|  | ||||
| instantiate_parametrized_tests(TestDTensorDebugMode) | ||||
|  | ||||
|  | ||||
| @ -352,7 +352,7 @@ class MicroPipelineTPTest(TestCase): | ||||
|     @parametrize("scatter_dim", [0, 1, 2]) | ||||
|     @fresh_cache() | ||||
|     def test_fuse_scaled_matmul_reduce_scatter(self, A_dims, scatter_dim): | ||||
|         if scatter_dim >= A_dims - 1: | ||||
|         if scatter_dim >= A_dims: | ||||
|             return | ||||
|  | ||||
|         group = dist.group.WORLD | ||||
| @ -402,7 +402,7 @@ class MicroPipelineTPTest(TestCase): | ||||
|  | ||||
|     @runOnRocmArch(MI300_ARCH) | ||||
|     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") | ||||
|     @parametrize("scatter_dim", [0, 1]) | ||||
|     @parametrize("scatter_dim", [0, 1, 2]) | ||||
|     @fresh_cache() | ||||
|     def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape( | ||||
|         self, scatter_dim | ||||
|  | ||||
| @ -880,34 +880,6 @@ class DistMathOpsTest(DTensorTestBase): | ||||
|                 out_full = out_dt.full_tensor() | ||||
|                 self.assertEqual(global_bins, out_full) | ||||
|  | ||||
|     @with_comms | ||||
|     def test_logsumexp(self): | ||||
|         mesh = self.build_device_mesh() | ||||
|         comm_mode = CommDebugMode() | ||||
|         inp = torch.rand(3, 5, device=self.device_type) | ||||
|  | ||||
|         shard_dim = 0 | ||||
|         input_dtensor = distribute_tensor( | ||||
|             inp, device_mesh=mesh, placements=[Shard(shard_dim)] | ||||
|         ) | ||||
|  | ||||
|         logsumexp_dims = [0, 1] | ||||
|         for dim in logsumexp_dims: | ||||
|             output = torch.logsumexp(inp, dim=dim) | ||||
|             with comm_mode: | ||||
|                 output_dtensor = torch.logsumexp(input_dtensor, dim=dim) | ||||
|                 if dim == shard_dim: | ||||
|                     self.assertEqual(comm_mode.get_total_counts(), 1) | ||||
|                     self.assertEqual( | ||||
|                         comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], | ||||
|                         1, | ||||
|                     ) | ||||
|                     self.assertTrue(output_dtensor.placements[0].is_replicate()) | ||||
|                 else: | ||||
|                     self.assertEqual(comm_mode.get_total_counts(), 0) | ||||
|                     self.assertTrue(output_dtensor.placements[0].is_shard(shard_dim)) | ||||
|                 self.assertEqual(output_dtensor.full_tensor(), output) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -31,7 +31,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import ( | ||||
|     skip_unless_torch_gpu, | ||||
|     with_comms, | ||||
| ) | ||||
| from torch.utils._typing_utils import not_none | ||||
|  | ||||
|  | ||||
| def get_generator_seed_for_device_type(device_type: str) -> int: | ||||
| @ -550,9 +549,7 @@ class DistTensorRandomOpTest(DTensorTestBase): | ||||
|             # local_shard_list_on_dim[i] has the list of all shards on that dim | ||||
|             # as a tuple (local_shard_offset, local_shard_size) | ||||
|             dtensor_shape = dtensor.shape | ||||
|             local_shard_list_on_dim: list[list[tuple[int, int]]] = [ | ||||
|                 [(0, l)] for l in dtensor_shape | ||||
|             ] | ||||
|             local_shard_list_on_dim = [[(0, l)] for l in dtensor_shape] | ||||
|             for idx, placement in enumerate(placements): | ||||
|                 if isinstance(placement, Shard): | ||||
|                     mesh_dim_size = device_mesh.size(idx) | ||||
| @ -568,7 +565,7 @@ class DistTensorRandomOpTest(DTensorTestBase): | ||||
|                             shard_idx_on_dim, | ||||
|                         ) | ||||
|                         local_shard_list_on_dim[shard_dim].append( | ||||
|                             (not_none(shard_offset), shard_size) | ||||
|                             (shard_offset, shard_size) | ||||
|                         ) | ||||
|  | ||||
|             local_shard_comb = itertools.product(*local_shard_list_on_dim) | ||||
|  | ||||
| @ -691,25 +691,6 @@ class TestStridedSharding(DTensorTestBase): | ||||
|         ) | ||||
|         self.assertEqual(full_tensor, x) | ||||
|  | ||||
|     @with_comms | ||||
|     def test_2d_mesh_uneven_strided_shard(self): | ||||
|         mesh = init_device_mesh( | ||||
|             self.device_type, | ||||
|             (self.world_size // 2, 2), | ||||
|             mesh_dim_names=("fsdp", "tp"), | ||||
|         ) | ||||
|  | ||||
|         for size in (2, 3, 5, 11): | ||||
|             tensor = torch.arange(size, device=self.device_type).view(1, -1) | ||||
|             dtensor = distribute_tensor( | ||||
|                 tensor, | ||||
|                 device_mesh=mesh, | ||||
|                 placements=(Replicate(), Replicate()), | ||||
|             ).redistribute( | ||||
|                 mesh, placements=(_StridedShard(dim=1, split_factor=2), Shard(1)) | ||||
|             ) | ||||
|             self.assertEqual(dtensor.full_tensor(), tensor) | ||||
|  | ||||
|  | ||||
| class Test2DStridedLocalShard(DTensorTestBase): | ||||
|     @property | ||||
|  | ||||
| @ -8,7 +8,6 @@ import torch.distributed as dist | ||||
| import torch.distributed._functional_collectives as funcol | ||||
| from torch._C._distributed_c10d import Backend as C10dBackend | ||||
| from torch._subclasses.fake_tensor import FakeTensorMode | ||||
| from torch.distributed._mesh_layout import _MeshLayout as _Layout | ||||
| from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh | ||||
| from torch.distributed.distributed_c10d import ( | ||||
|     _get_default_group, | ||||
| @ -28,7 +27,7 @@ from torch.distributed.tensor._collective_utils import ( | ||||
| ) | ||||
| from torch.distributed.tensor.placement_types import _Partial, Shard | ||||
| from torch.testing._internal.common_distributed import skip_if_lt_x_gpu | ||||
| from torch.testing._internal.common_utils import run_tests, TEST_XPU, TestCase | ||||
| from torch.testing._internal.common_utils import run_tests, TEST_XPU | ||||
| from torch.testing._internal.distributed._tensor.common_dtensor import ( | ||||
|     DTensorTestBase, | ||||
|     with_comms, | ||||
| @ -863,7 +862,7 @@ class TestDeviceMeshGetItem(DTensorTestBase): | ||||
|  | ||||
|         # Test flatten into an existing mesh_dim_name inside the mesh | ||||
|         with self.assertRaisesRegex( | ||||
|             ValueError, | ||||
|             RuntimeError, | ||||
|             "already exists for submesh of the DeviceMesh", | ||||
|         ): | ||||
|             mesh_3d._flatten("dp") | ||||
| @ -903,18 +902,6 @@ class TestDeviceMeshGetItem(DTensorTestBase): | ||||
|         cp_tp_mesh._flatten("dummy") | ||||
|         self.assertEqual(mesh_3d["dummy"].mesh_dim_names[0], "dummy") | ||||
|  | ||||
|         # Test flatten into an existing mesh_dim_name inside the mesh | ||||
|         with self.assertRaisesRegex( | ||||
|             ValueError, | ||||
|             "dp already exists for submesh of the DeviceMesh", | ||||
|         ): | ||||
|             mesh_3d._flatten("dp") | ||||
|         with self.assertRaisesRegex( | ||||
|             ValueError, | ||||
|             "Flatten mesh with mesh_dim_name dp_tp has been created before", | ||||
|         ): | ||||
|             mesh_3d["cp", "tp"]._flatten("dp_tp") | ||||
|  | ||||
|     @with_comms(eager_init=True) | ||||
|     def test_flatten_mesh_4d(self): | ||||
|         mesh_shape = (2, 2, 2, 1) | ||||
| @ -1301,204 +1288,5 @@ class DeviceMeshCollectiveTest(DTensorTestBase): | ||||
|             self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) | ||||
|  | ||||
|  | ||||
| class CuTeLayoutTest(TestCase): | ||||
|     def test_coalesce(self): | ||||
|         # ((3,2),(2,1)) -> (6,1) | ||||
|         l = _Layout((3, 2), (2, 1)) | ||||
|         l = l.coalesce() | ||||
|         self.assertEqual(list(l.sizes_and_strides), [(6, 1)]) | ||||
|  | ||||
|         # ((2,12),(3,4),(4,1)) -> (24,1) | ||||
|         l = _Layout((2, 3, 4), (12, 4, 1)) | ||||
|         l = l.coalesce() | ||||
|         self.assertEqual(list(l.sizes_and_strides), [(24, 1)]) | ||||
|  | ||||
|     def test_coalesce_non_coalescible(self): | ||||
|         # ((3,4),(2,1)) stays as-is (4 ≠ 2*1) | ||||
|         l = _Layout((3, 2), (4, 1)) | ||||
|         l = l.coalesce() | ||||
|         self.assertEqual(list(l.sizes_and_strides), [(3, 4), (2, 1)]) | ||||
|  | ||||
|     def test_complement_n_group_layout(self): | ||||
|         # complement((4,2), 8) = (2,1); together form (8,1) | ||||
|         pg_layout = _Layout( | ||||
|             (4,), | ||||
|             (2,), | ||||
|         ) | ||||
|         outer = pg_layout.complement(world_size=8) | ||||
|         self.assertEqual(list(outer.sizes_and_strides), [(2, 1)]) | ||||
|         self.assertEqual( | ||||
|             pg_layout.all_ranks_from_zero(), | ||||
|             [0, 2, 4, 6], | ||||
|         ) | ||||
|         groups = [ | ||||
|             [o + i for i in pg_layout.all_ranks_from_zero()] | ||||
|             for o in outer.all_ranks_from_zero() | ||||
|         ] | ||||
|         self.assertEqual( | ||||
|             groups, | ||||
|             [ | ||||
|                 [0, 2, 4, 6], | ||||
|                 [1, 3, 5, 7], | ||||
|             ], | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             pg_layout.global_ranks(8), | ||||
|             [ | ||||
|                 [0, 2, 4, 6], | ||||
|                 [1, 3, 5, 7], | ||||
|             ], | ||||
|         ) | ||||
|         # complement((4,2), 16) = ((2,8), (2,1)); together form (16,1) | ||||
|         outer = pg_layout.complement(world_size=16) | ||||
|         self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)]) | ||||
|         self.assertEqual( | ||||
|             outer.all_ranks_from_zero(), | ||||
|             [0, 1, 8, 9], | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             pg_layout.global_ranks(16), | ||||
|             [ | ||||
|                 [0, 2, 4, 6], | ||||
|                 [1, 3, 5, 7], | ||||
|                 [8, 10, 12, 14], | ||||
|                 [9, 11, 13, 15], | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|         # The complement of ((2,4), (2,1)) under world_size=16 is ((2,8), (2,2)) | ||||
|         pg_layout = _Layout((2, 2), (4, 1)) | ||||
|         self.assertEqual( | ||||
|             pg_layout.all_ranks_from_zero(), | ||||
|             [0, 1, 4, 5], | ||||
|         ) | ||||
|         outer = pg_layout.complement(world_size=16) | ||||
|         self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 2)]) | ||||
|         self.assertEqual( | ||||
|             outer.all_ranks_from_zero(), | ||||
|             [0, 2, 8, 10], | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             pg_layout.global_ranks(16), | ||||
|             [ | ||||
|                 [0, 1, 4, 5], | ||||
|                 [2, 3, 6, 7], | ||||
|                 [8, 9, 12, 13], | ||||
|                 [10, 11, 14, 15], | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|         # Test global_ranks and all_ranks_from_zero | ||||
|         pg_layout = _Layout((2, 2), (4, 2)) | ||||
|         self.assertEqual( | ||||
|             pg_layout.all_ranks_from_zero(), | ||||
|             [0, 2, 4, 6], | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             pg_layout.global_ranks(16), | ||||
|             [ | ||||
|                 [0, 2, 4, 6], | ||||
|                 [1, 3, 5, 7], | ||||
|                 [8, 10, 12, 14], | ||||
|                 [9, 11, 13, 15], | ||||
|             ], | ||||
|         ) | ||||
|         outer = pg_layout.complement(world_size=16) | ||||
|         self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)]) | ||||
|         # Test that when the strides are not monotonically decreasing, the complement | ||||
|         # layout is the same as the one obtained after sorting the strides. | ||||
|         pg_layout_r = _Layout((2, 2), (2, 4)) | ||||
|         outer = pg_layout_r.complement(world_size=16) | ||||
|         self.assertEqual(list(outer.sizes_and_strides), [(2, 8), (2, 1)]) | ||||
|         self.assertEqual( | ||||
|             pg_layout_r.global_ranks(16), | ||||
|             [ | ||||
|                 [0, 4, 2, 6], | ||||
|                 [1, 5, 3, 7], | ||||
|                 [8, 12, 10, 14], | ||||
|                 [9, 13, 11, 15], | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|         # Test just all_ranks_from_zero and global_ranks. | ||||
|         pg_layout = _Layout((4,), (2,)) | ||||
|         self.assertEqual( | ||||
|             pg_layout.all_ranks_from_zero(), | ||||
|             [0, 2, 4, 6], | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             pg_layout.global_ranks(16), | ||||
|             [ | ||||
|                 [0, 2, 4, 6], | ||||
|                 [1, 3, 5, 7], | ||||
|                 [8, 10, 12, 14], | ||||
|                 [9, 11, 13, 15], | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|     def test_composition(self): | ||||
|         # self = ((4,2), (2,1)), l = (2,1)  → self o l = (2,1) | ||||
|         orig_l = _Layout((4, 2), (2, 1)) | ||||
|         right_l = _Layout((2,), (1,)) | ||||
|         composed_layout = orig_l.composition(right_l) | ||||
|         self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 1)]) | ||||
|         self.assertEqual( | ||||
|             composed_layout.global_ranks(8), | ||||
|             [ | ||||
|                 [0, 1], | ||||
|                 [2, 3], | ||||
|                 [4, 5], | ||||
|                 [6, 7], | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|         # self = (4,2), l = (2,1)  → self o l = (2,2) | ||||
|         orig_l = _Layout((4,), (2,)) | ||||
|         right_l = _Layout((2,), (1,)) | ||||
|         composed_layout = orig_l.composition(right_l) | ||||
|         self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 2)]) | ||||
|         self.assertEqual( | ||||
|             composed_layout.global_ranks(8), | ||||
|             [ | ||||
|                 [0, 2], | ||||
|                 [1, 3], | ||||
|                 [4, 6], | ||||
|                 [5, 7], | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|         # self = (4,2), l = ((2,2), (2,1))  → self o l = ((2,4), (2,2)) | ||||
|         # This mimics un-flattening a 1D mesh into a 2D mesh. | ||||
|         right_l = _Layout((2, 2), (2, 1)) | ||||
|         composed_layout = orig_l.composition(right_l) | ||||
|         self.assertEqual(list(composed_layout.sizes_and_strides), [(2, 4), (2, 2)]) | ||||
|         self.assertEqual( | ||||
|             composed_layout[0].global_ranks(8), | ||||
|             [ | ||||
|                 [0, 4], | ||||
|                 [1, 5], | ||||
|                 [2, 6], | ||||
|                 [3, 7], | ||||
|             ], | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             composed_layout[1].global_ranks(8), | ||||
|             [ | ||||
|                 [0, 2], | ||||
|                 [1, 3], | ||||
|                 [4, 6], | ||||
|                 [5, 7], | ||||
|             ], | ||||
|         ) | ||||
|  | ||||
|         # Error case. | ||||
|         orig_l = _Layout((4, 2), (4, 1)) | ||||
|         with self.assertRaises( | ||||
|             AssertionError, | ||||
|         ): | ||||
|             right_l = _Layout((2,), (3,)) | ||||
|             orig_l.composition(right_l) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
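A minimal standalone sketch of the rank enumeration that CuTeLayoutTest asserts above, under the assumption that a layout is just a (sizes, strides) pair and that grouping by the complement amounts to partitioning [0, world_size) into cosets of the layout's rank set; `ranks_from_zero` and `global_ranks_sketch` below are hypothetical stand-ins, not the `_MeshLayout` API.

from itertools import product

def ranks_from_zero(sizes, strides):
    # All rank offsets reachable by the layout: sum(index[k] * strides[k]).
    return sorted(
        sum(i * s for i, s in zip(idx, strides))
        for idx in product(*(range(n) for n in sizes))
    )

def global_ranks_sketch(sizes, strides, world_size):
    # Partition [0, world_size) into cosets of the layout's rank set.
    inner = ranks_from_zero(sizes, strides)
    seen, groups = set(), []
    for base in range(world_size):
        if base not in seen:
            group = [base + r for r in inner]
            seen.update(group)
            groups.append(group)
    return groups

# Mirrors the (4,2) and ((2,4),(2,1)) examples asserted in the test above.
assert ranks_from_zero((4,), (2,)) == [0, 2, 4, 6]
assert global_ranks_sketch((4,), (2,), 8) == [[0, 2, 4, 6], [1, 3, 5, 7]]
assert global_ranks_sketch((2, 2), (4, 1), 16) == [
    [0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15],
]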
| @ -299,33 +299,28 @@ class NVSHMEMAll2AllTest(MultiProcContinuousTest): | ||||
|             torch.randn(max_inp_numel, dtype=dtype, device=self.device) | ||||
|         ) | ||||
|         out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1) | ||||
|         in_splits = symm_mem.empty( | ||||
|             self.world_size, dtype=torch.int64, device=self.device | ||||
|         ) | ||||
|         out_splits_offsets = symm_mem.empty( | ||||
|             (2, self.world_size), dtype=torch.int64, device=self.device | ||||
|         in_out_splits = symm_mem.empty( | ||||
|             (3, self.world_size), dtype=torch.int64, device=self.device | ||||
|         ) | ||||
|         # Row 0 is input splits | ||||
|         in_splits.copy_(inp_splits) | ||||
|         in_out_splits[0].copy_(inp_splits) | ||||
|  | ||||
|         # Sync all ranks to ensure remote tensors are allocated | ||||
|         dist.barrier() | ||||
|  | ||||
|         torch.ops.symm_mem.all_to_all_vdev( | ||||
|             inp, out, in_splits, out_splits_offsets, group_name | ||||
|         ) | ||||
|         torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name) | ||||
|  | ||||
|         # Check input splits (row 0) -- should not change | ||||
|         torch.testing.assert_close(in_splits, inp_splits) | ||||
|         torch.testing.assert_close(in_out_splits[0], inp_splits) | ||||
|  | ||||
|         # Check output splits (row 1) | ||||
|         torch.testing.assert_close(out_splits_offsets[0], out_splits) | ||||
|         torch.testing.assert_close(in_out_splits[1], out_splits) | ||||
|  | ||||
|         # Check output offsets (row 2) | ||||
|         out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan | ||||
|         # output offsets from `all_to_all_vdev` are an exclusive scan | ||||
|         self.assertEqual(out_splits_offsets[1][0], 0) | ||||
|         torch.testing.assert_close(out_splits_offsets[1][1:], out_offsets[:-1]) | ||||
|         self.assertEqual(in_out_splits[2][0], 0) | ||||
|         torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1]) | ||||
|  | ||||
|         # Check data | ||||
|         expected = torch.empty(out_numel, dtype=dtype, device=self.device) | ||||
|  | ||||
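A minimal sketch of the offsets convention checked above, assuming only the row layout shown in the test (row 0 input splits, row 1 output splits, row 2 output offsets): row 2 is the exclusive scan of row 1, which can be reproduced with plain torch ops. The split values below are made up for illustration.

import torch

out_splits = torch.tensor([3, 1, 4, 2])      # hypothetical per-peer output splits (row 1)
inclusive = torch.cumsum(out_splits, dim=0)  # tensor([3, 4, 8, 10])
exclusive = torch.zeros_like(inclusive)      # exclusive scan starts at 0
exclusive[1:] = inclusive[:-1]               # tensor([0, 3, 4, 8]) -- what row 2 should hold
assert exclusive[0].item() == 0
assert torch.equal(exclusive[1:], inclusive[:-1])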
| @ -2,8 +2,6 @@ | ||||
| # To run: | ||||
| # python test/distributed/test_nvshmem_triton.py | ||||
|  | ||||
| import sys | ||||
|  | ||||
| import triton.language as tl | ||||
|  | ||||
| import torch | ||||
| @ -11,7 +9,6 @@ import torch.distributed as dist | ||||
| import torch.distributed._symmetric_memory as symm_mem | ||||
| import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem | ||||
| from torch._inductor.runtime.triton_compat import triton | ||||
| from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem | ||||
| from torch.testing._internal.common_distributed import MultiProcContinuousTest | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
| @ -23,9 +20,12 @@ from torch.testing._internal.common_utils import ( | ||||
| from torch.testing._internal.inductor_utils import IS_H100, requires_triton | ||||
|  | ||||
|  | ||||
| if not symm_mem.is_nvshmem_available(): | ||||
|     print("NVSHMEM not available, skipping tests") | ||||
|     sys.exit(0) | ||||
| # Decorators | ||||
| def requires_nvshmem(): | ||||
|     return skip_but_pass_in_sandcastle_if( | ||||
|         not symm_mem.is_nvshmem_available(), | ||||
|         "test_nvshmem requires NVSHMEM, skipping tests", | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def requires_h100(): | ||||
| @ -41,11 +41,8 @@ device_module = torch.get_device_module(device_type) | ||||
|  | ||||
|  | ||||
| # Shared Triton JIT kernels | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_put_kernel( | ||||
| def nvshmem_put_kernel( | ||||
|     dest, | ||||
|     src, | ||||
|     nelems, | ||||
| @ -54,9 +51,8 @@ def my_put_kernel( | ||||
|     nvshmem.put(dest, src, nelems, pe) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_get_kernel( | ||||
| def nvshmem_get_kernel( | ||||
|     dest, | ||||
|     src, | ||||
|     nelems, | ||||
| @ -65,9 +61,8 @@ def my_get_kernel( | ||||
|     nvshmem.get(dest, src, nelems, pe) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_putmem_signal_block_kernel( | ||||
| def nvshmem_putmem_signal_block_kernel( | ||||
|     dst, | ||||
|     src, | ||||
|     size_bytes, | ||||
| @ -79,15 +74,13 @@ def my_putmem_signal_block_kernel( | ||||
|     nvshmem.putmem_signal_block(dst, src, size_bytes, signal, sig_val, sig_op, peer) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_signal_wait_until_kernel(signal, cmp_op, cmp_val): | ||||
| def nvshmem_signal_wait_until_kernel(signal, cmp_op, cmp_val): | ||||
|     nvshmem.signal_wait_until(signal, cmp_op, cmp_val) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_signal_op_kernel( | ||||
| def nvshmem_signal_op_kernel( | ||||
|     sig_addr, | ||||
|     signal, | ||||
|     sig_op, | ||||
| @ -96,9 +89,8 @@ def my_signal_op_kernel( | ||||
|     nvshmem.signal_op(sig_addr, signal, sig_op, peer) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_wait_until_kernel( | ||||
| def nvshmem_wait_until_kernel( | ||||
|     ivar, | ||||
|     cmp_op, | ||||
|     cmp_val, | ||||
| @ -106,15 +98,13 @@ def my_wait_until_kernel( | ||||
|     nvshmem.wait_until(ivar, cmp_op, cmp_val) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_fence_kernel(): | ||||
| def nvshmem_fence_kernel(): | ||||
|     nvshmem.fence() | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_put_with_fence_kernel( | ||||
| def nvshmem_put_with_fence_kernel( | ||||
|     dst1, | ||||
|     src1, | ||||
|     dst2, | ||||
| @ -136,9 +126,8 @@ def my_put_with_fence_kernel( | ||||
|     nvshmem.put(flag_dst, flag_src, 1, peer) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_put_with_quiet_kernel( | ||||
| def nvshmem_put_with_quiet_kernel( | ||||
|     dst, | ||||
|     src, | ||||
|     flag_dst, | ||||
| @ -155,9 +144,8 @@ def my_put_with_quiet_kernel( | ||||
|     nvshmem.put(flag_dst, flag_src, 1, peer) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_barrier_test_kernel( | ||||
| def nvshmem_barrier_test_kernel( | ||||
|     dst, | ||||
|     src, | ||||
|     nelems, | ||||
| @ -190,15 +178,13 @@ def my_barrier_test_kernel( | ||||
|         tl.store(p_dst, received + 1) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_barrier_all_kernel(): | ||||
| def nvshmem_barrier_all_kernel(): | ||||
|     nvshmem.barrier_all() | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_sync_test_kernel( | ||||
| def nvshmem_sync_test_kernel( | ||||
|     local_data, | ||||
|     remote_data, | ||||
|     nelems, | ||||
| @ -224,9 +210,8 @@ def my_sync_test_kernel( | ||||
|     # because sync_all() made those local stores visible | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_alltoall_kernel( | ||||
| def nvshmem_alltoall_kernel( | ||||
|     team_handle, | ||||
|     dst, | ||||
|     src, | ||||
| @ -235,9 +220,8 @@ def my_alltoall_kernel( | ||||
|     nvshmem.alltoall(team_handle, dst, src, nelems_per_pe) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_broadcast_kernel( | ||||
| def nvshmem_broadcast_kernel( | ||||
|     team_handle, | ||||
|     dst, | ||||
|     src, | ||||
| @ -247,9 +231,8 @@ def my_broadcast_kernel( | ||||
|     nvshmem.broadcast(team_handle, dst, src, nelems, pe_root) | ||||
|  | ||||
|  | ||||
| @requires_nvshmem | ||||
| @triton.jit | ||||
| def my_reduce_kernel( | ||||
| def nvshmem_reduce_kernel( | ||||
|     team_handle, | ||||
|     dest_tensor, | ||||
|     source_tensor, | ||||
| @ -260,6 +243,7 @@ def my_reduce_kernel( | ||||
|  | ||||
|  | ||||
| @instantiate_parametrized_tests | ||||
| @requires_nvshmem() | ||||
| class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def _init_device(self) -> None: | ||||
|         # TODO: relax this requirement (it seems to hang without it) | ||||
| @ -278,6 +262,9 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|  | ||||
|         # Enable NVSHMEM for Triton | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|  | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -307,11 +294,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         peer = 1 - rank | ||||
|         if rank == 0: | ||||
|             # Rank 0 puts its data to Rank 1 | ||||
|             my_put_kernel[(1,)]( | ||||
|             nvshmem_put_kernel[(1,)]( | ||||
|                 dst, | ||||
|                 src, | ||||
|                 nelems, | ||||
|                 peer, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|  | ||||
|         # Synchronize after operation | ||||
| @ -331,6 +319,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|  | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -352,11 +341,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         peer = 1 - rank | ||||
|         if rank == 1: | ||||
|             # Rank 1 gets data from rank 0 using tensor-aware API | ||||
|             my_get_kernel[(1,)]( | ||||
|             nvshmem_get_kernel[(1,)]( | ||||
|                 out, | ||||
|                 inp, | ||||
|                 numel, | ||||
|                 peer, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|         if rank == 1: | ||||
|             torch.testing.assert_close( | ||||
| @ -370,6 +360,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|  | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -392,11 +383,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         peer = (rank - 1) % world_size | ||||
|  | ||||
|         # All ranks execute the get operation using tensor-aware API | ||||
|         my_get_kernel[(1,)]( | ||||
|         nvshmem_get_kernel[(1,)]( | ||||
|             out, | ||||
|             inp, | ||||
|             numel, | ||||
|             peer, | ||||
|             extern_libs=nvshmem_lib, | ||||
|         ) | ||||
|  | ||||
|         expected_value = peer | ||||
| @ -411,6 +403,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|  | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|  | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -437,7 +431,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|  | ||||
|         if rank == 0: | ||||
|             # Rank 0 puts into Rank 1 | ||||
|             my_putmem_signal_block_kernel[(1, 1, 1)]( | ||||
|             nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( | ||||
|                 out, | ||||
|                 inp, | ||||
|                 size_bytes=msg_size_bytes, | ||||
| @ -445,14 +439,16 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|                 sig_val=SIGNAL_VAL, | ||||
|                 sig_op=NVSHMEM_SIGNAL_SET, | ||||
|                 peer=peer, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|  | ||||
|         if rank == 1: | ||||
|             # Wait until signal flag is set by Rank 0 | ||||
|             my_signal_wait_until_kernel[(1,)]( | ||||
|             nvshmem_signal_wait_until_kernel[(1,)]( | ||||
|                 flag, | ||||
|                 cmp_op=NVSHMEM_CMP_EQ, | ||||
|                 cmp_val=SIGNAL_VAL, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|             # After wait completes, verify data and flag contents | ||||
|             torch.testing.assert_close( | ||||
| @ -469,6 +465,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|  | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|  | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -495,7 +493,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|  | ||||
|         if rank == 0: | ||||
|             # Rank 0 puts into Rank 1 | ||||
|             my_putmem_signal_block_kernel[(1, 1, 1)]( | ||||
|             nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( | ||||
|                 out, | ||||
|                 inp, | ||||
|                 size_bytes=msg_size_bytes, | ||||
| @ -503,13 +501,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|                 sig_val=SIGNAL_VAL, | ||||
|                 sig_op=NVSHMEM_SIGNAL_ADD, | ||||
|                 peer=peer, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|  | ||||
|         if rank == 1: | ||||
|             my_signal_wait_until_kernel[(1, 1, 1)]( | ||||
|             nvshmem_signal_wait_until_kernel[(1, 1, 1)]( | ||||
|                 flag, | ||||
|                 cmp_op=NVSHMEM_CMP_EQ, | ||||
|                 cmp_val=SIGNAL_VAL, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|             torch.testing.assert_close( | ||||
|                 out, val * torch.ones(numel, dtype=dtype, device=self.device) | ||||
| @ -525,6 +525,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|  | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|  | ||||
| @ -543,12 +544,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|             [FLAG_FINAL_VALUE], dtype=torch.int32, device=self.device | ||||
|         ) | ||||
|  | ||||
|         nvshmem_barrier_all_kernel[(1,)](extern_libs=nvshmem_lib) | ||||
|  | ||||
|         if rank == 0: | ||||
|             # Rank 0 (the waiter) | ||||
|             my_wait_until_kernel[(1,)]( | ||||
|             nvshmem_wait_until_kernel[(1,)]( | ||||
|                 flag, | ||||
|                 cmp_op=NVSHMEM_CMP_EQ, | ||||
|                 cmp_val=FLAG_FINAL_VALUE, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|  | ||||
|             # Verification | ||||
| @ -560,11 +564,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         if rank == 1: | ||||
|             # Rank 1 (the signaler) | ||||
|             # Launch a kernel to put the value to Rank 0's flag tensor. | ||||
|             my_put_kernel[(1,)]( | ||||
|             nvshmem_put_kernel[(1,)]( | ||||
|                 flag,  # Destination symmetric tensor on the remote PE | ||||
|                 expected_flag,  # Source data tensor (local) | ||||
|                 1,  # Number of elements | ||||
|                 peer,  # The target PE (Rank 0) | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|  | ||||
|     @skipIfRocm | ||||
| @ -572,6 +577,8 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     @requires_h100() | ||||
|     def test_triton_signal_wait_until(self) -> None: | ||||
|         self._init_device() | ||||
|         # Enable NVSHMEM for Triton | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -601,7 +608,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|  | ||||
|         if rank == 0: | ||||
|             # Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag | ||||
|             my_putmem_signal_block_kernel[(1, 1, 1)]( | ||||
|             nvshmem_putmem_signal_block_kernel[(1, 1, 1)]( | ||||
|                 out, | ||||
|                 inp, | ||||
|                 size_bytes=msg_size_bytes, | ||||
| @ -609,13 +616,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|                 sig_val=COMPLETION_FLAG_VAL, | ||||
|                 sig_op=NVSHMEM_SIGNAL_SET, | ||||
|                 peer=peer, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|         elif rank == 1: | ||||
|             # Consumer (rank 1): Waits on the signal variable using `signal_wait_until`. | ||||
|             my_signal_wait_until_kernel[(1, 1, 1)]( | ||||
|             nvshmem_signal_wait_until_kernel[(1, 1, 1)]( | ||||
|                 flag, | ||||
|                 cmp_op=NVSHMEM_CMP_EQ, | ||||
|                 cmp_val=COMPLETION_FLAG_VAL, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|             # After the wait returns, verify data and flag | ||||
|             torch.testing.assert_close( | ||||
| @ -642,6 +651,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         """ | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -672,7 +682,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         NVSHMEM_CMP_EQ = 0  # compare equal | ||||
|  | ||||
|         if rank == 0: | ||||
|             my_put_with_fence_kernel[(1,)]( | ||||
|             nvshmem_put_with_fence_kernel[(1,)]( | ||||
|                 out1, | ||||
|                 inp1, | ||||
|                 out2, | ||||
| @ -681,13 +691,15 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|                 flag_update_val, | ||||
|                 nelems=numel, | ||||
|                 peer=peer, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|         elif rank == 1: | ||||
|             # Wait until flag is set by Rank 0 | ||||
|             my_wait_until_kernel[(1,)]( | ||||
|             nvshmem_wait_until_kernel[(1,)]( | ||||
|                 flag, | ||||
|                 cmp_op=NVSHMEM_CMP_EQ, | ||||
|                 cmp_val=flag_val, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|  | ||||
|             # Verify ordered data arrival. | ||||
| @ -707,6 +719,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def test_triton_quiet(self) -> None: | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -732,19 +745,21 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|  | ||||
|         dist.barrier() | ||||
|         if rank == 1: | ||||
|             my_put_with_quiet_kernel[(1,)]( | ||||
|             nvshmem_put_with_quiet_kernel[(1,)]( | ||||
|                 out, | ||||
|                 inp, | ||||
|                 flag, | ||||
|                 flag_update_val, | ||||
|                 nelems=numel, | ||||
|                 peer=peer, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|         elif rank == 0: | ||||
|             my_wait_until_kernel[(1,)]( | ||||
|             nvshmem_wait_until_kernel[(1,)]( | ||||
|                 flag, | ||||
|                 cmp_op=NVSHMEM_CMP_EQ, | ||||
|                 cmp_val=flag_val, | ||||
|                 extern_libs=nvshmem_lib, | ||||
|             ) | ||||
|             torch.testing.assert_close( | ||||
|                 out, val * torch.ones(numel, dtype=dtype, device=self.device) | ||||
| @ -757,6 +772,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def test_triton_barrier(self) -> None: | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -768,10 +784,11 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         symm_mem.rendezvous(src, group=group_name) | ||||
|         symm_mem.rendezvous(dst, group=group_name) | ||||
|  | ||||
|         my_barrier_test_kernel[(1,)]( | ||||
|         nvshmem_barrier_test_kernel[(1,)]( | ||||
|             dst, | ||||
|             src, | ||||
|             nelems=numel, | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|             num_ctas=1, | ||||
|         ) | ||||
| @ -793,6 +810,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|  | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -806,10 +824,11 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         symm_mem.rendezvous(remote_data, group=group_name) | ||||
|  | ||||
|         # Launch kernel with cooperative grid | ||||
|         my_sync_test_kernel[(1,)]( | ||||
|         nvshmem_sync_test_kernel[(1,)]( | ||||
|             local_data, | ||||
|             remote_data, | ||||
|             nelems=numel, | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|             num_ctas=1, | ||||
|         ) | ||||
| @ -836,6 +855,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def test_triton_alltoall(self) -> None: | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         world_size = dist.get_world_size() | ||||
| @ -860,11 +880,12 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         dist.barrier() | ||||
|         team_handle = 0  # NVSHMEM_TEAM_WORLD handle is 0 | ||||
|         # Launch the kernel using new tensor-aware API | ||||
|         my_alltoall_kernel[(1,)]( | ||||
|         nvshmem_alltoall_kernel[(1,)]( | ||||
|             team_handle, | ||||
|             dst, | ||||
|             src, | ||||
|             nelems_per_pe, | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|         ) | ||||
|         # Synchronize after alltoall | ||||
| @ -883,6 +904,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def test_triton_broadcast(self) -> None: | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         rank = self.rank | ||||
| @ -913,12 +935,13 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|  | ||||
|         # Execute broadcast | ||||
|         team_handle = 0  # NVSHMEM_TEAM_WORLD | ||||
|         my_broadcast_kernel[(1,)]( | ||||
|         nvshmem_broadcast_kernel[(1,)]( | ||||
|             team_handle, | ||||
|             dst, | ||||
|             src, | ||||
|             nelems, | ||||
|             pe_root, | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|         ) | ||||
|  | ||||
| @ -951,6 +974,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def test_triton_sum_reduce(self, dtype) -> None: | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         world_size = dist.get_world_size() | ||||
| @ -977,12 +1001,13 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|  | ||||
|         # Execute sum reduction across all ranks | ||||
|         team_handle = 0  # NVSHMEM_TEAM_WORLD | ||||
|         my_reduce_kernel[(1,)]( | ||||
|         nvshmem_reduce_kernel[(1,)]( | ||||
|             team_handle, | ||||
|             dst, | ||||
|             src, | ||||
|             nreduce, | ||||
|             operation="sum", | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|         ) | ||||
|  | ||||
| @ -1013,6 +1038,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def test_triton_minmax_reduce(self, dtype) -> None: | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         world_size = dist.get_world_size() | ||||
| @ -1054,21 +1080,23 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|         dist.barrier() | ||||
|         # Execute MIN reduction | ||||
|         team_handle = 0 | ||||
|         my_reduce_kernel[(1,)]( | ||||
|         nvshmem_reduce_kernel[(1,)]( | ||||
|             team_handle, | ||||
|             dst_min, | ||||
|             src_min, | ||||
|             nreduce, | ||||
|             operation="min", | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|         ) | ||||
|         # Execute MAX reduction | ||||
|         my_reduce_kernel[(1,)]( | ||||
|         nvshmem_reduce_kernel[(1,)]( | ||||
|             team_handle, | ||||
|             dst_max, | ||||
|             src_max, | ||||
|             nreduce, | ||||
|             operation="max", | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|         ) | ||||
|         dist.barrier() | ||||
| @ -1099,6 +1127,7 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|     def test_triton_prod_reduce(self, dtype) -> None: | ||||
|         torch.manual_seed(42 + self.rank) | ||||
|         self._init_device() | ||||
|         nvshmem_lib = nvshmem.enable_triton() | ||||
|         group_name = dist.distributed_c10d._get_default_group().group_name | ||||
|         symm_mem.enable_symm_mem_for_group(group_name) | ||||
|         world_size = dist.get_world_size() | ||||
| @ -1138,12 +1167,13 @@ class NVSHMEMTritonTest(MultiProcContinuousTest): | ||||
|  | ||||
|         # Execute product reduction across all ranks | ||||
|         team_handle = 0  # NVSHMEM_TEAM_WORLD | ||||
|         my_reduce_kernel[(1,)]( | ||||
|         nvshmem_reduce_kernel[(1,)]( | ||||
|             team_handle, | ||||
|             dst, | ||||
|             src, | ||||
|             nreduce, | ||||
|             operation="prod", | ||||
|             extern_libs=nvshmem_lib, | ||||
|             launch_cooperative_grid=True, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @ -505,7 +505,7 @@ class AsyncTPTest(MultiProcContinuousTest): | ||||
|         not PLATFORM_SUPPORTS_SYMM_MEM, "SymmMem is not supported on this ROCm arch" | ||||
|     ) | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @parametrize("scatter_dim", [0, 1, 2]) | ||||
|     @parametrize("scatter_dim", [0, 1]) | ||||
|     def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None: | ||||
|         self._init_process() | ||||
|  | ||||
|  | ||||
| @ -519,7 +519,11 @@ class AOTAutogradCacheTests(InductorTestCase): | ||||
|     @functorch_config.patch( | ||||
|         {"enable_autograd_cache": True, "view_replay_for_aliased_outputs": True} | ||||
|     ) | ||||
|     def test_view_replay(self): | ||||
|     def test_view_replay_bypass(self): | ||||
|         """ | ||||
|         Should bypass when view replay is turned on | ||||
|         """ | ||||
|  | ||||
|         def fn(a): | ||||
|             tmp = a.detach() | ||||
|             a.mul_(2) | ||||
| @ -527,25 +531,10 @@ class AOTAutogradCacheTests(InductorTestCase): | ||||
|  | ||||
|         with torch.autograd._force_original_view_tracking(True): | ||||
|             compiled_fn = torch.compile(fn) | ||||
|             compiled_fn(torch.rand(2, 3)) | ||||
|  | ||||
|         def run_and_check(miss, hit, bypass): | ||||
|             self._clear_dynamo_and_codecache() | ||||
|  | ||||
|             inp = torch.rand(2, 3) | ||||
|             compiled_inp = inp.clone().detach() | ||||
|  | ||||
|             with torch.autograd._force_original_view_tracking(True): | ||||
|                 out = fn(inp) | ||||
|                 compiled_out = compiled_fn(compiled_inp) | ||||
|  | ||||
|             self.assertEqual(out, compiled_out) | ||||
|             self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], miss) | ||||
|             self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], hit) | ||||
|             self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], bypass) | ||||
|  | ||||
|         run_and_check(miss=1, hit=0, bypass=0) | ||||
|         run_and_check(miss=1, hit=1, bypass=0) | ||||
|         run_and_check(miss=1, hit=2, bypass=0) | ||||
|         self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) | ||||
|         self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1) | ||||
|  | ||||
|     @inductor_config.patch("fx_graph_remote_cache", False) | ||||
|     @inductor_config.patch("fx_graph_cache", True) | ||||
|  | ||||
| @ -1,10 +1,8 @@ | ||||
| # Owner(s): ["module: dynamo"] | ||||
|  | ||||
| import inspect | ||||
| import os | ||||
| import pickle | ||||
| from contextlib import contextmanager | ||||
| from unittest.mock import patch | ||||
|  | ||||
| import torch | ||||
| import torch._dynamo.testing | ||||
| @ -31,27 +29,8 @@ class CustomCompiledFunction(torch._dynamo.aot_compile.SerializableCallable): | ||||
|  | ||||
|     @classmethod | ||||
|     def serialize_compile_artifacts(cls, fn) -> bytes: | ||||
|         import sympy | ||||
|  | ||||
|         from torch._subclasses import FakeTensorMode | ||||
|         from torch.fx._graph_pickler import Options | ||||
|  | ||||
|         state = fn.__dict__.copy() | ||||
|         graph_reducer_override = GraphPickler.reducer_override | ||||
|  | ||||
|         def _graph_reducer_override(self, obj): | ||||
|             if ( | ||||
|                 inspect.isclass(obj) | ||||
|                 and issubclass(obj, sympy.Function) | ||||
|                 and hasattr(obj, "_torch_unpickler") | ||||
|             ): | ||||
|                 return obj._torch_unpickler, (obj._torch_handler_name,) | ||||
|             if isinstance(obj, FakeTensorMode): | ||||
|                 return type(None), () | ||||
|             return graph_reducer_override(self, obj) | ||||
|  | ||||
|         with patch.object(GraphPickler, "reducer_override", _graph_reducer_override): | ||||
|             state["gm"] = GraphPickler.dumps(state["gm"], Options(ops_filter=None)) | ||||
|         state["gm"] = GraphPickler.dumps(state["gm"]) | ||||
|         return pickle.dumps(state) | ||||
|  | ||||
|     @classmethod | ||||
| @ -75,14 +54,6 @@ class SimpleLinearModule(torch.nn.Module): | ||||
|         return self.linear(x) | ||||
|  | ||||
|  | ||||
| class RepeatInterleaveModule(torch.nn.Module): | ||||
|     def forward(self, x): | ||||
|         chunk = x.chunk(2, dim=-1) | ||||
|         y = chunk[0] | ||||
|         y_repeat = y.repeat_interleave(2, dim=-1) | ||||
|         return y_repeat | ||||
|  | ||||
|  | ||||
| @torch._dynamo.config.patch("enable_aot_compile", True) | ||||
| @instantiate_parametrized_tests | ||||
| class TestAOTCompile(torch._inductor.test_case.TestCase): | ||||
| @ -143,34 +114,6 @@ class TestAOTCompile(torch._inductor.test_case.TestCase): | ||||
|             actual = compiled_fn(mod, *inputs) | ||||
|             self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_aot_compile_repeat_interleave(self): | ||||
|         mod = RepeatInterleaveModule() | ||||
|  | ||||
|         def backend(gm, example_inputs): | ||||
|             return CustomCompiledFunction(gm, example_inputs) | ||||
|  | ||||
|         inputs = (torch.randn(2, 4),) | ||||
|  | ||||
|         # The first dim should be dynamic to repro the issue of repeat_interleave | ||||
|         # torch._dynamo.mark_dynamic(inputs[0], [0]) | ||||
|  | ||||
|         compiled_fn = torch.compile( | ||||
|             mod, | ||||
|             fullgraph=True, | ||||
|             backend=backend, | ||||
|         ).forward.aot_compile((inputs, {})) | ||||
|  | ||||
|         expected = mod(*inputs) | ||||
|         actual = compiled_fn(mod, *inputs) | ||||
|         self.assertEqual(expected, actual) | ||||
|         compiled_fn.save_compiled_function(self.path()) | ||||
|         torch._dynamo.reset() | ||||
|         with torch.compiler.set_stance("fail_on_recompile"): | ||||
|             with open(self.path(), "rb") as f: | ||||
|                 compiled_fn = torch.compiler.load_compiled_function(f) | ||||
|             actual = compiled_fn(mod, *inputs) | ||||
|             self.assertEqual(expected, actual) | ||||
|  | ||||
|     def test_decorated_function_aot(self): | ||||
|         def check_inputs(fn): | ||||
|             def _fn(*args, **kwargs): | ||||
|  | ||||
| @ -80,7 +80,7 @@ def fn(): | ||||
|         self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab) | ||||
|  | ||||
|     @unittest.skipIf( | ||||
|         sys.version_info >= (3, 11), | ||||
|         sys.version_info < (3, 10) or sys.version_info >= (3, 11), | ||||
|         "linetable test for Python 3.10", | ||||
|     ) | ||||
|     def test_linetable_310_writer(self): | ||||
| @ -95,6 +95,19 @@ def fn(): | ||||
|         result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) | ||||
|         self.assertTrue(result[1] == fn.__code__.co_linetable) | ||||
|  | ||||
|     @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10") | ||||
|     def test_lnotab_writer(self): | ||||
|         def fn(): | ||||
|             a = 10 | ||||
|             b = 20 | ||||
|             c = a + b | ||||
|             f = "lnotab_writer" | ||||
|             return f"Test if {f} generates correct co_lnotab: {c}" | ||||
|  | ||||
|         inst = dis.get_instructions(fn) | ||||
|         result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) | ||||
|         self.assertTrue(result[1] == fn.__code__.co_lnotab) | ||||
|  | ||||
|     def test_if_tensor_is_none(self): | ||||
|         """ | ||||
|         Python 3.11 adds new jump instructions that check if | ||||
|  | ||||
| @ -410,6 +410,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase): | ||||
|             combs.append(torch.ones(size)) | ||||
|         return combs | ||||
|  | ||||
|     @unittest.skipIf( | ||||
|         sys.version_info < (3, 10), | ||||
|         "itertools.pairwise was added at Python 3.10", | ||||
|     ) | ||||
|     @make_test | ||||
|     def test_itertools_pairwise(a): | ||||
|         pairs = [] | ||||
| @ -4694,6 +4698,10 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): | ||||
|         self.assertEqual(len(lst), 2) | ||||
|         self.assertEqual(lst[0], lst[1]) | ||||
|  | ||||
|     @unittest.skipIf( | ||||
|         sys.version_info < (3, 10), | ||||
|         "zip strict kwargs not implemented for Python < 3.10", | ||||
|     ) | ||||
|     def test_zip_strict(self): | ||||
|         def fn(x, ys, zs): | ||||
|             x = x.clone() | ||||
|  | ||||
| @ -8005,11 +8005,8 @@ utils_device.CURRENT_DEVICE == None""".split("\n"): | ||||
|         torch._dynamo.decorators.mark_unbacked(b, 1) | ||||
|         func(a, b) | ||||
|         func(torch.rand(4, 5), torch.rand(4, 5)) | ||||
|         # This does not raise an error right now because of a recompilation. | ||||
|         # https://github.com/pytorch/pytorch/issues/163785 | ||||
|         # with self.assertRaises(AssertionError): | ||||
|         #     func(torch.rand(1, 1), torch.rand(2, 1)) | ||||
|         func(torch.rand(1, 1), torch.rand(2, 1)) | ||||
|         with self.assertRaises(RuntimeError): | ||||
|             func(torch.rand(1, 1), torch.rand(2, 1)) | ||||
|  | ||||
|     @torch._dynamo.config.patch(capture_scalar_outputs=True) | ||||
|     def test_sym_constrain_range_on_replaced_unbacked_symbol(self): | ||||
|  | ||||
| @ -443,15 +443,59 @@ def run(cnt): | ||||
|             f(t(2, 4), t(2, 2)) | ||||
|             f(t(4, 2), t(2, 2)) | ||||
|  | ||||
|             # with both default remote present, we ignore extra remote. | ||||
|             # with default remote (dynamic x) + extra remote (dynamic y), | ||||
|             # we should be able to wobble x & y with no recompiles. | ||||
|             self.reset() | ||||
|             cnts.clear() | ||||
|             with torch.compiler.config.patch(pgo_extra_read_key="sticky_1"): | ||||
|                 f(t(2, 2), t(2, 2)) | ||||
|                 f(t(6, 8), t(2, 2)) | ||||
|                 f(t(2, 4), t(4, 2)) | ||||
|                 f(t(4, 2), t(2, 4)) | ||||
|                 self.assertEqual(cnts.frame_count, 1) | ||||
|                 f(t(2, 2), t(2, 4)) | ||||
|                 self.assertEqual(cnts.frame_count, 2) | ||||
|  | ||||
|     def test_profile_merges(self): | ||||
|         from torch._dynamo.pgo import auto_dynamic, merge_pgo_entry | ||||
|  | ||||
|         @torch.compile(backend="eager", fullgraph=True) | ||||
|         def f(ints, t_scalar, tensors): | ||||
|             # arbitrary compute | ||||
|             return ints[0] + ints[1], t_scalar + 1, [t + 1 for t in tensors] | ||||
|  | ||||
|         # single static run | ||||
|         f( | ||||
|             [0, 2], | ||||
|             torch.tensor(0), | ||||
|             [ | ||||
|                 torch.randn(2), | ||||
|                 torch.randn(2, 2), | ||||
|                 torch.randn(4, 4), | ||||
|             ], | ||||
|         ) | ||||
|         # collect profiles | ||||
|         profile = next( | ||||
|             iter(torch._dynamo.pgo.get_code_state().values()) | ||||
|         ).automatic_dynamic | ||||
|         i0, i1 = profile["L['ints'][0]"], profile["L['ints'][1]"] | ||||
|         ts = profile["L['t_scalar']"] | ||||
|         t0, t1, t2 = ( | ||||
|             profile["L['tensors'][0]"], | ||||
|             profile["L['tensors'][1]"], | ||||
|             profile["L['tensors'][2]"], | ||||
|         ) | ||||
|         # merging same scalar, or tensor into scalar -> no-op | ||||
|         merge_pgo_entry(i0, i0) | ||||
|         merge_pgo_entry(ts, i0) | ||||
|         merge_pgo_entry(t0, i0) | ||||
|         self.assertEqual(i0.scalar, 0) | ||||
|         # merging different scalars -> dynamic | ||||
|         merge_pgo_entry(i1, i0) | ||||
|         self.assertEqual(i0.scalar, auto_dynamic) | ||||
|         # merging different rank tensors -> static | ||||
|         merge_pgo_entry(t0, t2) | ||||
|         self.assertEqual(t2.size, (4, 4)) | ||||
|         # merging same rank tensors -> dynamic | ||||
|         merge_pgo_entry(t1, t2) | ||||
|         self.assertEqual(t2.size, (auto_dynamic, auto_dynamic)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
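A minimal sketch of the merge rules that test_profile_merges asserts above, with hypothetical helpers standing in for the real torch._dynamo.pgo entries (assumptions: a scalar profile is a plain value and a tensor profile is a size tuple).

AUTO_DYNAMIC = "auto_dynamic"  # stand-in for torch._dynamo.pgo.auto_dynamic

def merge_scalar(src, dst):
    # Merging the same observed value keeps the entry static; a mismatch marks it dynamic.
    return dst if src == dst else AUTO_DYNAMIC

def merge_size(src, dst):
    # Mismatched ranks cannot be merged, so dst keeps its static size; equal ranks
    # go dynamic in every dimension that disagrees.
    if len(src) != len(dst):
        return dst
    return tuple(d if s == d else AUTO_DYNAMIC for s, d in zip(src, dst))

assert merge_scalar(0, 0) == 0                                      # same scalar -> no-op
assert merge_scalar(2, 0) == AUTO_DYNAMIC                           # different scalars -> dynamic
assert merge_size((2,), (4, 4)) == (4, 4)                           # different rank -> stays static
assert merge_size((2, 2), (4, 4)) == (AUTO_DYNAMIC, AUTO_DYNAMIC)   # same rank -> dynamic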
| @ -32,8 +32,7 @@ class TestExperiment(TestCase): | ||||
|         m = Module() | ||||
|         example_inputs = (torch.randn(3),) | ||||
|         m(*example_inputs) | ||||
|         with torch._export.config.patch(use_new_tracer_experimental=True): | ||||
|             ep = torch.export.export(m, example_inputs, strict=True) | ||||
|         ep = torch.export.export(m, example_inputs, strict=True) | ||||
|         joint_ep = _export_forward_backward(ep) | ||||
|         self.assertExpectedInline( | ||||
|             str(joint_ep.graph_module.code).strip(), | ||||
|  | ||||
| @ -21,7 +21,6 @@ from unittest.mock import MagicMock, patch | ||||
|  | ||||
| import torch | ||||
| import torch._dynamo as torchdynamo | ||||
| import torch.fx.traceback as fx_traceback | ||||
| import torch.nn.functional as F | ||||
| import torch.utils._pytree as pytree | ||||
| from functorch.experimental.control_flow import cond, map | ||||
| @ -1087,93 +1086,6 @@ graph(): | ||||
|         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256)) | ||||
|         self.assertEqual(gm(*args), m(*args)) | ||||
|  | ||||
|     # stride() is called for an undefined tensor | ||||
|     @testing.expectedFailureCppRuntimeNonStrict | ||||
|     def test_native_multi_attention_head(self): | ||||
|         embed_dim = 64 | ||||
|         num_heads = 4 | ||||
|         bs = 16 | ||||
|         sl = 8 | ||||
|         device = "cpu" | ||||
|  | ||||
|         q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3 | ||||
|         k = q | ||||
|         v = q | ||||
|  | ||||
|         qkv = torch.nn.Linear( | ||||
|             embed_dim, 3 * embed_dim, device=device, dtype=torch.float32 | ||||
|         ) | ||||
|         proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32) | ||||
|  | ||||
|         class NativeMHA(torch.nn.Module): | ||||
|             def __init__( | ||||
|                 self, | ||||
|                 embed_dim, | ||||
|                 num_heads, | ||||
|                 qkv, | ||||
|                 proj, | ||||
|                 need_weights, | ||||
|                 average_attn_weights, | ||||
|                 mask_type, | ||||
|             ): | ||||
|                 super().__init__() | ||||
|                 self.qkv = qkv | ||||
|                 self.proj = proj | ||||
|                 self.embed_dim = embed_dim | ||||
|                 self.num_heads = num_heads | ||||
|                 self.need_weights = need_weights | ||||
|                 self.average_attn_weights = average_attn_weights | ||||
|                 self.mask_type = mask_type | ||||
|  | ||||
|             def forward(self, q, k, v, key_padding_mask): | ||||
|                 return torch._native_multi_head_attention( | ||||
|                     q, | ||||
|                     k, | ||||
|                     v, | ||||
|                     self.embed_dim, | ||||
|                     self.num_heads, | ||||
|                     self.qkv.weight, | ||||
|                     self.qkv.bias, | ||||
|                     self.proj.weight, | ||||
|                     self.proj.bias, | ||||
|                     key_padding_mask, | ||||
|                     need_weights=False, | ||||
|                     average_attn_weights=False, | ||||
|                     mask_type=1,  # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask | ||||
|                 ) | ||||
|  | ||||
|         for mask_type in (0, 1): | ||||
|             for need_weights in (True, False): | ||||
|                 for average_attn_weights in (True, False): | ||||
|                     npt = NativeMHA( | ||||
|                         embed_dim=embed_dim, | ||||
|                         num_heads=num_heads, | ||||
|                         qkv=qkv, | ||||
|                         proj=proj, | ||||
|                         need_weights=need_weights, | ||||
|                         average_attn_weights=average_attn_weights, | ||||
|                         mask_type=mask_type, | ||||
|                     ) | ||||
|                     sample_input = (q, k, v, None) | ||||
|  | ||||
|                     ep = export( | ||||
|                         npt, | ||||
|                         args=sample_input, | ||||
|                         dynamic_shapes={ | ||||
|                             "q": { | ||||
|                                 0: Dim("dim0_q", max=1024), | ||||
|                             }, | ||||
|                             "k": { | ||||
|                                 0: Dim("dim0_k", max=1024), | ||||
|                             }, | ||||
|                             "v": { | ||||
|                                 0: Dim("dim0_v", max=1024), | ||||
|                             }, | ||||
|                             "key_padding_mask": None, | ||||
|                         }, | ||||
|                     ) | ||||
|                     self.assertEqual(ep.module()(*sample_input), npt(*sample_input)) | ||||
|  | ||||
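For context, here is a minimal standalone sketch of the dynamic_shapes / Dim API exercised by the removed test above (the module, argument names, and sizes are illustrative assumptions, not taken from this diff):

import torch
from torch.export import Dim, export

class Add(torch.nn.Module):
    def forward(self, a, b):
        return a + b

# dimension 0 of both inputs may vary at runtime, up to 1024
batch = Dim("batch", max=1024)
ep = export(
    Add(),
    (torch.randn(8, 4), torch.randn(8, 4)),
    dynamic_shapes={"a": {0: batch}, "b": {0: batch}},
)
# the exported module accepts a batch size other than the example inputs'
print(ep.module()(torch.randn(16, 4), torch.randn(16, 4)).shape)

Sharing one Dim object between the two inputs is what ties their leading dimensions together, analogous to how dim0_q, dim0_k, and dim0_v presumably end up unified above, since q, k, and v alias the same tensor.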
|     def test_unused_constant(self): | ||||
|         class M(torch.nn.Module): | ||||
|             def forward(self, x): | ||||
| @ -2007,8 +1919,8 @@ class GraphModule(torch.nn.Module): | ||||
|                 # z = 3 | ||||
|                 return x + y + z | ||||
|  | ||||
|         with self.assertWarnsRegex( | ||||
|             UserWarning, | ||||
|         with self.assertRaisesRegex( | ||||
|             ValueError, | ||||
|             "The tensor attribute self.buf was assigned during export", | ||||
|         ): | ||||
|             export(M(), (torch.randn(2, 3),), strict=False) | ||||
| @ -2065,8 +1977,8 @@ class GraphModule(torch.nn.Module): | ||||
|                 # z = 3 + 3 | ||||
|                 return x + y + z | ||||
|  | ||||
|         with self.assertWarnsRegex( | ||||
|             UserWarning, | ||||
|         with self.assertRaisesRegex( | ||||
|             ValueError, | ||||
|             "The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export", | ||||
|         ): | ||||
|             export(M(), (torch.randn(2, 3),), strict=False) | ||||
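A minimal sketch of the behavior these two hunks now assert; the module body and shapes here are assumed for illustration, while the strict=False call and the error message are taken from the test:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        # a tensor attribute assigned while export traces forward()
        self.buf = torch.ones(2, 3)
        return x + self.buf

try:
    torch.export.export(M(), (torch.randn(2, 3),), strict=False)
except ValueError as e:
    # expected to match "The tensor attribute self.buf was assigned during export"
    print(e)

With the switch from assertWarnsRegex to assertRaisesRegex, this pattern is treated as a hard error in non-strict export rather than a warning.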
| @ -15159,39 +15071,6 @@ def forward(self, x): | ||||
|             test_serdes=True, | ||||
|         ) | ||||
|  | ||||
|     # TODO: the following tests should be fixed | ||||
|     @testing.expectedFailureTrainingIRToRunDecomp | ||||
|     @testing.expectedFailureTrainingIRToRunDecompNonStrict | ||||
|     def test_preserve_annotation(self): | ||||
|         class M(torch.nn.Module): | ||||
|             def forward(self, x): | ||||
|                 with fx_traceback.annotate({"pp_stage": 0}): | ||||
|                     with fx_traceback.annotate({"fsdp_bucket": 0}): | ||||
|                         x = x + 1 | ||||
|                     x = x - 2 | ||||
|                     with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}): | ||||
|                         x = x * 2 | ||||
|                 x = x / 3 | ||||
|                 return x | ||||
|  | ||||
|         m = M() | ||||
|  | ||||
|         with fx_traceback.preserve_node_meta(): | ||||
|             ep = export(m, (torch.randn(10),)) | ||||
|  | ||||
|         for node in ep.graph.nodes: | ||||
|             if node.target == torch.ops.aten.add.default: | ||||
|                 self.assertEqual(node.meta["custom"], {"pp_stage": 0, "fsdp_bucket": 0}) | ||||
|             if node.target == torch.ops.aten.sub.default: | ||||
|                 self.assertEqual(node.meta["custom"], {"pp_stage": 0}) | ||||
|             if node.target == torch.ops.aten.mul.default: | ||||
|                 self.assertEqual( | ||||
|                     node.meta["custom"], | ||||
|                     {"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1}, | ||||
|                 ) | ||||
|             if node.target == torch.ops.aten.div.default: | ||||
|                 self.assertEqual(node.meta.get("custom", {}), {}) | ||||
|  | ||||
|     def test_dynamic_shapes_serdes_generic(self): | ||||
|         from torch._export.serde.dynamic_shapes import ( | ||||
|             _dump_dynamic_shapes, | ||||
| @ -15898,50 +15777,6 @@ class GraphModule(torch.nn.Module): | ||||
|             ] | ||||
|             self.assertEqual(len(shift_op), 1) | ||||
|  | ||||
|     def test_export_rnn_variants_with_warning(self): | ||||
|         """ | ||||
|         Test that exporting RNN, LSTM, and GRU models in non-strict mode: | ||||
|  | ||||
|         1. Produces expected warnings about tensor attributes being assigned during export | ||||
|         2. Does not leak fake tensors in the model's flat weights | ||||
|         3. Does not produce extra tensor constants in the graph signature | ||||
|         """ | ||||
|         rnn_types = [ | ||||
|             (torch.nn.RNN, "RNN"), | ||||
|             (torch.nn.LSTM, "LSTM"), | ||||
|             (torch.nn.GRU, "GRU"), | ||||
|         ] | ||||
|  | ||||
|         for rnn_class, rnn_name in rnn_types: | ||||
|             with self.subTest(rnn_type=rnn_name): | ||||
|                 m = rnn_class( | ||||
|                     input_size=2, hidden_size=4, num_layers=1, batch_first=True | ||||
|                 ) | ||||
|                 sample_inputs = (torch.randn(1, 2, 2),) | ||||
|                 eager_out = m(*sample_inputs) | ||||
|  | ||||
|                 # Verify that export produces the expected warning about tensor attributes | ||||
|                 with self.assertWarnsRegex( | ||||
|                     UserWarning, | ||||
|                     r"The tensor attributes self\._flat_weights\[0\], self\._flat_weights\[1\], " | ||||
|                     r"self\._flat_weights\[2\], self\._flat_weights\[3\] were assigned during export.*", | ||||
|                 ): | ||||
|                     ep = torch.export.export(m, sample_inputs, strict=False) | ||||
|  | ||||
|                 ep_out = ep.module()(*sample_inputs) | ||||
|                 self.assertEqual(eager_out, ep_out) | ||||
|  | ||||
|                 # Verify no fake tensor leakage: flat weights should be real tensors | ||||
|                 for flat_weight in m._flat_weights: | ||||
|                     self.assertTrue( | ||||
|                         not isinstance( | ||||
|                             flat_weight, torch._subclasses.fake_tensor.FakeTensor | ||||
|                         ) | ||||
|                     ) | ||||
|  | ||||
|                 # Verify no tensor constants in graph signature | ||||
|                 self.assertEqual(len(ep.graph_signature.lifted_tensor_constants), 0) | ||||
|  | ||||
|     @contextmanager | ||||
|     def distributed_env(self, world_size): | ||||
|         try: | ||||
|  | ||||
| @ -57,6 +57,7 @@ fake_export_failures = { | ||||
|     xfail("nn.functional.grid_sample"), | ||||
|     xfail("to_sparse"), | ||||
|     # cannot xfail as these pass on CPU-only builds | ||||
|     skip("nn.functional.conv2d"), | ||||
|     skip("nn.functional.scaled_dot_product_attention"), | ||||
|     # following are failing due to OptionalDeviceGuard | ||||
|     xfail("__getitem__"), | ||||
| @ -80,6 +81,7 @@ def _test_export_helper(self, dtype, op): | ||||
|     sample_inputs_itr = op.sample_inputs("cpu", dtype, requires_grad=False) | ||||
|  | ||||
|     mode = FakeTensorMode(allow_non_fake_inputs=True) | ||||
|     converter = mode.fake_tensor_converter | ||||
|     # intentionally avoid cuda:0 to flush out some bugs | ||||
|     target_device = "cuda:1" | ||||
|  | ||||
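As a rough illustration of the FakeTensorMode machinery this helper sets up (standalone usage, not part of the test): fake tensors mirror shape, dtype, and device metadata without allocating real storage, so operator coverage can be exercised even for devices the runner does not have.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode(allow_non_fake_inputs=True)
real = torch.randn(4, 8)
# from_tensor wraps a real tensor into a FakeTensor with matching metadata
fake = mode.from_tensor(real)
with mode:
    out = fake @ fake.T  # dispatched through the fake mode; no real kernel runs
print(type(out).__name__, out.shape, out.dtype)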
|  | ||||
| @ -9,7 +9,6 @@ | ||||
| from contextlib import ExitStack | ||||
|  | ||||
| import torch | ||||
| import torch.fx.traceback as fx_traceback | ||||
| import torch.nn as nn | ||||
| import torch.utils._pytree as pytree | ||||
| from torch._decomp import decomposition_table | ||||
| @ -762,52 +761,6 @@ class inner_f(torch.nn.Module): | ||||
|         compiled_fn(*dict(model.named_parameters()).values(), inputs).sum().backward() | ||||
|         self.assertIsNotNone(model.linear.weight.grad) | ||||
|  | ||||
|     def test_preserve_annotate_simple(self): | ||||
|         """Test basic linear module with aot_export_joint_with_descriptors""" | ||||
|  | ||||
|         class SimpleLinear(nn.Module): | ||||
|             def __init__(self): | ||||
|                 super().__init__() | ||||
|                 self.linear = nn.Linear(3, 2) | ||||
|  | ||||
|             def forward(self, x): | ||||
|                 with fx_traceback.annotate({"pp_stage": 0}): | ||||
|                     y = self.linear(x) | ||||
|                 return y - 1 | ||||
|  | ||||
|         inputs = (torch.randn(4, 3),) | ||||
|  | ||||
|         for with_export in [True, False]: | ||||
|             with ExitStack() as stack: | ||||
|                 model = None | ||||
|                 with fx_traceback.preserve_node_meta(): | ||||
|                     if with_export: | ||||
|                         ep = torch.export.export(SimpleLinear(), inputs) | ||||
|                         model = ep.module() | ||||
|                     else: | ||||
|                         model = SimpleLinear() | ||||
|  | ||||
|                     joint_with_descriptors = aot_export_joint_with_descriptors( | ||||
|                         stack, model, inputs, decompositions=decomposition_table | ||||
|                     ) | ||||
|  | ||||
|                 for node in joint_with_descriptors.graph_module.graph.nodes: | ||||
|                     if ( | ||||
|                         node.target | ||||
|                         in ( | ||||
|                             torch.ops.prims.transpose.default, | ||||
|                             torch.ops.aten.mm.default, | ||||
|                             torch.ops.prims.mul.default, | ||||
|                             torch.ops.prims.broadcast_in_dim.default, | ||||
|                             torch.ops.prims.add.default, | ||||
|                         ) | ||||
|                         # TODO: add annotation to backward graph nodes | ||||
|                         and node.meta.get("partitioner_tag") != "is_backward" | ||||
|                     ): | ||||
|                         self.assertEqual(node.meta["custom"], {"pp_stage": 0}) | ||||
|                     if node.target == torch.ops.aten.sub.default: | ||||
|                         self.assertEqual(node.meta.get("custom", {}), {}) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
|  | ||||
| @ -8500,6 +8500,7 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo): | ||||
|         { | ||||
|             "enable_autograd_cache": True, | ||||
|             "strict_autograd_cache": True, | ||||
|             "view_replay_for_aliased_outputs": False, | ||||
|         } | ||||
|     ) | ||||
|     @torch._inductor.config.patch("fx_graph_cache", True) | ||||
|  | ||||
| @ -5074,52 +5074,6 @@ class AOTInductorTestsTemplate: | ||||
|  | ||||
|             self.check_model(Model(N, K, self.device), example_inputs) | ||||
|  | ||||
|     def test_aoti_user_defined_triton_kernel_profiling(self): | ||||
|         if self.device != GPU_TYPE or self.device == "mps": | ||||
|             raise unittest.SkipTest("requires GPU") | ||||
|  | ||||
|         class Model(torch.nn.Module): | ||||
|             def __init__(self) -> None: | ||||
|                 super().__init__() | ||||
|  | ||||
|             def forward(self, x, y): | ||||
|                 out = torch.zeros_like(x) | ||||
|                 add_kernel[(4,)](x, y, out, n_elements=4, BLOCK_SIZE=16) | ||||
|                 return out | ||||
|  | ||||
|         example_inputs = ( | ||||
|             torch.randn(4, 4, device=self.device), | ||||
|             torch.randn(4, 4, device=self.device), | ||||
|         ) | ||||
|  | ||||
|         with ( | ||||
|             config.patch({"cpp.enable_kernel_profile": True}), | ||||
|             torch.profiler.profile( | ||||
|                 record_shapes=True, | ||||
|                 activities=[ | ||||
|                     torch.profiler.ProfilerActivity.CPU, | ||||
|                     torch.profiler.ProfilerActivity.CUDA, | ||||
|                 ], | ||||
|             ) as prof, | ||||
|         ): | ||||
|             self.check_model(Model(), example_inputs) | ||||
|         with common_utils.TemporaryFileName(mode="w+") as fname: | ||||
|             prof.export_chrome_trace(fname) | ||||
|             with open(fname) as f: | ||||
|                 import json | ||||
|  | ||||
|                 j = json.load(f) | ||||
|                 op_events = [ | ||||
|                     e | ||||
|                     for e in j["traceEvents"] | ||||
|                     if e.get("name", "") == "kernels_.add_kernel_0" | ||||
|                 ] | ||||
|                 self.assertEqual(len(op_events), 1) | ||||
|                 self.assertEqual( | ||||
|                     op_events[0]["args"].get("Input Args", ""), | ||||
|                     ["in_ptr0", "in_ptr1", "out_ptr", "n_elements"], | ||||
|                 ) | ||||
|  | ||||
|     def test_aoti_debug_printer_user_defined_triton_kernel(self): | ||||
|         if self.device != GPU_TYPE: | ||||
|             raise unittest.SkipTest("requires GPU") | ||||
| @ -7194,37 +7148,6 @@ class AOTInductorTestsTemplate: | ||||
|                         for lib in torch_libs: | ||||
|                             self.assertTrue(lib not in line) | ||||
|  | ||||
|     def test_unbounded_expr_substitutions(self): | ||||
|         class Model(torch.nn.Module): | ||||
|             def forward(self, x, y, a, b): | ||||
|                 u0, s0 = a.item(), b.item() | ||||
|                 u_max = max(u0, 15) | ||||
|                 # construct the equality rule Max(15, u0) == s0 * Max(15, u0) | ||||
|                 torch._check(u_max == s0 * u_max) | ||||
|                 # x has size [Max(u0, 15), 64] after the expand below | ||||
|                 x = x.expand(u_max, *x.shape).clone() | ||||
|                 return x @ y | ||||
|  | ||||
|         model = Model() | ||||
|  | ||||
|         example_inputs = ( | ||||
|             torch.randn((64,), dtype=torch.bfloat16, device=self.device), | ||||
|             torch.randn((64, 16), dtype=torch.bfloat16, device=self.device), | ||||
|             torch.tensor(19, device=self.device), | ||||
|             torch.tensor(1, device=self.device), | ||||
|         ) | ||||
|         torch._dynamo.mark_dynamic(example_inputs[-1], 0) | ||||
|  | ||||
|         so_path, code = run_and_get_cpp_code( | ||||
|             AOTIRunnerUtil.legacy_compile, model, example_inputs | ||||
|         ) | ||||
|  | ||||
|         compiled = AOTIRunnerUtil.legacy_load(self.device, so_path) | ||||
|         compiled_outputs = compiled(*example_inputs) | ||||
|  | ||||
|         eager_outputs = model(*example_inputs) | ||||
|         torch.testing.assert_close(eager_outputs, compiled_outputs) | ||||
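Walking through the shape arithmetic of the removed test with its example inputs (a = 19, b = 1), which is all the equality rule in the comment above encodes:

u0, s0 = 19, 1               # a.item(), b.item()
u_max = max(u0, 15)          # = 19
assert u_max == s0 * u_max   # the relation recorded by torch._check: 19 == 1 * 19
# x: [64] is expanded to [u_max, 64] = [19, 64]; y has shape [64, 16],
# so x @ y produces a [19, 16] result in both the eager and AOTI-compiled runs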
|  | ||||
|  | ||||
| class AOTInductorLoggingTest(LoggingTestCase): | ||||
|     @make_logging_test(dynamic=logging.DEBUG) | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff.