Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 16:04:58 +08:00)

2023-12-04 nightly release (3fbfa8cd0a5cefadb3f116c5cd0d60e96ab8c99e)
@@ -1 +1 @@
-d96c49e878defa7cccf976b30ae67bc7b20f531a
+5f72802710a941d53a49b28ae0c7386332d41935

@@ -1 +1 @@
-6e4932cda85fb9128728072c821bd54410a9b5be
+bcad9dabe15021c53b6a88296e9d7a210044f108
@@ -96,13 +96,7 @@ if [[ "$PACKAGE_TYPE" == conda ]]; then
     conda install \${EXTRA_CONDA_FLAGS} -y "\$pkg" --offline
   )
 elif [[ "$PACKAGE_TYPE" != libtorch ]]; then
-  if [[ "$(uname -m)" == aarch64 ]]; then
-    # Using "extra-index-url" until all needed aarch64 dependencies are
-    # added to "https://download.pytorch.org/whl/"
-    pip install "\$pkg" --extra-index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}"
-  else
-    pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}"
-  fi
+  pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}"
   retry pip install -q numpy protobuf typing-extensions
 fi
 if [[ "$PACKAGE_TYPE" == libtorch ]]; then
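Functionally, the new branch boils down to "always install the just-built wheel from the per-channel index", with no aarch64 special case. A rough Python paraphrase of that logic (illustrative only; CHANNEL and DESIRED_CUDA are the same CI variables used in the script above, and the helper name is made up):

```python
import os
import subprocess
import sys

def install_built_wheel(pkg: str) -> None:
    """Hypothetical helper mirroring the simplified branch above."""
    channel = os.environ.get("CHANNEL", "nightly")        # e.g. "nightly" or "test"
    desired_cuda = os.environ.get("DESIRED_CUDA", "cpu")  # e.g. "cpu", "cu118", "cu121"
    index_url = f"https://download.pytorch.org/whl/{channel}/{desired_cuda}"
    # Every architecture now resolves its dependencies from the same --index-url.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", pkg, "--index-url", index_url]
    )
```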
							
								
								
									
.github/ci_commit_pins/audio.txt (vendored, 2 lines changed)

@@ -1 +1 @@
-db624844f5c95bb7618fe5a5f532bf9b68efeb45
+6518fa9b2c74e84d7eb1fc6e3eb51e43213f0c05
							
								
								
									
.github/scripts/build_triton_wheel.py (vendored, 18 lines changed)

@@ -67,10 +67,12 @@ def build_triton(
         max_jobs = os.cpu_count() or 1
         env["MAX_JOBS"] = str(max_jobs)
 
+    version_suffix = ""
     if not release:
         # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
         # while release build should only include the version, i.e. 2.1.0
-        version = f"{version}+{commit_hash[:10]}"
+        version_suffix = f"+{commit_hash[:10]}"
+        version += version_suffix
 
     with TemporaryDirectory() as tmpdir:
         triton_basedir = Path(tmpdir) / "triton"

@@ -132,17 +134,21 @@ def build_triton(
             shutil.copy(conda_path, Path.cwd())
             return Path.cwd() / conda_path.name
 
-        patch_setup_py(
-            triton_pythondir / "setup.py",
-            name=triton_pkg_name,
-            version=f"{version}",
-        )
+        # change built wheel name and version
+        env["TRITON_WHEEL_NAME"] = triton_pkg_name
+        env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
         patch_init_py(
             triton_pythondir / "triton" / "__init__.py",
             version=f"{version}",
         )
 
         if build_rocm:
             # TODO: Remove me when ROCM triton is updated
+            patch_setup_py(
+                triton_pythondir / "setup.py",
+                name=triton_pkg_name,
+                version=f"{version}",
+            )
             check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
             print("ROCm libraries setup for triton installation...")
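The version handling above is small enough to restate on its own: nightly builds get a `+<10-character commit hash>` suffix, and that suffix is now handed to the Triton build through environment variables instead of patching setup.py directly (ROCm builds still patch setup.py for now). A standalone sketch of the idea, with an assumed package name for illustration:

```python
import os

def triton_version(base: str, commit_hash: str, release: bool) -> tuple:
    """Return (full_version, version_suffix) following the scheme in build_triton_wheel.py."""
    version_suffix = "" if release else f"+{commit_hash[:10]}"
    return base + version_suffix, version_suffix

version, suffix = triton_version("2.1.0", "e6216047b8123456789a", release=False)
print(version)  # 2.1.0+e6216047b8

env = os.environ.copy()
env["TRITON_WHEEL_NAME"] = "pytorch-triton"   # assumed wheel name, for illustration only
env["TRITON_WHEEL_VERSION_SUFFIX"] = suffix   # presumably read by Triton's own build
```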
@@ -174,7 +174,7 @@ LIBTORCH_CONTAINER_IMAGES: Dict[Tuple[str, str], str] = {
     ("cpu", CXX11_ABI): f"pytorch/libtorch-cxx11-builder:cpu-{DEFAULT_TAG}",
 }
 
-FULL_PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11"]
+FULL_PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11", "3.12"]
 
 
 def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:

@@ -288,7 +288,7 @@ def generate_wheels_matrix(
         package_type = "manywheel"
 
     if python_versions is None:
-        python_versions = FULL_PYTHON_VERSIONS + ["3.12"]
+        python_versions = FULL_PYTHON_VERSIONS
 
     if arches is None:
         # Define default compute archivectures
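The net effect of these two hunks is that Python 3.12 moves from an explicit add-on into the default version list. Condensed into a few lines (names taken from the diff; the real matrix generator lives in the same script):

```python
from typing import List, Optional

FULL_PYTHON_VERSIONS = ["3.8", "3.9", "3.10", "3.11", "3.12"]

def resolve_python_versions(python_versions: Optional[List[str]] = None) -> List[str]:
    # Previously the caller appended "3.12" by hand (FULL_PYTHON_VERSIONS + ["3.12"]);
    # after this change the full list already contains it.
    if python_versions is None:
        python_versions = FULL_PYTHON_VERSIONS
    return python_versions

assert "3.12" in resolve_python_versions()
```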
							
								
								
									
.github/workflows/build-triton-wheel.yml (vendored, 11 lines changed)

@@ -4,7 +4,6 @@ on:
   push:
     branches:
       - main
-      - nightly
     tags:
       # NOTE: Binary build pipelines should only get triggered on release candidate builds
       # Release candidate tags look like: v1.11.0-rc1

@@ -134,7 +133,7 @@ jobs:
     needs: build-wheel
     container:
       image: continuumio/miniconda3:4.12.0
-    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
+    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
     steps:
       - uses: actions/checkout@v3
 
@@ -145,7 +144,7 @@ jobs:
           path: ${{ runner.temp }}/artifacts/
 
       - name: Set DRY_RUN (only for tagged pushes)
-        if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || (startsWith(github.event.ref, 'refs/tags/v'))) }}
+        if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) }}
         shell: bash
         run: |
           echo "DRY_RUN=disabled" >> "$GITHUB_ENV"

@@ -161,7 +160,7 @@ jobs:
             echo "UPLOAD_CHANNEL=test" >> "$GITHUB_ENV"
           fi
 
-      # NB: This step is gated by DRY_RUN, which is enabled everywhere except nightly and release branches
+      # NB: This step is gated by DRY_RUN, which is enabled everywhere except main and release branches
       - name: Upload binaries
         env:
           PACKAGE_TYPE: wheel

@@ -247,7 +246,7 @@ jobs:
     needs: build-conda
     container:
       image: continuumio/miniconda3:4.12.0
-    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
+    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
     steps:
       - uses: actions/checkout@v3
 
@@ -258,7 +257,7 @@ jobs:
           path: ${{ runner.temp }}/artifacts/
 
       - name: Set DRY_RUN (only for tagged pushes)
-        if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || (startsWith(github.event.ref, 'refs/tags/v'))) }}
+        if: ${{ github.event_name == 'push' && (github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) }}
         shell: bash
         run: |
           echo "DRY_RUN=disabled" >> "$GITHUB_ENV"
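All of the workflow edits above tighten the same gate: uploads (and the conda-aws-upload environment) now key off pushes to main or release tags rather than the nightly branch. Read as a predicate, the new condition is roughly the following (a plain-Python paraphrase of the GitHub Actions expression, for illustration only):

```python
def should_upload(event_name: str, ref: str) -> bool:
    # Mirrors: github.event_name == 'push' &&
    #   (github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v'))
    return event_name == "push" and (
        ref == "refs/heads/main" or ref.startswith("refs/tags/v")
    )

assert should_upload("push", "refs/heads/main")
assert should_upload("push", "refs/tags/v2.2.0-rc1")
assert not should_upload("push", "refs/heads/nightly")  # previously allowed, now a dry run
```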
							
								
								
									
.github/workflows/generated-linux-binary-conda-nightly.yml (generated, vendored, 182 lines changed)

@@ -763,3 +763,185 @@ jobs:
       conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
+
+  conda-py3_12-cpu-build:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    uses: ./.github/workflows/_binary-build-linux.yml
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cpu
+      GPU_ARCH_TYPE: cpu
+      DOCKER_IMAGE: pytorch/conda-builder:cpu-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cpu
+      build_environment: linux-binary-conda
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+  conda-py3_12-cpu-test:  # Testing
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cpu-build
+    uses: ./.github/workflows/_binary-test-linux.yml
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cpu
+      GPU_ARCH_TYPE: cpu
+      DOCKER_IMAGE: pytorch/conda-builder:cpu-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cpu
+      build_environment: linux-binary-conda
+      runs_on: linux.4xlarge
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+  conda-py3_12-cpu-upload:  # Uploading
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cpu-test
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cpu
+      GPU_ARCH_TYPE: cpu
+      DOCKER_IMAGE: pytorch/conda-builder:cpu-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cpu
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
+      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
+      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
+      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
+    uses: ./.github/workflows/_binary-upload.yml
+
+  conda-py3_12-cuda11_8-build:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    uses: ./.github/workflows/_binary-build-linux.yml
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cu118
+      GPU_ARCH_VERSION: 11.8
+      GPU_ARCH_TYPE: cuda
+      DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main
+      DESIRED_PYTHON: "3.12"
+      runs_on: linux.24xlarge
+      build_name: conda-py3_12-cuda11_8
+      build_environment: linux-binary-conda
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+  conda-py3_12-cuda11_8-test:  # Testing
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cuda11_8-build
+    uses: ./.github/workflows/_binary-test-linux.yml
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cu118
+      GPU_ARCH_VERSION: 11.8
+      GPU_ARCH_TYPE: cuda
+      DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cuda11_8
+      build_environment: linux-binary-conda
+      runs_on: linux.4xlarge.nvidia.gpu
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+  conda-py3_12-cuda11_8-upload:  # Uploading
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cuda11_8-test
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cu118
+      GPU_ARCH_VERSION: 11.8
+      GPU_ARCH_TYPE: cuda
+      DOCKER_IMAGE: pytorch/conda-builder:cuda11.8-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cuda11_8
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
+      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
+      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
+      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
+    uses: ./.github/workflows/_binary-upload.yml
+
+  conda-py3_12-cuda12_1-build:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    uses: ./.github/workflows/_binary-build-linux.yml
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cu121
+      GPU_ARCH_VERSION: 12.1
+      GPU_ARCH_TYPE: cuda
+      DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main
+      DESIRED_PYTHON: "3.12"
+      runs_on: linux.24xlarge
+      build_name: conda-py3_12-cuda12_1
+      build_environment: linux-binary-conda
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+  conda-py3_12-cuda12_1-test:  # Testing
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cuda12_1-build
+    uses: ./.github/workflows/_binary-test-linux.yml
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cu121
+      GPU_ARCH_VERSION: 12.1
+      GPU_ARCH_TYPE: cuda
+      DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cuda12_1
+      build_environment: linux-binary-conda
+      runs_on: linux.4xlarge.nvidia.gpu
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+  conda-py3_12-cuda12_1-upload:  # Uploading
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cuda12_1-test
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cu121
+      GPU_ARCH_VERSION: 12.1
+      GPU_ARCH_TYPE: cuda
+      DOCKER_IMAGE: pytorch/conda-builder:cuda12.1-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cuda12_1
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
+      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
+      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
+      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
+    uses: ./.github/workflows/_binary-upload.yml
+
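Each new Python 3.12 configuration added here follows the same three-stage chain as the existing entries in this generated workflow: a build job, a test job that consumes the build artifact, and an upload job gated on the pytorch org. Schematically (an illustration of the dependency structure, not workflow syntax):

```python
configs = ["cpu", "cuda11_8", "cuda12_1"]   # DESIRED_CUDA: cpu, cu118, cu121
jobs = {}
for cfg in configs:
    build, test, upload = (f"conda-py3_12-{cfg}-{stage}" for stage in ("build", "test", "upload"))
    jobs[build] = {"needs": []}
    jobs[test] = {"needs": [build]}    # reuses the artifact produced by the build job
    jobs[upload] = {"needs": [test]}   # runs only when repository_owner == 'pytorch'
```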
							
								
								
									
.github/workflows/generated-macos-arm64-binary-conda-nightly.yml (generated, vendored, 117 lines changed)

@@ -502,3 +502,120 @@ jobs:
       conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
+  conda-py3_12-cpu-build:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    runs-on: macos-12-xl
+    timeout-minutes: 240
+    env:
+      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
+      BUILDER_ROOT: ${{ github.workspace }}/builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cpu
+      GPU_ARCH_TYPE: cpu
+      SKIP_ALL_TESTS: 1
+      DESIRED_PYTHON: "3.12"
+      # For sccache access (only on non-forked PRs)
+      AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }}
+      AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }}
+    steps:
+      # NOTE: These environment variables are put here so that they can be applied on every job equally
+      #       They are also here because setting them at a workflow level doesn't give us access to the
+      #       runner.temp variable, which we need.
+      - name: Populate binary env
+        shell: bash
+        run: |
+          # shellcheck disable=SC2129
+          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
+          # shellcheck disable=SC2129
+          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
+          # shellcheck disable=SC2129
+          echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}"
+      - name: Install conda and dependencies
+        run: |
+          # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on
+          curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" "https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-MacOSX-$(uname -m).sh"
+          chmod +x "${RUNNER_TEMP}/conda.sh"
+          /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda"
+          echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}"
+          if [ -d "/Applications/Xcode_14.3.1.app" ]; then
+            echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
+          elif [ -d "/Applications/Xcode_13.3.1.app" ]; then
+            echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
+          fi
+      - name: Checkout PyTorch
+        uses: malfet/checkout@silent-checkout
+        with:
+          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+          submodules: recursive
+          path: pytorch
+          quiet-checkout: true
+      - name: Clean PyTorch checkout
+        run: |
+          # Remove any artifacts from the previous checkouts
+          git clean -fxd
+        working-directory: pytorch
+      - name: Checkout pytorch/builder
+        uses: malfet/checkout@silent-checkout
+        with:
+          ref: main
+          submodules: recursive
+          repository: pytorch/builder
+          path: builder
+          quiet-checkout: true
+      - name: Clean pytorch/builder checkout
+        run: |
+          # Remove any artifacts from the previous checkouts
+          git clean -fxd
+        working-directory: builder
+      - name: Install sccache (only for non-forked PRs, and pushes to trunk)
+        uses: nick-fields/retry@v2.8.2
+        if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }}
+        with:
+          timeout_minutes: 5
+          max_attempts: 3
+          retry_wait_seconds: 90
+          command: |
+            sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache
+            sudo chmod +x /usr/local/bin/sccache
+            echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}"
+      - name: Populate binary env
+        run: |
+          # shellcheck disable=SC1091
+          source "${RUNNER_TEMP}/anaconda/bin/activate"
+          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
+      - name: Build PyTorch binary
+        run: |
+          # shellcheck disable=SC1091
+          source "${RUNNER_TEMP}/anaconda/bin/activate"
+          "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh"
+      - uses: actions/upload-artifact@v3
+        if: always()
+        with:
+          name: conda-py3_12-cpu
+          retention-days: 14
+          if-no-files-found: error
+          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
+  conda-py3_12-cpu-upload:  # Uploading
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cpu-build
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cpu
+      GPU_ARCH_TYPE: cpu
+      DOCKER_IMAGE: pytorch/conda-builder:cpu-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cpu
+      use_s3: False
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
+      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
+      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
+      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
+    uses: ./.github/workflows/_binary-upload.yml
+
							
								
								
									
.github/workflows/generated-macos-binary-conda-nightly.yml (generated, vendored, 117 lines changed)

@@ -500,3 +500,120 @@ jobs:
       conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
       conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
     uses: ./.github/workflows/_binary-upload.yml
+  conda-py3_12-cpu-build:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    runs-on: macos-12-xl
+    timeout-minutes: 240
+    env:
+      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
+      BUILDER_ROOT: ${{ github.workspace }}/builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cpu
+      GPU_ARCH_TYPE: cpu
+      SKIP_ALL_TESTS: 1
+      DESIRED_PYTHON: "3.12"
+      # For sccache access (only on non-forked PRs)
+      AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }}
+      AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }}
+    steps:
+      # NOTE: These environment variables are put here so that they can be applied on every job equally
+      #       They are also here because setting them at a workflow level doesn't give us access to the
+      #       runner.temp variable, which we need.
+      - name: Populate binary env
+        shell: bash
+        run: |
+          # shellcheck disable=SC2129
+          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
+          # shellcheck disable=SC2129
+          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
+          # shellcheck disable=SC2129
+          echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}"
+      - name: Install conda and dependencies
+        run: |
+          # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on
+          curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" "https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-MacOSX-$(uname -m).sh"
+          chmod +x "${RUNNER_TEMP}/conda.sh"
+          /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda"
+          echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}"
+          if [ -d "/Applications/Xcode_14.3.1.app" ]; then
+            echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
+          elif [ -d "/Applications/Xcode_13.3.1.app" ]; then
+            echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
+          fi
+      - name: Checkout PyTorch
+        uses: malfet/checkout@silent-checkout
+        with:
+          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+          submodules: recursive
+          path: pytorch
+          quiet-checkout: true
+      - name: Clean PyTorch checkout
+        run: |
+          # Remove any artifacts from the previous checkouts
+          git clean -fxd
+        working-directory: pytorch
+      - name: Checkout pytorch/builder
+        uses: malfet/checkout@silent-checkout
+        with:
+          ref: main
+          submodules: recursive
+          repository: pytorch/builder
+          path: builder
+          quiet-checkout: true
+      - name: Clean pytorch/builder checkout
+        run: |
+          # Remove any artifacts from the previous checkouts
+          git clean -fxd
+        working-directory: builder
+      - name: Install sccache (only for non-forked PRs, and pushes to trunk)
+        uses: nick-fields/retry@v2.8.2
+        if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }}
+        with:
+          timeout_minutes: 5
+          max_attempts: 3
+          retry_wait_seconds: 90
+          command: |
+            sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache
+            sudo chmod +x /usr/local/bin/sccache
+            echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}"
+      - name: Populate binary env
+        run: |
+          # shellcheck disable=SC1091
+          source "${RUNNER_TEMP}/anaconda/bin/activate"
+          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
+      - name: Build PyTorch binary
+        run: |
+          # shellcheck disable=SC1091
+          source "${RUNNER_TEMP}/anaconda/bin/activate"
+          "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh"
+      - uses: actions/upload-artifact@v3
+        if: always()
+        with:
+          name: conda-py3_12-cpu
+          retention-days: 14
+          if-no-files-found: error
+          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
+  conda-py3_12-cpu-upload:  # Uploading
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs: conda-py3_12-cpu-build
+    with:
+      PYTORCH_ROOT: /pytorch
+      BUILDER_ROOT: /builder
+      PACKAGE_TYPE: conda
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: cpu
+      GPU_ARCH_TYPE: cpu
+      DOCKER_IMAGE: pytorch/conda-builder:cpu-main
+      DESIRED_PYTHON: "3.12"
+      build_name: conda-py3_12-cpu
+      use_s3: False
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
+      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
+      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
+      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
+    uses: ./.github/workflows/_binary-upload.yml
+
							
								
								
									
										729
									
								
								.github/workflows/generated-windows-binary-conda-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										729
									
								
								.github/workflows/generated-windows-binary-conda-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -2948,3 +2948,732 @@ jobs:
 | 
			
		||||
      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
 | 
			
		||||
      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
  conda-py3_12-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    runs-on: windows.4xlarge.nonephemeral
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cpu
 | 
			
		||||
      GPU_ARCH_TYPE: cpu
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Display EC2 information
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -euo pipefail
 | 
			
		||||
          function get_ec2_metadata() {
 | 
			
		||||
            # Pulled from instance metadata endpoint for EC2
 | 
			
		||||
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
 | 
			
		||||
            category=$1
 | 
			
		||||
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
 | 
			
		||||
          }
 | 
			
		||||
          echo "ami-id: $(get_ec2_metadata ami-id)"
 | 
			
		||||
          echo "instance-id: $(get_ec2_metadata instance-id)"
 | 
			
		||||
          echo "instance-type: $(get_ec2_metadata instance-type)"
 | 
			
		||||
          echo "system info $(uname -a)"
 | 
			
		||||
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        with:
 | 
			
		||||
          github-secret: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
 | 
			
		||||
      - name: Enable long paths on Windows
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
 | 
			
		||||
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
 | 
			
		||||
      # removed once Windows Defender is removed from the AMI
 | 
			
		||||
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
 | 
			
		||||
          # Let's both exclude the path and disable Windows Defender completely just to be sure
 | 
			
		||||
          # that it doesn't interfere
 | 
			
		||||
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
 | 
			
		||||
      # NOTE: These environment variables are put here so that they can be applied on every job equally
 | 
			
		||||
      #       They are also here because setting them at a workflow level doesn't give us access to the
 | 
			
		||||
      #       runner.temp variable, which we need.
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: Checkout pytorch/builder
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: main
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          repository: pytorch/builder
 | 
			
		||||
          path: builder
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean pytorch/builder checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: builder
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
 | 
			
		||||
      - name: Build PyTorch binary
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
 | 
			
		||||
      - uses: actions/upload-artifact@v3
 | 
			
		||||
        if: always()
 | 
			
		||||
        with:
 | 
			
		||||
          name: conda-py3_12-cpu
 | 
			
		||||
          retention-days: 14
 | 
			
		||||
          if-no-files-found: error
 | 
			
		||||
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
 | 
			
		||||
      - name: Wait until all sessions have drained
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        timeout-minutes: 120
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\wait_for_ssh_to_drain.ps1
 | 
			
		||||
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\kill_active_ssh_sessions.ps1
 | 
			
		||||
  conda-py3_12-cpu-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: conda-py3_12-cpu-build
 | 
			
		||||
    runs-on: windows.4xlarge.nonephemeral
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cpu
 | 
			
		||||
      GPU_ARCH_TYPE: cpu
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Display EC2 information
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -euo pipefail
 | 
			
		||||
          function get_ec2_metadata() {
 | 
			
		||||
            # Pulled from instance metadata endpoint for EC2
 | 
			
		||||
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
 | 
			
		||||
            category=$1
 | 
			
		||||
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
 | 
			
		||||
          }
 | 
			
		||||
          echo "ami-id: $(get_ec2_metadata ami-id)"
 | 
			
		||||
          echo "instance-id: $(get_ec2_metadata instance-id)"
 | 
			
		||||
          echo "instance-type: $(get_ec2_metadata instance-type)"
 | 
			
		||||
          echo "system info $(uname -a)"
 | 
			
		||||
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        with:
 | 
			
		||||
          github-secret: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
 | 
			
		||||
      - name: Enable long paths on Windows
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
 | 
			
		||||
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
 | 
			
		||||
      # removed once Windows Defender is removed from the AMI
 | 
			
		||||
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
 | 
			
		||||
          # Let's both exclude the path and disable Windows Defender completely just to be sure
 | 
			
		||||
          # that it doesn't interfere
 | 
			
		||||
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
 | 
			
		||||
      # NOTE: These environment variables are put here so that they can be applied on every job equally
 | 
			
		||||
      #       They are also here because setting them at a workflow level doesn't give us access to the
 | 
			
		||||
      #       runner.temp variable, which we need.
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
 | 
			
		||||
      - uses: actions/download-artifact@v3
 | 
			
		||||
        name: Download Build Artifacts
 | 
			
		||||
        with:
 | 
			
		||||
          name: conda-py3_12-cpu
 | 
			
		||||
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: Checkout pytorch/builder
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: main
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          repository: pytorch/builder
 | 
			
		||||
          path: builder
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean pytorch/builder checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: builder
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
 | 
			
		||||
      - name: Test PyTorch binary
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
 | 
			
		||||
      - name: Wait until all sessions have drained
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        timeout-minutes: 120
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\wait_for_ssh_to_drain.ps1
 | 
			
		||||
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\kill_active_ssh_sessions.ps1
 | 
			
		||||
  conda-py3_12-cpu-upload:  # Uploading
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: conda-py3_12-cpu-test
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cpu
 | 
			
		||||
      GPU_ARCH_TYPE: cpu
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: conda-py3_12-cpu
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
 | 
			
		||||
      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
 | 
			
		||||
      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
 | 
			
		||||
      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
  conda-py3_12-cuda11_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    runs-on: windows.4xlarge.nonephemeral
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu118
 | 
			
		||||
      GPU_ARCH_VERSION: 11.8
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Display EC2 information
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -euo pipefail
 | 
			
		||||
          function get_ec2_metadata() {
 | 
			
		||||
            # Pulled from instance metadata endpoint for EC2
 | 
			
		||||
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
 | 
			
		||||
            category=$1
 | 
			
		||||
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
 | 
			
		||||
          }
 | 
			
		||||
          echo "ami-id: $(get_ec2_metadata ami-id)"
 | 
			
		||||
          echo "instance-id: $(get_ec2_metadata instance-id)"
 | 
			
		||||
          echo "instance-type: $(get_ec2_metadata instance-type)"
 | 
			
		||||
          echo "system info $(uname -a)"
 | 
			
		||||
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        with:
 | 
			
		||||
          github-secret: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
 | 
			
		||||
      - name: Enable long paths on Windows
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
 | 
			
		||||
      # Since it's just a defensive command, the workflow should continue even the command fails. This step can be
 | 
			
		||||
      # removed once Windows Defender is removed from the AMI
 | 
			
		||||
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
 | 
			
		||||
          # Let's both exclude the path and disable Windows Defender completely just to be sure
 | 
			
		||||
          # that it doesn't interfere
 | 
			
		||||
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
 | 
			
		||||
      # NOTE: These environment variables are put here so that they can be applied on every job equally
 | 
			
		||||
      #       They are also here because setting them at a workflow level doesn't give us access to the
 | 
			
		||||
      #       runner.temp variable, which we need.
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: Checkout pytorch/builder
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: main
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          repository: pytorch/builder
 | 
			
		||||
          path: builder
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean pytorch/builder checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: builder
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
 | 
			
		||||
      - name: Build PyTorch binary
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
 | 
			
		||||
      - uses: actions/upload-artifact@v3
 | 
			
		||||
        if: always()
 | 
			
		||||
        with:
 | 
			
		||||
          name: conda-py3_12-cuda11_8
 | 
			
		||||
          retention-days: 14
 | 
			
		||||
          if-no-files-found: error
 | 
			
		||||
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
 | 
			
		||||
      - name: Wait until all sessions have drained
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        timeout-minutes: 120
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\wait_for_ssh_to_drain.ps1
 | 
			
		||||
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\kill_active_ssh_sessions.ps1
 | 
			
		||||
  conda-py3_12-cuda11_8-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: conda-py3_12-cuda11_8-build
 | 
			
		||||
    runs-on: windows.8xlarge.nvidia.gpu
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu118
 | 
			
		||||
      GPU_ARCH_VERSION: 11.8
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Display EC2 information
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -euo pipefail
 | 
			
		||||
          function get_ec2_metadata() {
 | 
			
		||||
            # Pulled from instance metadata endpoint for EC2
 | 
			
		||||
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
 | 
			
		||||
            category=$1
 | 
			
		||||
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
 | 
			
		||||
          }
 | 
			
		||||
          echo "ami-id: $(get_ec2_metadata ami-id)"
 | 
			
		||||
          echo "instance-id: $(get_ec2_metadata instance-id)"
 | 
			
		||||
          echo "instance-type: $(get_ec2_metadata instance-type)"
 | 
			
		||||
          echo "system info $(uname -a)"
 | 
			
		||||
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        with:
 | 
			
		||||
          github-secret: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
 | 
			
		||||
      - name: Enable long paths on Windows
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
 | 
			
		||||
      # Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
 | 
			
		||||
      # removed once Windows Defender is removed from the AMI
 | 
			
		||||
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
 | 
			
		||||
          # Let's both exclude the path and disable Windows Defender completely just to be sure
 | 
			
		||||
          # that it doesn't interfere
 | 
			
		||||
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
 | 
			
		||||
      # NOTE: These environment variables are put here so that they can be applied on every job equally
 | 
			
		||||
      #       They are also here because setting them at a workflow level doesn't give us access to the
 | 
			
		||||
      #       runner.temp variable, which we need.
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
 | 
			
		||||
      - uses: actions/download-artifact@v3
 | 
			
		||||
        name: Download Build Artifacts
 | 
			
		||||
        with:
 | 
			
		||||
          name: conda-py3_12-cuda11_8
 | 
			
		||||
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: Checkout pytorch/builder
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: main
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          repository: pytorch/builder
 | 
			
		||||
          path: builder
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean pytorch/builder checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: builder
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
 | 
			
		||||
      - name: Test PyTorch binary
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
 | 
			
		||||
      - name: Wait until all sessions have drained
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        timeout-minutes: 120
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\wait_for_ssh_to_drain.ps1
 | 
			
		||||
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\kill_active_ssh_sessions.ps1
 | 
			
		||||
  conda-py3_12-cuda11_8-upload:  # Uploading
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: conda-py3_12-cuda11_8-test
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu118
 | 
			
		||||
      GPU_ARCH_VERSION: 11.8
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: conda-py3_12-cuda11_8
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
 | 
			
		||||
      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
 | 
			
		||||
      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
 | 
			
		||||
      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
  conda-py3_12-cuda12_1-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    runs-on: windows.4xlarge.nonephemeral
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu121
 | 
			
		||||
      GPU_ARCH_VERSION: 12.1
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Display EC2 information
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -euo pipefail
 | 
			
		||||
          function get_ec2_metadata() {
 | 
			
		||||
            # Pulled from instance metadata endpoint for EC2
 | 
			
		||||
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
 | 
			
		||||
            category=$1
 | 
			
		||||
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
 | 
			
		||||
          }
 | 
			
		||||
          echo "ami-id: $(get_ec2_metadata ami-id)"
 | 
			
		||||
          echo "instance-id: $(get_ec2_metadata instance-id)"
 | 
			
		||||
          echo "instance-type: $(get_ec2_metadata instance-type)"
 | 
			
		||||
          echo "system info $(uname -a)"
 | 
			
		||||
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        with:
 | 
			
		||||
          github-secret: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
 | 
			
		||||
      - name: Enable long paths on Windows
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
 | 
			
		||||
      # Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
 | 
			
		||||
      # removed once Windows Defender is removed from the AMI
 | 
			
		||||
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
 | 
			
		||||
          # Let's both exclude the path and disable Windows Defender completely just to be sure
 | 
			
		||||
          # that it doesn't interfere
 | 
			
		||||
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
 | 
			
		||||
      # NOTE: These environment variables are put here so that they can be applied on every job equally
 | 
			
		||||
      #       They are also here because setting them at a workflow level doesn't give us access to the
 | 
			
		||||
      #       runner.temp variable, which we need.
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: Checkout pytorch/builder
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: main
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          repository: pytorch/builder
 | 
			
		||||
          path: builder
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean pytorch/builder checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: builder
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
 | 
			
		||||
      - name: Build PyTorch binary
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
 | 
			
		||||
      - uses: actions/upload-artifact@v3
 | 
			
		||||
        if: always()
 | 
			
		||||
        with:
 | 
			
		||||
          name: conda-py3_12-cuda12_1
 | 
			
		||||
          retention-days: 14
 | 
			
		||||
          if-no-files-found: error
 | 
			
		||||
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
 | 
			
		||||
      - name: Wait until all sessions have drained
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        timeout-minutes: 120
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\wait_for_ssh_to_drain.ps1
 | 
			
		||||
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\kill_active_ssh_sessions.ps1
 | 
			
		||||
  conda-py3_12-cuda12_1-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: conda-py3_12-cuda12_1-build
 | 
			
		||||
    runs-on: windows.8xlarge.nvidia.gpu
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu121
 | 
			
		||||
      GPU_ARCH_VERSION: 12.1
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Display EC2 information
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          set -euo pipefail
 | 
			
		||||
          function get_ec2_metadata() {
 | 
			
		||||
            # Pulled from instance metadata endpoint for EC2
 | 
			
		||||
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
 | 
			
		||||
            category=$1
 | 
			
		||||
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
 | 
			
		||||
          }
 | 
			
		||||
          echo "ami-id: $(get_ec2_metadata ami-id)"
 | 
			
		||||
          echo "instance-id: $(get_ec2_metadata instance-id)"
 | 
			
		||||
          echo "instance-type: $(get_ec2_metadata instance-type)"
 | 
			
		||||
          echo "system info $(uname -a)"
 | 
			
		||||
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/setup-ssh@main
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        with:
 | 
			
		||||
          github-secret: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
 | 
			
		||||
      - name: Enable long paths on Windows
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
 | 
			
		||||
      # Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
 | 
			
		||||
      # removed once Windows Defender is removed from the AMI
 | 
			
		||||
      - name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
 | 
			
		||||
        continue-on-error: true
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        run: |
 | 
			
		||||
          Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
 | 
			
		||||
          # Let's both exclude the path and disable Windows Defender completely just to be sure
 | 
			
		||||
          # that it doesn't interfere
 | 
			
		||||
          Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
 | 
			
		||||
      # NOTE: These environment variables are put here so that they can be applied on every job equally
 | 
			
		||||
      #       They are also here because setting them at a workflow level doesn't give us access to the
 | 
			
		||||
      #       runner.temp variable, which we need.
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
 | 
			
		||||
          echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
 | 
			
		||||
      - uses: actions/download-artifact@v3
 | 
			
		||||
        name: Download Build Artifacts
 | 
			
		||||
        with:
 | 
			
		||||
          name: conda-py3_12-cuda12_1
 | 
			
		||||
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: Checkout pytorch/builder
 | 
			
		||||
        uses: malfet/checkout@silent-checkout
 | 
			
		||||
        with:
 | 
			
		||||
          ref: main
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          repository: pytorch/builder
 | 
			
		||||
          path: builder
 | 
			
		||||
          quiet-checkout: true
 | 
			
		||||
      - name: Clean pytorch/builder checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: builder
 | 
			
		||||
      - name: Populate binary env
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
 | 
			
		||||
      - name: Test PyTorch binary
 | 
			
		||||
        shell: bash
 | 
			
		||||
        run: |
 | 
			
		||||
          "${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
 | 
			
		||||
      - name: Wait until all sessions have drained
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        timeout-minutes: 120
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\wait_for_ssh_to_drain.ps1
 | 
			
		||||
      - name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
 | 
			
		||||
        shell: powershell
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
        if: always()
 | 
			
		||||
        run: |
 | 
			
		||||
          .github\scripts\kill_active_ssh_sessions.ps1
 | 
			
		||||
  conda-py3_12-cuda12_1-upload:  # Uploading
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs: conda-py3_12-cuda12_1-test
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
 | 
			
		||||
      BUILDER_ROOT: ${{ github.workspace }}/builder
 | 
			
		||||
      PACKAGE_TYPE: conda
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu121
 | 
			
		||||
      GPU_ARCH_VERSION: 12.1
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: conda-py3_12-cuda12_1
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
      aws-pytorch-uploader-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }}
 | 
			
		||||
      aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }}
 | 
			
		||||
      conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
 | 
			
		||||
      conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
 | 
			
		||||
@ -118,7 +118,7 @@ hf_T5,pass,0

hf_T5_generate,pass,20
hf_T5_generate,pass,18


@ -118,7 +118,7 @@ hf_T5,pass,0

hf_T5_generate,fail_to_run,10
hf_T5_generate,fail_to_run,9


@ -102,7 +102,7 @@ hf_Reformer,pass,27

hf_T5_base,OOM,7
hf_T5_base,pass,7


@ -118,7 +118,7 @@ hf_T5,pass,0

hf_T5_generate,fail_to_run,10
hf_T5_generate,fail_to_run,9


@ -118,7 +118,7 @@ hf_T5,pass,0

hf_T5_generate,pass,20
hf_T5_generate,pass,18


@ -118,7 +118,7 @@ hf_T5,pass,0

hf_T5_generate,pass,20
hf_T5_generate,pass,18

@ -69,6 +69,7 @@ BATCH_SIZE_DIVISORS = {
 | 
			
		||||
 | 
			
		||||
REQUIRE_HIGHER_TOLERANCE = {
 | 
			
		||||
    "fbnetv3_b",
 | 
			
		||||
    "gmixer_24_224",
 | 
			
		||||
    "hrnet_w18",
 | 
			
		||||
    "inception_v3",
 | 
			
		||||
    "sebotnet33ts_256",
 | 
			
		||||
 | 
			
		||||
@ -1461,7 +1461,6 @@ def define_buck_targets(
 | 
			
		||||
            "torch/csrc/jit/mobile/train/random.cpp",
 | 
			
		||||
            "torch/csrc/jit/mobile/train/sequential.cpp",
 | 
			
		||||
            ":gen_aten_libtorch[autograd/generated/Functions.cpp]",
 | 
			
		||||
            "torch/csrc/quantized/quantized_backward.cpp",
 | 
			
		||||
        ],
 | 
			
		||||
        compiler_flags = get_pt_compiler_flags(),
 | 
			
		||||
        exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
 | 
			
		||||
 | 
			
		||||
@ -923,12 +923,12 @@ Here is an example that shows logging modes of each type::
 | 
			
		||||
  class FunctionLog(TorchFunctionMode):
 | 
			
		||||
      def __torch_function__(self, func, types, args, kwargs=None):
 | 
			
		||||
          print(f"Function Log: {resolve_name(func)}(*{args}, **{kwargs})")
 | 
			
		||||
          return func(*args, **kwargs or {})
 | 
			
		||||
          return func(*args, **(kwargs or {}))
 | 
			
		||||
 | 
			
		||||
  class DispatchLog(TorchDispatchMode):
 | 
			
		||||
      def __torch_dispatch__(self, func, types, args, kwargs=None):
 | 
			
		||||
          print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
 | 
			
		||||
          return func(*args, **kwargs or {})
 | 
			
		||||
          return func(*args, **(kwargs or {}))
 | 
			
		||||
 | 
			
		||||
  def f():
 | 
			
		||||
      a = torch.rand(10, requires_grad=True)
 | 
			
		||||
 | 
			
		||||
@ -85,6 +85,9 @@ select = [
 | 
			
		||||
    "PIE804",
 | 
			
		||||
    "PIE807",
 | 
			
		||||
    "PIE810",
 | 
			
		||||
    "PLC0131", # type bivariance
 | 
			
		||||
    "PLC0132", # type param mismatch
 | 
			
		||||
    "PLC0205", # string as __slots__
 | 
			
		||||
    "PLE",
 | 
			
		||||
    "PLR0133", # constant comparison
 | 
			
		||||
    "PLR0206", # property with params
 | 
			
		||||
 | 
			
		||||
@ -2679,6 +2679,18 @@ TEST_F(ModulesTest, AdaptiveLogSoftmaxWithLoss) {
 | 
			
		||||
    ASSERT_TRUE(
 | 
			
		||||
        torch::allclose(asfm(x, y).output.squeeze(0), asfm(x2, y2).output));
 | 
			
		||||
  }
 | 
			
		||||
  {
 | 
			
		||||
    // test div_value
 | 
			
		||||
    auto options =
 | 
			
		||||
        AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.);
 | 
			
		||||
    ASSERT_THROWS_WITH(
 | 
			
		||||
        AdaptiveLogSoftmaxWithLoss(options),
 | 
			
		||||
        "div_value should not be equal to 0");
 | 
			
		||||
 | 
			
		||||
    options =
 | 
			
		||||
        AdaptiveLogSoftmaxWithLossOptions(16, 20, {4, 10, 15}).div_value(0.25);
 | 
			
		||||
    ASSERT_TRUE(AdaptiveLogSoftmaxWithLoss(options));
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ModulesTest, Softmax2d) {
 | 
			
		||||
 | 
			
		||||
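For context on the div_value check exercised above, here is a hedged Python-side sketch of the same module configuration (torch.nn.AdaptiveLogSoftmaxWithLoss is the Python counterpart of the C++ module; the input tensors are illustrative, not taken from the test):

import torch
from torch import nn

# Mirrors the C++ options above: 16 input features, 20 classes,
# cluster cutoffs [4, 10, 15], and a non-default div_value.
asfm = nn.AdaptiveLogSoftmaxWithLoss(16, 20, cutoffs=[4, 10, 15], div_value=0.25)

x = torch.randn(2, 16)
y = torch.tensor([0, 17])
out = asfm(x, y)          # namedtuple with fields `output` and `loss`
print(out.output.shape)   # torch.Size([2]): per-sample log-prob of the target class
print(out.loss)
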
@ -54,6 +54,7 @@ def make_dynamic_cls(cls):
 | 
			
		||||
    test_classes[test_class.__name__] = test_class
 | 
			
		||||
    # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING
 | 
			
		||||
    globals()[test_class.__name__] = test_class
 | 
			
		||||
    test_class.__module__ = __name__
 | 
			
		||||
    return test_class
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -745,6 +745,9 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
            dd["a"] = x + 1
 | 
			
		||||
            dd[param] = 123
 | 
			
		||||
            dd["c"] = x * 2
 | 
			
		||||
            dd.update({"b": x * 3})
 | 
			
		||||
            dd.update([["d", x - 2], ("e", x + 2)])
 | 
			
		||||
            dd.update(zip("ab", [x + 3, x + 4]))
 | 
			
		||||
            return dd["b"], dd
 | 
			
		||||
 | 
			
		||||
        x = torch.randn(10, 10)
 | 
			
		||||
@ -754,7 +757,10 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(same(ref[0], res[0]))
 | 
			
		||||
        self.assertTrue(same(ref[1]["a"], res[1]["a"]))
 | 
			
		||||
        self.assertTrue(same(ref[1]["b"], res[1]["b"]))
 | 
			
		||||
        self.assertTrue(same(ref[1]["c"], res[1]["c"]))
 | 
			
		||||
        self.assertTrue(same(ref[1]["d"], res[1]["d"]))
 | 
			
		||||
        self.assertTrue(same(ref[1]["e"], res[1]["e"]))
 | 
			
		||||
        self.assertTrue(same(ref[1][param], res[1][param]))
 | 
			
		||||
 | 
			
		||||
    @make_test
 | 
			
		||||
@ -802,6 +808,42 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
 | 
			
		||||
        d2["c"] = x + 20
 | 
			
		||||
        return d1["a"] + d2["c"] + 1
 | 
			
		||||
 | 
			
		||||
    @make_test
 | 
			
		||||
    def test_dict_fromkeys(x, y):
 | 
			
		||||
        lst = ["a", "b"]
 | 
			
		||||
        d = dict.fromkeys(lst)
 | 
			
		||||
        d1 = dict.fromkeys(d, x + 1)
 | 
			
		||||
        d2 = collections.defaultdict.fromkeys(iter(d1), x - 2)
 | 
			
		||||
        d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
 | 
			
		||||
        return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1
 | 
			
		||||
 | 
			
		||||
    @make_test
 | 
			
		||||
    def test_dict_copy(x):
 | 
			
		||||
        my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
 | 
			
		||||
        d1 = dict(my_list)
 | 
			
		||||
        d1["a"] = x + 10
 | 
			
		||||
        d2 = d1.copy()
 | 
			
		||||
        d2["a"] = x - 5
 | 
			
		||||
        d2["b"] = x + 3
 | 
			
		||||
        d3 = collections.OrderedDict(my_list)
 | 
			
		||||
        d3["c"] = x + 20
 | 
			
		||||
        d4 = d3.copy()
 | 
			
		||||
        d4["c"] = x - 10
 | 
			
		||||
        return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1
 | 
			
		||||
 | 
			
		||||
    @make_test
 | 
			
		||||
    def test_dict_update(x, y, z):
 | 
			
		||||
        d = {"a": x, "b": y}
 | 
			
		||||
        d.update({"a": y - 1})
 | 
			
		||||
        d.update([("b", z + 1), ["c", z]])
 | 
			
		||||
        d.update(zip("ab", [z + 3, y + 2]))
 | 
			
		||||
 | 
			
		||||
        od = collections.OrderedDict(a=x * 3, b=y + 2)
 | 
			
		||||
        od.update({"a": y + 5})
 | 
			
		||||
        od.update([["b", z + 6], ("c", z - 7)])
 | 
			
		||||
        od.update(zip("ab", [z - 3, x + 2]))
 | 
			
		||||
        return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"]
 | 
			
		||||
 | 
			
		||||
    @make_test
 | 
			
		||||
    def test_min_max(a, b):
 | 
			
		||||
        c = a + b
 | 
			
		||||
 | 
			
		||||
@ -2940,6 +2940,52 @@ class <lambda>(torch.nn.Module):
 | 
			
		||||
        ):
 | 
			
		||||
            aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1)
 | 
			
		||||
 | 
			
		||||
    @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond needs dynamo to run")
 | 
			
		||||
    def test_aot_export_with_torch_cond(self):
 | 
			
		||||
        class M(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                def true_fn(x):
 | 
			
		||||
                    y = x + 4
 | 
			
		||||
                    y.add_(5)
 | 
			
		||||
                    return x.cos()
 | 
			
		||||
 | 
			
		||||
                def false_fn(x):
 | 
			
		||||
                    y = x + 5
 | 
			
		||||
                    y.add_(6)
 | 
			
		||||
                    return x.sin()
 | 
			
		||||
 | 
			
		||||
                a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x])
 | 
			
		||||
                return (a + 3, a + 4)
 | 
			
		||||
 | 
			
		||||
        inp = torch.randn(3, 4)
 | 
			
		||||
        gm, _ = aot_export_module(M(), (inp,), trace_joint=False)
 | 
			
		||||
        self.assertExpectedInline(gm.code.strip(), """\
 | 
			
		||||
def forward(self, arg0_1):
 | 
			
		||||
    true_graph_0 = self.true_graph_0
 | 
			
		||||
    false_graph_0 = self.false_graph_0
 | 
			
		||||
    conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]);  true_graph_0 = false_graph_0 = arg0_1 = None
 | 
			
		||||
    getitem = conditional[0];  conditional = None
 | 
			
		||||
    add = torch.ops.aten.add.Tensor(getitem, 3)
 | 
			
		||||
    add_1 = torch.ops.aten.add.Tensor(getitem, 4);  getitem = None
 | 
			
		||||
    return (add, add_1)""")  # noqa: B950
 | 
			
		||||
 | 
			
		||||
        self.assertExpectedInline(gm.true_graph_0.code.strip(), """\
 | 
			
		||||
def forward(self, arg0_1):
 | 
			
		||||
    add = torch.ops.aten.add.Tensor(arg0_1, 4)
 | 
			
		||||
    add_1 = torch.ops.aten.add.Tensor(add, 5);  add = None
 | 
			
		||||
    cos = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
 | 
			
		||||
    return (cos,)""")
 | 
			
		||||
 | 
			
		||||
        self.assertExpectedInline(gm.false_graph_0.code.strip(), """\
 | 
			
		||||
def forward(self, arg0_1):
 | 
			
		||||
    add = torch.ops.aten.add.Tensor(arg0_1, 5)
 | 
			
		||||
    add_1 = torch.ops.aten.add.Tensor(add, 6);  add = None
 | 
			
		||||
    sin = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
 | 
			
		||||
    return (sin,)""")
 | 
			
		||||
 | 
			
		||||
    def test_aot_export_simplified_input_mutations_banned(self):
 | 
			
		||||
        def fn(x):
 | 
			
		||||
            x.mul_(2)
 | 
			
		||||
 | 
			
		||||
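The export test above captures torch.cond into torch.ops.higher_order.cond. A rough, assumed user-level sketch of the same control-flow primitive, run here under torch.compile with the eager backend (how cleanly it traces depends on the Dynamo version):

import torch

def true_fn(x):
    return x.cos()

def false_fn(x):
    return x.sin()

@torch.compile(backend="eager")
def f(x):
    # Both branches must be traceable functions of the operands in the list.
    return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

print(f(torch.ones(3, 4)))   # positive sum -> cos branch
print(f(-torch.ones(3, 4)))  # negative sum -> sin branch
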
@ -1422,6 +1422,27 @@ class AOTInductorTestsTemplate:
 | 
			
		||||
 | 
			
		||||
        self.check_model(Model(), inputs)
 | 
			
		||||
 | 
			
		||||
    def test_convolution(self):
 | 
			
		||||
        class Model(torch.nn.Module):
 | 
			
		||||
            def __init__(self):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
 | 
			
		||||
            def forward(self, x, w, b):
 | 
			
		||||
                return torch.ops.aten.convolution(x, w, b, [4], [0], [1], True, [0], 1)
 | 
			
		||||
 | 
			
		||||
        example_inputs = (
 | 
			
		||||
            torch.randn([2, 32, 90], device=self.device),
 | 
			
		||||
            torch.randn([32, 16, 8], device=self.device),
 | 
			
		||||
            torch.randn([16], device=self.device),
 | 
			
		||||
        )
 | 
			
		||||
        with config.patch(
 | 
			
		||||
            {
 | 
			
		||||
                "max_autotune": True,
 | 
			
		||||
                "max_autotune_gemm_backends": "Triton",
 | 
			
		||||
            }
 | 
			
		||||
        ):
 | 
			
		||||
            self.check_model(Model(), example_inputs)
 | 
			
		||||
 | 
			
		||||
    def test_zero_size_weight(self):
 | 
			
		||||
        class Model(torch.nn.Module):
 | 
			
		||||
            def __init__(self, channel, r=8):
 | 
			
		||||
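For reference, the aten.convolution call in test_convolution above (transposed=True, stride 4, zero padding, dilation 1, groups 1) is a 1-D transposed convolution; a small sketch of the presumed eager-mode equivalent:

import torch
import torch.nn.functional as F

x = torch.randn(2, 32, 90)   # (batch, in_channels, length)
w = torch.randn(32, 16, 8)   # transposed-conv weight: (in_channels, out_channels, kernel)
b = torch.randn(16)

# aten.convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)
out_aten = torch.ops.aten.convolution(x, w, b, [4], [0], [1], True, [0], 1)
out_ref = F.conv_transpose1d(x, w, b, stride=4)

print(out_aten.shape)                     # torch.Size([2, 16, 364])
print(torch.allclose(out_aten, out_ref))  # expected: True
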
@ -1470,8 +1491,8 @@ copy_tests(
 | 
			
		||||
        # TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
 | 
			
		||||
        #   NotImplementedError: Cannot access storage of OpaqueTensorImpl
 | 
			
		||||
        "test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
 | 
			
		||||
        # Need to support convolution
 | 
			
		||||
        "test_missing_cubin": TestFailure(("abi_compatible_cpu",)),
 | 
			
		||||
        # FIXME: failed with Segfault while exiting the Python runtime
 | 
			
		||||
        "test_missing_cubin": TestFailure(("abi_compatible_cpu",), is_skip=True),
 | 
			
		||||
        "test_normal_functional": TestFailure(("abi_compatible_cpu",)),
 | 
			
		||||
        "test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
 | 
			
		||||
        # There is a double-free issue which will be fixed in another PR
 | 
			
		||||
@ -1501,8 +1522,6 @@ copy_tests(
 | 
			
		||||
    # test_failures, xfail by default, set is_skip=True to skip
 | 
			
		||||
    {
 | 
			
		||||
        "test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
 | 
			
		||||
        # Need to support convolution
 | 
			
		||||
        "test_missing_cubin": TestFailure(("abi_compatible_cuda",)),
 | 
			
		||||
        "test_normal_functional": TestFailure(("abi_compatible_cuda",)),
 | 
			
		||||
        # There is a double-free issue which will be fixed in another PR
 | 
			
		||||
        "test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
 | 
			
		||||
 | 
			
		||||
@ -2423,6 +2423,21 @@ class CPUReproTests(TestCase):
 | 
			
		||||
            self.assertFalse("= as_strided(" in code)
 | 
			
		||||
            self.assertEqual(run(*v), mod(*v))
 | 
			
		||||
 | 
			
		||||
    def test_invalid_dropout_args(self):
 | 
			
		||||
        class MyModel(torch.nn.Module):
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                x = x * 2
 | 
			
		||||
                x = torch.nn.functional.dropout(x, p=0.5)
 | 
			
		||||
                x = torch.relu(x)
 | 
			
		||||
                return x
 | 
			
		||||
 | 
			
		||||
        example_inputs = torch.tensor([[1, 2, 3], [4, 5, 6]])
 | 
			
		||||
 | 
			
		||||
        func = MyModel()
 | 
			
		||||
        jit_func = torch.compile(func)
 | 
			
		||||
        self.assertRaises(RuntimeError, lambda: func(example_inputs))
 | 
			
		||||
        self.assertRaises(RuntimeError, lambda: jit_func(example_inputs))
 | 
			
		||||
 | 
			
		||||
    @config.patch(inplace_buffers=True)
 | 
			
		||||
    def test_in_out_buffer(self):
 | 
			
		||||
        def fn(x, y):
 | 
			
		||||
@ -2463,6 +2478,55 @@ class CPUReproTests(TestCase):
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(metrics.generated_kernel_count, 1)
 | 
			
		||||
 | 
			
		||||
    def test_attention_size_mismatch(self):
 | 
			
		||||
        class Attention(torch.nn.Module):
 | 
			
		||||
            def __init__(self, hidden_size, num_heads):
 | 
			
		||||
                super().__init__()
 | 
			
		||||
                self.hidden_size = hidden_size
 | 
			
		||||
                self.num_heads = num_heads
 | 
			
		||||
                self.head_size = hidden_size // num_heads
 | 
			
		||||
                self.query = torch.nn.Linear(hidden_size, hidden_size)
 | 
			
		||||
                self.key = torch.nn.Linear(hidden_size, hidden_size)
 | 
			
		||||
                self.value = torch.nn.Linear(hidden_size, hidden_size)
 | 
			
		||||
                self.inv_scale = torch.nn.Parameter(
 | 
			
		||||
                    torch.Tensor([1 / self.head_size**0.5]), requires_grad=False
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                query = self.query(x)
 | 
			
		||||
                key = self.key(x)
 | 
			
		||||
                value = self.value(x)
 | 
			
		||||
                (batch_size, seq_len, hidden_size) = query.size()
 | 
			
		||||
                query = query.view(
 | 
			
		||||
                    batch_size, seq_len, self.num_heads, self.head_size
 | 
			
		||||
                ).permute(0, 2, 1, 3)
 | 
			
		||||
                key = key.view(
 | 
			
		||||
                    batch_size, seq_len, self.num_heads, self.head_size
 | 
			
		||||
                ).permute(0, 2, 3, 1)
 | 
			
		||||
                value = value.view(
 | 
			
		||||
                    batch_size, seq_len, self.num_heads, self.head_size
 | 
			
		||||
                ).permute(0, 2, 1, 3)
 | 
			
		||||
                attention_weights = (
 | 
			
		||||
                    torch.matmul(query, key).div(self.inv_scale).softmax(dim=-1)
 | 
			
		||||
                )
 | 
			
		||||
                output = torch.matmul(attention_weights, value)
 | 
			
		||||
                return output
 | 
			
		||||
 | 
			
		||||
        torch.manual_seed(123)
 | 
			
		||||
        hidden_size = 16
 | 
			
		||||
        num_heads = 1
 | 
			
		||||
        seq_len = 4
 | 
			
		||||
        batch_size = 1
 | 
			
		||||
        x = torch.randn(batch_size, seq_len, hidden_size)
 | 
			
		||||
 | 
			
		||||
        func = Attention(hidden_size, num_heads).to("cpu")
 | 
			
		||||
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            res1 = func(x)
 | 
			
		||||
            jit_func = torch.compile(func)
 | 
			
		||||
            res2 = jit_func(x)
 | 
			
		||||
        self.assertEqual(res1, res2)
 | 
			
		||||
 | 
			
		||||
    def test_scalar_mul_bfloat16(self):
 | 
			
		||||
        def f(x):
 | 
			
		||||
            return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)
 | 
			
		||||
 | 
			
		||||
@ -217,9 +217,9 @@ class TestPoitwiseOps(torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@requires_cuda()
 | 
			
		||||
@torch._inductor.config.patch(
 | 
			
		||||
    post_grad_fusion_options={"group_linear": {"require_fbgemm": True}}
 | 
			
		||||
)
 | 
			
		||||
@torch._inductor.config.patch(post_grad_fusion_options={})
 | 
			
		||||
@torch._inductor.config.patch(pre_grad_fusion_options={})
 | 
			
		||||
@torch._inductor.config.patch(group_fusion=True, batch_fusion=True)
 | 
			
		||||
class TestGroupBatchFusion(TestCase):
 | 
			
		||||
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
 | 
			
		||||
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
 | 
			
		||||
@ -433,9 +433,7 @@ class TestBMMFusionModule(torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@requires_cuda()
 | 
			
		||||
@torch._inductor.config.patch(
 | 
			
		||||
    post_grad_fusion_options={"batch_linear": {"require_fbgemm": False}}
 | 
			
		||||
)
 | 
			
		||||
@torch._inductor.config.patch(post_grad_fusion_options={"batch_linear_post_grad": {}})
 | 
			
		||||
class TestPostGradBatchLinearFusion(TestCase):
 | 
			
		||||
    def test_batch_linear_post_grad_fusion(self):
 | 
			
		||||
        pt1_module = TestBMMFusionModule().cuda()
 | 
			
		||||
 | 
			
		||||
@ -4143,7 +4143,7 @@ SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
 | 
			
		||||
                # Don't use node.name() here as it is not consistent on windows
 | 
			
		||||
                node_name = node.__class__.__name__ if node else "None"
 | 
			
		||||
                pr.append(f"Running {func} from within {node_name}")
 | 
			
		||||
                return func(*args, **kwargs or {})
 | 
			
		||||
                return func(*args, **(kwargs or {}))
 | 
			
		||||
 | 
			
		||||
        with MyMode():
 | 
			
		||||
            pr.append("FW")

Submodule third_party/fbgemm updated: f49dea6a16...88fc6e741b

@ -1095,6 +1095,10 @@ def native_dropout(input: Tensor, p: float, train: Optional[bool]):
 | 
			
		||||
    if train and p != 0:
 | 
			
		||||
        if p == 1:
 | 
			
		||||
            return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
 | 
			
		||||
        if not input.dtype.is_floating_point:
 | 
			
		||||
            raise RuntimeError(
 | 
			
		||||
                "result type Float can't be cast to the desired output type Long"
 | 
			
		||||
            )
 | 
			
		||||
        bool_mask = torch.rand_like(input) > p
 | 
			
		||||
        res = bool_mask * input * float(1.0 / (1.0 - p))
 | 
			
		||||
        return (res, bool_mask)
 | 
			
		||||
 | 
			
		||||
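The decomposition above keeps each element with probability 1 - p and rescales the survivors by 1 / (1 - p), so the expected value of the output matches the input. A minimal standalone sketch of that math (plain PyTorch; the function name is mine, not part of the decomposition machinery):

import torch

def dropout_reference(x: torch.Tensor, p: float, train: bool = True):
    if not train or p == 0.0:
        return x, torch.ones_like(x, dtype=torch.bool)
    if p == 1.0:
        return torch.zeros_like(x), torch.zeros_like(x, dtype=torch.bool)
    if not x.dtype.is_floating_point:
        # The decomposition above rejects integral inputs as well.
        raise RuntimeError("result type Float can't be cast to the desired output type Long")
    mask = torch.rand_like(x) > p          # keep with probability 1 - p
    return mask * x * (1.0 / (1.0 - p)), mask

out, mask = dropout_reference(torch.randn(4, 4), p=0.5)
print(out.shape, mask.dtype)  # torch.Size([4, 4]) torch.bool
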
@ -294,7 +294,12 @@ class VariableBuilder:
 | 
			
		||||
        # NB: Careful not to close over self to avoid ref cycle from lru_cache
 | 
			
		||||
        entries = [
 | 
			
		||||
            (
 | 
			
		||||
                (torch.Tensor, torch.nn.Parameter, torch._subclasses.FakeTensor),
 | 
			
		||||
                (
 | 
			
		||||
                    torch.Tensor,
 | 
			
		||||
                    torch.nn.Parameter,
 | 
			
		||||
                    torch._subclasses.FakeTensor,
 | 
			
		||||
                    torch._subclasses.functional_tensor.FunctionalTensor,
 | 
			
		||||
                ),
 | 
			
		||||
                cls.wrap_tensor,
 | 
			
		||||
            ),
 | 
			
		||||
            ((tuple, list, odict_values, collections.deque), cls.wrap_listlike),
 | 
			
		||||
@ -1005,6 +1010,7 @@ class VariableBuilder:
 | 
			
		||||
                torch.Tensor,
 | 
			
		||||
                torch.nn.Parameter,
 | 
			
		||||
                torch._subclasses.fake_tensor.FakeTensor,
 | 
			
		||||
                torch._subclasses.functional_tensor.FunctionalTensor,
 | 
			
		||||
            ) or is_traceable_wrapper_subclass(value), type(value)
 | 
			
		||||
            subclass_type = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@ import logging
 | 
			
		||||
import math
 | 
			
		||||
import operator
 | 
			
		||||
import types
 | 
			
		||||
from collections import defaultdict, OrderedDict
 | 
			
		||||
from typing import Dict, List
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
@ -39,7 +40,7 @@ from ..utils import (
 | 
			
		||||
from .base import MutableLocal, typestr, VariableTracker
 | 
			
		||||
from .constant import ConstantVariable
 | 
			
		||||
from .ctx_manager import EventVariable, StreamVariable
 | 
			
		||||
from .dicts import ConstDictVariable, SetVariable
 | 
			
		||||
from .dicts import ConstDictVariable, DefaultDictVariable, SetVariable
 | 
			
		||||
from .lists import (
 | 
			
		||||
    BaseListVariable,
 | 
			
		||||
    ListIteratorVariable,
 | 
			
		||||
@ -662,6 +663,17 @@ class BuiltinVariable(VariableTracker):
 | 
			
		||||
                )
 | 
			
		||||
        return super().call_function(tx, args, kwargs)
 | 
			
		||||
 | 
			
		||||
    def call_method(
 | 
			
		||||
        self,
 | 
			
		||||
        tx,
 | 
			
		||||
        name,
 | 
			
		||||
        args: "List[VariableTracker]",
 | 
			
		||||
        kwargs: "Dict[str, VariableTracker]",
 | 
			
		||||
    ) -> "VariableTracker":
 | 
			
		||||
        if self.fn == dict and name == "fromkeys":
 | 
			
		||||
            return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
 | 
			
		||||
        return super().call_method(tx, name, args, kwargs)
 | 
			
		||||
 | 
			
		||||
    def _call_min_max(self, tx, *args):
 | 
			
		||||
        if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
 | 
			
		||||
            # expand iterable
 | 
			
		||||
@ -898,7 +910,44 @@ class BuiltinVariable(VariableTracker):
 | 
			
		||||
            return variables.ConstDictVariable(
 | 
			
		||||
                dict(kwargs), user_cls=user_cls, mutable_local=MutableLocal()
 | 
			
		||||
            )
 | 
			
		||||
        unimplemented(f"dict(): {args} {kwargs}")
 | 
			
		||||
        unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def call_custom_dict_fromkeys(tx, user_cls, *args, **kwargs):
 | 
			
		||||
        assert user_cls in {dict, OrderedDict, defaultdict}
 | 
			
		||||
        if kwargs:
 | 
			
		||||
            # Only `OrderedDict.fromkeys` accepts `value` passed by keyword
 | 
			
		||||
            assert user_cls is OrderedDict
 | 
			
		||||
            assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs
 | 
			
		||||
            args = (*args, kwargs.pop("value"))
 | 
			
		||||
        if len(args) == 0:
 | 
			
		||||
            raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0")
 | 
			
		||||
        if len(args) == 1:
 | 
			
		||||
            args = (*args, ConstantVariable.create(None))
 | 
			
		||||
        assert len(args) == 2
 | 
			
		||||
        arg, value = args
 | 
			
		||||
        DictVariableType = (
 | 
			
		||||
            ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if isinstance(arg, dict):
 | 
			
		||||
            return DictVariableType(
 | 
			
		||||
                dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal()
 | 
			
		||||
            )
 | 
			
		||||
        elif isinstance(
 | 
			
		||||
            arg,
 | 
			
		||||
            (
 | 
			
		||||
                ConstDictVariable,
 | 
			
		||||
                ListVariable,
 | 
			
		||||
                TupleVariable,
 | 
			
		||||
                ListIteratorVariable,
 | 
			
		||||
            ),
 | 
			
		||||
        ):
 | 
			
		||||
            keys = [DictVariableType.get_key(x) for x in arg.unpack_var_sequence(tx)]
 | 
			
		||||
            return DictVariableType(
 | 
			
		||||
                dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal()
 | 
			
		||||
            )
 | 
			
		||||
        unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}")
 | 
			
		||||
 | 
			
		||||
    def call_zip(self, tx, *args, **kwargs):
 | 
			
		||||
        if kwargs:
 | 
			
		||||
 | 
			
		||||
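A hedged illustration of the constructor forms the new fromkeys handling above is meant to cover when a function is compiled; it mirrors the test_dict_fromkeys test earlier in this commit but is an assumed example, not taken verbatim:

import collections
import torch

@torch.compile(backend="eager")
def fn(x, y):
    d = dict.fromkeys(["a", "b"])                               # values default to None
    d1 = dict.fromkeys(d, x + 1)                                # keys taken from an existing dict
    d2 = collections.defaultdict.fromkeys(iter(d1), x - 2)      # keys from an iterator
    d3 = collections.OrderedDict.fromkeys(("a", "b"), value=y)  # only OrderedDict takes value=
    return d1["a"] + d2["b"] + d3["a"]

out = fn(torch.randn(3), torch.randn(3))
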
@ -77,13 +77,18 @@ class ConstDictVariable(VariableTracker):
 | 
			
		||||
        args: "List[VariableTracker]",
 | 
			
		||||
        kwargs: "Dict[str, VariableTracker]",
 | 
			
		||||
    ) -> "VariableTracker":
 | 
			
		||||
        from . import ConstantVariable, TupleVariable
 | 
			
		||||
        from . import (
 | 
			
		||||
            ConstantVariable,
 | 
			
		||||
            ListIteratorVariable,
 | 
			
		||||
            ListVariable,
 | 
			
		||||
            TupleVariable,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        val = self.items
 | 
			
		||||
 | 
			
		||||
        if name == "__getitem__":
 | 
			
		||||
            assert len(args) == 1
 | 
			
		||||
            return self.getitem_const(args[0])
 | 
			
		||||
 | 
			
		||||
        elif name == "items":
 | 
			
		||||
            assert not (args or kwargs)
 | 
			
		||||
            return TupleVariable(
 | 
			
		||||
@ -112,10 +117,12 @@ class ConstDictVariable(VariableTracker):
 | 
			
		||||
                ],
 | 
			
		||||
                mutable_local=MutableLocal(),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        elif name == "values":
 | 
			
		||||
            assert not (args or kwargs)
 | 
			
		||||
            return TupleVariable(list(val.values()))
 | 
			
		||||
        elif name == "copy":
 | 
			
		||||
            assert not (args or kwargs)
 | 
			
		||||
            return self.modifed(self.items.copy(), mutable_local=MutableLocal())
 | 
			
		||||
        elif name == "__len__":
 | 
			
		||||
            assert not (args or kwargs)
 | 
			
		||||
            return ConstantVariable.create(len(self.items))
 | 
			
		||||
@ -139,10 +146,9 @@ class ConstDictVariable(VariableTracker):
 | 
			
		||||
            )
 | 
			
		||||
        elif (
 | 
			
		||||
            name in ("pop", "get")
 | 
			
		||||
            and args
 | 
			
		||||
            and len(args) == 2
 | 
			
		||||
            and ConstDictVariable.is_valid_key(args[0])
 | 
			
		||||
            and ConstDictVariable.get_key(args[0]) not in self.items
 | 
			
		||||
            and len(args) == 2
 | 
			
		||||
        ):
 | 
			
		||||
            # missing item, return the default value
 | 
			
		||||
            return args[1]
 | 
			
		||||
@ -158,12 +164,34 @@ class ConstDictVariable(VariableTracker):
 | 
			
		||||
            return result
 | 
			
		||||
        elif (
 | 
			
		||||
            name == "update"
 | 
			
		||||
            and args
 | 
			
		||||
            and len(args) == 1
 | 
			
		||||
            and isinstance(args[0], ConstDictVariable)
 | 
			
		||||
            and self.mutable_local
 | 
			
		||||
        ):
 | 
			
		||||
            newval = dict(val)
 | 
			
		||||
            newval.update(args[0].items)
 | 
			
		||||
            newval.update(kwargs)  # all keys in kwargs are valid (`str`s)
 | 
			
		||||
            result = self.modifed(newval)
 | 
			
		||||
            return tx.replace_all(self, result)
 | 
			
		||||
        elif (
 | 
			
		||||
            name == "update"
 | 
			
		||||
            and len(args) == 1
 | 
			
		||||
            and isinstance(
 | 
			
		||||
                args[0],
 | 
			
		||||
                (
 | 
			
		||||
                    ListVariable,
 | 
			
		||||
                    TupleVariable,
 | 
			
		||||
                    ListIteratorVariable,
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
            and self.mutable_local
 | 
			
		||||
        ):
 | 
			
		||||
            newval = dict(val)
 | 
			
		||||
            for x in args[0].unpack_var_sequence(tx):
 | 
			
		||||
                k, v = x.unpack_var_sequence(tx)
 | 
			
		||||
                assert ConstDictVariable.is_valid_key(k)
 | 
			
		||||
                newval[ConstDictVariable.get_key(k)] = v
 | 
			
		||||
            newval.update(kwargs)  # all keys in kwargs are valid (`str`s)
 | 
			
		||||
            result = self.modifed(newval)
 | 
			
		||||
            return tx.replace_all(self, result)
 | 
			
		||||
        elif (
 | 
			
		||||
 | 
			
		||||
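The branch added above lets ConstDictVariable handle dict.update called with an iterable of key/value pairs (lists, tuples, iterators such as zip), in addition to another dict. A hedged user-level sketch of what that covers, mirroring the test_dict_update test earlier in this commit:

import torch

@torch.compile(backend="eager")
def fn(x, y):
    d = {"a": x, "b": y}
    d.update({"a": y - 1})                # update from another dict
    d.update([("b", x + 1), ["c", x]])    # update from a list of pairs
    d.update(zip("de", [x + 3, y + 2]))   # update from an iterator of pairs
    return d["a"] + d["b"] + d["c"] + d["d"] + d["e"]

out = fn(torch.randn(3), torch.randn(3))
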
@ -972,6 +972,24 @@ class SkipFilesVariable(VariableTracker):
 | 
			
		||||
            msg += f"', {self.reason}'" if self.reason else ""
 | 
			
		||||
            unimplemented(msg)
 | 
			
		||||
 | 
			
		||||
    def call_method(
 | 
			
		||||
        self,
 | 
			
		||||
        tx,
 | 
			
		||||
        name,
 | 
			
		||||
        args: "List[VariableTracker]",
 | 
			
		||||
        kwargs: "Dict[str, VariableTracker]",
 | 
			
		||||
    ) -> "VariableTracker":
 | 
			
		||||
        if (
 | 
			
		||||
            self.value in {collections.OrderedDict, collections.defaultdict}
 | 
			
		||||
            and name == "fromkeys"
 | 
			
		||||
        ):
 | 
			
		||||
            from .builtin import BuiltinVariable
 | 
			
		||||
 | 
			
		||||
            return BuiltinVariable.call_custom_dict_fromkeys(
 | 
			
		||||
                tx, self.value, *args, **kwargs
 | 
			
		||||
            )
 | 
			
		||||
        return super().call_method(tx, name, args, kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TypingVariable(VariableTracker):
 | 
			
		||||
    def __init__(self, value, **kwargs):
 | 
			
		||||
 | 
			
		||||
@ -1147,7 +1147,7 @@ def aot_compile(
 | 
			
		||||
        constraints,
 | 
			
		||||
        disable_constraint_solver=disable_constraint_solver
 | 
			
		||||
    )
 | 
			
		||||
    flat_example_inputs = pytree.arg_tree_leaves(*args, **kwargs or {})
 | 
			
		||||
    flat_example_inputs = pytree.arg_tree_leaves(*args, **(kwargs or {}))
 | 
			
		||||
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        so_path = torch._inductor.aot_compile(gm, flat_example_inputs, options)  # type: ignore[arg-type]
 | 
			
		||||
 | 
			
		||||
@ -27,7 +27,7 @@ from ..codecache import code_hash, get_path, PyCodeCache
 | 
			
		||||
from ..dependencies import MemoryDep, StarDep
 | 
			
		||||
from ..ir import IRNode, ReductionHint, TritonTemplateBuffer
 | 
			
		||||
from ..optimize_indexing import indexing_dtype_strength_reduction
 | 
			
		||||
from ..scheduler import BaseScheduling
 | 
			
		||||
from ..scheduler import BaseScheduling, WhyNoFuse
 | 
			
		||||
from ..triton_heuristics import AutotuneHint
 | 
			
		||||
from ..utils import (
 | 
			
		||||
    do_bench,
 | 
			
		||||
@ -2226,12 +2226,13 @@ class TritonScheduling(BaseScheduling):
 | 
			
		||||
 | 
			
		||||
        _, (numel1, rnumel1) = node1.group
 | 
			
		||||
        _, (numel2, rnumel2) = node2.group
 | 
			
		||||
        why = WhyNoFuse(node1, node2)
 | 
			
		||||
 | 
			
		||||
        if node1.is_reduction() and node2.is_reduction():
 | 
			
		||||
            reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2
 | 
			
		||||
            if not reduction_can_fuse:
 | 
			
		||||
                fusion_log.debug(
 | 
			
		||||
                    "cannot fuse (triton:1): numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)",
 | 
			
		||||
                why(
 | 
			
		||||
                    "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)",
 | 
			
		||||
                    numel1,
 | 
			
		||||
                    numel2,
 | 
			
		||||
                    rnumel1,
 | 
			
		||||
@ -2241,8 +2242,8 @@ class TritonScheduling(BaseScheduling):

        if not node1.is_reduction() and not node2.is_reduction():
            if not (numel1 == numel2 and rnumel1 == rnumel2):
                fusion_log.debug(
                    "cannot fuse (triton:2): numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)",
                why(
                    "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)",
                    numel1,
                    numel2,
                    rnumel1,
@ -2255,10 +2256,7 @@ class TritonScheduling(BaseScheduling):
                # Fusion for CUDATemplates are not supported.
                is_triton_template = isinstance(node1.node, TritonTemplateBuffer)
                if not is_triton_template:
                    fusion_log.debug(
                        "cannot fuse (triton:3): is not TritonTemplateBuffer %s",
                        node1,
                    )
                    why("node1 is not TritonTemplateBuffer")
                return is_triton_template

            # check for a bad combined tiling
@ -2277,13 +2275,13 @@ class TritonScheduling(BaseScheduling):
                elif len(tiling2) > 2:
                    cond = tiling2 == tiling3
                if not cond:
                    fusion_log.debug(
                        "cannot fuse (triton:4): tiling mismatch (%s, %s, %s)",
                    why(
                        "tiling mismatch (%s, %s, %s)",
                        tiling1,
                        tiling2,
                        tiling3,
                    )
                    return cond
                    return False

            return True

@ -2294,9 +2292,7 @@ class TritonScheduling(BaseScheduling):
                    TritonKernel.is_compatible((numel2, rnumel2), n.get_ranges())
                    for n in node1.get_nodes()
                ):
                    fusion_log.debug(
                        "cannot fuse (triton:5): nodes numel/rnumel incompatibility"
                    )
                    why("nodes numel/rnumel incompatibility")
                    return False
                if (
                    config.triton.tiling_prevents_reduction_fusion
@ -2309,12 +2305,12 @@ class TritonScheduling(BaseScheduling):
                        (numel2, rnumel2, 1),
                    )
                    if not is_reduction_tiling_valid:
                        fusion_log.debug(
                            "cannot fuse (triton:6): invalid tiling for reduction"
                        )
                        why("invalid tiling for reduction")
                    return is_reduction_tiling_valid
                return True

            if numel1 != numel2:
                why("nodes numel incompatibility")
            return numel1 == numel2

        assert node1.is_reduction() and not node2.is_reduction()

@ -497,6 +497,9 @@ class WrapperCodeGen(CodeGen):
    def generate_end(self, result):
        return

    def generate_fallback_kernel(self, fallback_kernel, args):
        self.generate_extern_kernel_alloc(fallback_kernel, args)

    def generate_extern_kernel_alloc(self, extern_kernel, args):
        ending = self.ending
        if config.memory_planning and "view_as_complex" in str(extern_kernel.kernel):
@ -879,10 +882,8 @@ class WrapperCodeGen(CodeGen):
        signature: List[Union[TensorArg, SizeArg]] = []
        constants = {}
        for key, arg in kwargs.items():
            if (
                key in kernel.__annotations__
                and "constexpr" in kernel.__annotations__[key]
            ):
            idx = kernel.arg_names.index(key)
            if idx in kernel.constexprs:
                constants[key] = arg
                continue
            if isinstance(arg, (ir.Buffer, ir.ReinterpretView)):
@ -1732,11 +1733,28 @@ class CppWrapperCodeGen(WrapperCodeGen):
        shim_fn = f"aoti_torch_{kernel_suffix}"
        self.writeline(f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));")

    def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
    def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
        # registered output buffer name
        name = extern_kernel.name
        output_handle_name = f"{name}_handle"
        self.writeline(f"AtenTensorHandle {output_handle_name};")
        output_arg = f"&{output_handle_name}"
        self.generate_c_shim_extern_kernel_call(
            extern_kernel.codegen_kernel_name(), args + [output_arg]
        )
        self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")

    def generate_extern_kernel_alloc(self, extern_kernel, args):
        if V.graph.aot_mode and config.aot_inductor.abi_compatible:
            self.generate_c_shim_extern_kernel_alloc(extern_kernel, args)
        else:
            super().generate_extern_kernel_alloc(extern_kernel, args)

    def generate_c_shim_fallback_kernel(self, fallback_kernel, args):
        output_args = []
        output_raii_handles = []
        output_name_base = extern_kernel.get_name()
        for idx, output in enumerate(extern_kernel.outputs):
        output_name_base = fallback_kernel.get_name()
        for idx, output in enumerate(fallback_kernel.outputs):
            if isinstance(output, ir.MultiOutput):
                name = f"{output.get_name()}"
                output_handle_name = f"{name}_handle"
@ -1759,19 +1777,19 @@ class CppWrapperCodeGen(WrapperCodeGen):
                raise NotImplementedError("unsupported type of {output=}")
        args = args + output_args
        assert (
            extern_kernel.abi_compatible_kernel is not None
        ), f"abi_compatible_kernel is None for {extern_kernel.kernel=}"
            fallback_kernel.abi_compatible_kernel is not None
        ), f"abi_compatible_kernel is None for {fallback_kernel.kernel=}"
        self.generate_c_shim_extern_kernel_call(
            extern_kernel.abi_compatible_kernel, args
            fallback_kernel.abi_compatible_kernel, args
        )
        for raii_handle in output_raii_handles:
            self.writeline(raii_handle)

    def generate_extern_kernel_alloc(self, extern_kernel, args):
    def generate_fallback_kernel(self, fallback_kernel, args):
        if V.graph.aot_mode and config.aot_inductor.abi_compatible:
            self.generate_c_shim_extern_kernel_alloc_call(extern_kernel, args)
            self.generate_c_shim_fallback_kernel(fallback_kernel, args)
        else:
            super().generate_extern_kernel_alloc(extern_kernel, args)
            super().generate_fallback_kernel(fallback_kernel, args)

    def generate_extern_kernel_out(self, output_view, codegen_reference, args, kernel):
        if output_view:

@ -117,7 +117,7 @@ class BatchPointwiseOpsFusionFactory(BatchFusion):
        self.op = op


@register_fusion("batch_linear", pre_grad=False)
@register_fusion("batch_linear_post_grad", pre_grad=False)
class PostGradBatchLinearFusion(BatchFusion):
    """
    Fuse ops in a batch way in post grad (aten level).
@ -172,20 +172,18 @@ class PostGradBatchLinearFusion(BatchFusion):
            fused_inputs = decompose_stack(graph, batch_inputs)
            fused_weights = decompose_stack(graph, batch_weights)
            fused_bmm = graph.call_function(
                torch.ops.aten.bmm,
                aten.bmm,
                args=(fused_inputs, fused_weights),
            )

        for i, original_mm in enumerate(batch_nodes):
            has_bias = False
            with graph.inserting_after(fused_bmm):
                new_mm = graph.call_function(
                    torch.ops.aten.select, args=((fused_bmm, 0, i))
                )
                new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i)))
                if batch_biases[i]:
                    has_bias = True
                    new_bias_add = graph.call_function(
                        torch.ops.aten.add, args=((batch_biases[i], new_mm))
                        aten.add, args=((batch_biases[i], new_mm))
                    )
            new_mm_cont = new_bias_add if has_bias else new_mm
            original_mm.replace_all_uses_with(new_mm_cont)
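The PostGradBatchLinearFusion rewrite above replaces N independent matmuls with one stacked bmm plus per-index selects. A hedged, standalone illustration of the equivalence it relies on (plain tensors, no FX graph involved):

import torch

xs = [torch.randn(2, 3) for _ in range(4)]
ws = [torch.randn(3, 5) for _ in range(4)]

# one batched matmul over the stacked operands
fused = torch.bmm(torch.stack(xs), torch.stack(ws))

# each original result is a select along dim 0 of the fused output
for i, (x, w) in enumerate(zip(xs, ws)):
    assert torch.allclose(x @ w, fused.select(0, i), atol=1e-5)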
@ -763,7 +761,25 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):
def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
    print_graph(graph, "Before group_batch fusion in pre grad pass.")
    fusions: List[GroupBatchFusionBase] = []

    # we keep all current pre grad fusions to keep
    # current implementation, will remove this later
    # TODO: deperate batch_fusion and group_fusion flags
    if config.batch_fusion:
        config.pre_grad_fusion_options = {
            "batch_linear": {},
            "batch_linear_lhs": {},
            "batch_layernorm": {},
            "batch_tanh": {},
            "batch_relu": {},
            "batch_sigmoid": {},
        }
        # config.post_grad_fusion_options = {
        #     "batch_linear_post_grad": {},
        # }
    if config.group_fusion:
        config.post_grad_fusion_options = {
            "group_linear": {"require_fbgemm": True},
        }
    if pre_grad:
        fusions += generate_fusion_from_config(
            config.pre_grad_fusion_options, pre_grad=True

@ -3787,38 +3787,12 @@ class RandomSeeds(ExternKernelOut):


class ExternKernelAlloc(ExternKernel):
    # Generate abi-compatible kernel names for shim kernels.
    # Each individual shim kernel may have its own versioning rule.
    # However, we don't expect we would end up with too many of such rules.
    def _get_abi_compatible_kernel(self):
        if not V.graph.cpp_wrapper:
            return self.kernel

        def sdpa_ver_fn():
            # For sdpa, we need the v2 version only if any optional
            # kwarg is missing.
            if any(
                self.get_kwargs_value(arg_name) is None
                for arg_name in self.ordered_kwargs_for_cpp_kernel
            ):
                return f"{self.cpp_kernel}_v2"
            else:
                return self.cpp_kernel

        kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
        if (ver_fn := kernel_to_ver.get(self.cpp_kernel, None)) is not None:
            return ver_fn()
        return self.cpp_kernel

    def codegen_kernel_name(self):
        return self.cpp_kernel if V.graph.cpp_wrapper else self.kernel

    def codegen(self, wrapper):
        self.codegen_comment(wrapper)
        args = [*self.codegen_args(), *self.codegen_kwargs()]
        # Now we setup abi_compatible_kernel after self.kernel
        # and kwargs are adjusted appropriately.
        self.abi_compatible_kernel = self._get_abi_compatible_kernel()
        V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
        if isinstance(self.layout, Layout):
            self.codegen_size_asserts(wrapper)
@ -3839,7 +3813,6 @@ class ExternKernelAlloc(ExternKernel):
        self.name = V.graph.register_buffer(self)
        self.kernel = kernel
        self.cpp_kernel = cpp_kernel
        self.abi_compatible_kernel = None
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel

    def should_allocate(self):
@ -4294,6 +4267,7 @@ class FallbackKernel(ExternKernelAlloc):
        # output through the abi-compatible interface.
        self.outputs: Sequence[Any] = []
        self.use_runtime_dispatch = False
        self.abi_compatible_kernel = None

        assert isinstance(
            kernel,
@ -4351,6 +4325,29 @@ class FallbackKernel(ExternKernelAlloc):
        ), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}"
        return self.args_default_value[pos]["value"]

    # Generate abi-compatible kernel names for shim kernels.
    # Each individual shim kernel may have its own versioning rule.
    # However, we don't expect we would end up with too many of such rules.
    def _get_abi_compatible_kernel(self):
        if not V.graph.cpp_wrapper:
            return self.kernel

        def sdpa_ver_fn():
            # For sdpa, we need the v2 version only if any optional
            # kwarg is missing.
            if any(
                self.get_kwargs_value(arg_name) is None
                for arg_name in self.ordered_kwargs_for_cpp_kernel
            ):
                return f"{self.cpp_kernel}_v2"
            else:
                return self.cpp_kernel

        kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
        if (ver_fn := kernel_to_ver.get(self.cpp_kernel, None)) is not None:
            return ver_fn()
        return self.cpp_kernel

    def codegen_args(self):
        @dataclasses.dataclass
        class Shim:
@ -4361,6 +4358,9 @@ class FallbackKernel(ExternKernelAlloc):

        tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
        args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
        # Now we setup abi_compatible_kernel after self.kernel
        # and kwargs are adjusted appropriately.
        self.abi_compatible_kernel = self._get_abi_compatible_kernel()

        if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
            args = [
@ -4579,7 +4579,11 @@ class FallbackKernel(ExternKernelAlloc):
                self.outputs,
            )
        else:
            super().codegen(wrapper)
            self.codegen_comment(wrapper)
            args = [*self.codegen_args(), *self.codegen_kwargs()]
            V.graph.wrapper_code.generate_fallback_kernel(self, args)
            if isinstance(self.layout, Layout):
                self.codegen_size_asserts(wrapper)

    @staticmethod
    def tensor_to_layout(output: torch.Tensor):

@ -919,7 +919,15 @@ def register_replacement(
                        device=args[i].device,
                        requires_grad=grad,
                    )
            specific_graph = trace_fn(search_fn, args)
            try:
                specific_graph = trace_fn(search_fn, args)
            except RuntimeError as e:
                log.info(
                    "Replacement pattern %s failed to apply due to shape mismatch: %s",
                    search_fn.__name__,
                    e,
                )
                return False
            specific_pattern = fx_to_pattern(
                specific_graph,
                argnames=argnames,

@ -7,7 +7,18 @@ import math
import os
import pprint
import textwrap
from typing import Counter, DefaultDict, Dict, List, Optional, Sequence, Set, Union
from typing import (
    Any,
    Counter,
    DefaultDict,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import sympy

@ -41,6 +52,28 @@ log = logging.getLogger(__name__)
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")


class WhyNoFuse:
    # TODO when we drop support for Python < 3.10, we can use
    # @dataclass(slots=True) instead of manually specifying __slots__.
    __slots__ = ["node1", "node2", "reason", "args"]
    reason: str
    args: Tuple[Any, ...]

    def __init__(self, node1: "BaseSchedulerNode", node2: "BaseSchedulerNode"):
        self.node1 = node1
        self.node2 = node2

    def __call__(self, reason, *args):
        self.reason = reason
        self.args = args
        fusion_log.debug(self)

    def __str__(self):
        return f"cannot fuse {self.node1.get_name()} with {self.node2.get_name()}: " + (
            self.reason % self.args
        )


def pformat(obj):
    if isinstance(obj, set):
        # pformat has trouble with sets of sympy exprs
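WhyNoFuse above centralizes the scheduler's "cannot fuse" logging: the reason string and its args are stored, and the message is only %-formatted when the log record is actually rendered. A simplified, standalone sketch of the same pattern (names here are illustrative, not the scheduler's real node types):

import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("fusion")

class WhyNoFuseSketch:
    def __init__(self, node1, node2):
        self.node1, self.node2 = node1, node2

    def __call__(self, reason, *args):
        self.reason, self.args = reason, args
        log.debug(self)  # __str__ runs only when the record is emitted

    def __str__(self):
        return f"cannot fuse {self.node1} with {self.node2}: " + (self.reason % self.args)

WhyNoFuseSketch("buf0", "buf1")("numel/rnumel mismatch (%s, %s)", 4, 8)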
@ -930,12 +963,11 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):

    @classmethod
    def can_fuse(cls, producer, consumer):
        why = WhyNoFuse(producer, consumer)
        if producer.is_foreach() and consumer.is_foreach():
            foreach_match = len(producer.snodes) == len(consumer.snodes)
            if not foreach_match:
                fusion_log.debug(
                    "cannot fuse (foreach:1): foreach do not have same length"
                )
                why("foreach do not have same length")
            return foreach_match and all(
                producer.scheduler.can_fuse(l, r)
                for l, r in zip(producer.snodes, consumer.snodes)
@ -945,9 +977,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
            if consumer_subnode is not None:
                return consumer.scheduler.can_fuse(producer, consumer_subnode)

            fusion_log.debug(
                "cannot fuse (foreach:2): candidate producer is not dep of any foreach consumer"
            )
            why("candidate producer is not dep of any foreach consumer")
            return False

        elif producer.is_foreach():
@ -955,9 +985,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
            if producer_subnode is not None:
                return producer.scheduler.can_fuse(producer_subnode, consumer)

            fusion_log.debug(
                "cannot fuse (foreach:3): candidate consumer has no dep in any foreach producer"
            )
            why("candidate consumer has no dep in any foreach producer")
            return False

        raise AssertionError(
@ -1785,9 +1813,7 @@ class Scheduler:
        combined_ancestors = (node1.ancestors | node2.ancestors) - combined_names
        cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
        if cycle:
            fusion_log.debug(
                "cannot fuse (cycle): will create cycle - %s %s", node1, node2
            )
            WhyNoFuse(node1, node2)("will create cycle")
        return cycle

    def can_fusion_increase_peak_memory(
@ -1825,24 +1851,27 @@ class Scheduler:

        if node1 is node2:
            return False

        why = WhyNoFuse(node1, node2)

        if (
            isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
            and not node1.is_template()
        ):
            fusion_log.debug("cannot fuse (1): node1 %s is extern or nop", node1)
            why("node1 is extern or nop")
            return False
        if (
            isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
            and not node2.is_template()
        ):
            fusion_log.debug("cannot fuse (2): node2 %s is extern or nop", node2)
            why("node2 is extern or nop")
            return False

        if node1.is_foreach() or node2.is_foreach():
            return ForeachKernelSchedulerNode.can_fuse(node1, node2)

        if node2.get_names() & node1.ancestors:
            fusion_log.debug("cannot fuse (3): node1 must go before node2")
            why("node1 must go before node2")
            return False

        if (
@ -1865,24 +1894,20 @@ class Scheduler:
                return False

        if node2.is_template():
            fusion_log.debug("cannot fuse (4): templates can only fuse epilogues")
            why("templates can only fuse epilogues")
            return False
        if node1.is_template() and (
            node2.has_aliasing_or_mutation()
            or node2.is_reduction()
            or not config.epilogue_fusion
        ):
            fusion_log.debug("cannot fuse (5): template epilogue not satisfied")
            why("template epilogue not satisfied")
            return False

        device = node1.get_device()
        device2 = node2.get_device()
        if device != device2:
            fusion_log.debug(
                "cannot fuse (6): device mismatch (node1: %s, node2: %s)",
                device,
                device2,
            )
            why("device mismatch (%s vs %s)", device, device2)
            return False
        del device2

@ -1890,7 +1915,7 @@ class Scheduler:
        if no_shared_data and (
            not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
        ):
            fusion_log.debug("cannot fuse (7): no shared data")
            why("no shared data")
            return False  # heuristic not needed for correctness

        if (
@ -1898,7 +1923,7 @@ class Scheduler:
            and not node2.is_foreach()
            and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size
        ):
            fusion_log.debug("cannot fuse (8): exceeds max fusion")
            why("exceeds max fusion")
            return False  # heuristic not needed for correctness

        if node1.get_names() & node2.ancestors:
@ -1908,7 +1933,7 @@ class Scheduler:
            return self.get_backend(device).can_fuse_vertical(node1, node2)
        else:  # nodes don't depend on each other, but may have common reads
            if self.can_fusion_increase_peak_memory(node1, node2):
                fusion_log.debug("cannot fuse (9): will increase peak memory")
                why("will increase peak memory")
                return False
            return self.get_backend(device).can_fuse_horizontal(node1, node2)

@ -1922,6 +1947,7 @@ class Scheduler:
        """
        node1_names = node1.get_names()
        computed_deps = set()
        why = WhyNoFuse(node1, node2)

        for rd in node2.unmet_dependencies:
            for cd in node1.read_writes.writes:
@ -1946,13 +1972,11 @@ class Scheduler:
            # Examples here include:
            #   - MemoryDep("foo", x) != MemoryDep("foo", x + 1)
            #   - MemoryDep("foo", x) != StarDep("foo")
            fusion_log.debug("cannot fuse (vert:1): memory deps did not match")
            why("memory deps did not match")
            return False
        for name in remaining_deps:
            if node1_names & self.name_to_fused_node[name].ancestors:
                fusion_log.debug(
                    "cannot fuse (vert:2): intermediate nodes between node1 & node2"
                )
                why("intermediate nodes between node1 & node2")
                return False
        return True


@ -1110,16 +1110,17 @@ def blue_text(msg):

@functools.lru_cache(None)
def get_device_tflops(dtype):
    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi

    assert dtype in (torch.float16, torch.bfloat16, torch.float32)
    cur_sm_clock = nvsmi(["clocks.current.sm"])[0]
    if dtype in (torch.float16, torch.bfloat16):
        return get_max_tensorcore_tflops(dtype)
        return get_max_tensorcore_tflops(dtype, cur_sm_clock)

    if torch.backends.cuda.matmul.allow_tf32:
        return get_max_tensorcore_tflops(torch.float32)
        return get_max_tensorcore_tflops(torch.float32, cur_sm_clock)
    else:
        return get_max_simd_tflops(torch.float32)
        return get_max_simd_tflops(torch.float32, cur_sm_clock)


@functools.lru_cache(None)

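The get_device_tflops change above queries the current SM clock via nvsmi and forwards it to the triton.testing helpers, so the estimate tracks the clock the GPU is actually running at. A hedged sketch of the underlying arithmetic (the constants below are made up for illustration, not taken from any real device table):

def peak_tflops(flops_per_clock_per_sm, num_sms, sm_clock_mhz):
    # peak throughput scales linearly with the SM clock
    return flops_per_clock_per_sm * num_sms * sm_clock_mhz * 1e6 / 1e12

# hypothetical part: 1024 tensor-core FLOPs/clock/SM, 108 SMs, 1410 MHz
print(peak_tflops(1024, 108, 1410))  # ~155.9 TFLOPS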
@ -17,7 +17,7 @@ ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar]

DTypeLike = typing.TypeVar("DTypeLike")
AxisLike = typing.TypeVar("AxisLike")
NDArray = typing.TypeVar("NDarray")
NDArray = typing.TypeVar("NDArray")
CastingModes = typing.TypeVar("CastingModes")
KeepDims = typing.TypeVar("KeepDims")


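The one-character fix above makes the TypeVar's declared name match the variable it is bound to; the string argument becomes the TypeVar's __name__ and shows up in reprs and type-checker messages:

import typing

NDArray = typing.TypeVar("NDArray")
assert NDArray.__name__ == "NDArray"  # previously reported as "NDarray"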
@ -7,7 +7,7 @@ import types
from typing import Any, Callable, Dict, List, Type, Union

import torch._C

import torch.utils._pytree as pytree
from torch import _utils_internal
from torch._functorch.pyfunctorch import dispatch_functorch

@ -369,7 +369,7 @@ class HigherOrderOperator(OperatorBase):


def _to_flat_tuple(args, kwargs):
    return torch.utils._pytree.arg_tree_leaves(*args, **kwargs)
    return pytree.arg_tree_leaves(*args, **kwargs)


def _compute_keyset(args, kwargs, non_fallthrough_keys):
@ -506,7 +506,7 @@ class OpOverload(OperatorBase):
        )

    def __call__(self, *args, **kwargs):
        return self._op(*args, **kwargs or {})
        return self._op(*args, **(kwargs or {}))

    def __hash__(self):
        return hash(self._op)
@ -601,9 +601,7 @@ class OpOverload(OperatorBase):
                    with temporarily_pop_mode(curr_stack) as curr_mode:
                        assert hasattr(curr_mode, "__torch_dispatch__")
                        overload_types = []
                        args_flattened, _ = torch.utils._pytree.tree_flatten(
                            (args, kwargs.values())
                        )
                        args_flattened = pytree.arg_tree_leaves(*args, **kwargs)
                        for a in args_flattened:
                            # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
                            # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
@ -750,7 +748,7 @@ class OpOverloadPacket:
        # is still callable from JIT
        # We save the function ptr as the `op` attribute on
        # OpOverloadPacket to access it here.
        return self._op(*args, **kwargs or {})
        return self._op(*args, **(kwargs or {}))

    # TODO: use this to make a __dir__
    def overloads(self):

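Several call sites above move from tree_flatten over an (args, kwargs.values()) tuple to pytree.arg_tree_leaves, which flattens positional and keyword arguments directly. A small usage sketch (the exact leaf order shown is what the helper is expected to produce):

import torch.utils._pytree as pytree

leaves = pytree.arg_tree_leaves(1, (2, 3), x={"a": 4})
print(leaves)  # expected: [1, 2, 3, 4]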
@ -38,6 +38,7 @@ void AdaptiveLogSoftmaxWithLossImpl::reset() {
                  .size() == options.cutoffs().size(),
      "cutoffs should be a sequence of unique, positive integers sorted in an increasing order, ",
      "where each value is between 1 and n_classes-1");
  TORCH_CHECK(options.div_value() != 0, "div_value should not be equal to 0");

  cutoffs = options.cutoffs();
  cutoffs.push_back(options.n_classes());
@ -53,8 +54,8 @@ void AdaptiveLogSoftmaxWithLossImpl::reset() {
  tail = this->register_module("tail", ModuleList());

  for (const auto i : c10::irange(n_clusters)) {
    int64_t hsz = options.in_features() /
        static_cast<int64_t>(std::pow(options.div_value(), (i + 1)));
    int64_t hsz = static_cast<int64_t>(std::floor(
        options.in_features() / std::pow(options.div_value(), (i + 1))));
    int64_t osz = cutoffs[i + 1] - cutoffs[i];

    Sequential projection(

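The C++ change above floors the floating-point quotient instead of dividing by a truncated integer power, which matches how the Python AdaptiveLogSoftmaxWithLoss floors in_features divided by div_value ** (i + 1); it also rejects div_value == 0, which would otherwise divide by zero. A small worked example of the difference (values are illustrative):

import math

in_features, div_value, i = 10, 2.5, 0
old_hsz = in_features // int(math.pow(div_value, i + 1))              # 10 // 2  -> 5
new_hsz = int(math.floor(in_features / math.pow(div_value, i + 1)))   # 10 / 2.5 -> 4
print(old_hsz, new_hsz)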
@ -233,6 +233,23 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm(
    AtenTensorHandle* ret0,
    AtenTensorHandle* ret1);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias, // optional argument
    int64_t* stride_ptr,
    int64_t stride_size,
    int64_t* padding_ptr,
    int64_t padding_size,
    int64_t* dilation_ptr,
    int64_t dilation_size,
    int transposed,
    int64_t* output_padding_ptr,
    int64_t output_padding_size,
    int64_t groups,
    AtenTensorHandle* ret // returns new reference
);

// This function will create a new uninitialized tensor object
// and its pointer is returned through *ret.
AOTI_TORCH_EXPORT AOTITorchError

@ -359,6 +359,47 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
      ret8);
}

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution(
    AtenTensorHandle input,
    AtenTensorHandle weight,
    AtenTensorHandle bias, // optional argument
    int64_t* stride_ptr,
    int64_t stride_size,
    int64_t* padding_ptr,
    int64_t padding_size,
    int64_t* dilation_ptr,
    int64_t dilation_size,
    int transposed,
    int64_t* output_padding_ptr,
    int64_t output_padding_size,
    int64_t groups,
    AtenTensorHandle* out // returns new reference
) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input);
    at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight);
    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
    auto optional_bias = pointer_to_optional(bias_tensor);
    c10::IntArrayRef stride(stride_ptr, stride_size);
    c10::IntArrayRef padding(padding_ptr, padding_size);
    c10::IntArrayRef dilation(dilation_ptr, dilation_size);
    c10::IntArrayRef output_padding(output_padding_ptr, output_padding_size);

    at::Tensor out_tensor = at::convolution(
        *input_tensor,
        *weight_tensor,
        optional_bias,
        stride,
        padding,
        dilation,
        static_cast<bool>(transposed),
        output_padding,
        groups);
    at::Tensor* out_tensor_ptr = new at::Tensor(std::move(out_tensor));
    *out = tensor_pointer_to_tensor_handle(out_tensor_ptr);
  });
}

AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    at::Tensor* out_tensor = new at::Tensor();

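The new shim above is the ABI-stable C entry point for convolutions in AOTInductor's abi_compatible mode: it unpacks the raw handles and pointer/size pairs and forwards to at::convolution. For orientation, a hedged sketch of the Python-level op it mirrors, with the same argument list:

import torch

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
# aten::convolution(input, weight, bias, stride, padding, dilation,
#                   transposed, output_padding, groups)
out = torch.ops.aten.convolution(x, w, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
print(out.shape)  # torch.Size([1, 4, 8, 8])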
@ -1,167 +0,0 @@
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <torch/library.h>
#include <torch/torch.h>

namespace {
using namespace torch::autograd;
using namespace at;
// This class is a custom gradient function that enables quantized tensor to
// pass input gradient back to the previous layers This function can be used
// when the user is adapting mixed precision for traninig after quantization
// From torch layer, we have no access to linear_dynamic operator which needs to
// access via redispatching mechanism TO-DO : currently we are supporting per
// tensor quantization only, will expand to per channel later on
class PackedLinearWeightDynamicBackward
    : public Function<PackedLinearWeightDynamicBackward> {
 public:
  static torch::Tensor forward(
      AutogradContext* ctx,
      at::Tensor input,
      const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
      bool reduce_range) {
    static auto op =
        at::Dispatcher::singleton()
            .findSchemaOrThrow("quantized::linear_dynamic", "")
            .typed<at::Tensor(
                at::Tensor,
                c10::intrusive_ptr<
                    LinearPackedParamsBase,
                    c10::detail::intrusive_target_default_null_type<
                        LinearPackedParamsBase>> const&,
                bool)>();
    // Calculate statistics for quantization of input Tensor
    float x_min = 0;
    float x_max = 0;
    if (input.numel() > 0) {
      auto input_contig = input.contiguous();
      x_min = input_contig.min().item<float>();
      x_max = input_contig.max().item<float>();
    }
    auto output = op.redispatch(
        DispatchKeySet({DispatchKey::CPU}),
        std::move(input),
        packed_weight,
        reduce_range);
    auto q_params = quant_utils::ChooseQuantizationParams(
        /*min=*/x_min,
        /*max=*/x_max,
        /*qmin=*/0,
        /*qmax=*/255);
    ctx->saved_data["weight"] = packed_weight;
    // q_params.scale : shape [1] (per-tensor)
    ctx->saved_data["input_scale"] = q_params.scale;
    return output;
  }

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
    if (grad_outputs.empty()) {
      return {torch::Tensor(), torch::Tensor(), torch::Tensor()};
    }
    auto packed_weight =
        ctx->saved_data["weight"].toCustomClass<LinearPackedParamsBase>();
    auto unpacked_parameters = packed_weight->unpack();
    auto original_weight = std::get<0>(unpacked_parameters);
    auto input_scale = ctx->saved_data["input_scale"].toDouble();

    // Gradient for post-scaling
    // Let us rewrite this layer by separating the matmul from the output
    // scaling: y = (x * s1) @ w * s2 + b So you now back-propagate through four
    // operations: + b, * s2, @ W, and * s1. The steps are: start with the
    // gradient from the top, aka the adjoint, which is grad_outputs[0].
    // gradient for  + b: this is a no-op.
    // gradient for * s2: scale by s2. That's the affine/per-channel scale baked
    // into W. gradient for @ W: matmul with W.t. gradient for * s1: scale by
    // s1.
    auto grad_output0 = grad_outputs[0];
    const auto qtype = original_weight.qscheme();
    if (qtype == at::kPerTensorAffine) {
      grad_output0 *= original_weight.q_scale();
      original_weight = at::permute(original_weight, {1, 0});
    } else if (qtype == at::kPerChannelAffine) {
      // Per Channel quantizer does not support transpose.
      // Manual transpose is necessary
      original_weight = original_weight.dequantize();

// kwanghoon(TODO): This is going to be a long term solution that is applicable
// to every models One issue with quantizing a gradient, we can't get good
// enough gradient to improve model accuracy when model become complicated As of
// now, we can disable, and comeback when we figure it out better solution.
#if 0
      // Enable Kernel backend for quantized backpropagaiton matrix
      // multiplication
      original_weight = at::permute(original_weight, {1, 0});
      // Take advantage of QNNPACK for matrix multiplication
      // Per channel scales & zero point computation
      // Sources :
      // https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/observer.py#L350-L353
      auto [amin, amax] = at::aminmax(original_weight, /*dim* = */ 1);
      // QInt8 type signed quantization
      auto qmax = 127;
      auto qmin = -128;
      // Clamp with some epsilon number, so that value does not go below zero
      auto epsilon = 1e-9;
      auto new_scales = (amax - amin) / float(qmax - qmin);
      new_scales = at::clamp(new_scales, epsilon);
      auto new_zero_point =
          qmin - at::round(amin / new_scales).toType(c10::kInt);
      new_zero_point = at::clamp(new_zero_point, qmin, qmax);
      // TO-DO (BUGBUG)
      // Backend kernel is designed for inference, tightly coded for output
      // channel. For mathematical correctness, we should enable to run kernel
      // with input channel axis after transpose. As workaround, we are simply
      // either exploring per tensor quantization or per channel quantization
      // with axis = 0
      original_weight = at::quantize_per_channel(
          original_weight,
          new_scales,
          new_zero_point,
          /*axis = 1 for transpose, but we are forcing it to non-transposed case
             due to above issue*/
          0,
          c10::kQInt8);
#endif
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
    }
#if 1
    // Pure FP32 computation, useful for debugging purpose
    auto dLdX1 = torch::matmul(grad_output0, original_weight);
#else
    // Take advantage of QNNPACK for matrix multiplication
    static auto op = at::Dispatcher::singleton()
                         .findSchemaOrThrow("quantized::linear_prepack", "")
                         .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
                             at::Tensor, c10::optional<at::Tensor>)>();
    auto prepacked_weight = op.call(original_weight, nullopt);

    auto dLdX1 =
        prepacked_weight->apply_dynamic(grad_output0.toType(c10::kFloat));
#endif

    auto input_grad0 = dLdX1 * input_scale;
    return {input_grad0, torch::Tensor(), torch::Tensor()};
  }
};

at::Tensor packed_linear_weight_grad(
    c10::DispatchKeySet ks,
    at::Tensor input,
    const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
    bool reduce_range) {
  return PackedLinearWeightDynamicBackward::apply(
      std::move(input), packed_weight, reduce_range);
}
} // namespace

namespace at {
namespace native {
namespace {
TORCH_LIBRARY_IMPL(quantized, Autograd, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("quantized::linear_dynamic"),
      TORCH_FN(packed_linear_weight_grad));
}
} // namespace
} // namespace native
} // namespace at
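For reference, the gradient that the deleted backward implemented (per its own comments) can be written compactly: viewing the dynamically quantized linear as y = (x * s1) @ W * s2 + b, the input gradient is

\frac{\partial L}{\partial x} \;=\; s_1 \left( \frac{\partial L}{\partial y}\, s_2 \right) W^{\top}

which matches the code path above: scale the incoming gradient by the weight scale (q_scale, i.e. s2), matmul with the transposed weight, then scale by the saved input scale s1.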
@ -73,7 +73,7 @@ class DiagTensorBelow(WrapperTensor):
        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **kwargs or {})
            return fn(*args, **(kwargs or {}))
        else:
            # Note that here, because we don't need to provide the autograd formulas
            # we can have a default "fallback" that creates a plain Tensor based

@ -159,7 +159,7 @@ def generate_cct_and_mode(autograd_view_consistency=True):

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            all_args = pytree.arg_tree_leaves(*args, **kwargs or {})
            all_args = pytree.arg_tree_leaves(*args, **(kwargs or {}))
            modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
            if not all_same_mode(modes):
                raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
