mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-11-04 16:04:58 +08:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			cpp-docs-d
			...
			gh/dzmitry
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b7c93c8157 | 
@ -13,4 +13,3 @@ exclude:
 | 
			
		||||
  - "**/benchmarks/**"
 | 
			
		||||
  - "**/test_*.py"
 | 
			
		||||
  - "**/*_test.py"
 | 
			
		||||
  - "tools/**"
 | 
			
		||||
 | 
			
		||||
@ -195,16 +195,13 @@ case "$tag" in
 | 
			
		||||
    NINJA_VERSION=1.9.0
 | 
			
		||||
    TRITON=yes
 | 
			
		||||
    ;;
 | 
			
		||||
  pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
 | 
			
		||||
  pytorch-linux-jammy-xpu-n-py3)
 | 
			
		||||
    ANACONDA_PYTHON_VERSION=3.10
 | 
			
		||||
    GCC_VERSION=11
 | 
			
		||||
    VISION=yes
 | 
			
		||||
    XPU_VERSION=2025.2
 | 
			
		||||
    NINJA_VERSION=1.9.0
 | 
			
		||||
    TRITON=yes
 | 
			
		||||
    if [[ $tag =~ "benchmarks" ]]; then
 | 
			
		||||
      INDUCTOR_BENCHMARKS=yes
 | 
			
		||||
    fi
 | 
			
		||||
    ;;
 | 
			
		||||
  pytorch-linux-jammy-py3-gcc11-inductor-benchmarks)
 | 
			
		||||
    ANACONDA_PYTHON_VERSION=3.10
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,7 @@
 | 
			
		||||
 | 
			
		||||
set -eux
 | 
			
		||||
 | 
			
		||||
ACL_VERSION=${ACL_VERSION:-"v52.6.0"}
 | 
			
		||||
ACL_VERSION=${ACL_VERSION:-"v25.02"}
 | 
			
		||||
ACL_INSTALL_DIR="/acl"
 | 
			
		||||
 | 
			
		||||
# Clone ACL
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,11 @@ EOF
 | 
			
		||||
 | 
			
		||||
    # Default url values
 | 
			
		||||
    rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}"
 | 
			
		||||
    amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu"
 | 
			
		||||
 | 
			
		||||
    # Add amdgpu repository
 | 
			
		||||
    UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'`
 | 
			
		||||
    echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list
 | 
			
		||||
 | 
			
		||||
    # Add rocm repository
 | 
			
		||||
    wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -
 | 
			
		||||
 | 
			
		||||
@ -12,8 +12,8 @@ function do_install() {
 | 
			
		||||
 | 
			
		||||
    rocm_version_nodot=${rocm_version//./}
 | 
			
		||||
 | 
			
		||||
    # post merge of https://github.com/icl-utk-edu/magma/pull/65
 | 
			
		||||
    MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
 | 
			
		||||
    # https://github.com/icl-utk-edu/magma/pull/65
 | 
			
		||||
    MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
 | 
			
		||||
    magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
 | 
			
		||||
 | 
			
		||||
    rocm_dir="/opt/rocm"
 | 
			
		||||
 | 
			
		||||
@ -97,7 +97,7 @@ case ${image} in
 | 
			
		||||
    manylinux2_28-builder:xpu)
 | 
			
		||||
        TARGET=xpu_final
 | 
			
		||||
        GPU_IMAGE=amd64/almalinux:8
 | 
			
		||||
        DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13"
 | 
			
		||||
        DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=11"
 | 
			
		||||
        MANY_LINUX_VERSION="2_28"
 | 
			
		||||
        ;;
 | 
			
		||||
    *)
 | 
			
		||||
 | 
			
		||||
@ -1,11 +1,15 @@
 | 
			
		||||
sphinx==7.2.6
 | 
			
		||||
sphinx==5.3.0
 | 
			
		||||
#Description: This is used to generate PyTorch docs
 | 
			
		||||
#Pinned versions: 7.2.6
 | 
			
		||||
#Pinned versions: 5.3.0
 | 
			
		||||
 | 
			
		||||
pytorch_sphinx_theme2==0.2.0
 | 
			
		||||
#Description: This is needed to generate PyTorch docs
 | 
			
		||||
#Pinned versions: 0.2.0
 | 
			
		||||
standard-imghdr==3.13.0; python_version >= "3.13"
 | 
			
		||||
#Description: This is needed by Sphinx, so it needs to be added here.
 | 
			
		||||
# The reasons are as follows:
 | 
			
		||||
# 1) This module was removed from the Python standard library in Python 3.13 (https://peps.python.org/pep-0594/#imghdr);
 | 
			
		||||
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
 | 
			
		||||
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
 | 
			
		||||
 | 
			
		||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
 | 
			
		||||
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
 | 
			
		||||
# but it doesn't seem to work and hangs around idly. The initial thought is that it is probably
 | 
			
		||||
# something related to Docker setup. We can investigate this later.
 | 
			
		||||
@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
 | 
			
		||||
#Description: This is used to generate PyTorch docs
 | 
			
		||||
#Pinned versions: 2.13.0
 | 
			
		||||
 | 
			
		||||
breathe==4.36.0
 | 
			
		||||
breathe==4.34.0
 | 
			
		||||
#Description: This is used to generate PyTorch C++ docs
 | 
			
		||||
#Pinned versions: 4.36.0
 | 
			
		||||
#Pinned versions: 4.34.0
 | 
			
		||||
 | 
			
		||||
exhale==0.3.7
 | 
			
		||||
exhale==0.2.3
 | 
			
		||||
#Description: This is used to generate PyTorch C++ docs
 | 
			
		||||
#Pinned versions: 0.3.7
 | 
			
		||||
#Pinned versions: 0.2.3
 | 
			
		||||
 | 
			
		||||
docutils==0.20
 | 
			
		||||
docutils==0.16
 | 
			
		||||
#Description: This is used to generate PyTorch C++ docs
 | 
			
		||||
#Pinned versions: 0.20
 | 
			
		||||
#Pinned versions: 0.16
 | 
			
		||||
 | 
			
		||||
bs4==0.0.1
 | 
			
		||||
#Description: This is used to generate PyTorch C++ docs
 | 
			
		||||
@ -52,13 +56,13 @@ IPython==8.12.0
 | 
			
		||||
#Description: This is used to generate PyTorch functorch docs
 | 
			
		||||
#Pinned versions: 8.12.0
 | 
			
		||||
 | 
			
		||||
myst-nb==1.3.0
 | 
			
		||||
myst-nb==0.17.2
 | 
			
		||||
#Description: This is used to generate PyTorch functorch and torch.compile docs.
 | 
			
		||||
#Pinned versions: 1.3.0
 | 
			
		||||
#Pinned versions: 0.17.2
 | 
			
		||||
 | 
			
		||||
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 | 
			
		||||
python-etcd==0.4.5
 | 
			
		||||
sphinx-copybutton==0.5.0
 | 
			
		||||
sphinx-design==0.6.1
 | 
			
		||||
sphinx-design==0.4.0
 | 
			
		||||
sphinxcontrib-mermaid==1.0.0
 | 
			
		||||
myst-parser==4.0.1
 | 
			
		||||
myst-parser==0.18.1
 | 
			
		||||
 | 
			
		||||
@ -54,15 +54,12 @@ ENV OPENSSL_DIR /opt/openssl
 | 
			
		||||
RUN rm install_openssl.sh
 | 
			
		||||
 | 
			
		||||
ARG INDUCTOR_BENCHMARKS
 | 
			
		||||
ARG ANACONDA_PYTHON_VERSION
 | 
			
		||||
ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION
 | 
			
		||||
COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
 | 
			
		||||
COPY ./common/common_utils.sh common_utils.sh
 | 
			
		||||
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
 | 
			
		||||
COPY ci_commit_pins/timm.txt timm.txt
 | 
			
		||||
COPY ci_commit_pins/torchbench.txt torchbench.txt
 | 
			
		||||
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
 | 
			
		||||
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
 | 
			
		||||
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt
 | 
			
		||||
 | 
			
		||||
# Install XPU Dependencies
 | 
			
		||||
ARG XPU_VERSION
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,7 @@ dependencies = [
 | 
			
		||||
    "GitPython==3.1.45",
 | 
			
		||||
    "docker==7.1.0",
 | 
			
		||||
    "pytest==7.3.2",
 | 
			
		||||
    "uv==0.9.6"
 | 
			
		||||
    "uv==0.9.5"
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[tool.setuptools]
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
SHELL=/usr/bin/env bash
 | 
			
		||||
 | 
			
		||||
DOCKER_CMD ?= docker
 | 
			
		||||
DESIRED_ROCM ?= 7.1
 | 
			
		||||
DESIRED_ROCM ?= 7.0
 | 
			
		||||
DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
 | 
			
		||||
PACKAGE_NAME = magma-rocm
 | 
			
		||||
# inherit this from underlying docker image, do not pass this env var to docker
 | 
			
		||||
@ -16,7 +16,6 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
 | 
			
		||||
	magma-rocm/build_magma.sh
 | 
			
		||||
 | 
			
		||||
.PHONY: all
 | 
			
		||||
all: magma-rocm71
 | 
			
		||||
all: magma-rocm70
 | 
			
		||||
all: magma-rocm64
 | 
			
		||||
 | 
			
		||||
@ -25,11 +24,6 @@ clean:
 | 
			
		||||
	$(RM) -r magma-*
 | 
			
		||||
	$(RM) -r output
 | 
			
		||||
 | 
			
		||||
.PHONY: magma-rocm71
 | 
			
		||||
magma-rocm71: DESIRED_ROCM := 7.1
 | 
			
		||||
magma-rocm71:
 | 
			
		||||
	$(DOCKER_RUN)
 | 
			
		||||
 | 
			
		||||
.PHONY: magma-rocm70
 | 
			
		||||
magma-rocm70: DESIRED_ROCM := 7.0
 | 
			
		||||
magma-rocm70:
 | 
			
		||||
 | 
			
		||||
@ -6,8 +6,8 @@ set -eou pipefail
 | 
			
		||||
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
 | 
			
		||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
 | 
			
		||||
 | 
			
		||||
# post merge of https://github.com/icl-utk-edu/magma/pull/65
 | 
			
		||||
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
 | 
			
		||||
# https://github.com/icl-utk-edu/magma/pull/65
 | 
			
		||||
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
 | 
			
		||||
 | 
			
		||||
# Folders for the build
 | 
			
		||||
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
 | 
			
		||||
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
 | 
			
		||||
 | 
			
		||||
# Fetch magma sources and verify checksum
 | 
			
		||||
pushd ${PACKAGE_DIR}
 | 
			
		||||
git clone https://github.com/icl-utk-edu/magma
 | 
			
		||||
git clone https://github.com/jeffdaily/magma
 | 
			
		||||
pushd magma
 | 
			
		||||
git checkout ${MAGMA_VERSION}
 | 
			
		||||
popd
 | 
			
		||||
 | 
			
		||||
@ -426,7 +426,7 @@ fi
 | 
			
		||||
if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
 | 
			
		||||
  # export test times so that potential sharded tests that'll branch off this build will use consistent data
 | 
			
		||||
  # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
 | 
			
		||||
  PYTHONPATH=. python tools/stats/export_test_times.py
 | 
			
		||||
  python tools/stats/export_test_times.py
 | 
			
		||||
fi
 | 
			
		||||
# don't do this for bazel or s390x or riscv64 as they don't use sccache
 | 
			
		||||
if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then
 | 
			
		||||
 | 
			
		||||
@ -89,23 +89,20 @@ if [ "$is_main_doc" = true ]; then
 | 
			
		||||
 | 
			
		||||
  make coverage
 | 
			
		||||
  # Now that we have the coverage report, we need to make sure it is empty.
 | 
			
		||||
  # Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
 | 
			
		||||
  # showing the undocumented count in the third column.
 | 
			
		||||
  # Example: | TOTAL | 99.83% | 2 |
 | 
			
		||||
  # Count the number of lines in the file and turn that number into a variable
 | 
			
		||||
  # $lines. The `cut -f1 ...` is to only parse the number, not the filename
 | 
			
		||||
  # Skip the report header by subtracting 2: the header will be output even if
 | 
			
		||||
  # there are no undocumented items.
 | 
			
		||||
  #
 | 
			
		||||
  # Also: see docs/source/conf.py for "coverage_ignore*" items, which should
 | 
			
		||||
  # be documented then removed from there.
 | 
			
		||||
 | 
			
		||||
  # Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
 | 
			
		||||
  # The table format is: | Module | Coverage | Undocumented |
 | 
			
		||||
  # Extract the third column (undocumented count) from the TOTAL row
 | 
			
		||||
  undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
 | 
			
		||||
 | 
			
		||||
  if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
 | 
			
		||||
  lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
 | 
			
		||||
  undocumented=$((lines - 2))
 | 
			
		||||
  if [ $undocumented -lt 0 ]; then
 | 
			
		||||
    echo coverage output not found
 | 
			
		||||
    exit 1
 | 
			
		||||
  elif [ "$undocumented" -gt 0 ]; then
 | 
			
		||||
    echo "undocumented objects found:"
 | 
			
		||||
  elif [ $undocumented -gt 0 ]; then
 | 
			
		||||
    echo undocumented objects found:
 | 
			
		||||
    cat build/coverage/python.txt
 | 
			
		||||
    echo "Make sure you've updated relevant .rsts in docs/source!"
 | 
			
		||||
    echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
 | 
			
		||||
 | 
			
		||||
@ -572,8 +572,6 @@ fi
 | 
			
		||||
 | 
			
		||||
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
 | 
			
		||||
  DYNAMO_BENCHMARK_FLAGS+=(--device cpu)
 | 
			
		||||
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
 | 
			
		||||
  DYNAMO_BENCHMARK_FLAGS+=(--device xpu)
 | 
			
		||||
else
 | 
			
		||||
  DYNAMO_BENCHMARK_FLAGS+=(--device cuda)
 | 
			
		||||
fi
 | 
			
		||||
@ -667,8 +665,6 @@ test_perf_for_dashboard() {
 | 
			
		||||
    device=cuda_b200
 | 
			
		||||
  elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
 | 
			
		||||
    device=rocm
 | 
			
		||||
  elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
 | 
			
		||||
    device=xpu
 | 
			
		||||
  fi
 | 
			
		||||
 | 
			
		||||
  for mode in "${modes[@]}"; do
 | 
			
		||||
@ -1761,7 +1757,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
 | 
			
		||||
  else
 | 
			
		||||
    # Do this after checkout_install_torchbench to ensure we clobber any
 | 
			
		||||
    # nightlies that torchbench may pull in
 | 
			
		||||
    if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then
 | 
			
		||||
    if [[ "${TEST_CONFIG}" != *cpu* ]]; then
 | 
			
		||||
      install_torchrec_and_fbgemm
 | 
			
		||||
    fi
 | 
			
		||||
    PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"
 | 
			
		||||
 | 
			
		||||
@ -60,11 +60,9 @@ performance-*,
 | 
			
		||||
readability-container-size-empty,
 | 
			
		||||
readability-delete-null-pointer,
 | 
			
		||||
readability-duplicate-include,
 | 
			
		||||
readability-named-parameter,
 | 
			
		||||
readability-misplaced-array-index,
 | 
			
		||||
readability-redundant*,
 | 
			
		||||
readability-simplify-subscript-expr,
 | 
			
		||||
readability-static-definition-in-anonymous-namespace
 | 
			
		||||
readability-string-compare,
 | 
			
		||||
-readability-redundant-access-specifiers,
 | 
			
		||||
-readability-redundant-control-flow,
 | 
			
		||||
 | 
			
		||||
@ -1,319 +0,0 @@
 | 
			
		||||
---
 | 
			
		||||
name: add-uint-support
 | 
			
		||||
description: Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support.
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# Add Unsigned Integer (uint) Support to Operators
 | 
			
		||||
 | 
			
		||||
This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.
 | 
			
		||||
 | 
			
		||||
## When to use this skill
 | 
			
		||||
 | 
			
		||||
Use this skill when:
 | 
			
		||||
- Adding uint16, uint32, or uint64 support to an operator
 | 
			
		||||
- User mentions "unsigned types", "uint support", "barebones unsigned types"
 | 
			
		||||
- Enabling support for kUInt16, kUInt32, kUInt64 in kernels
 | 
			
		||||
- Working with operator implementations that need expanded type coverage
 | 
			
		||||
 | 
			
		||||
## Quick reference
 | 
			
		||||
 | 
			
		||||
**Add unsigned types to existing dispatch:**
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES));
 | 
			
		||||
 | 
			
		||||
// After (method 1: add unsigned types explicitly)
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
 | 
			
		||||
 | 
			
		||||
// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present)
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Type group reference
 | 
			
		||||
 | 
			
		||||
**Unsigned type groups:**
 | 
			
		||||
- `AT_BAREBONES_UNSIGNED_TYPES`: kUInt16, kUInt32, kUInt64
 | 
			
		||||
- `AT_INTEGRAL_TYPES_V2`: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES
 | 
			
		||||
 | 
			
		||||
**Relationship:**
 | 
			
		||||
```cpp
 | 
			
		||||
AT_INTEGRAL_TYPES          // kByte, kChar, kInt, kLong, kShort
 | 
			
		||||
AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
 | 
			
		||||
AT_INTEGRAL_TYPES_V2       // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Instructions
 | 
			
		||||
 | 
			
		||||
### Step 1: Determine if conversion to V2 is needed
 | 
			
		||||
 | 
			
		||||
Check if the file uses AT_DISPATCH_V2:
 | 
			
		||||
 | 
			
		||||
**If using old AT_DISPATCH:**
 | 
			
		||||
- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
 | 
			
		||||
- Then proceed with adding uint support
 | 
			
		||||
 | 
			
		||||
**If already using AT_DISPATCH_V2:**
 | 
			
		||||
- Proceed directly to Step 2
 | 
			
		||||
 | 
			
		||||
### Step 2: Analyze the current dispatch macro
 | 
			
		||||
 | 
			
		||||
Identify what type groups are currently in use:
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  // body
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
 | 
			
		||||
    ^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
    Current type coverage
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Common patterns:
 | 
			
		||||
- `AT_EXPAND(AT_ALL_TYPES)` → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
 | 
			
		||||
- `AT_EXPAND(AT_INTEGRAL_TYPES)` → kByte (uint8) plus the signed integer types (kChar, kShort, kInt, kLong); does not include kUInt16, kUInt32, or kUInt64
 | 
			
		||||
- `AT_EXPAND(AT_FLOATING_TYPES)` → floating point types
 | 
			
		||||
 | 
			
		||||
### Step 3: Choose the uint addition method
 | 
			
		||||
 | 
			
		||||
Two approaches:
 | 
			
		||||
 | 
			
		||||
**Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly**
 | 
			
		||||
- Use when: You want to be explicit about adding uint support
 | 
			
		||||
- Add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the type list
 | 
			
		||||
 | 
			
		||||
**Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2**
 | 
			
		||||
- Use when: The dispatch already uses `AT_EXPAND(AT_INTEGRAL_TYPES)`
 | 
			
		||||
- More concise: replaces one type group with its superset
 | 
			
		||||
- Only applicable if AT_INTEGRAL_TYPES is present
 | 
			
		||||
 | 
			
		||||
### Step 4: Apply the transformation
 | 
			
		||||
 | 
			
		||||
**Method 1 example:**
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
    dtype,
 | 
			
		||||
    "min_values_cuda",
 | 
			
		||||
    AT_WRAP([&]() {
 | 
			
		||||
      kernel_impl<scalar_t>(iter);
 | 
			
		||||
    }),
 | 
			
		||||
    AT_EXPAND(AT_ALL_TYPES),
 | 
			
		||||
    kBFloat16, kHalf, kBool
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// After (add unsigned types)
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
    dtype,
 | 
			
		||||
    "min_values_cuda",
 | 
			
		||||
    AT_WRAP([&]() {
 | 
			
		||||
      kernel_impl<scalar_t>(iter);
 | 
			
		||||
    }),
 | 
			
		||||
    AT_EXPAND(AT_ALL_TYPES),
 | 
			
		||||
    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
 | 
			
		||||
    kBFloat16, kHalf, kBool
 | 
			
		||||
);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Method 2 example:**
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
    dtype,
 | 
			
		||||
    "integral_op",
 | 
			
		||||
    AT_WRAP([&]() {
 | 
			
		||||
      kernel<scalar_t>();
 | 
			
		||||
    }),
 | 
			
		||||
    AT_EXPAND(AT_INTEGRAL_TYPES)
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// After (substitute with V2)
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
    dtype,
 | 
			
		||||
    "integral_op",
 | 
			
		||||
    AT_WRAP([&]() {
 | 
			
		||||
      kernel<scalar_t>();
 | 
			
		||||
    }),
 | 
			
		||||
    AT_EXPAND(AT_INTEGRAL_TYPES_V2)
 | 
			
		||||
);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 5: Handle AT_ALL_TYPES vs individual type groups
 | 
			
		||||
 | 
			
		||||
If the dispatch uses `AT_EXPAND(AT_ALL_TYPES)`:
 | 
			
		||||
- `AT_ALL_TYPES` = `AT_INTEGRAL_TYPES` + `AT_FLOATING_TYPES`
 | 
			
		||||
- To add uint: add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the list
 | 
			
		||||
 | 
			
		||||
If the dispatch separately lists INTEGRAL and FLOATING:
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
 | 
			
		||||
 | 
			
		||||
// After (Method 2 preferred)
 | 
			
		||||
AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 6: Verify all dispatch sites
 | 
			
		||||
 | 
			
		||||
Check the file for ALL dispatch macros that need uint support:
 | 
			
		||||
- Some operators have multiple dispatch sites (CPU, CUDA, different functions)
 | 
			
		||||
- Apply the transformation consistently across all sites
 | 
			
		||||
- Ensure each gets the same type coverage updates
 | 
			
		||||
 | 
			
		||||
### Step 7: Validate the changes
 | 
			
		||||
 | 
			
		||||
Check that:
 | 
			
		||||
- [ ] AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
 | 
			
		||||
- [ ] Unsigned types are added via one of the two methods
 | 
			
		||||
- [ ] All relevant dispatch sites in the file are updated
 | 
			
		||||
- [ ] Type groups use `AT_EXPAND()`
 | 
			
		||||
- [ ] Arguments are properly formatted and comma-separated
 | 
			
		||||
 | 
			
		||||
## Common patterns
 | 
			
		||||
 | 
			
		||||
### Pattern 1: AT_ALL_TYPES + extras
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Pattern 2: Separate INTEGRAL + FLOATING
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Pattern 3: Old dispatch needs conversion first
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before (needs v2 conversion first)
 | 
			
		||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
// After v2 conversion
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
 | 
			
		||||
 | 
			
		||||
// After adding uint support
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Multiple dispatch sites example
 | 
			
		||||
 | 
			
		||||
For a file with multiple functions:
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
void min_values_kernel_cuda(TensorIterator& iter) {
 | 
			
		||||
  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
 | 
			
		||||
    impl<scalar_t>(iter);
 | 
			
		||||
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
 | 
			
		||||
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
  //                           Added uint support
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void min_launch_kernel(TensorIterator &iter) {
 | 
			
		||||
  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
 | 
			
		||||
    gpu_reduce_kernel<scalar_t>(iter);
 | 
			
		||||
  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
 | 
			
		||||
  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
  //                           Added uint support here too
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Decision tree
 | 
			
		||||
 | 
			
		||||
Use this decision tree to determine the approach:
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
Is the file using AT_DISPATCH_V2?
 | 
			
		||||
├─ No → Use at-dispatch-v2 skill first, then continue
 | 
			
		||||
└─ Yes
 | 
			
		||||
   └─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
 | 
			
		||||
      ├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
 | 
			
		||||
      └─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Edge cases
 | 
			
		||||
 | 
			
		||||
### Case 1: Dispatch with only floating types
 | 
			
		||||
 | 
			
		||||
If the operator only supports floating point types, don't add uint support:
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Leave as-is - floating point only operator
 | 
			
		||||
AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Case 2: Complex types present
 | 
			
		||||
 | 
			
		||||
Unsigned types work alongside complex types:
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES),
 | 
			
		||||
    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
 | 
			
		||||
    AT_EXPAND(AT_COMPLEX_TYPES),
 | 
			
		||||
    kHalf, kBFloat16);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Case 3: Already has uint support
 | 
			
		||||
 | 
			
		||||
Check if uint types are already present:
 | 
			
		||||
- If `AT_INTEGRAL_TYPES_V2` is used → already has uint support
 | 
			
		||||
- If `AT_BAREBONES_UNSIGNED_TYPES` is already in list → already has uint support
 | 
			
		||||
- Skip the file if uint support is already present
 | 
			
		||||
 | 
			
		||||
## Workflow
 | 
			
		||||
 | 
			
		||||
When asked to add uint support:
 | 
			
		||||
 | 
			
		||||
1. Read the target file
 | 
			
		||||
2. Check if using AT_DISPATCH_V2:
 | 
			
		||||
   - If not → use at-dispatch-v2 skill first
 | 
			
		||||
3. Identify all dispatch macro sites
 | 
			
		||||
4. For each dispatch:
 | 
			
		||||
   - Analyze current type groups
 | 
			
		||||
   - Choose method (add BAREBONES_UNSIGNED or upgrade to V2)
 | 
			
		||||
   - Apply transformation with Edit tool
 | 
			
		||||
5. Show the user the changes
 | 
			
		||||
6. Explain what was modified
 | 
			
		||||
 | 
			
		||||
## Important notes
 | 
			
		||||
 | 
			
		||||
- Always check if v2 conversion is needed first
 | 
			
		||||
- Apply changes consistently across all dispatch sites in the file
 | 
			
		||||
- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
 | 
			
		||||
- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
 | 
			
		||||
- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte which is uint8)
 | 
			
		||||
- Some operators may not semantically support unsigned types - use judgment
 | 
			
		||||
 | 
			
		||||
## Testing
 | 
			
		||||
 | 
			
		||||
After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.
 | 
			
		||||
@ -1,305 +0,0 @@
 | 
			
		||||
---
 | 
			
		||||
name: at-dispatch-v2
 | 
			
		||||
description: Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
# AT_DISPATCH to AT_DISPATCH_V2 Converter
 | 
			
		||||
 | 
			
		||||
This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in `aten/src/ATen/Dispatch_v2.h`.
 | 
			
		||||
 | 
			
		||||
## When to use this skill
 | 
			
		||||
 | 
			
		||||
Use this skill when:
 | 
			
		||||
- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
 | 
			
		||||
- Porting ATen kernels to use the new dispatch API
 | 
			
		||||
- Working with files in `aten/src/ATen/native/` that use dispatch macros
 | 
			
		||||
- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
 | 
			
		||||
 | 
			
		||||
## Quick reference
 | 
			
		||||
 | 
			
		||||
**Old format:**
 | 
			
		||||
```cpp
 | 
			
		||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
 | 
			
		||||
  // lambda body
 | 
			
		||||
});
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**New format:**
 | 
			
		||||
```cpp
 | 
			
		||||
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
 | 
			
		||||
  // lambda body
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Key transformations
 | 
			
		||||
 | 
			
		||||
1. **Reorder arguments**: `scalar_type` and `name` come first, then lambda, then types
 | 
			
		||||
2. **Wrap the lambda**: Use `AT_WRAP(lambda)` to handle internal commas
 | 
			
		||||
3. **Expand type groups**: Use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion
 | 
			
		||||
4. **List individual types**: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
 | 
			
		||||
5. **Add include**: `#include <ATen/Dispatch_v2.h>` near other Dispatch includes
 | 
			
		||||
 | 
			
		||||
## Instructions
 | 
			
		||||
 | 
			
		||||
### Step 1: Add the Dispatch_v2.h include
 | 
			
		||||
 | 
			
		||||
Add the v2 header near the existing `#include <ATen/Dispatch.h>`:
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
#include <ATen/Dispatch.h>
 | 
			
		||||
#include <ATen/Dispatch_v2.h>
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Keep the old Dispatch.h include for now (other code may still need it).
 | 
			
		||||
 | 
			
		||||
### Step 2: Identify the old dispatch pattern
 | 
			
		||||
 | 
			
		||||
Common patterns to convert:
 | 
			
		||||
 | 
			
		||||
- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)`
 | 
			
		||||
- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)`
 | 
			
		||||
- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)`
 | 
			
		||||
- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)`
 | 
			
		||||
 | 
			
		||||
### Step 3: Map the old macro to type groups
 | 
			
		||||
 | 
			
		||||
Identify which type group macro corresponds to the base types:
 | 
			
		||||
 | 
			
		||||
| Old macro base | AT_DISPATCH_V2 type group |
 | 
			
		||||
|----------------|---------------------------|
 | 
			
		||||
| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` |
 | 
			
		||||
| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` |
 | 
			
		||||
| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` |
 | 
			
		||||
| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` |
 | 
			
		||||
| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` |
 | 
			
		||||
 | 
			
		||||
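For combined patterns, you can equivalently spell the group out as multiple `AT_EXPAND()` entries (`AT_ALL_TYPES_AND_COMPLEX` is `AT_ALL_TYPES` plus `AT_COMPLEX_TYPES`):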
 | 
			
		||||
```cpp
 | 
			
		||||
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
 | 
			
		||||
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 4: Extract the individual types
 | 
			
		||||
 | 
			
		||||
From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (type1, type2, etc.).
 | 
			
		||||
 | 
			
		||||
These become the trailing arguments after the type group:
 | 
			
		||||
```cpp
 | 
			
		||||
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
 | 
			
		||||
                                             ^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
                                             Individual types from AND3
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 5: Transform to AT_DISPATCH_V2
 | 
			
		||||
 | 
			
		||||
Apply the transformation:
 | 
			
		||||
 | 
			
		||||
**Pattern:**
 | 
			
		||||
```cpp
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
  scalar_type,           // 1st: The dtype expression
 | 
			
		||||
  "name",                // 2nd: The debug string
 | 
			
		||||
  AT_WRAP(lambda),       // 3rd: The lambda wrapped in AT_WRAP
 | 
			
		||||
  type_groups,           // 4th+: Type groups with AT_EXPAND()
 | 
			
		||||
  individual_types       // Last: Individual types
 | 
			
		||||
)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Example transformation:**
 | 
			
		||||
```cpp
 | 
			
		||||
// BEFORE
 | 
			
		||||
AT_DISPATCH_ALL_TYPES_AND3(
 | 
			
		||||
    kBFloat16, kHalf, kBool,
 | 
			
		||||
    iter.dtype(),
 | 
			
		||||
    "min_values_cuda",
 | 
			
		||||
    [&]() {
 | 
			
		||||
      min_values_kernel_cuda_impl<scalar_t>(iter);
 | 
			
		||||
    }
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// AFTER
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
    iter.dtype(),
 | 
			
		||||
    "min_values_cuda",
 | 
			
		||||
    AT_WRAP([&]() {
 | 
			
		||||
      min_values_kernel_cuda_impl<scalar_t>(iter);
 | 
			
		||||
    }),
 | 
			
		||||
    AT_EXPAND(AT_ALL_TYPES),
 | 
			
		||||
    kBFloat16, kHalf, kBool
 | 
			
		||||
);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 6: Handle multi-line lambdas
 | 
			
		||||
 | 
			
		||||
For lambdas with internal commas or complex expressions, AT_WRAP is essential:
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
    dtype,
 | 
			
		||||
    "complex_kernel",
 | 
			
		||||
    AT_WRAP([&]() {
 | 
			
		||||
      gpu_reduce_kernel<scalar_t, scalar_t>(
 | 
			
		||||
        iter,
 | 
			
		||||
        MinOps<scalar_t>{},
 | 
			
		||||
        thrust::pair<scalar_t, int64_t>(upper_bound(), 0)  // Commas inside!
 | 
			
		||||
      );
 | 
			
		||||
    }),
 | 
			
		||||
    AT_EXPAND(AT_ALL_TYPES)
 | 
			
		||||
);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Step 7: Verify the conversion
 | 
			
		||||
 | 
			
		||||
Check that:
 | 
			
		||||
- [ ] `AT_WRAP()` wraps the entire lambda
 | 
			
		||||
- [ ] Type groups use `AT_EXPAND()`
 | 
			
		||||
- [ ] Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`)
 | 
			
		||||
- [ ] Argument order is: scalar_type, name, lambda, types
 | 
			
		||||
- [ ] Include added: `#include <ATen/Dispatch_v2.h>`
 | 
			
		||||
 | 
			
		||||
## Type group reference
 | 
			
		||||
 | 
			
		||||
Available type group macros (use with `AT_EXPAND()`):
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
AT_INTEGRAL_TYPES      // kByte, kChar, kInt, kLong, kShort
 | 
			
		||||
AT_FLOATING_TYPES      // kDouble, kFloat
 | 
			
		||||
AT_COMPLEX_TYPES       // kComplexDouble, kComplexFloat
 | 
			
		||||
AT_QINT_TYPES         // kQInt8, kQUInt8, kQInt32
 | 
			
		||||
AT_ALL_TYPES          // INTEGRAL_TYPES + FLOATING_TYPES
 | 
			
		||||
AT_ALL_TYPES_AND_COMPLEX  // ALL_TYPES + COMPLEX_TYPES
 | 
			
		||||
AT_INTEGRAL_TYPES_V2  // INTEGRAL_TYPES + unsigned types
 | 
			
		||||
AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
 | 
			
		||||
AT_FLOAT8_TYPES       // Float8 variants
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Common patterns
 | 
			
		||||
 | 
			
		||||
### Pattern: AT_DISPATCH_ALL_TYPES_AND2
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
 | 
			
		||||
  kernel<scalar_t>(data);
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>(data);
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Pattern: AT_DISPATCH_FLOATING_TYPES_AND3
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
 | 
			
		||||
    tensor.scalar_type(), "float_op", [&] {
 | 
			
		||||
  process<scalar_t>(tensor);
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
 | 
			
		||||
  process<scalar_t>(tensor);
 | 
			
		||||
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
 | 
			
		||||
    kComplexHalf, kHalf,
 | 
			
		||||
    self.scalar_type(),
 | 
			
		||||
    "complex_op",
 | 
			
		||||
    [&] {
 | 
			
		||||
      result = compute<scalar_t>(self);
 | 
			
		||||
    }
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(
 | 
			
		||||
    self.scalar_type(),
 | 
			
		||||
    "complex_op",
 | 
			
		||||
    AT_WRAP([&] {
 | 
			
		||||
      result = compute<scalar_t>(self);
 | 
			
		||||
    }),
 | 
			
		||||
    AT_EXPAND(AT_ALL_TYPES),
 | 
			
		||||
    AT_EXPAND(AT_COMPLEX_TYPES),
 | 
			
		||||
    kComplexHalf,
 | 
			
		||||
    kHalf
 | 
			
		||||
);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Edge cases
 | 
			
		||||
 | 
			
		||||
### Case 1: No extra types (rare)
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES));
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Case 2: Many individual types (AND4, AND5, etc.)
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
 | 
			
		||||
    dtype, "float8_op", [&]() { kernel<scalar_t>(); });
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
 | 
			
		||||
  kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Case 3: Lambda with no captures
 | 
			
		||||
 | 
			
		||||
```cpp
 | 
			
		||||
// Before
 | 
			
		||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
 | 
			
		||||
  static_kernel<scalar_t>();
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
// After
 | 
			
		||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
 | 
			
		||||
  static_kernel<scalar_t>();
 | 
			
		||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Benefits of AT_DISPATCH_V2
 | 
			
		||||
 | 
			
		||||
1. **No arity in macro name**: Don't need different macros for AND2, AND3, AND4
 | 
			
		||||
2. **Composable type sets**: Mix and match type groups with `AT_EXPAND()`
 | 
			
		||||
3. **Extensible**: Easy to add more types without hitting macro limits
 | 
			
		||||
4. **Clearer**: Type groups are explicit, not implicit in macro name
 | 
			
		||||
 | 
			
		||||
## Important notes
 | 
			
		||||
 | 
			
		||||
- Keep `#include <ATen/Dispatch.h>` - other code may need it
 | 
			
		||||
- `AT_WRAP()` is mandatory - it keeps commas inside the lambda from being parsed as macro argument separators
 | 
			
		||||
- Type groups need `AT_EXPAND()`, individual types don't
 | 
			
		||||
- The v2 API is in `aten/src/ATen/Dispatch_v2.h` - refer to it for full docs
 | 
			
		||||
- See the header file for the Python script to regenerate the macro implementation
 | 
			
		||||
 | 
			
		||||
## Workflow
 | 
			
		||||
 | 
			
		||||
When asked to convert AT_DISPATCH macros:
 | 
			
		||||
 | 
			
		||||
1. Read the file to identify all AT_DISPATCH uses
 | 
			
		||||
2. Add `#include <ATen/Dispatch_v2.h>` if not present
 | 
			
		||||
3. For each dispatch macro:
 | 
			
		||||
   - Identify the pattern and extract components
 | 
			
		||||
   - Map the base type group
 | 
			
		||||
   - Extract individual types
 | 
			
		||||
   - Construct the AT_DISPATCH_V2 call
 | 
			
		||||
   - Apply with Edit tool
 | 
			
		||||
4. Show the user the complete converted file
 | 
			
		||||
5. Explain what was changed
 | 
			
		||||
 | 
			
		||||
Do NOT compile or test the code - focus on accurate conversion only.
 | 
			
		||||
							
								
								
									
										4
									
								
								.github/actions/diskspace-cleanup/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.github/actions/diskspace-cleanup/action.yml
									
									
									
									
										vendored
									
									
								
							@ -27,9 +27,7 @@ runs:
 | 
			
		||||
            docker system prune -af
 | 
			
		||||
            diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
 | 
			
		||||
            if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
 | 
			
		||||
                diskspace_cutoff_int=$((diskspace_cutoff + 0))
 | 
			
		||||
                difference=$((100 - diskspace_cutoff_int))
 | 
			
		||||
                echo "Error: Available diskspace is less than $difference percent. Not enough diskspace."
 | 
			
		||||
                echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace."
 | 
			
		||||
                echo "$msg"
 | 
			
		||||
                exit 1
 | 
			
		||||
            else
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/audio.txt
									
									
									
									
										vendored
									
									
								
							@ -1 +1 @@
 | 
			
		||||
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2
 | 
			
		||||
69bbe7363897764f9e758d851cd0340147d27f94
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/vision.txt
									
									
									
									
										vendored
									
									
								
							@ -1 +1 @@
 | 
			
		||||
cfbc5c2f1c798991715a6b06bb3ce46478c4487c
 | 
			
		||||
218d2ab791d437309f91e0486eb9fa7f00badc17
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ci_commit_pins/xla.txt
									
									
									
									
										vendored
									
									
								
							@ -1 +1 @@
 | 
			
		||||
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
 | 
			
		||||
df6798dfb931ce7c7fe5bed2447cd1092a5981af
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/pytorch-probot.yml
									
									
									
									
										vendored
									
									
								
							@ -19,7 +19,6 @@ ciflow_push_tags:
 | 
			
		||||
- ciflow/inductor-perf-test-nightly-rocm-mi300
 | 
			
		||||
- ciflow/inductor-perf-test-nightly-rocm-mi355
 | 
			
		||||
- ciflow/inductor-perf-test-nightly-x86-zen
 | 
			
		||||
- ciflow/inductor-perf-test-nightly-xpu
 | 
			
		||||
- ciflow/inductor-periodic
 | 
			
		||||
- ciflow/inductor-rocm
 | 
			
		||||
- ciflow/linux-aarch64
 | 
			
		||||
@ -27,7 +26,6 @@ ciflow_push_tags:
 | 
			
		||||
- ciflow/nightly
 | 
			
		||||
- ciflow/op-benchmark
 | 
			
		||||
- ciflow/periodic
 | 
			
		||||
- ciflow/periodic-rocm-mi200
 | 
			
		||||
- ciflow/periodic-rocm-mi300
 | 
			
		||||
- ciflow/pull
 | 
			
		||||
- ciflow/quantization-periodic
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										89
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										89
									
								
								.github/scripts/generate_binary_build_matrix.py
									
									
									
									
										vendored
									
									
								
							@ -11,17 +11,11 @@ architectures:
 | 
			
		||||
    * Latest XPU
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SCRIPT_DIR = Path(__file__).absolute().parent
 | 
			
		||||
REPO_ROOT = SCRIPT_DIR.parent.parent
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this
 | 
			
		||||
CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]
 | 
			
		||||
CUDA_STABLE = "12.8"
 | 
			
		||||
CUDA_ARCHES_FULL_VERSION = {
 | 
			
		||||
@ -37,7 +31,8 @@ CUDA_ARCHES_CUDNN_VERSION = {
 | 
			
		||||
    "13.0": "9",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ROCM_ARCHES = ["7.0", "7.1"]
 | 
			
		||||
# NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this
 | 
			
		||||
ROCM_ARCHES = ["6.4", "7.0"]
 | 
			
		||||
 | 
			
		||||
XPU_ARCHES = ["xpu"]
 | 
			
		||||
 | 
			
		||||
@ -142,48 +137,9 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Used by tools/nightly.py
 | 
			
		||||
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
 | 
			
		||||
NIGHTLY_SOURCE_MATRIX = {
 | 
			
		||||
    "cpu": dict(
 | 
			
		||||
        name="cpu",
 | 
			
		||||
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
 | 
			
		||||
        supported_platforms=["Linux", "macOS", "Windows"],
 | 
			
		||||
        accelerator="cpu",
 | 
			
		||||
    )
 | 
			
		||||
}
 | 
			
		||||
CUDA_NIGHTLY_SOURCE_MATRIX = {
 | 
			
		||||
    f"cuda-{major}.{minor}": dict(
 | 
			
		||||
        name=f"cuda-{major}.{minor}",
 | 
			
		||||
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}",
 | 
			
		||||
        supported_platforms=["Linux", "Windows"],
 | 
			
		||||
        accelerator="cuda",
 | 
			
		||||
    )
 | 
			
		||||
    for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES)
 | 
			
		||||
}
 | 
			
		||||
ROCM_NIGHTLY_SOURCE_MATRIX = {
 | 
			
		||||
    f"rocm-{major}.{minor}": dict(
 | 
			
		||||
        name=f"rocm-{major}.{minor}",
 | 
			
		||||
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm{major}.{minor}",
 | 
			
		||||
        supported_platforms=["Linux"],
 | 
			
		||||
        accelerator="rocm",
 | 
			
		||||
    )
 | 
			
		||||
    for major, minor in (map(int, version.split(".")) for version in ROCM_ARCHES)
 | 
			
		||||
}
 | 
			
		||||
XPU_NIGHTLY_SOURCE_MATRIX = {
 | 
			
		||||
    "xpu": dict(
 | 
			
		||||
        name="xpu",
 | 
			
		||||
        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/xpu",
 | 
			
		||||
        supported_platforms=["Linux"],
 | 
			
		||||
        accelerator="xpu",
 | 
			
		||||
    )
 | 
			
		||||
}
 | 
			
		||||
NIGHTLY_SOURCE_MATRIX.update(CUDA_NIGHTLY_SOURCE_MATRIX)
 | 
			
		||||
NIGHTLY_SOURCE_MATRIX.update(ROCM_NIGHTLY_SOURCE_MATRIX)
 | 
			
		||||
NIGHTLY_SOURCE_MATRIX.update(XPU_NIGHTLY_SOURCE_MATRIX)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_nccl_wheel_version(arch_version: str) -> str:
 | 
			
		||||
    import re
 | 
			
		||||
 | 
			
		||||
    requirements = map(
 | 
			
		||||
        str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
 | 
			
		||||
    )
 | 
			
		||||
@ -191,14 +147,17 @@ def get_nccl_wheel_version(arch_version: str) -> str:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_nccl_pin(arch_version: str) -> str:
 | 
			
		||||
    nccl_pin_path = (
 | 
			
		||||
        REPO_ROOT
 | 
			
		||||
        / ".ci"
 | 
			
		||||
        / "docker"
 | 
			
		||||
        / "ci_commit_pins"
 | 
			
		||||
        / f"nccl-cu{arch_version[:2]}.txt"
 | 
			
		||||
    from pathlib import Path
 | 
			
		||||
 | 
			
		||||
    nccl_pin_path = os.path.join(
 | 
			
		||||
        Path(__file__).absolute().parents[2],
 | 
			
		||||
        ".ci",
 | 
			
		||||
        "docker",
 | 
			
		||||
        "ci_commit_pins",
 | 
			
		||||
        f"nccl-cu{arch_version[:2]}.txt",
 | 
			
		||||
    )
 | 
			
		||||
    return nccl_pin_path.read_text().strip()
 | 
			
		||||
    with open(nccl_pin_path) as f:
 | 
			
		||||
        return f.read().strip()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def validate_nccl_dep_consistency(arch_version: str) -> None:
 | 
			
		||||
@ -206,8 +165,7 @@ def validate_nccl_dep_consistency(arch_version: str) -> None:
 | 
			
		||||
    wheel_ver = get_nccl_wheel_version(arch_version)
 | 
			
		||||
    if not nccl_release_tag.startswith(f"v{wheel_ver}"):
 | 
			
		||||
        raise RuntimeError(
 | 
			
		||||
            f"{arch_version} NCCL release tag version {nccl_release_tag} "
 | 
			
		||||
            f"does not correspond to wheel version {wheel_ver}"
 | 
			
		||||
            f"{arch_version} NCCL release tag version {nccl_release_tag} does not correspond to wheel version {wheel_ver}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -454,14 +412,7 @@ def generate_wheels_matrix(
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
arch_version = ""
 | 
			
		||||
for arch_version in CUDA_ARCHES:
 | 
			
		||||
    validate_nccl_dep_consistency(arch_version)
 | 
			
		||||
del arch_version
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # Used by tools/nightly.py
 | 
			
		||||
    (SCRIPT_DIR / "nightly_source_matrix.json").write_text(
 | 
			
		||||
        json.dumps(NIGHTLY_SOURCE_MATRIX, indent=4) + "\n"
 | 
			
		||||
    )
 | 
			
		||||
validate_nccl_dep_consistency("13.0")
 | 
			
		||||
validate_nccl_dep_consistency("12.9")
 | 
			
		||||
validate_nccl_dep_consistency("12.8")
 | 
			
		||||
validate_nccl_dep_consistency("12.6")
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										13
									
								
								.github/workflows/_xpu-test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										13
									
								
								.github/workflows/_xpu-test.yml
									
									
									
									
										vendored
									
									
								
							@ -38,10 +38,6 @@ on:
 | 
			
		||||
        default: ""
 | 
			
		||||
        description: |
 | 
			
		||||
          List of tests to include (empty string implies default list)
 | 
			
		||||
      dashboard-tag:
 | 
			
		||||
        required: false
 | 
			
		||||
        type: string
 | 
			
		||||
        default: ""
 | 
			
		||||
      disable-monitor:
 | 
			
		||||
        description: |
 | 
			
		||||
          [Experimental] Disable utilization monitoring for tests.
 | 
			
		||||
@ -62,11 +58,6 @@ on:
 | 
			
		||||
        required: false
 | 
			
		||||
        type: number
 | 
			
		||||
        default: 1
 | 
			
		||||
    secrets:
 | 
			
		||||
      HUGGING_FACE_HUB_TOKEN:
 | 
			
		||||
        required: false
 | 
			
		||||
        description: |
 | 
			
		||||
          HF Auth token to avoid rate limits when downloading models or datasets from hub
 | 
			
		||||
permissions:
 | 
			
		||||
  id-token: write
 | 
			
		||||
  contents: read
 | 
			
		||||
@ -205,8 +196,6 @@ jobs:
 | 
			
		||||
          PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
 | 
			
		||||
          PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
 | 
			
		||||
          TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
 | 
			
		||||
          DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
 | 
			
		||||
          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
 | 
			
		||||
        timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
 | 
			
		||||
        run: |
 | 
			
		||||
          # Fetch aws credential from IMDs
 | 
			
		||||
@ -257,8 +246,6 @@ jobs:
 | 
			
		||||
            -e PYTORCH_TEST_RERUN_DISABLED_TESTS \
 | 
			
		||||
            -e TESTS_TO_INCLUDE \
 | 
			
		||||
            -e ZE_AFFINITY_MASK \
 | 
			
		||||
            -e HUGGING_FACE_HUB_TOKEN \
 | 
			
		||||
            -e DASHBOARD_TAG \
 | 
			
		||||
            --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
 | 
			
		||||
            --ulimit stack=10485760:83886080 \
 | 
			
		||||
            --ulimit core=0 \
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/workflows/build-almalinux-images.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/build-almalinux-images.yml
									
									
									
									
										vendored
									
									
								
							@ -36,7 +36,7 @@ jobs:
 | 
			
		||||
    runs-on: linux.9xlarge.ephemeral
 | 
			
		||||
    strategy:
 | 
			
		||||
      matrix:
 | 
			
		||||
        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm7.0", "rocm7.1", "cpu"]
 | 
			
		||||
        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Build docker image
 | 
			
		||||
        uses: pytorch/pytorch/.github/actions/binary-docker-build@main
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/workflows/build-libtorch-images.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/build-libtorch-images.yml
									
									
									
									
										vendored
									
									
								
							@ -52,8 +52,8 @@ jobs:
 | 
			
		||||
          { tag: "cuda12.9" },
 | 
			
		||||
          { tag: "cuda12.8" },
 | 
			
		||||
          { tag: "cuda12.6" },
 | 
			
		||||
          { tag: "rocm6.4"  },
 | 
			
		||||
          { tag: "rocm7.0"  },
 | 
			
		||||
          { tag: "rocm7.1"  },
 | 
			
		||||
          { tag: "cpu"      },
 | 
			
		||||
        ]
 | 
			
		||||
    steps:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/workflows/build-magma-rocm-linux.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/build-magma-rocm-linux.yml
									
									
									
									
										vendored
									
									
								
							@ -34,7 +34,7 @@ jobs:
 | 
			
		||||
      id-token: write
 | 
			
		||||
    strategy:
 | 
			
		||||
      matrix:
 | 
			
		||||
        rocm_version: ["71", "70"]
 | 
			
		||||
        rocm_version: ["70", "64"]
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/workflows/build-manywheel-images.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/build-manywheel-images.yml
									
									
									
									
										vendored
									
									
								
							@ -54,8 +54,8 @@ jobs:
 | 
			
		||||
          { name: "manylinuxaarch64-builder",       tag: "cuda12.9",          runner: "linux.arm64.2xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinuxaarch64-builder",       tag: "cuda12.8",          runner: "linux.arm64.2xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinuxaarch64-builder",       tag: "cuda12.6",          runner: "linux.arm64.2xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinux2_28-builder",          tag: "rocm6.4",           runner: "linux.9xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinux2_28-builder",          tag: "rocm7.0",           runner: "linux.9xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinux2_28-builder",          tag: "rocm7.1",           runner: "linux.9xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinux2_28-builder",          tag: "cpu",               runner: "linux.9xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinux2_28_aarch64-builder",  tag: "cpu-aarch64",       runner: "linux.arm64.2xlarge.ephemeral" },
 | 
			
		||||
          { name: "manylinux2_28-builder",          tag: "xpu",               runner: "linux.9xlarge.ephemeral" },
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										9
									
								
								.github/workflows/build-triton-wheel.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/workflows/build-triton-wheel.yml
									
									
									
									
										vendored
									
									
								
							@ -55,7 +55,7 @@ jobs:
 | 
			
		||||
        docker-image: ["pytorch/manylinux2_28-builder:cpu"]
 | 
			
		||||
        include:
 | 
			
		||||
          - device: "rocm"
 | 
			
		||||
            rocm_version: "7.1"
 | 
			
		||||
            rocm_version: "7.0"
 | 
			
		||||
            runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
 | 
			
		||||
          - device: "cuda"
 | 
			
		||||
            rocm_version: ""
 | 
			
		||||
@ -159,7 +159,12 @@ jobs:
 | 
			
		||||
            WITH_CLANG_LDD="--with-clang-ldd"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD"
 | 
			
		||||
          if [[ "${BUILD_DEVICE}" == xpu ]]; then
 | 
			
		||||
            docker exec -t "${container_name}" bash -c "dnf install -y gcc-toolset-13-gcc-c++"
 | 
			
		||||
            docker exec -t "${container_name}" bash -c "source /opt/rh/gcc-toolset-13/enable && ${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE"
 | 
			
		||||
          else
 | 
			
		||||
            docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD"
 | 
			
		||||
          fi
 | 
			
		||||
 | 
			
		||||
          if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "xpu") ]]; then
 | 
			
		||||
            docker exec -t "${container_name}"  bash -c "auditwheel repair --plat ${PLATFORM} //artifacts/*.whl"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/docker-builds.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/docker-builds.yml
									
									
									
									
										vendored
									
									
								
							@ -67,7 +67,6 @@ jobs:
 | 
			
		||||
          pytorch-linux-jammy-py3.12-halide,
 | 
			
		||||
          pytorch-linux-jammy-xpu-n-1-py3,
 | 
			
		||||
          pytorch-linux-jammy-xpu-n-py3,
 | 
			
		||||
          pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
 | 
			
		||||
          pytorch-linux-jammy-py3-clang18-asan,
 | 
			
		||||
          pytorch-linux-jammy-py3-clang12-onnx,
 | 
			
		||||
          pytorch-linux-jammy-linter,
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										236
									
								
								.github/workflows/generated-linux-binary-libtorch-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										236
									
								
								.github/workflows/generated-linux-binary-libtorch-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -384,6 +384,124 @@ jobs:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
 | 
			
		||||
  libtorch-rocm6_4-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-build-linux.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: libtorch
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: rocm6.4
 | 
			
		||||
      GPU_ARCH_VERSION: "6.4"
 | 
			
		||||
      GPU_ARCH_TYPE: rocm
 | 
			
		||||
      DOCKER_IMAGE: libtorch-cxx11-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
 | 
			
		||||
      LIBTORCH_CONFIG: release
 | 
			
		||||
      LIBTORCH_VARIANT: shared-with-deps
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      timeout-minutes: 300
 | 
			
		||||
      build_name: libtorch-rocm6_4-shared-with-deps-release
 | 
			
		||||
      build_environment: linux-binary-libtorch
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  libtorch-rocm6_4-shared-with-deps-release-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs:
 | 
			
		||||
      - libtorch-rocm6_4-shared-with-deps-release-build
 | 
			
		||||
      - get-label-type
 | 
			
		||||
    runs-on: linux.rocm.gpu.mi250
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: libtorch
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: rocm6.4
 | 
			
		||||
      GPU_ARCH_VERSION: "6.4"
 | 
			
		||||
      GPU_ARCH_TYPE: rocm
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DOCKER_IMAGE: libtorch-cxx11-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
 | 
			
		||||
      LIBTORCH_CONFIG: release
 | 
			
		||||
      LIBTORCH_VARIANT: shared-with-deps
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Setup ROCm
 | 
			
		||||
        uses: ./.github/actions/setup-rocm
 | 
			
		||||
      - uses: actions/download-artifact@v4.1.7
 | 
			
		||||
        name: Download Build Artifacts
 | 
			
		||||
        with:
 | 
			
		||||
          name: libtorch-rocm6_4-shared-with-deps-release
 | 
			
		||||
          path: "${{ runner.temp }}/artifacts/"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: actions/checkout@v4
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          show-progress: false
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: ROCm set GPU_FLAG
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
 | 
			
		||||
      - name: configure aws credentials
 | 
			
		||||
        id: aws_creds
 | 
			
		||||
        if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
 | 
			
		||||
        uses: aws-actions/configure-aws-credentials@v4
 | 
			
		||||
        with:
 | 
			
		||||
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
 | 
			
		||||
          aws-region: us-east-1
 | 
			
		||||
          role-duration-seconds: 18000
 | 
			
		||||
      - name: Calculate docker image
 | 
			
		||||
        id: calculate-docker-image
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
 | 
			
		||||
        with:
 | 
			
		||||
          docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
 | 
			
		||||
          docker-image-name: libtorch-cxx11-builder
 | 
			
		||||
          custom-tag-prefix: rocm6.4
 | 
			
		||||
          docker-build-dir: .ci/docker
 | 
			
		||||
          working-directory: pytorch
 | 
			
		||||
      - name: Pull Docker image
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
 | 
			
		||||
        with:
 | 
			
		||||
          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
 | 
			
		||||
      - name: Test Pytorch binary
 | 
			
		||||
        uses: ./pytorch/.github/actions/test-pytorch-binary
 | 
			
		||||
        env:
 | 
			
		||||
          DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
 | 
			
		||||
      - name: Teardown ROCm
 | 
			
		||||
        uses: ./.github/actions/teardown-rocm
 | 
			
		||||
  libtorch-rocm6_4-shared-with-deps-release-upload:  # Uploading
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    needs: libtorch-rocm6_4-shared-with-deps-release-test
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: libtorch
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: rocm6.4
 | 
			
		||||
      GPU_ARCH_VERSION: "6.4"
 | 
			
		||||
      GPU_ARCH_TYPE: rocm
 | 
			
		||||
      DOCKER_IMAGE: libtorch-cxx11-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
 | 
			
		||||
      LIBTORCH_CONFIG: release
 | 
			
		||||
      LIBTORCH_VARIANT: shared-with-deps
 | 
			
		||||
      build_name: libtorch-rocm6_4-shared-with-deps-release
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
 | 
			
		||||
  libtorch-rocm7_0-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-build-linux.yml
 | 
			
		||||
@ -501,121 +619,3 @@ jobs:
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
 | 
			
		||||
  libtorch-rocm7_1-shared-with-deps-release-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-build-linux.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: libtorch
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: rocm7.1
 | 
			
		||||
      GPU_ARCH_VERSION: "7.1"
 | 
			
		||||
      GPU_ARCH_TYPE: rocm
 | 
			
		||||
      DOCKER_IMAGE: libtorch-cxx11-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm7.1
 | 
			
		||||
      LIBTORCH_CONFIG: release
 | 
			
		||||
      LIBTORCH_VARIANT: shared-with-deps
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      timeout-minutes: 300
 | 
			
		||||
      build_name: libtorch-rocm7_1-shared-with-deps-release
 | 
			
		||||
      build_environment: linux-binary-libtorch
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  libtorch-rocm7_1-shared-with-deps-release-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs:
 | 
			
		||||
      - libtorch-rocm7_1-shared-with-deps-release-build
 | 
			
		||||
      - get-label-type
 | 
			
		||||
    runs-on: linux.rocm.gpu.mi250
 | 
			
		||||
    timeout-minutes: 240
 | 
			
		||||
    env:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: libtorch
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: rocm7.1
 | 
			
		||||
      GPU_ARCH_VERSION: "7.1"
 | 
			
		||||
      GPU_ARCH_TYPE: rocm
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DOCKER_IMAGE: libtorch-cxx11-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm7.1
 | 
			
		||||
      LIBTORCH_CONFIG: release
 | 
			
		||||
      LIBTORCH_VARIANT: shared-with-deps
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Setup ROCm
 | 
			
		||||
        uses: ./.github/actions/setup-rocm
 | 
			
		||||
      - uses: actions/download-artifact@v4.1.7
 | 
			
		||||
        name: Download Build Artifacts
 | 
			
		||||
        with:
 | 
			
		||||
          name: libtorch-rocm7_1-shared-with-deps-release
 | 
			
		||||
          path: "${{ runner.temp }}/artifacts/"
 | 
			
		||||
      - name: Checkout PyTorch
 | 
			
		||||
        uses: actions/checkout@v4
 | 
			
		||||
        with:
 | 
			
		||||
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
          submodules: recursive
 | 
			
		||||
          path: pytorch
 | 
			
		||||
          show-progress: false
 | 
			
		||||
      - name: Clean PyTorch checkout
 | 
			
		||||
        run: |
 | 
			
		||||
          # Remove any artifacts from the previous checkouts
 | 
			
		||||
          git clean -fxd
 | 
			
		||||
        working-directory: pytorch
 | 
			
		||||
      - name: ROCm set GPU_FLAG
 | 
			
		||||
        run: |
 | 
			
		||||
          echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
 | 
			
		||||
      - name: configure aws credentials
 | 
			
		||||
        id: aws_creds
 | 
			
		||||
        if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
 | 
			
		||||
        uses: aws-actions/configure-aws-credentials@v4
 | 
			
		||||
        with:
 | 
			
		||||
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
 | 
			
		||||
          aws-region: us-east-1
 | 
			
		||||
          role-duration-seconds: 18000
 | 
			
		||||
      - name: Calculate docker image
 | 
			
		||||
        id: calculate-docker-image
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
 | 
			
		||||
        with:
 | 
			
		||||
          docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
 | 
			
		||||
          docker-image-name: libtorch-cxx11-builder
 | 
			
		||||
          custom-tag-prefix: rocm7.1
 | 
			
		||||
          docker-build-dir: .ci/docker
 | 
			
		||||
          working-directory: pytorch
 | 
			
		||||
      - name: Pull Docker image
 | 
			
		||||
        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
 | 
			
		||||
        with:
 | 
			
		||||
          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
 | 
			
		||||
      - name: Test Pytorch binary
 | 
			
		||||
        uses: ./pytorch/.github/actions/test-pytorch-binary
 | 
			
		||||
        env:
 | 
			
		||||
          DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
 | 
			
		||||
      - name: Teardown ROCm
 | 
			
		||||
        uses: ./.github/actions/teardown-rocm
 | 
			
		||||
  libtorch-rocm7_1-shared-with-deps-release-upload:  # Uploading
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    needs: libtorch-rocm7_1-shared-with-deps-release-test
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: libtorch
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: rocm7.1
 | 
			
		||||
      GPU_ARCH_VERSION: "7.1"
 | 
			
		||||
      GPU_ARCH_TYPE: rocm
 | 
			
		||||
      DOCKER_IMAGE: libtorch-cxx11-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm7.1
 | 
			
		||||
      LIBTORCH_CONFIG: release
 | 
			
		||||
      LIBTORCH_VARIANT: shared-with-deps
 | 
			
		||||
      build_name: libtorch-rocm7_1-shared-with-deps-release
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1610
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										1610
									
								
								.github/workflows/generated-linux-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										148
									
								
								.github/workflows/inductor-perf-test-nightly-xpu.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										148
									
								
								.github/workflows/inductor-perf-test-nightly-xpu.yml
									
									
									
									
										vendored
									
									
								
							@ -1,148 +0,0 @@
 | 
			
		||||
name: inductor-perf-nightly-xpu
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    tags:
 | 
			
		||||
      - ciflow/inductor-perf-test-nightly-xpu/*
 | 
			
		||||
  schedule:
 | 
			
		||||
    - cron: 30 17 * * *
 | 
			
		||||
  workflow_dispatch:
 | 
			
		||||
    inputs:
 | 
			
		||||
      training:
 | 
			
		||||
        description: Run training (on by default)?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: true
 | 
			
		||||
      inference:
 | 
			
		||||
        description: Run inference (on by default)?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: true
 | 
			
		||||
      default:
 | 
			
		||||
        description: Run inductor_default?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: false
 | 
			
		||||
      dynamic:
 | 
			
		||||
        description: Run inductor_dynamic_shapes?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: false
 | 
			
		||||
      cppwrapper:
 | 
			
		||||
        description: Run inductor_cpp_wrapper?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: false
 | 
			
		||||
      cudagraphs:
 | 
			
		||||
        description: Run inductor_cudagraphs?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: false
 | 
			
		||||
      freezing_cudagraphs:
 | 
			
		||||
        description: Run inductor_cudagraphs with freezing for inference?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: false
 | 
			
		||||
      aotinductor:
 | 
			
		||||
        description: Run aot_inductor for inference?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: false
 | 
			
		||||
      maxautotune:
 | 
			
		||||
        description: Run inductor_max_autotune?
 | 
			
		||||
        required: false
 | 
			
		||||
        type: boolean
 | 
			
		||||
        default: false
 | 
			
		||||
      benchmark_configs:
 | 
			
		||||
        description: The list of configs used the benchmark
 | 
			
		||||
        required: false
 | 
			
		||||
        type: string
 | 
			
		||||
        default: inductor_huggingface_perf,inductor_timm_perf,inductor_torchbench_perf,cachebench
 | 
			
		||||
 | 
			
		||||
concurrency:
 | 
			
		||||
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
 | 
			
		||||
  cancel-in-progress: true
 | 
			
		||||
 | 
			
		||||
permissions: read-all
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  get-label-type:
 | 
			
		||||
    name: get-label-type
 | 
			
		||||
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
 | 
			
		||||
    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
 | 
			
		||||
    with:
 | 
			
		||||
      triggering_actor: ${{ github.triggering_actor }}
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
      opt_out_experiments: lf
 | 
			
		||||
 | 
			
		||||
  xpu-n-py3_10-inductor-benchmark-build:
 | 
			
		||||
    name: xpu-n-py3.10-inductor-benchmark
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-xpu-n-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
 | 
			
		||||
      runner: linux.c7i.12xlarge
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "inductor_huggingface_perf_xpu", shard: 1, num_shards: 5, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_xpu", shard: 2, num_shards: 5, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_xpu", shard: 3, num_shards: 5, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_xpu", shard: 4, num_shards: 5, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_huggingface_perf_xpu", shard: 5, num_shards: 5, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_timm_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_timm_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_timm_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_timm_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_timm_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_timm_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
          { config: "inductor_torchbench_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  xpu-n-py3_10-inductor-benchmark-test-nightly:
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    if: github.event_name != 'workflow_dispatch'
 | 
			
		||||
    name: xpu-n-py3.10-inductor-benchmark
 | 
			
		||||
    uses: ./.github/workflows/_xpu-test.yml
 | 
			
		||||
    needs: xpu-n-py3_10-inductor-benchmark-build
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-xpu-n-py3.10
 | 
			
		||||
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
 | 
			
		||||
      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
 | 
			
		||||
      timeout-minutes: 720
 | 
			
		||||
      # Disable monitor in perf tests for more investigation
 | 
			
		||||
      disable-monitor: true
 | 
			
		||||
      monitor-log-interval: 10
 | 
			
		||||
      monitor-data-collect-interval: 2
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  xpu-n-py3_10-inductor-benchmark-test:
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    if: github.event_name == 'workflow_dispatch'
 | 
			
		||||
    name: xpu-n-py3.10-inductor-test
 | 
			
		||||
    uses: ./.github/workflows/_xpu-test.yml
 | 
			
		||||
    needs: xpu-n-py3_10-inductor-benchmark-build
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-xpu-n-py3.10
 | 
			
		||||
      dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
 | 
			
		||||
      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
 | 
			
		||||
      timeout-minutes: 720
 | 
			
		||||
      disable-monitor: false
 | 
			
		||||
      monitor-log-interval: 15
 | 
			
		||||
      monitor-data-collect-interval: 4
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
							
								
								
									
										15
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							@ -76,12 +76,11 @@ jobs:
 | 
			
		||||
 | 
			
		||||
  # NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes
 | 
			
		||||
  #       fails to find types when it should
 | 
			
		||||
  # NOTE: We should be able to disable this and consolidate with Pyrefly
 | 
			
		||||
  lintrunner-pyrefly:
 | 
			
		||||
  lintrunner-mypy:
 | 
			
		||||
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
 | 
			
		||||
    name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
 | 
			
		||||
    name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
 | 
			
		||||
    needs: [get-label-type, get-changed-files]
 | 
			
		||||
    # Only run if there are changed files relevant to pyrefly
 | 
			
		||||
    # Only run if there are changed files relevant to mypy
 | 
			
		||||
    if: |
 | 
			
		||||
      github.repository_owner == 'pytorch' && (
 | 
			
		||||
        needs.get-changed-files.outputs.changed-files == '*' ||
 | 
			
		||||
@ -99,8 +98,8 @@ jobs:
 | 
			
		||||
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 | 
			
		||||
      script: |
 | 
			
		||||
        CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
 | 
			
		||||
        echo "Running pyrefly"
 | 
			
		||||
        ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh
 | 
			
		||||
        echo "Running mypy"
 | 
			
		||||
        ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
 | 
			
		||||
 | 
			
		||||
  lintrunner-noclang:
 | 
			
		||||
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
 | 
			
		||||
@ -119,9 +118,9 @@ jobs:
 | 
			
		||||
        CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
 | 
			
		||||
        echo "Running all other linters"
 | 
			
		||||
        if [ "$CHANGED_FILES" = '*' ]; then
 | 
			
		||||
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY --all-files" .github/scripts/lintrunner.sh
 | 
			
		||||
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
 | 
			
		||||
        else
 | 
			
		||||
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
 | 
			
		||||
          ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
 | 
			
		||||
        fi
 | 
			
		||||
 | 
			
		||||
  quick-checks:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										84
									
								
								.github/workflows/periodic-rocm-mi200.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										84
									
								
								.github/workflows/periodic-rocm-mi200.yml
									
									
									
									
										vendored
									
									
								
							@ -1,84 +0,0 @@
 | 
			
		||||
name: periodic-rocm-mi200
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  schedule:
 | 
			
		||||
    # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
 | 
			
		||||
    # Also run less frequently on weekends.
 | 
			
		||||
    - cron: 45 0,8,16 * * 1-5
 | 
			
		||||
    - cron: 45 4 * * 0,6
 | 
			
		||||
    - cron: 45 4,12,20 * * 1-5
 | 
			
		||||
    - cron: 45 12 * * 0,6
 | 
			
		||||
    - cron: 29 8 * * *  # about 1:29am PDT, for mem leak check and rerun disabled tests
 | 
			
		||||
  push:
 | 
			
		||||
    tags:
 | 
			
		||||
      - ciflow/periodic/*
 | 
			
		||||
      - ciflow/periodic-rocm-mi200/*
 | 
			
		||||
    branches:
 | 
			
		||||
      - release/*
 | 
			
		||||
  workflow_dispatch:
 | 
			
		||||
 | 
			
		||||
concurrency:
 | 
			
		||||
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
 | 
			
		||||
  cancel-in-progress: true
 | 
			
		||||
 | 
			
		||||
permissions:
 | 
			
		||||
  id-token: write
 | 
			
		||||
  contents: read
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  llm-td:
 | 
			
		||||
    if: github.repository_owner == 'pytorch'
 | 
			
		||||
    name: before-test
 | 
			
		||||
    uses: ./.github/workflows/llm_td_retrieval.yml
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
 | 
			
		||||
  target-determination:
 | 
			
		||||
    name: before-test
 | 
			
		||||
    uses: ./.github/workflows/target_determination.yml
 | 
			
		||||
    needs: llm-td
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
 | 
			
		||||
  get-label-type:
 | 
			
		||||
    name: get-label-type
 | 
			
		||||
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
 | 
			
		||||
    if: (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch'
 | 
			
		||||
    with:
 | 
			
		||||
      triggering_actor: ${{ github.triggering_actor }}
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-build:
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
 | 
			
		||||
          { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
 | 
			
		||||
          { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-test:
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_rocm-test.yml
 | 
			
		||||
    needs:
 | 
			
		||||
      - linux-jammy-rocm-py3_10-build
 | 
			
		||||
      - target-determination
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
							
								
								
									
										31
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										31
									
								
								.github/workflows/periodic.yml
									
									
									
									
										vendored
									
									
								
							@ -204,6 +204,37 @@ jobs:
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-build:
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
 | 
			
		||||
          { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
 | 
			
		||||
          { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-rocm-py3_10-test:
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write
 | 
			
		||||
      contents: read
 | 
			
		||||
    name: linux-jammy-rocm-py3.10
 | 
			
		||||
    uses: ./.github/workflows/_rocm-test.yml
 | 
			
		||||
    needs:
 | 
			
		||||
      - linux-jammy-rocm-py3_10-build
 | 
			
		||||
      - target-determination
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3.10
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								.github/workflows/upload-test-stats.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/upload-test-stats.yml
									
									
									
									
										vendored
									
									
								
							@ -6,7 +6,6 @@ on:
 | 
			
		||||
      - pull
 | 
			
		||||
      - trunk
 | 
			
		||||
      - periodic
 | 
			
		||||
      - periodic-rocm-mi200
 | 
			
		||||
      - periodic-rocm-mi300
 | 
			
		||||
      - inductor
 | 
			
		||||
      - unstable
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -143,7 +143,6 @@ scripts/release_notes/*.json
 | 
			
		||||
sccache-stats*.json
 | 
			
		||||
lint.json
 | 
			
		||||
merge_record.json
 | 
			
		||||
.github/scripts/nightly_source_matrix.json
 | 
			
		||||
 | 
			
		||||
# These files get copied over on invoking setup.py
 | 
			
		||||
torchgen/packaged/*
 | 
			
		||||
@ -398,4 +397,3 @@ CLAUDE.local.md
 | 
			
		||||
/test_*.py
 | 
			
		||||
/debug_*.py
 | 
			
		||||
CLAUDE_CONTEXT/
 | 
			
		||||
/.claude/settings.local.json
 | 
			
		||||
 | 
			
		||||
@ -121,6 +121,94 @@ command = [
 | 
			
		||||
]
 | 
			
		||||
is_formatter = true
 | 
			
		||||
 | 
			
		||||
[[linter]]
 | 
			
		||||
code = 'MYPY'
 | 
			
		||||
include_patterns = [
 | 
			
		||||
    'setup.py',
 | 
			
		||||
    'functorch/dim/**/*.py',
 | 
			
		||||
    'torch/**/*.py',
 | 
			
		||||
    'torch/**/*.pyi',
 | 
			
		||||
    'caffe2/**/*.py',
 | 
			
		||||
    'caffe2/**/*.pyi',
 | 
			
		||||
    'test/test_bundled_images.py',
 | 
			
		||||
    'test/test_bundled_inputs.py',
 | 
			
		||||
    'test/test_complex.py',
 | 
			
		||||
    'test/test_datapipe.py',
 | 
			
		||||
    'test/test_futures.py',
 | 
			
		||||
    'test/test_numpy_interop.py',
 | 
			
		||||
    'test/test_torch.py',
 | 
			
		||||
    'test/test_type_hints.py',
 | 
			
		||||
    'test/test_type_info.py',
 | 
			
		||||
    'test/test_utils.py',
 | 
			
		||||
]
 | 
			
		||||
exclude_patterns = [
 | 
			
		||||
    '**/fb/**',
 | 
			
		||||
]
 | 
			
		||||
command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
    'tools/linter/adapters/mypy_linter.py',
 | 
			
		||||
    '--config=mypy.ini',
 | 
			
		||||
    '--',
 | 
			
		||||
    '@{{PATHSFILE}}'
 | 
			
		||||
]
 | 
			
		||||
init_command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
    'tools/linter/adapters/pip_init.py',
 | 
			
		||||
    '--dry-run={{DRYRUN}}',
 | 
			
		||||
    'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
 | 
			
		||||
    'numpy==2.1.0 ; python_version >= "3.12"',
 | 
			
		||||
    'expecttest==0.3.0',
 | 
			
		||||
    'mypy==1.16.0',
 | 
			
		||||
    'sympy==1.13.3',
 | 
			
		||||
    'types-requests==2.27.25',
 | 
			
		||||
    'types-pyyaml==6.0.2',
 | 
			
		||||
    'types-tabulate==0.8.8',
 | 
			
		||||
    'types-protobuf==5.29.1.20250403',
 | 
			
		||||
    'types-setuptools==79.0.0.20250422',
 | 
			
		||||
    'types-jinja2==2.11.9',
 | 
			
		||||
    'types-colorama==0.4.6',
 | 
			
		||||
    'filelock==3.18.0',
 | 
			
		||||
    'junitparser==2.1.1',
 | 
			
		||||
    'rich==14.1.0',
 | 
			
		||||
    'pyyaml==6.0.2',
 | 
			
		||||
    'optree==0.13.0',
 | 
			
		||||
    'dataclasses-json==0.6.7',
 | 
			
		||||
    'pandas==2.2.3',
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[[linter]]
 | 
			
		||||
code = 'MYPYSTRICT'
 | 
			
		||||
include_patterns = [
 | 
			
		||||
    '.github/**/*.py',
 | 
			
		||||
    'benchmarks/instruction_counts/**/*.py',
 | 
			
		||||
    'tools/**/*.py',
 | 
			
		||||
    'torchgen/**/*.py',
 | 
			
		||||
    'torch/utils/_pytree.py',
 | 
			
		||||
    'torch/utils/_cxx_pytree.py',
 | 
			
		||||
    'torch/utils/benchmark/utils/common.py',
 | 
			
		||||
    'torch/utils/benchmark/utils/timer.py',
 | 
			
		||||
    'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py',
 | 
			
		||||
]
 | 
			
		||||
exclude_patterns = [
 | 
			
		||||
    # (linbinyu) copied from internal repo
 | 
			
		||||
    '**/fb/**',
 | 
			
		||||
    'tools/code_analyzer/gen_operators_yaml.py',
 | 
			
		||||
    'tools/dynamo/verify_dynamo.py',
 | 
			
		||||
    'tools/gen_vulkan_spv.py',
 | 
			
		||||
    'tools/test/gen_operators_yaml_test.py',
 | 
			
		||||
    'tools/test/gen_oplist_test.py',
 | 
			
		||||
    'tools/test/test_selective_build.py',
 | 
			
		||||
    'tools/experimental/torchfuzz/**',
 | 
			
		||||
]
 | 
			
		||||
command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
    'tools/linter/adapters/mypy_linter.py',
 | 
			
		||||
    '--config=mypy-strict.ini',
 | 
			
		||||
    '--code=MYPYSTRICT',
 | 
			
		||||
    '--',
 | 
			
		||||
    '@{{PATHSFILE}}'
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[[linter]]
 | 
			
		||||
code = 'PYREFLY'
 | 
			
		||||
@ -142,7 +230,6 @@ init_command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
    'tools/linter/adapters/pip_init.py',
 | 
			
		||||
    '--dry-run={{DRYRUN}}',
 | 
			
		||||
    'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
 | 
			
		||||
    'numpy==2.1.0 ; python_version >= "3.12"',
 | 
			
		||||
    'expecttest==0.3.0',
 | 
			
		||||
    'pyrefly==0.36.2',
 | 
			
		||||
 | 
			
		||||
@ -374,7 +374,7 @@ cmake_dependent_option(
 | 
			
		||||
  "Build the lazy Torchscript backend, not compatible with mobile builds" ON
 | 
			
		||||
  "NOT INTERN_BUILD_MOBILE" OFF)
 | 
			
		||||
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
 | 
			
		||||
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin folder"
 | 
			
		||||
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
 | 
			
		||||
                       OFF "USE_CUDA" OFF)
 | 
			
		||||
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
 | 
			
		||||
                        "CPU_AARCH64" OFF)
 | 
			
		||||
 | 
			
		||||
@ -11,6 +11,7 @@ aspects of contributing to PyTorch.
 | 
			
		||||
<!-- toc -->
 | 
			
		||||
 | 
			
		||||
- [Developing PyTorch](#developing-pytorch)
 | 
			
		||||
  - [Setup the development environment](#setup-the-development-environment)
 | 
			
		||||
  - [Tips and Debugging](#tips-and-debugging)
 | 
			
		||||
- [Nightly Checkout & Pull](#nightly-checkout--pull)
 | 
			
		||||
- [Codebase structure](#codebase-structure)
 | 
			
		||||
@ -66,6 +67,23 @@ aspects of contributing to PyTorch.
 | 
			
		||||
 | 
			
		||||
Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.
 | 
			
		||||
 | 
			
		||||
### Setup the development environment
 | 
			
		||||
 | 
			
		||||
First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to setup your SSH authentication credentials.
 | 
			
		||||
 | 
			
		||||
Then clone the PyTorch project and setup the development environment:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
git clone git@github.com:<USERNAME>/pytorch.git
 | 
			
		||||
cd pytorch
 | 
			
		||||
git remote add upstream git@github.com:pytorch/pytorch.git
 | 
			
		||||
 | 
			
		||||
make setup-env
 | 
			
		||||
# Or run `make setup-env-cuda` for pre-built CUDA binaries
 | 
			
		||||
# Or run `make setup-env-rocm` for pre-built ROCm binaries
 | 
			
		||||
source venv/bin/activate  # or `. .\venv\Scripts\activate` on Windows
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Tips and Debugging
 | 
			
		||||
 | 
			
		||||
* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.
 | 
			
		||||
 | 
			
		||||
@ -181,7 +181,7 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
 | 
			
		||||
  static const size_t size = sizeof(CPUGeneratorImplState);
 | 
			
		||||
  static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType");
 | 
			
		||||
 | 
			
		||||
  auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
 | 
			
		||||
  auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
 | 
			
		||||
  auto rng_state = state_tensor.data_ptr();
 | 
			
		||||
 | 
			
		||||
  // accumulate generator data to be copied into byte tensor
 | 
			
		||||
 | 
			
		||||
@ -223,7 +223,7 @@ void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
 | 
			
		||||
    "setSDPPriority order expected ", sdp_priority_order.size() - 1, " but got ",
 | 
			
		||||
    at::num_sdp_backends, " unique backends specified in priority order.");
 | 
			
		||||
  for (uint32_t i = 0; i < order.size(); i++) {
 | 
			
		||||
    sdp_priority_order[i] = static_cast<at::SDPBackend>(order[i]);
 | 
			
		||||
    sdp_priority_order[i] = (at::SDPBackend) order[i];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -825,14 +825,6 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
 | 
			
		||||
  display_vmap_fallback_warnings_ = enabled;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Context::warnOnAccumulateGradStreamMismatch() const {
 | 
			
		||||
  return warn_on_accumulate_grad_stream_mismatch_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
 | 
			
		||||
  warn_on_accumulate_grad_stream_mismatch_ = enabled;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool Context::isDefaultMobileCPUAllocatorSet() {
 | 
			
		||||
  return prev_allocator_ptr_ != nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -404,9 +404,6 @@ class TORCH_API Context {
 | 
			
		||||
  void setDisplayVmapFallbackWarnings(bool enabled);
 | 
			
		||||
  bool areVmapFallbackWarningsEnabled() const;
 | 
			
		||||
 | 
			
		||||
  void setWarnOnAccumulateGradStreamMismatch(bool enabled);
 | 
			
		||||
  bool warnOnAccumulateGradStreamMismatch() const;
 | 
			
		||||
 | 
			
		||||
  bool isDefaultMobileCPUAllocatorSet();
 | 
			
		||||
  void setDefaultMobileCPUAllocator();
 | 
			
		||||
  void unsetDefaultMobileCPUAllocator();
 | 
			
		||||
@ -497,7 +494,6 @@ class TORCH_API Context {
 | 
			
		||||
  bool release_original_weights = false;
 | 
			
		||||
#endif
 | 
			
		||||
  bool display_vmap_fallback_warnings_ = false;
 | 
			
		||||
  bool warn_on_accumulate_grad_stream_mismatch_ = true;
 | 
			
		||||
  std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
 | 
			
		||||
  bool enable_sparse_tensor_invariant_checks = false;
 | 
			
		||||
  bool allow_fp16_reduction_cpu = false;
 | 
			
		||||
 | 
			
		||||
@ -197,7 +197,6 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
 | 
			
		||||
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
 | 
			
		||||
    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
 | 
			
		||||
    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
 | 
			
		||||
    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")             \
 | 
			
		||||
    switch (_st) {                                                          \
 | 
			
		||||
      __VA_ARGS__                                                           \
 | 
			
		||||
      default:                                                              \
 | 
			
		||||
@ -209,7 +208,6 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
 | 
			
		||||
            toString(_st),                                                  \
 | 
			
		||||
            "'");                                                           \
 | 
			
		||||
    }                                                                       \
 | 
			
		||||
    C10_DIAGNOSTIC_POP()                                                    \
 | 
			
		||||
  }()
 | 
			
		||||
 | 
			
		||||
#define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
 | 
			
		||||
 | 
			
		||||
@ -252,13 +252,13 @@ MapAllocator::MapAllocator(WithFd /*unused*/, std::string_view filename, int fd,
 | 
			
		||||
    if (!(flags_ & ALLOCATOR_MAPPED_FROMFD)) {
 | 
			
		||||
      if (flags_ & ALLOCATOR_MAPPED_SHARED) {
 | 
			
		||||
        // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
 | 
			
		||||
        if ((fd = open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) {
 | 
			
		||||
        if ((fd = open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
 | 
			
		||||
          TORCH_CHECK(false, "unable to open file <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
 | 
			
		||||
        }
 | 
			
		||||
      } else if (flags_ & ALLOCATOR_MAPPED_SHAREDMEM) {
 | 
			
		||||
#ifdef HAVE_SHM_OPEN
 | 
			
		||||
        // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
 | 
			
		||||
        if((fd = shm_open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) {
 | 
			
		||||
        if((fd = shm_open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
 | 
			
		||||
          TORCH_CHECK(false, "unable to open shared memory object <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
 | 
			
		||||
        }
 | 
			
		||||
#else
 | 
			
		||||
@ -503,7 +503,7 @@ RefcountedMapAllocator::RefcountedMapAllocator(WithFd /*unused*/, const char *fi
 | 
			
		||||
 | 
			
		||||
void RefcountedMapAllocator::initializeAlloc() {
 | 
			
		||||
  TORCH_CHECK(base_ptr_, "base_ptr_ is null");
 | 
			
		||||
  MapInfo *map_info = static_cast<MapInfo*>(base_ptr_);
 | 
			
		||||
  MapInfo *map_info = (MapInfo*)base_ptr_;
 | 
			
		||||
 | 
			
		||||
#ifdef _WIN32
 | 
			
		||||
  ReleaseContext* r_ctx = new ReleaseContext;
 | 
			
		||||
@ -539,7 +539,7 @@ void RefcountedMapAllocator::close() {
 | 
			
		||||
  }
 | 
			
		||||
#else /* _WIN32 */
 | 
			
		||||
 | 
			
		||||
  MapInfo *info = static_cast<MapInfo*>(data);
 | 
			
		||||
  MapInfo *info = (MapInfo*)(data);
 | 
			
		||||
  if (--info->refcount == 0) {
 | 
			
		||||
#ifdef HAVE_SHM_UNLINK
 | 
			
		||||
    if (shm_unlink(filename_.c_str()) == -1) {
 | 
			
		||||
 | 
			
		||||
@ -862,7 +862,7 @@ void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
 | 
			
		||||
  shape_[dim] = size;
 | 
			
		||||
  view_offsets_[dim] += start;
 | 
			
		||||
  for (auto& op : operands_) {
 | 
			
		||||
    op.data = (static_cast<char*>(op.data)) + op.stride_bytes[dim] * start;
 | 
			
		||||
    op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
 | 
			
		||||
  }
 | 
			
		||||
  if (size == 1 && !is_reduction_) {
 | 
			
		||||
    coalesce_dimensions();
 | 
			
		||||
@ -873,7 +873,7 @@ void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indic
 | 
			
		||||
  TORCH_INTERNAL_ASSERT(start_dim <= ndim());
 | 
			
		||||
  for (const auto i : c10::irange(start_dim, ndim())) {
 | 
			
		||||
    for (auto& op : operands_) {
 | 
			
		||||
      op.data = (static_cast<char*>(op.data)) + op.stride_bytes[i] * indices[i - start_dim];
 | 
			
		||||
      op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
 | 
			
		||||
    }
 | 
			
		||||
    shape_[i] = 1;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,7 @@ inline void serial_for_each(
 | 
			
		||||
    IntArrayRef strides,
 | 
			
		||||
    char** base_ptrs,
 | 
			
		||||
    size_t ntensors,
 | 
			
		||||
    TensorIteratorBase::loop2d_t loop,
 | 
			
		||||
    typename TensorIteratorBase::loop2d_t loop,
 | 
			
		||||
    Range range) {
 | 
			
		||||
  const auto ndim = shape.size();
 | 
			
		||||
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
 | 
			
		||||
 | 
			
		||||
@ -190,14 +190,12 @@ class IListRef;
 | 
			
		||||
 * it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
 | 
			
		||||
 */
 | 
			
		||||
#define TORCH_ILISTREF_UNWRAP(TAG, BODY)                         \
 | 
			
		||||
  C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")    \
 | 
			
		||||
  switch (TAG) {                                                 \
 | 
			
		||||
    TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
 | 
			
		||||
    break;                                                       \
 | 
			
		||||
    default:                                                     \
 | 
			
		||||
      TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");     \
 | 
			
		||||
  } \
 | 
			
		||||
  C10_DIAGNOSTIC_POP()
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
enum class IListRefTag {
 | 
			
		||||
#define DEFINE_TAG(tag, ...) tag,
 | 
			
		||||
 | 
			
		||||
@ -56,7 +56,7 @@ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
 | 
			
		||||
 * in this overloaded version
 | 
			
		||||
 */
 | 
			
		||||
template <typename T, typename V>
 | 
			
		||||
C10_HOST_DEVICE inline std::enable_if_t<!std::is_floating_point_v<T>, T>uniform_int(V val) {
 | 
			
		||||
C10_HOST_DEVICE inline std::enable_if_t<!(std::is_floating_point_v<T>), T>uniform_int(V val) {
 | 
			
		||||
  if constexpr (std::is_same_v<T, bool>) {
 | 
			
		||||
    return static_cast<bool>(val & 1);
 | 
			
		||||
  } else if constexpr (std::is_same_v<T, int64_t>) {
 | 
			
		||||
 | 
			
		||||
@ -114,25 +114,25 @@ inline typename remove_symint<T>::type unpackSymInt(T x) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
 | 
			
		||||
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
 | 
			
		||||
  return x.guard_int(__FILE__, __LINE__);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
 | 
			
		||||
inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
 | 
			
		||||
    c10::SymIntArrayRef x) {
 | 
			
		||||
  return C10_AS_INTARRAYREF_SLOW(x);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
 | 
			
		||||
inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
 | 
			
		||||
    std::optional<c10::SymInt> x) {
 | 
			
		||||
  return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
 | 
			
		||||
                       : std::nullopt;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
inline remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
 | 
			
		||||
inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
 | 
			
		||||
    at::OptionalSymIntArrayRef x) {
 | 
			
		||||
  return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
 | 
			
		||||
                       : std::nullopt;
 | 
			
		||||
 | 
			
		||||
@ -631,8 +631,8 @@ call_functor_with_args_from_stack_(
 | 
			
		||||
    Stack* stack,
 | 
			
		||||
    std::index_sequence<ivalue_arg_indices...> /*unused*/,
 | 
			
		||||
    guts::typelist::typelist<ArgTypes...>* /*unused*/) {
 | 
			
		||||
  (void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would
 | 
			
		||||
               // be unused and we have to silence the compiler warning.
 | 
			
		||||
  (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
 | 
			
		||||
                 // be unused and we have to silence the compiler warning.
 | 
			
		||||
 | 
			
		||||
  // We're explicitly filtering out DispatchKeySet from the argument list.
 | 
			
		||||
  // Some kernels take a DispatchKeySet as their first argument in order to
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,6 @@ struct TORCH_API EnumType : public NamedType {
 | 
			
		||||
      TypePtr value,
 | 
			
		||||
      std::vector<EnumNameValue> enum_names_values,
 | 
			
		||||
      std::weak_ptr<::torch::jit::CompilationUnit> cu) {
 | 
			
		||||
    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
 | 
			
		||||
    switch (value->kind()) {
 | 
			
		||||
      case TypeKind::IntType:
 | 
			
		||||
      case TypeKind::FloatType:
 | 
			
		||||
@ -35,7 +34,6 @@ struct TORCH_API EnumType : public NamedType {
 | 
			
		||||
            value->str(),
 | 
			
		||||
            "', only int, float and string are supported");
 | 
			
		||||
    }
 | 
			
		||||
    C10_DIAGNOSTIC_POP()
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::string str() const override {
 | 
			
		||||
 | 
			
		||||
@ -601,8 +601,8 @@ std::ostream& IValue::repr(
 | 
			
		||||
      double d = v.toDouble();
 | 
			
		||||
      int c = std::fpclassify(d);
 | 
			
		||||
      if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
 | 
			
		||||
        int64_t i = static_cast<int64_t>(d);
 | 
			
		||||
        if (static_cast<double>(i) == d) {
 | 
			
		||||
        int64_t i = int64_t(d);
 | 
			
		||||
        if (double(i) == d) {
 | 
			
		||||
          // -0.0 (signed zero) needs to be parsed as -0.
 | 
			
		||||
          if (i == 0 && std::signbit(d)) {
 | 
			
		||||
            return out << "-" << i << ".";
 | 
			
		||||
@ -799,8 +799,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
 | 
			
		||||
      double d = v.toDouble();
 | 
			
		||||
      int c = std::fpclassify(d);
 | 
			
		||||
      if (c == FP_NORMAL || c == FP_ZERO) {
 | 
			
		||||
        int64_t i = static_cast<int64_t>(d);
 | 
			
		||||
        if (static_cast<double>(i) == d) {
 | 
			
		||||
        int64_t i = int64_t(d);
 | 
			
		||||
        if (double(i) == d) {
 | 
			
		||||
          return out << i << ".";
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,7 @@ void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
 | 
			
		||||
inline bool is_contiguous_strides(
 | 
			
		||||
    const IntArrayRef sizes,
 | 
			
		||||
    const IntArrayRef strides) {
 | 
			
		||||
  size_t n_dim = sizes.size();
 | 
			
		||||
  int n_dim = static_cast<int>(sizes.size());
 | 
			
		||||
  if (n_dim == 0) {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
@ -50,7 +50,7 @@ inline bool is_contiguous_strides(
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  for (int i = static_cast<int>(n_dim) - 2; i >= 0; i--) {
 | 
			
		||||
  for (int i = n_dim - 2; i >= 0; i--) {
 | 
			
		||||
    if (strides[i] != strides[i + 1] * sizes[i + 1]) {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
@ -922,7 +922,6 @@ struct TORCH_API DictType : public SharedType {
 | 
			
		||||
    if (auto dyn = key->castRaw<DynamicType>()) {
 | 
			
		||||
      kind = dyn->dynamicKind();
 | 
			
		||||
    }
 | 
			
		||||
    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
 | 
			
		||||
    switch (kind) {
 | 
			
		||||
      case TypeKind::AnyType:
 | 
			
		||||
      case TypeKind::IntType:
 | 
			
		||||
@ -939,7 +938,6 @@ struct TORCH_API DictType : public SharedType {
 | 
			
		||||
            key->str(),
 | 
			
		||||
            "', only int, float, complex, Tensor, device and string keys are supported");
 | 
			
		||||
    }
 | 
			
		||||
    C10_DIAGNOSTIC_POP()
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // aligned with the format in FunctionSchema
 | 
			
		||||
@ -2373,7 +2371,7 @@ private:
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
inline detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
 | 
			
		||||
inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
 | 
			
		||||
  if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
 | 
			
		||||
      kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
 | 
			
		||||
    return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this());
 | 
			
		||||
@ -2382,7 +2380,7 @@ inline detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
inline detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
 | 
			
		||||
inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
 | 
			
		||||
  if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
 | 
			
		||||
      kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
 | 
			
		||||
    return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this());
 | 
			
		||||
 | 
			
		||||
@ -19,13 +19,6 @@ inline namespace CPU_CAPABILITY {
 | 
			
		||||
#error "Big endian is not supported."
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// GCC does not properly optimize bf16 operators
 | 
			
		||||
#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19)
 | 
			
		||||
#define BF16_ARITHMETIC_SUPPORTED() 1
 | 
			
		||||
#else
 | 
			
		||||
#define BF16_ARITHMETIC_SUPPORTED() 0
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// Unlike the float16_t family of types, bfloat16_t is not available
 | 
			
		||||
// when we're not targeting bfloat16 hardware support on some
 | 
			
		||||
// platforms (but not Mac, so we have to be careful not to shadow the
 | 
			
		||||
@ -359,72 +352,18 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
 | 
			
		||||
        other, &Vectorized<float>::name);                        \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
 | 
			
		||||
  Vectorized frac() const;
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
 | 
			
		||||
 | 
			
		||||
#ifdef __ARM_FEATURE_BF16
 | 
			
		||||
  // Flip sign bit
 | 
			
		||||
  Vectorized<c10::BFloat16> neg() const {
 | 
			
		||||
    return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
 | 
			
		||||
  }
 | 
			
		||||
  // Fast reciprocal is fine because we are truncating results
 | 
			
		||||
  Vectorized<c10::BFloat16> reciprocal() const {
 | 
			
		||||
    auto x = vcvtq_low_f32_bf16(values);
 | 
			
		||||
    auto y = vcvtq_high_f32_bf16(values);
 | 
			
		||||
    x = vrecpeq_f32(x);
 | 
			
		||||
    y = vrecpeq_f32(y);
 | 
			
		||||
    return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
 | 
			
		||||
  }
 | 
			
		||||
  // Clearing the sign bit
 | 
			
		||||
  Vectorized<c10::BFloat16> abs() const {
 | 
			
		||||
    return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
 | 
			
		||||
  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// These functions are optimized on clang-21+
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
 | 
			
		||||
  Vectorized<c10::BFloat16> operator==(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values == other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator!=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values != other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values < other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator<=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values <= other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values > other.values;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Vectorized<c10::BFloat16> operator>=(
 | 
			
		||||
      const Vectorized<c10::BFloat16>& other) const {
 | 
			
		||||
    return values >= other.values;
 | 
			
		||||
  }
 | 
			
		||||
#else
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
 | 
			
		||||
  DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 | 
			
		||||
@ -473,52 +412,28 @@ template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator+(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x + y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator-(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x - y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator*(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x * y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
Vectorized<c10::BFloat16> inline operator/(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  return x / y;
 | 
			
		||||
#else
 | 
			
		||||
  return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// frac. Implement this here so we can use subtraction
 | 
			
		||||
@ -629,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16!  Also,
 | 
			
		||||
  // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
 | 
			
		||||
  // elements, not the bottom and top half, so they don't seem
 | 
			
		||||
  // particularly useful here. Ideally we would include dot product in
 | 
			
		||||
  // the Vectorized interface...
 | 
			
		||||
  return a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -649,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y + z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b + c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -665,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return x * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -681,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
 | 
			
		||||
    const Vectorized<c10::BFloat16>& a,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& b,
 | 
			
		||||
    const Vectorized<c10::BFloat16>& c) {
 | 
			
		||||
#if BF16_ARITHMETIC_SUPPORTED()
 | 
			
		||||
  bfloat16x8_t x = a;
 | 
			
		||||
  bfloat16x8_t y = b;
 | 
			
		||||
  bfloat16x8_t z = c;
 | 
			
		||||
  return (-x) * y - z;
 | 
			
		||||
#else
 | 
			
		||||
  // See NOTE [BF16 FMA] above.
 | 
			
		||||
  return -a * b - c;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
 | 
			
		||||
 | 
			
		||||
@ -6,9 +6,9 @@ namespace at::vec {
 | 
			
		||||
inline namespace CPU_CAPABILITY {
 | 
			
		||||
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
 | 
			
		||||
 | 
			
		||||
// Enable auto-vectorization for clang-17+
 | 
			
		||||
// Enable auto-vectorization for GCC-13+ and clang-17+
 | 
			
		||||
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
 | 
			
		||||
#if defined(__clang__) && (__clang_major__ >= 17)
 | 
			
		||||
#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
 | 
			
		||||
 | 
			
		||||
template <typename from_type, typename to_type>
 | 
			
		||||
inline void convertImpl(
 | 
			
		||||
 | 
			
		||||
@ -309,7 +309,7 @@ class Vectorized<float> {
 | 
			
		||||
  DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
 | 
			
		||||
  // Implementation copied from Arm Optimized Routine
 | 
			
		||||
  // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
 | 
			
		||||
  inline Vectorized<float> vexpq_f32_u20() const {
 | 
			
		||||
  Vectorized<float> exp_u20() const {
 | 
			
		||||
    // bail out to sleef if it's a special case:
 | 
			
		||||
    // i.e. there's an input s.t. |input| > 87.3....
 | 
			
		||||
    const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
 | 
			
		||||
@ -348,9 +348,6 @@ class Vectorized<float> {
 | 
			
		||||
 | 
			
		||||
    return vfmaq_f32(scale, poly, scale);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> exp_u20() const {
 | 
			
		||||
    return vexpq_f32_u20();
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<float> fexp_u20() const {
 | 
			
		||||
    return exp_u20();
 | 
			
		||||
  }
 | 
			
		||||
@ -637,7 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
 | 
			
		||||
  // - exp(- x * x)
 | 
			
		||||
  auto pow_2 = (*this) * (*this);
 | 
			
		||||
  auto neg_pow_2 = pow_2 ^ neg_zero_vec;
 | 
			
		||||
  auto tmp4 = neg_pow_2.vexpq_f32_u20();
 | 
			
		||||
  auto tmp4 = neg_pow_2.exp();
 | 
			
		||||
  auto tmp5 = tmp4 ^ neg_zero_vec;
 | 
			
		||||
  // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
 | 
			
		||||
  auto tmp6 = t * tmp5;
 | 
			
		||||
 | 
			
		||||
@ -514,7 +514,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 | 
			
		||||
 | 
			
		||||
  using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
 | 
			
		||||
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
 | 
			
		||||
  using value_type = c10::qint8::underlying;
 | 
			
		||||
  using value_type = typename c10::qint8::underlying;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  using Vectorizedqi::Vectorizedqi;
 | 
			
		||||
@ -727,7 +727,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 | 
			
		||||
 | 
			
		||||
  using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
 | 
			
		||||
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
 | 
			
		||||
  using value_type = c10::quint8::underlying;
 | 
			
		||||
  using value_type = typename c10::quint8::underlying;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  using Vectorizedqi::Vectorizedqi;
 | 
			
		||||
 | 
			
		||||
@ -567,7 +567,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 | 
			
		||||
 | 
			
		||||
  using float_vec_return_type = std::array<Vectorized<float>, 4>;
 | 
			
		||||
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
 | 
			
		||||
  using value_type = c10::qint8::underlying;
 | 
			
		||||
  using value_type = typename c10::qint8::underlying;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  using Vectorizedqi::Vectorizedqi;
 | 
			
		||||
@ -804,7 +804,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 | 
			
		||||
 | 
			
		||||
  using float_vec_return_type = std::array<Vectorized<float>, 4>;
 | 
			
		||||
  using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
 | 
			
		||||
  using value_type = c10::quint8::underlying;
 | 
			
		||||
  using value_type = typename c10::quint8::underlying;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  using Vectorizedqi::Vectorizedqi;
 | 
			
		||||
 | 
			
		||||
@ -672,7 +672,7 @@ struct Vectorized {
 | 
			
		||||
    return map(std::sqrt);
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<T> reciprocal() const {
 | 
			
		||||
    return map([](T x) { return (T)1 / x; });
 | 
			
		||||
    return map([](T x) { return (T)(1) / x; });
 | 
			
		||||
  }
 | 
			
		||||
  Vectorized<T> rsqrt() const {
 | 
			
		||||
    return map([](T x) { return (T)1 / std::sqrt(x); });
 | 
			
		||||
 | 
			
		||||
@ -46,7 +46,7 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
 | 
			
		||||
  parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
 | 
			
		||||
    map(
 | 
			
		||||
        [](const Vectorized<scalar_t>& x) {
 | 
			
		||||
          return Vectorized<scalar_t>((scalar_t)1) / x.sqrt();
 | 
			
		||||
          return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt();
 | 
			
		||||
        },
 | 
			
		||||
        out + begin,
 | 
			
		||||
        in + begin,
 | 
			
		||||
 | 
			
		||||
@ -194,8 +194,8 @@ void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
 | 
			
		||||
void CUDAGeneratorState::capture_prologue() {
 | 
			
		||||
  capturing_ = true;
 | 
			
		||||
  offset_intragraph_ = 0;
 | 
			
		||||
  seed_extragraph_.fill_(static_cast<int64_t>(seed_));
 | 
			
		||||
  offset_extragraph_.fill_(0);
 | 
			
		||||
  seed_extragraph_.fill_(int64_t(seed_));
 | 
			
		||||
  offset_extragraph_.fill_(int64_t(0));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
@ -216,8 +216,8 @@ void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
 | 
			
		||||
  at::cuda::assertNotCapturing(
 | 
			
		||||
      "Cannot prepare for replay during capturing stage.");
 | 
			
		||||
  if (wholegraph_increment) {
 | 
			
		||||
      seed_extragraph_.fill_(static_cast<int64_t>(seed_));
 | 
			
		||||
      offset_extragraph_.fill_(static_cast<int64_t>(philox_offset_per_thread_));
 | 
			
		||||
      seed_extragraph_.fill_(int64_t(seed_));
 | 
			
		||||
      offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
 | 
			
		||||
      // Applies the total increment achieved during previous captures to update the
 | 
			
		||||
      // offset.
 | 
			
		||||
      increase(wholegraph_increment);
 | 
			
		||||
@ -329,7 +329,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
 | 
			
		||||
  constexpr size_t offset_size = sizeof(int64_t);
 | 
			
		||||
  constexpr size_t total_size = seed_size + offset_size;
 | 
			
		||||
 | 
			
		||||
  auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(total_size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
 | 
			
		||||
  auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
 | 
			
		||||
  auto rng_state = state_tensor.data_ptr<uint8_t>();
 | 
			
		||||
  auto current_seed = this->current_seed();
 | 
			
		||||
  auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>
 | 
			
		||||
 | 
			
		||||
@ -1,90 +1,78 @@
 | 
			
		||||
#include <ATen/cuda/CUDAGreenContext.h>
 | 
			
		||||
 | 
			
		||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 | 
			
		||||
#include <c10/cuda/driver_api.h>
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
#include <vector>
 | 
			
		||||
#define HAS_CUDA_GREEN_CONTEXT() 1
 | 
			
		||||
#else
 | 
			
		||||
#define HAS_CUDA_GREEN_CONTEXT() 0
 | 
			
		||||
// Suppress unsued private field warnings as this class is not supposed to be called
 | 
			
		||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field")
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace at::cuda {
 | 
			
		||||
  GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    int driver_version;
 | 
			
		||||
    C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
        driver_version >= 12080, "cuda driver too old to use green context!");
 | 
			
		||||
    CUcontext pctx = nullptr;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
 | 
			
		||||
    if (C10_UNLIKELY(!pctx)) {
 | 
			
		||||
      TORCH_WARN(
 | 
			
		||||
          "Attempted to create a green context but"
 | 
			
		||||
          " there was no primary context! Creating a primary context...");
 | 
			
		||||
 | 
			
		||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
#if HAS_CUDA_GREEN_CONTEXT()
 | 
			
		||||
  int driver_version;
 | 
			
		||||
  C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
 | 
			
		||||
  TORCH_CHECK(
 | 
			
		||||
      driver_version >= 12080, "cuda driver too old to use green context!");
 | 
			
		||||
  CUcontext pctx = nullptr;
 | 
			
		||||
  C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
 | 
			
		||||
  if (C10_UNLIKELY(!pctx)) {
 | 
			
		||||
    TORCH_WARN(
 | 
			
		||||
        "Attempted to create a green context but"
 | 
			
		||||
        " there was no primary context! Creating a primary context...");
 | 
			
		||||
      cudaFree(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    cudaFree(0);
 | 
			
		||||
  }
 | 
			
		||||
    CUdevice device;
 | 
			
		||||
    device_id_ = device_id;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
 | 
			
		||||
 | 
			
		||||
   CUdevice device;
 | 
			
		||||
  device_id_ = device_id;
 | 
			
		||||
  C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
      c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
 | 
			
		||||
    // Get device resources
 | 
			
		||||
    CUdevResource device_resource;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
 | 
			
		||||
        device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
 | 
			
		||||
 | 
			
		||||
  // Get device resources
 | 
			
		||||
  CUdevResource device_resource;
 | 
			
		||||
  C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
 | 
			
		||||
      device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
 | 
			
		||||
    // Split resources
 | 
			
		||||
    std::vector<CUdevResource> result(1);
 | 
			
		||||
    auto result_data = result.data();
 | 
			
		||||
    unsigned int nb_groups = 1;
 | 
			
		||||
    CUdevResource remaining;
 | 
			
		||||
 | 
			
		||||
  // Split resources
 | 
			
		||||
  std::vector<CUdevResource> result(1);
 | 
			
		||||
  auto result_data = result.data();
 | 
			
		||||
  unsigned int nb_groups = 1;
 | 
			
		||||
  CUdevResource remaining;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
 | 
			
		||||
            result_data,
 | 
			
		||||
            &nb_groups,
 | 
			
		||||
            &device_resource,
 | 
			
		||||
            &remaining,
 | 
			
		||||
            0, // default flags
 | 
			
		||||
            num_sms));
 | 
			
		||||
 | 
			
		||||
  C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
      c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
 | 
			
		||||
          result_data,
 | 
			
		||||
          &nb_groups,
 | 
			
		||||
          &device_resource,
 | 
			
		||||
          &remaining,
 | 
			
		||||
          0, // default flags
 | 
			
		||||
          num_sms));
 | 
			
		||||
    TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
 | 
			
		||||
    // Generate resource descriptor
 | 
			
		||||
    CUdevResourceDesc desc;
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
 | 
			
		||||
            &desc, result_data, 1));
 | 
			
		||||
 | 
			
		||||
  // Generate resource descriptor
 | 
			
		||||
  CUdevResourceDesc desc;
 | 
			
		||||
  C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
      c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
 | 
			
		||||
          &desc, result_data, 1));
 | 
			
		||||
    // Create green context
 | 
			
		||||
    // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
 | 
			
		||||
    // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
 | 
			
		||||
        &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
 | 
			
		||||
 | 
			
		||||
  // Create green context
 | 
			
		||||
  // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
 | 
			
		||||
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
 | 
			
		||||
  C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
 | 
			
		||||
      &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
 | 
			
		||||
 | 
			
		||||
  // Convert to regular context
 | 
			
		||||
  C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
      c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
 | 
			
		||||
  TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
 | 
			
		||||
    // Convert to regular context
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
 | 
			
		||||
    TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
 | 
			
		||||
#else
 | 
			
		||||
  TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::unique_ptr<GreenContext> GreenContext::create(
 | 
			
		||||
      uint32_t num_sms,
 | 
			
		||||
      std::optional<uint32_t> device_id) {
 | 
			
		||||
#if HAS_CUDA_GREEN_CONTEXT()
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    if (!device_id.has_value()) {
 | 
			
		||||
      device_id = at::cuda::current_device();
 | 
			
		||||
    }
 | 
			
		||||
    return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
 | 
			
		||||
    return std::make_unique<GreenContext>(device_id.value(), num_sms);
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
@ -92,7 +80,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
 | 
			
		||||
  // Implement move operations
 | 
			
		||||
  GreenContext::GreenContext(GreenContext&& other) noexcept{
 | 
			
		||||
#if HAS_CUDA_GREEN_CONTEXT()
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    device_id_ = std::exchange(other.device_id_, -1);
 | 
			
		||||
    green_ctx_ = std::exchange(other.green_ctx_, nullptr);
 | 
			
		||||
    context_ = std::exchange(other.context_, nullptr);
 | 
			
		||||
@ -103,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
 | 
			
		||||
#if HAS_CUDA_GREEN_CONTEXT()
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    if (this != &other) {
 | 
			
		||||
      // Clean up current resources
 | 
			
		||||
      if (green_ctx_) {
 | 
			
		||||
@ -132,7 +120,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  GreenContext::~GreenContext() noexcept{
 | 
			
		||||
#if HAS_CUDA_GREEN_CONTEXT()
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    C10_CUDA_DRIVER_CHECK(
 | 
			
		||||
        c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
 | 
			
		||||
#else
 | 
			
		||||
@ -140,9 +128,25 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Get the underlying CUDA context
 | 
			
		||||
  CUcontext GreenContext::getContext() const {
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    return context_;
 | 
			
		||||
#else
 | 
			
		||||
    TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Get the underlying green context
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  CUgreenCtx GreenContext::getGreenContext() const {
 | 
			
		||||
    return green_ctx_;
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Make this context current
 | 
			
		||||
  void GreenContext::setContext() {
 | 
			
		||||
#if HAS_CUDA_GREEN_CONTEXT()
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    auto current_stream = c10::cuda::getCurrentCUDAStream();
 | 
			
		||||
    parent_stream_ = current_stream.stream();
 | 
			
		||||
 | 
			
		||||
@ -171,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  void GreenContext::popContext() {
 | 
			
		||||
#if HAS_CUDA_GREEN_CONTEXT()
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
    // see above note about stream being hardcoded to the default stream
 | 
			
		||||
    at::cuda::CUDAEvent ev;
 | 
			
		||||
    ev.record(c10::cuda::getCurrentCUDAStream());
 | 
			
		||||
 | 
			
		||||
@ -1,38 +1,53 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
#include <ATen/cuda/CUDAEvent.h>
 | 
			
		||||
#include <cuda.h>
 | 
			
		||||
 | 
			
		||||
// Forward declare green context as opaque ptr
 | 
			
		||||
typedef struct CUgreenCtx_st* CUgreenCtx;
 | 
			
		||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
 | 
			
		||||
#include <c10/cuda/driver_api.h>
 | 
			
		||||
#include <cuda.h>
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
#include <vector>
 | 
			
		||||
#define CUDA_HAS_GREEN_CONTEXT 1
 | 
			
		||||
#else
 | 
			
		||||
#define CUDA_HAS_GREEN_CONTEXT 0
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
namespace at::cuda {
 | 
			
		||||
 | 
			
		||||
class TORCH_CUDA_CPP_API GreenContext {
 | 
			
		||||
 public:
 | 
			
		||||
  // Green context creation
 | 
			
		||||
  static std::unique_ptr<GreenContext> create(
 | 
			
		||||
      uint32_t num_sms,
 | 
			
		||||
      std::optional<uint32_t> device_id);
 | 
			
		||||
  ~GreenContext() noexcept;
 | 
			
		||||
  GreenContext(uint32_t device_id, uint32_t num_sms);
 | 
			
		||||
 | 
			
		||||
  static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
 | 
			
		||||
 | 
			
		||||
  // Delete copy constructor and assignment
 | 
			
		||||
  GreenContext(const GreenContext&) = delete;
 | 
			
		||||
  GreenContext& operator=(const GreenContext&) = delete;
 | 
			
		||||
 | 
			
		||||
  // Implement move operations
 | 
			
		||||
  GreenContext(GreenContext&& other) noexcept;
 | 
			
		||||
  GreenContext& operator=(GreenContext&& other) noexcept;
 | 
			
		||||
  ~GreenContext() noexcept;
 | 
			
		||||
 | 
			
		||||
  // Get the underlying CUDA context
 | 
			
		||||
  CUcontext getContext() const;
 | 
			
		||||
 | 
			
		||||
  // Get the underlying green context
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  CUgreenCtx getGreenContext() const;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Make this context current
 | 
			
		||||
  void setContext();
 | 
			
		||||
 | 
			
		||||
  void popContext();
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  GreenContext(uint32_t device_id, uint32_t num_sms);
 | 
			
		||||
  // Implement move operations
 | 
			
		||||
  GreenContext(GreenContext&& other) noexcept;
 | 
			
		||||
  GreenContext& operator=(GreenContext&& other) noexcept;
 | 
			
		||||
 | 
			
		||||
#if CUDA_HAS_GREEN_CONTEXT
 | 
			
		||||
  int32_t device_id_ = -1;
 | 
			
		||||
  CUgreenCtx green_ctx_ = nullptr;
 | 
			
		||||
  CUcontext context_ = nullptr;
 | 
			
		||||
  cudaStream_t parent_stream_ = nullptr;
 | 
			
		||||
#endif
 | 
			
		||||
};
 | 
			
		||||
} // namespace at::cuda
 | 
			
		||||
 | 
			
		||||
@ -7,6 +7,17 @@
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#if defined(USE_ROCM)
 | 
			
		||||
// hipSparse const API added in v2.4.0
 | 
			
		||||
#if HIPSPARSE_VERSION >= 200400
 | 
			
		||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
 | 
			
		||||
#else
 | 
			
		||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
 | 
			
		||||
#endif
 | 
			
		||||
#else // USE_ROCM
 | 
			
		||||
#define AT_USE_HIPSPARSE_GENERIC_API() 0
 | 
			
		||||
#endif // USE_ROCM
 | 
			
		||||
 | 
			
		||||
// cuSparse Generic API spsv function was added in CUDA 11.3.0
 | 
			
		||||
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
 | 
			
		||||
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
 | 
			
		||||
 | 
			
		||||
@ -155,8 +155,8 @@ size_t parseChosenWorkspaceSize() {
 | 
			
		||||
    while (next != end) {
 | 
			
		||||
      std::smatch match = *next;
 | 
			
		||||
      TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
 | 
			
		||||
      size_t curr_size = std::stoull(match.str(1));
 | 
			
		||||
      size_t count = std::stoull(match.str(2));
 | 
			
		||||
      size_t curr_size = (size_t) std::stoi(match.str(1));
 | 
			
		||||
      size_t count = (size_t) std::stoi(match.str(2));
 | 
			
		||||
      total_size += curr_size * 1024 * count;
 | 
			
		||||
      next++;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,8 @@
 | 
			
		||||
#include <ATen/Tensor.h>
 | 
			
		||||
#include <ATen/cuda/Exceptions.h>
 | 
			
		||||
 | 
			
		||||
#include <mutex>
 | 
			
		||||
 | 
			
		||||
namespace at {
 | 
			
		||||
namespace cuda {
 | 
			
		||||
namespace detail {
 | 
			
		||||
@ -10,36 +12,39 @@ __device__ __constant__ float cublas_one_device;
 | 
			
		||||
__device__ __constant__ float cublas_zero_device;
 | 
			
		||||
 | 
			
		||||
float *get_cublas_device_one() {
 | 
			
		||||
  static float *ptr = nullptr;
 | 
			
		||||
  static auto init_flag = [&]() {
 | 
			
		||||
  static c10::once_flag init_flag;
 | 
			
		||||
 | 
			
		||||
  c10::call_once(init_flag, []() {
 | 
			
		||||
    const float one = 1.f;
 | 
			
		||||
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
 | 
			
		||||
    AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
 | 
			
		||||
    return true;
 | 
			
		||||
  }();
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  float *ptr;
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
 | 
			
		||||
  return ptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
float *get_cublas_device_zero() {
 | 
			
		||||
  static float *ptr = nullptr;
 | 
			
		||||
  static auto init_flag = [&]() {
 | 
			
		||||
  static c10::once_flag init_flag;
 | 
			
		||||
 | 
			
		||||
  c10::call_once(init_flag, []() {
 | 
			
		||||
    const float zero = 0.f;
 | 
			
		||||
    AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
 | 
			
		||||
    AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
 | 
			
		||||
    return true;
 | 
			
		||||
  }();
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  float *ptr;
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
 | 
			
		||||
  return ptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
float *get_user_alpha_ptr() {
 | 
			
		||||
  static float *alpha_ptr;
 | 
			
		||||
 | 
			
		||||
  static bool init_flag [[maybe_unused]] = []() {
 | 
			
		||||
  static c10::once_flag init_flag;
 | 
			
		||||
 | 
			
		||||
  c10::call_once(init_flag, []() {
 | 
			
		||||
    AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
 | 
			
		||||
    return true;
 | 
			
		||||
  }();
 | 
			
		||||
  });
 | 
			
		||||
 | 
			
		||||
  return alpha_ptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,6 @@
 | 
			
		||||
#include <ATen/ATen.h>
 | 
			
		||||
#include <c10/util/irange.h>
 | 
			
		||||
 | 
			
		||||
#include <array>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
 | 
			
		||||
@ -137,9 +136,9 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
 | 
			
		||||
    "Weight strides: ", t.strides(), "\n",
 | 
			
		||||
    "cuDNN suggested memory_format: ", memory_format);
 | 
			
		||||
 | 
			
		||||
  std::array<int, CUDNN_DIM_MAX> size;
 | 
			
		||||
  int size[CUDNN_DIM_MAX];
 | 
			
		||||
  for (const auto i : c10::irange(dim)) {
 | 
			
		||||
    size[i] = static_cast<int>(t.size(i));
 | 
			
		||||
    size[i] = (int) t.size(i);
 | 
			
		||||
  }
 | 
			
		||||
  for (const auto i : c10::irange(dim, pad)) {
 | 
			
		||||
    size[i] = 1;
 | 
			
		||||
@ -157,7 +156,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
 | 
			
		||||
    default:
 | 
			
		||||
      TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters");
 | 
			
		||||
  }
 | 
			
		||||
  set(getDataType(t), static_cast<int>(dim), size.data(), filter_format);
 | 
			
		||||
  set(getDataType(t), static_cast<int>(dim), size, filter_format);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,5 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <c10/core/CachingDeviceAllocator.h>
 | 
			
		||||
#include <c10/core/Device.h>
 | 
			
		||||
#include <c10/util/Exception.h>
 | 
			
		||||
 | 
			
		||||
@ -152,36 +151,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool isAvailable() const override;
 | 
			
		||||
 | 
			
		||||
  /* MTIAGraph related APIs */
 | 
			
		||||
  virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
    return -1;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual void mtiagraphCaptureEnd(int64_t handle) const {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual void mtiagraphInstantiate(int64_t handle) const {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual void mtiagraphReplay(int64_t handle) const {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual void mtiagraphReset(int64_t handle) const {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual MempoolId_t mtiagraphPool(int64_t handle) const {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API MTIAHooksArgs {};
 | 
			
		||||
 | 
			
		||||
@ -198,7 +198,7 @@ static void autogradBasedTransformSendToNext(
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Step 6
 | 
			
		||||
  stack->erase(stack->end() - static_cast<std::ptrdiff_t>(args_size + ret_size), stack->end() - static_cast<std::ptrdiff_t>(ret_size));
 | 
			
		||||
  stack->erase(stack->end() - std::ptrdiff_t(args_size + ret_size), stack->end() - std::ptrdiff_t(ret_size));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void GradInterpreterPtr::processImpl(
 | 
			
		||||
 | 
			
		||||
@ -443,14 +443,14 @@ static bool has_same_shape(
 | 
			
		||||
  if (!tensor.defined()) {
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
  if (rankWithoutBatchDim(tensor, tensor_bdim) != static_cast<int64_t>(normalized_shape.size())) {
 | 
			
		||||
  if (rankWithoutBatchDim(tensor, tensor_bdim) != (int64_t) normalized_shape.size()) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  const auto tensor_shape = tensor.sizes();
 | 
			
		||||
  for (const auto i : c10::irange(normalized_shape.size())) {
 | 
			
		||||
    auto j = i;
 | 
			
		||||
    // (0, 1, 2), 1 -> (0, 2, 3)
 | 
			
		||||
    if (tensor_bdim.has_value() && static_cast<int64_t>(i) >= tensor_bdim.value()) {
 | 
			
		||||
    if (tensor_bdim.has_value() && (int64_t)i >= tensor_bdim.value()) {
 | 
			
		||||
      j = j + 1;
 | 
			
		||||
    }
 | 
			
		||||
    if (normalized_shape[i] != tensor_shape[j]) {
 | 
			
		||||
 | 
			
		||||
@ -135,7 +135,7 @@ static void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit
 | 
			
		||||
    reduction_case = ReductionCase::DimArray;
 | 
			
		||||
    dims = arguments[dim_arg_pos].toIntList().vec();
 | 
			
		||||
    if (dims.empty()) {
 | 
			
		||||
      auto all_dims = range(0, std::max(static_cast<int64_t>(1), logical_dim));
 | 
			
		||||
      auto all_dims = range(0, std::max((int64_t)1, logical_dim));
 | 
			
		||||
      dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
 | 
			
		||||
    }
 | 
			
		||||
  } else if (arguments[dim_arg_pos].isInt()) {
 | 
			
		||||
 | 
			
		||||
@ -432,7 +432,7 @@ namespace {
 | 
			
		||||
    // Eg. Given `indexed_shape.size()` is 5 and
 | 
			
		||||
    // shape of `values` is (N, 2, 3), then following block
 | 
			
		||||
    // will reshape `values` to (N, 1, 1, 2, 3).
 | 
			
		||||
    if ( static_cast<int64_t>(indexed_shape.size()) > values_.dim()) {
 | 
			
		||||
    if ( (int64_t) indexed_shape.size() > values_.dim()) {
 | 
			
		||||
      auto values_sizes = values_.sym_sizes();
 | 
			
		||||
 | 
			
		||||
      // number of unit dims (for broadcasting value to indexed_shape)
 | 
			
		||||
 | 
			
		||||
@ -109,7 +109,7 @@ std::tuple<Tensor, std::optional<int64_t>> repeat_batch_rule(
 | 
			
		||||
  SymDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
 | 
			
		||||
  sizes_with_bdim.insert(sizes_with_bdim.begin(), 1);
 | 
			
		||||
  auto self_ = moveBatchDimToFront(self, self_bdim);
 | 
			
		||||
  while (self_.dim() < static_cast<int64_t>(sizes_with_bdim.size())) {
 | 
			
		||||
  while (self_.dim() < (int64_t)sizes_with_bdim.size()) {
 | 
			
		||||
    self_ = self_.unsqueeze(1);
 | 
			
		||||
  }
 | 
			
		||||
  return std::make_tuple(self_.repeat_symint(sizes_with_bdim), 0);
 | 
			
		||||
@ -534,20 +534,20 @@ Tensor trace_decomp(const Tensor& tensor) {
 | 
			
		||||
std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
 | 
			
		||||
    const Tensor& self,
 | 
			
		||||
    std::optional<int64_t> self_bdim,
 | 
			
		||||
    c10::SymInt diagonal = 0) {
 | 
			
		||||
    int64_t diagonal = 0) {
 | 
			
		||||
  TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
 | 
			
		||||
  auto self_ = moveBatchDimToFront(self, self_bdim);
 | 
			
		||||
  auto result = at::tril_symint(self_, std::move(diagonal));
 | 
			
		||||
  auto result = at::tril(self_, diagonal);
 | 
			
		||||
  return std::make_tuple(std::move(result), 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
 | 
			
		||||
    const Tensor& self,
 | 
			
		||||
    std::optional<int64_t> self_bdim,
 | 
			
		||||
    c10::SymInt diagonal = 0) {
 | 
			
		||||
    int64_t diagonal = 0) {
 | 
			
		||||
  TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
 | 
			
		||||
  auto self_ = moveBatchDimToFront(self, self_bdim);
 | 
			
		||||
  auto result = at::triu_symint(self_, std::move(diagonal));
 | 
			
		||||
  auto result = at::triu(self_, diagonal);
 | 
			
		||||
  return std::make_tuple(std::move(result), 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -191,7 +191,7 @@ static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, t
 | 
			
		||||
      // simplicity. When that is not the case, this code should be updated.
 | 
			
		||||
      const auto& argument = (*stack)[arguments_begin + arg_idx];
 | 
			
		||||
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
 | 
			
		||||
          || static_cast<int64_t>(arg_idx) != *batched_tensor_inputs_pos_iter) {
 | 
			
		||||
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
 | 
			
		||||
        // argument isn't a BatchedTensor
 | 
			
		||||
        torch::jit::push(stack, argument);
 | 
			
		||||
        continue;
 | 
			
		||||
@ -345,7 +345,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
 | 
			
		||||
      // simplicity. When that is not the case, this code should be updated.
 | 
			
		||||
      const auto& argument = (*stack)[arguments_begin + arg_idx];
 | 
			
		||||
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
 | 
			
		||||
          || static_cast<int64_t>(arg_idx) != *batched_tensor_inputs_pos_iter) {
 | 
			
		||||
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
 | 
			
		||||
        // argument isn't a BatchedTensor
 | 
			
		||||
        torch::jit::push(stack, argument);
 | 
			
		||||
        continue;
 | 
			
		||||
@ -473,7 +473,7 @@ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::ji
 | 
			
		||||
      // simplicity. When that is not the case, this code should be updated.
 | 
			
		||||
      const auto& argument = (*stack)[arguments_begin + arg_idx];
 | 
			
		||||
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
 | 
			
		||||
          || static_cast<int64_t>(arg_idx) != *batched_tensor_inputs_pos_iter) {
 | 
			
		||||
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
 | 
			
		||||
        // argument isn't a BatchedTensor
 | 
			
		||||
        torch::jit::push(stack, argument);
 | 
			
		||||
        continue;
 | 
			
		||||
 | 
			
		||||
@ -157,7 +157,7 @@ Tensor& squeeze__batching_rule(Tensor& self) {
 | 
			
		||||
  const auto physical_shape = batched->value().sizes();
 | 
			
		||||
  auto how_many_dims_of_size_1_before_bdim = 0;
 | 
			
		||||
  for (const auto i : c10::irange(0, physical_shape.size())) {
 | 
			
		||||
    if (static_cast<int64_t>(i) == bdim) {
 | 
			
		||||
    if ((int64_t)i == bdim) {
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    if (physical_shape[i] == 1) {
 | 
			
		||||
@ -573,7 +573,7 @@ Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
 | 
			
		||||
  std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional(static_cast<int64_t>(0)) : std::nullopt;
 | 
			
		||||
  std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : std::nullopt;
 | 
			
		||||
  auto result = at::cat(tensors_to_cat, new_dim);
 | 
			
		||||
  return makeBatched(result, new_bdim, get_current_level());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,7 @@
 | 
			
		||||
//  Copyright © 2022 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <c10/util/CallOnce.h>
 | 
			
		||||
 | 
			
		||||
#include <ATen/mps/IndexKernels.h>
 | 
			
		||||
#include <ATen/mps/MPSAllocatorInterface.h>
 | 
			
		||||
#include <ATen/mps/MPSDevice.h>
 | 
			
		||||
@ -8,6 +10,9 @@
 | 
			
		||||
 | 
			
		||||
namespace at::mps {
 | 
			
		||||
 | 
			
		||||
static std::unique_ptr<MPSDevice> mps_device;
 | 
			
		||||
static c10::once_flag mpsdev_init;
 | 
			
		||||
 | 
			
		||||
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
 | 
			
		||||
  // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
 | 
			
		||||
  // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+
 | 
			
		||||
@ -16,8 +21,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
MPSDevice* MPSDevice::getInstance() {
 | 
			
		||||
  static MPSDevice mps_device;
 | 
			
		||||
  return &mps_device;
 | 
			
		||||
  c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr<MPSDevice>(new MPSDevice()); });
 | 
			
		||||
  return mps_device.get();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
MPSDevice::~MPSDevice() {
 | 
			
		||||
 | 
			
		||||
@ -25,19 +25,18 @@ TORCH_PRECOMPUTE_META_FUNC(avg_pool2d)
 | 
			
		||||
  // #20866, #22032: Guarantee this for the official C++ API?
 | 
			
		||||
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
 | 
			
		||||
    "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints");
 | 
			
		||||
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
 | 
			
		||||
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);
 | 
			
		||||
  const int64_t kH = kernel_size[0];
 | 
			
		||||
  const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1];
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2,
 | 
			
		||||
    "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
 | 
			
		||||
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
 | 
			
		||||
  const int dW = stride.empty() ? kW :
 | 
			
		||||
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);
 | 
			
		||||
  const int64_t dH = stride.empty() ? kH : stride[0];
 | 
			
		||||
  const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1];
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
 | 
			
		||||
    "avg_pool2d: padding must either be a single int, or a tuple of two ints");
 | 
			
		||||
  const int padH = safe_downcast<int, int64_t>(padding[0]);
 | 
			
		||||
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);
 | 
			
		||||
  const int64_t padH = padding[0];
 | 
			
		||||
  const int64_t padW = padding.size() == 1 ? padH : padding[1];
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0,
 | 
			
		||||
    "divisor must be not zero");
 | 
			
		||||
 | 
			
		||||
@ -198,9 +198,9 @@ void avg_pool3d_out_frame(
 | 
			
		||||
            int64_t hend = std::min(hstart + kH, iheight + padH);
 | 
			
		||||
            int64_t wend = std::min(wstart + kW, iwidth + padW);
 | 
			
		||||
            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
 | 
			
		||||
            tstart = std::max(tstart, static_cast<int64_t>(0));
 | 
			
		||||
            hstart = std::max(hstart, static_cast<int64_t>(0));
 | 
			
		||||
            wstart = std::max(wstart, static_cast<int64_t>(0));
 | 
			
		||||
            tstart = std::max(tstart, (int64_t) 0);
 | 
			
		||||
            hstart = std::max(hstart, (int64_t) 0);
 | 
			
		||||
            wstart = std::max(wstart, (int64_t) 0);
 | 
			
		||||
            tend = std::min(tend, itime);
 | 
			
		||||
            hend = std::min(hend, iheight);
 | 
			
		||||
            wend = std::min(wend, iwidth);
 | 
			
		||||
@ -377,9 +377,9 @@ void avg_pool3d_backward_out_frame(
 | 
			
		||||
            int64_t hend = std::min(hstart + kH, iheight + padH);
 | 
			
		||||
            int64_t wend = std::min(wstart + kW, iwidth + padW);
 | 
			
		||||
            int64_t pool_size = (tend -tstart) * (hend - hstart) * (wend - wstart);
 | 
			
		||||
            tstart = std::max(tstart, static_cast<int64_t>(0));
 | 
			
		||||
            hstart = std::max(hstart, static_cast<int64_t>(0));
 | 
			
		||||
            wstart = std::max(wstart, static_cast<int64_t>(0));
 | 
			
		||||
            tstart = std::max(tstart, (int64_t) 0);
 | 
			
		||||
            hstart = std::max(hstart, (int64_t) 0);
 | 
			
		||||
            wstart = std::max(wstart, (int64_t) 0);
 | 
			
		||||
            tend = std::min(tend, itime);
 | 
			
		||||
            hend = std::min(hend, iheight);
 | 
			
		||||
            wend = std::min(wend, iwidth);
 | 
			
		||||
 | 
			
		||||
@ -946,10 +946,10 @@ void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& in
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
  // avoid overflow
 | 
			
		||||
  auto matrix_rank = std::min(m, n);
 | 
			
		||||
  float matrix_rank = float(std::min(m, n));
 | 
			
		||||
  // A heuristic tested on a 32 core/socket ICX system
 | 
			
		||||
  // https://github.com/pytorch/pytorch/pull/93037#discussion_r1090112948
 | 
			
		||||
  int64_t chunk_size_per_thread = static_cast<int64_t>(
 | 
			
		||||
  int64_t chunk_size_per_thread = int64_t(
 | 
			
		||||
      std::min(1.0, 3200.0 / (matrix_rank * matrix_rank * matrix_rank)));
 | 
			
		||||
  int64_t grain_size = chunk_size_per_thread * at::get_num_threads();
 | 
			
		||||
  at::parallel_for(0, batch_size, grain_size, loop);
 | 
			
		||||
 | 
			
		||||
@ -267,7 +267,7 @@ _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
 | 
			
		||||
 | 
			
		||||
  float input_scale = scale_a.item<float>();
 | 
			
		||||
  float weight_scale = scale_b.item<float>();
 | 
			
		||||
  float output_scale = 1.0f;
 | 
			
		||||
  float output_scale = float(1.0);
 | 
			
		||||
  if (scale_result.has_value() &&
 | 
			
		||||
      (*out_dtype == ScalarType::Float8_e4m3fn ||
 | 
			
		||||
       *out_dtype == ScalarType::Float8_e5m2)) {
 | 
			
		||||
 | 
			
		||||
@ -331,7 +331,7 @@ bool gemv_use_fast_path<double>(
 | 
			
		||||
    [[maybe_unused]] double beta,
 | 
			
		||||
    int64_t incy) {
 | 
			
		||||
  return gemv_use_fast_path<float>(
 | 
			
		||||
      trans, m, n, static_cast<float>(alpha), lda, incx, static_cast<float>(beta), incy);
 | 
			
		||||
      trans, m, n, (float)alpha, lda, incx, (float)beta, incy);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
@ -523,8 +523,8 @@ static inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
 | 
			
		||||
  if (n == 1) incx = 1;
 | 
			
		||||
#if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if (blas_impl::scal_use_fast_path<scalar_t>(n, incx)) {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    blas_impl::scal_fast_path<scalar_t>(&i_n, &a, x, &i_incx);
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
@ -545,11 +545,11 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i
 | 
			
		||||
#if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
 | 
			
		||||
    TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
 | 
			
		||||
    int i_m = static_cast<int>(m);
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_lda = static_cast<int>(lda);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_m = (int)m;
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_lda = (int)lda;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -680,9 +680,9 @@ void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
 | 
			
		||||
  {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_daxpy(i_n, a, x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -705,9 +705,9 @@ void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t in
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
 | 
			
		||||
  {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_saxpy(i_n, a, x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -730,9 +730,9 @@ void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int6
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
 | 
			
		||||
  {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -755,9 +755,9 @@ void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
 | 
			
		||||
  {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_caxpy(i_n, &a, x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -781,9 +781,9 @@ void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) {
 | 
			
		||||
  }
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_dcopy(i_n, x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -805,9 +805,9 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) {
 | 
			
		||||
  }
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_scopy(i_n, x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -829,9 +829,9 @@ void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<d
 | 
			
		||||
  }
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_zcopy(i_n, x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -853,9 +853,9 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
 | 
			
		||||
  }
 | 
			
		||||
  #if AT_BUILD_WITH_BLAS()
 | 
			
		||||
  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
 | 
			
		||||
    int i_n = static_cast<int>(n);
 | 
			
		||||
    int i_incx = static_cast<int>(incx);
 | 
			
		||||
    int i_incy = static_cast<int>(incy);
 | 
			
		||||
    int i_n = (int)n;
 | 
			
		||||
    int i_incx = (int)incx;
 | 
			
		||||
    int i_incy = (int)incy;
 | 
			
		||||
    #if C10_IOS
 | 
			
		||||
    cblas_ccopy(i_n, &x, i_incx, y, i_incy);
 | 
			
		||||
    #else
 | 
			
		||||
@ -1082,7 +1082,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
 | 
			
		||||
        M,
 | 
			
		||||
        N,
 | 
			
		||||
        K,
 | 
			
		||||
        1,
 | 
			
		||||
        int64_t(1),
 | 
			
		||||
        ld_a,
 | 
			
		||||
        ld_b,
 | 
			
		||||
        ld_c,
 | 
			
		||||
@ -1096,7 +1096,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
 | 
			
		||||
          M,
 | 
			
		||||
          N,
 | 
			
		||||
          K,
 | 
			
		||||
          1,
 | 
			
		||||
          int64_t(1),
 | 
			
		||||
          ld_a,
 | 
			
		||||
          ld_b,
 | 
			
		||||
          ld_c,
 | 
			
		||||
 | 
			
		||||
@ -410,8 +410,8 @@ struct ConvParams {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
    static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
 | 
			
		||||
    // broken on cuDNN 9.8 - 9.14
 | 
			
		||||
    if (cudnn_version >= 90800 && cudnn_version < 91500) {
 | 
			
		||||
    // broken on cuDNN 9.8
 | 
			
		||||
    if (cudnn_version >= 90800) {
 | 
			
		||||
      if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
 | 
			
		||||
          (input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
 | 
			
		||||
          weight.dim() == 5) {
 | 
			
		||||
 | 
			
		||||
@ -487,17 +487,17 @@ static Tensor _grid_sampler_2d_cpu_quantized(
 | 
			
		||||
  int64_t out_sC = output.stride(1);
 | 
			
		||||
  int64_t out_sH = output.stride(2);
 | 
			
		||||
  int64_t out_sW = output.stride(3);
 | 
			
		||||
  const uint8_t* inp_ptr = input.const_data_ptr<uint8_t>();
 | 
			
		||||
  uint8_t* out_ptr = output.data_ptr<uint8_t>();
 | 
			
		||||
  const float* grid_ptr = grid.const_data_ptr<float>();
 | 
			
		||||
  uint8_t* inp_ptr = (uint8_t*)input.data_ptr<quint8>();
 | 
			
		||||
  uint8_t* out_ptr = (uint8_t*)output.data_ptr<quint8>();
 | 
			
		||||
  float* grid_ptr = grid.data_ptr<float>();
 | 
			
		||||
  at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
 | 
			
		||||
    for (const auto n : c10::irange(start, end)) {
 | 
			
		||||
      const float* grid_ptr_N = grid_ptr + n * grid_sN;
 | 
			
		||||
      const uint8_t* inp_ptr_N = inp_ptr + n * inp_sN;
 | 
			
		||||
      float* grid_ptr_N = grid_ptr + n * grid_sN;
 | 
			
		||||
      uint8_t* inp_ptr_N = inp_ptr + n * inp_sN;
 | 
			
		||||
      for (const auto h : c10::irange(out_H)) {
 | 
			
		||||
        for (const auto w : c10::irange(out_W)) {
 | 
			
		||||
          // get the corresponding input x, y, z coordinates from grid
 | 
			
		||||
          const float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
 | 
			
		||||
          float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
 | 
			
		||||
          float x = *grid_ptr_NHW;
 | 
			
		||||
          float y = grid_ptr_NHW[grid_sCoor];
 | 
			
		||||
 | 
			
		||||
@ -527,7 +527,7 @@ static Tensor _grid_sampler_2d_cpu_quantized(
 | 
			
		||||
          float se = (ix - ix_nw) * (iy - iy_nw);
 | 
			
		||||
 | 
			
		||||
          // calculate bilinear weighted pixel value and set output pixel
 | 
			
		||||
          const uint8_t* inp_ptr_NC = inp_ptr_N;
 | 
			
		||||
          uint8_t* inp_ptr_NC = inp_ptr_N;
 | 
			
		||||
          uint8_t* out_ptr_NCHW =
 | 
			
		||||
              out_ptr + n * out_sN + h * out_sH + w * out_sW;
 | 
			
		||||
          for (int64_t c = 0; c < C;
 | 
			
		||||
 | 
			
		||||
@ -318,7 +318,7 @@ static std::vector<Tensor>& histogramdd_bin_edges_out(const Tensor& self, IntArr
 | 
			
		||||
 | 
			
		||||
    const int64_t N = self.size(-1);
 | 
			
		||||
    const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
 | 
			
		||||
            static_cast<int64_t>(1), std::multiplies<int64_t>());
 | 
			
		||||
            (int64_t)1, std::multiplies<int64_t>());
 | 
			
		||||
    Tensor reshaped_self = self.reshape({ M, N });
 | 
			
		||||
 | 
			
		||||
    auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);
 | 
			
		||||
 | 
			
		||||
@ -40,7 +40,7 @@ Tensor do_trapezoid(const Tensor& y, const Tensor& dx, int64_t dim) {
 | 
			
		||||
// When dx is constant, the above formula simplifies
 | 
			
		||||
// to dx * [(\sum_{i=1}^n y_i) - (y_1 + y_n)/2]
 | 
			
		||||
Tensor do_trapezoid(const Tensor& y, double dx, int64_t dim) {
 | 
			
		||||
    return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * 0.5) * dx;
 | 
			
		||||
    return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * (0.5)) * dx;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Tensor zeros_like_except(const Tensor& y, int64_t dim) {
 | 
			
		||||
 | 
			
		||||
@ -201,7 +201,7 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra
 | 
			
		||||
  out_size.reserve(out_num_dim);
 | 
			
		||||
  for (auto& d : lro) out_size.push_back(left.sym_size(d));
 | 
			
		||||
  for (auto& d : lo) out_size.push_back(left.sym_size(d));
 | 
			
		||||
  for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)d; }; // avoid warning about not using d
 | 
			
		||||
  for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)(d); }; // avoid warning about not using d
 | 
			
		||||
  for (auto& d : ro) out_size.push_back(right.sym_size(d));
 | 
			
		||||
 | 
			
		||||
  std::vector<int64_t> lpermutation(lro);
 | 
			
		||||
@ -640,7 +640,7 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return std::move(ops[0]);
 | 
			
		||||
  return ops[0];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// _trilinear computes a trilinear einstein sum with an unrolled dimension
 | 
			
		||||
@ -805,7 +805,7 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
 | 
			
		||||
  std::vector<SymInt> rsizes;  // rsizes: sizes of the result
 | 
			
		||||
  p1.reserve(input1.dim());
 | 
			
		||||
  p2.reserve(input2.dim());
 | 
			
		||||
  rsizes.reserve(input1.dim() + input2.dim() - static_cast<int64_t>(dims1.size()));
 | 
			
		||||
  rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
 | 
			
		||||
  SymInt size1 = 1; // number of non-contracted elements in input1
 | 
			
		||||
  SymInt size2 = 1; // number of non-contracted elements in input2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1655,7 +1655,7 @@ static inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self,
 | 
			
		||||
  auto s0 = self.accessor<const scalar_t, 3>();
 | 
			
		||||
  auto m0 = mat2.accessor<const scalar_t, 3>();
 | 
			
		||||
 | 
			
		||||
  int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), static_cast<int64_t>(1));
 | 
			
		||||
  int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1);
 | 
			
		||||
  using opmath_t = at::opmath_type<scalar_t>;
 | 
			
		||||
  parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
 | 
			
		||||
      for (const auto b : c10::irange(b_begin, b_end)) {
 | 
			
		||||
 | 
			
		||||
@ -235,7 +235,7 @@ void nll_loss_out_frame(
 | 
			
		||||
 | 
			
		||||
  constexpr int64_t cascade_sum_num_levels = 8;
 | 
			
		||||
  const int64_t level_power =
 | 
			
		||||
      std::max(static_cast<int64_t>(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
 | 
			
		||||
      std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
 | 
			
		||||
  const int64_t level_step = (1 << level_power);
 | 
			
		||||
  const int64_t level_mask = level_step - 1;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -129,7 +129,7 @@ void nll_loss2d_forward_out_frame(
 | 
			
		||||
      for (const auto b : c10::irange(start, end)) {
 | 
			
		||||
        for (const auto h : c10::irange(H)) {
 | 
			
		||||
          for (const auto w : c10::irange(W)) {
 | 
			
		||||
            const int64_t cur_target = target_acc[b][h][w];
 | 
			
		||||
            const int64_t cur_target = (int64_t)target_acc[b][h][w];
 | 
			
		||||
 | 
			
		||||
            if (cur_target == ignore_index) {
 | 
			
		||||
              output_acc[b][h][w] = static_cast<scalar_t>(0);
 | 
			
		||||
@ -188,7 +188,7 @@ void nll_loss2d_forward_out_frame(
 | 
			
		||||
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
 | 
			
		||||
  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
 | 
			
		||||
  const int64_t level_power =
 | 
			
		||||
      std::max(static_cast<int64_t>(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
 | 
			
		||||
      std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
 | 
			
		||||
  const int64_t level_step = (1 << level_power);
 | 
			
		||||
  const int64_t level_mask = level_step - 1;
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -192,7 +192,7 @@ Date:  February 1996
 | 
			
		||||
  x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(c10::pi<double>)))*std::exp(-x*x));
 | 
			
		||||
  x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(c10::pi<double>)))*std::exp(-x*x));
 | 
			
		||||
 | 
			
		||||
  return x;
 | 
			
		||||
  return(x);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#undef CENTRAL_RANGE
 | 
			
		||||
@ -3819,7 +3819,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
 | 
			
		||||
 | 
			
		||||
    if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) {
 | 
			
		||||
        if (std::sin(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) {
 | 
			
		||||
            return std::cos((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0));
 | 
			
		||||
            return std::cos(((n) + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (n % 2 == 0) {
 | 
			
		||||
 | 
			
		||||
@ -193,22 +193,22 @@ Tensor _nnpack_spatial_convolution(
 | 
			
		||||
  const size_t input_channels = input.size(1);
 | 
			
		||||
  const size_t output_channels = weight.size(0);
 | 
			
		||||
  const struct nnp_size input_size = {
 | 
			
		||||
      .width = static_cast<size_t>(input.size(3)),
 | 
			
		||||
      .height = static_cast<size_t>(input.size(2)),
 | 
			
		||||
      .width = (size_t)input.size(3),
 | 
			
		||||
      .height = (size_t)input.size(2),
 | 
			
		||||
  };
 | 
			
		||||
  const struct nnp_padding input_padding = {
 | 
			
		||||
      .top = static_cast<size_t>(padding[0]),
 | 
			
		||||
      .right = static_cast<size_t>(padding[1]),
 | 
			
		||||
      .bottom = static_cast<size_t>(padding[0]),
 | 
			
		||||
      .left = static_cast<size_t>(padding[1]),
 | 
			
		||||
      .top = (size_t)padding[0],
 | 
			
		||||
      .right = (size_t)padding[1],
 | 
			
		||||
      .bottom = (size_t)padding[0],
 | 
			
		||||
      .left = (size_t)padding[1],
 | 
			
		||||
  };
 | 
			
		||||
  const struct nnp_size kernel_size = {
 | 
			
		||||
      .width = static_cast<size_t>(weight.size(3)),
 | 
			
		||||
      .height = static_cast<size_t>(weight.size(2)),
 | 
			
		||||
      .width = (size_t)weight.size(3),
 | 
			
		||||
      .height = (size_t)weight.size(2),
 | 
			
		||||
  };
 | 
			
		||||
  const struct nnp_size output_size = {
 | 
			
		||||
      .width = static_cast<size_t>(output.size(3)),
 | 
			
		||||
      .height = static_cast<size_t>(output.size(2)),
 | 
			
		||||
      .width = (size_t)output.size(3),
 | 
			
		||||
      .height = (size_t)output.size(2),
 | 
			
		||||
  };
 | 
			
		||||
  const nnp_size output_subsample = {
 | 
			
		||||
      .width = static_cast<std::size_t>(stride[1]),
 | 
			
		||||
 | 
			
		||||
@ -248,8 +248,8 @@ void slow_conv_transpose3d_out_cpu_template(
 | 
			
		||||
  Tensor weight = weight_.contiguous();
 | 
			
		||||
  Tensor bias = bias_.defined() ? bias_.contiguous() : bias_;
 | 
			
		||||
 | 
			
		||||
  const auto n_input_plane = weight.size(0);
 | 
			
		||||
  const auto n_output_plane = weight.size(1);
 | 
			
		||||
  const int n_input_plane = (int)weight.size(0);
 | 
			
		||||
  const int n_output_plane = (int)weight.size(1);
 | 
			
		||||
 | 
			
		||||
  bool is_batch = false;
 | 
			
		||||
  if (input.dim() == 4) {
 | 
			
		||||
 | 
			
		||||
@ -84,8 +84,8 @@ static std::vector<int64_t> aligned_size(
 | 
			
		||||
    DimnameList aligned_names,
 | 
			
		||||
    bool is_aligning_two_tensors) {
 | 
			
		||||
  std::vector<int64_t> expanded_sizes(aligned_names.size(), 1);
 | 
			
		||||
  ptrdiff_t dim = static_cast<ptrdiff_t>(tensor_sizes.size()) - 1;
 | 
			
		||||
  ptrdiff_t idx = static_cast<ptrdiff_t>(aligned_names.size()) - 1;
 | 
			
		||||
  ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1;
 | 
			
		||||
  ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1;
 | 
			
		||||
  for (; idx >= 0 && dim >= 0; --idx) {
 | 
			
		||||
    if (tensor_names[dim] != aligned_names[idx]) {
 | 
			
		||||
      continue;
 | 
			
		||||
 | 
			
		||||
@ -25,7 +25,7 @@ std::tuple<Tensor, Tensor> _rowwise_prune_helper(
 | 
			
		||||
  auto mask_contig = mask.contiguous();
 | 
			
		||||
  auto mask_data = mask_contig.data_ptr<bool>();
 | 
			
		||||
  for (const auto i : c10::irange(mask.numel())) {
 | 
			
		||||
    num_non_masked_rows += ((mask_data[i] == true) ? 1 : 0);
 | 
			
		||||
    num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0);
 | 
			
		||||
  }
 | 
			
		||||
  int num_cols = weights.size(1);
 | 
			
		||||
  auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols},
 | 
			
		||||
 | 
			
		||||
@ -176,7 +176,7 @@ void host_softmax(
 | 
			
		||||
  scalar_t* input_data_base = input.data_ptr<scalar_t>();
 | 
			
		||||
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
 | 
			
		||||
  bool* mask_data_base = mask;
 | 
			
		||||
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
 | 
			
		||||
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
 | 
			
		||||
  parallel_for(
 | 
			
		||||
      0, outer_size * inner_size, grain_size,
 | 
			
		||||
      [&](int64_t begin, int64_t end) {
 | 
			
		||||
@ -265,7 +265,7 @@ void host_softmax_backward(
 | 
			
		||||
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
 | 
			
		||||
  scalar_t* gradOutput_data_base = grad.data_ptr<scalar_t>();
 | 
			
		||||
  bool* mask_data_base = mask;
 | 
			
		||||
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
 | 
			
		||||
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
 | 
			
		||||
  parallel_for(
 | 
			
		||||
      0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
 | 
			
		||||
        for (const auto i : c10::irange(begin, end)) {
 | 
			
		||||
 | 
			
		||||
@ -1701,13 +1701,13 @@ Tensor& index_select_out_cpu_(
 | 
			
		||||
                  TORCH_CHECK_INDEX(
 | 
			
		||||
                      (self_i >= 0) && (self_i < self_dim_size),
 | 
			
		||||
                      "index out of range in self");
 | 
			
		||||
                  auto self_data = const_cast<char*>(static_cast<const char*>(
 | 
			
		||||
                                       selfSlice_data)) +
 | 
			
		||||
                  auto self_data = static_cast<const char*>(selfSlice_data) +
 | 
			
		||||
                      self_i * self_stride_bytes;
 | 
			
		||||
                  auto result_data = static_cast<char*>(resultSlice_data) +
 | 
			
		||||
                      i * result_stride_bytes;
 | 
			
		||||
                  sub_iter.unsafe_replace_operand(0, result_data);
 | 
			
		||||
                  sub_iter.unsafe_replace_operand(1, self_data);
 | 
			
		||||
                  sub_iter.unsafe_replace_operand(
 | 
			
		||||
                      1, const_cast<char*>(self_data));
 | 
			
		||||
                  copy_stub(sub_iter.device_type(), sub_iter, false);
 | 
			
		||||
                };
 | 
			
		||||
              });
 | 
			
		||||
 | 
			
		||||
@ -1382,7 +1382,7 @@ void randperm_cpu(Tensor& result, int64_t n, CPUGeneratorImpl* generator) {
 | 
			
		||||
  // use no-initialization Fischer-Yates variant
 | 
			
		||||
  // https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_.22inside-out.22_algorithm
 | 
			
		||||
  for (int64_t i = 0; i < n; i++) {
 | 
			
		||||
    int64_t z = static_cast<int64_t>(generator->random64() % (i + 1));
 | 
			
		||||
    int64_t z = (int64_t)(generator->random64() % (i + 1));
 | 
			
		||||
    r__data[i * r__stride_0] = i;
 | 
			
		||||
    r__data[i * r__stride_0] = r__data[z * r__stride_0];
 | 
			
		||||
    r__data[z * r__stride_0] = i;
 | 
			
		||||
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user