Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-23 14:59:34 +08:00)

Compare commits: annotate_f...ciflow/ind (40 commits)
c831fcb0b1
20111eac3e
fcbde24c1c
861cdb887b
3154482072
9fccbdd4f0
7dabfb07cb
d0add0be43
11e2084308
9726553653
d82527b32a
5d9b024276
5b2afe4c5d
b2953f5643
470e2f61c3
e0fe37fa68
d2c82bafb7
98a488c9aa
5b3ea75895
556fc09a9f
ce109b3f79
4d833f859b
d7e275d4b4
d5db3aee0d
5641de7b6b
cbc08c8993
1a54d3333d
4c1c341fa0
5f21cc786a
e86942f422
2cd5fd1588
7d0f872cb3
fb06e49ce8
27a98e6ae9
b10f463b1a
431c13cf61
aead9270f5
9bf5b38c14
aba8c43594
37f3ba274a
@@ -113,6 +113,7 @@ case "$tag" in
     UCX_COMMIT=${_UCX_COMMIT}
     UCC_COMMIT=${_UCC_COMMIT}
     TRITON=yes
+    INSTALL_MINGW=yes
     ;;
   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
     CUDA_VERSION=13.0.0
@@ -361,6 +362,7 @@ docker build \
   --build-arg "OPENBLAS=${OPENBLAS:-}" \
   --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
   --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
+  --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
   -f $(dirname ${DOCKERFILE})/Dockerfile \
   -t "$tmp_tag" \
   "$@" \
.ci/docker/common/install_mingw.sh (new file, 10 lines)
@@ -0,0 +1,10 @@
#!/bin/bash

set -ex

# Install MinGW-w64 for Windows cross-compilation
apt-get update
apt-get install -y g++-mingw-w64-x86-64-posix

echo "MinGW-w64 installed successfully"
x86_64-w64-mingw32-g++ --version
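A quick way to sanity-check the MinGW-w64 toolchain installed above is to cross-compile a throwaway C++ program; the file name and flags below are illustrative and not part of the CI scripts.

// hello_mingw.cpp - hypothetical smoke test, not part of the repository.
// Cross-compile on the Linux image with:
//   x86_64-w64-mingw32-g++ -static -o hello.exe hello_mingw.cpp
// The resulting hello.exe is a Windows binary (runnable on Windows or under Wine).
#include <iostream>

int main() {
  std::cout << "MinGW-w64 cross-compilation works" << std::endl;
  return 0;
}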
@@ -20,7 +20,7 @@ pip_install \
 pip_install coloredlogs packaging
 pip_install onnxruntime==1.23.0
-pip_install onnxscript==0.5.3
+pip_install onnxscript==0.5.4

 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
@@ -103,6 +103,11 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
 RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
 RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt

+ARG INSTALL_MINGW
+COPY ./common/install_mingw.sh install_mingw.sh
+RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
+RUN rm install_mingw.sh
+
 ARG TRITON
 ARG TRITON_CPU

@@ -485,6 +485,22 @@ test_inductor_aoti() {
   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
 }

+test_inductor_aoti_cross_compile_for_windows() {
+
+  TEST_REPORTS_DIR=$(pwd)/test/test-reports
+  mkdir -p "$TEST_REPORTS_DIR"
+
+  # Set WINDOWS_CUDA_HOME environment variable
+  WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
+  export WINDOWS_CUDA_HOME
+
+  echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
+  echo "Contents:"
+  ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
+
+  python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
+}
+
 test_inductor_cpp_wrapper_shard() {
   if [[ -z "$NUM_TEST_SHARDS" ]]; then
     echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
@@ -900,7 +916,7 @@ test_inductor_set_cpu_affinity(){
   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"

-  if [[ "${TEST_CONFIG}" != *aarch64* ]]; then
+  if [[ "$(uname -m)" != "aarch64" ]]; then
     # Use Intel OpenMP for x86
     IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
     export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
@@ -914,7 +930,7 @@ test_inductor_set_cpu_affinity(){
   cores=$((cpus / thread_per_core))

   # Set number of cores to 16 on aarch64 for performance runs
-  if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
+  if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then
     cores=16
   fi
   export OMP_NUM_THREADS=$cores
@@ -1667,7 +1683,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then
   python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
   fi
   python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
-elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then
+elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then
   test_linux_aarch64
 elif [[ "${TEST_CONFIG}" == *backward* ]]; then
   test_forward_backward_compatibility
@@ -1718,6 +1734,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
   test_inductor_triton_cpu
 elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
   test_inductor_micro_benchmark
+elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
+  test_inductor_aoti_cross_compile_for_windows
 elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
   install_torchvision
   id=$((SHARD_NUMBER-1))
.github/pytorch-probot.yml (vendored, 1 changed line)
@@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
 ciflow_push_tags:
 - ciflow/b200
 - ciflow/b200-symm-mem
+- ciflow/b200-distributed
 - ciflow/binaries
 - ciflow/binaries_libtorch
 - ciflow/binaries_wheel
.github/workflows/_linux-test.yml (vendored, 40 changed lines)
@@ -224,6 +224,46 @@ jobs:
         continue-on-error: true
         uses: ./.github/actions/download-td-artifacts

+      - name: Download Windows torch wheel for cross-compilation
+        if: matrix.win_torch_wheel_artifact != ''
+        uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
+        with:
+          name: ${{ matrix.win_torch_wheel_artifact }}
+          path: win-torch-wheel
+
+      - name: Extract Windows wheel and setup CUDA libraries
+        if: matrix.win_torch_wheel_artifact != ''
+        shell: bash
+        run: |
+          set -x
+
+          # Find the wheel file
+          WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
+          if [ -z "$WHEEL_FILE" ]; then
+            echo "Error: No wheel file found in win-torch-wheel directory"
+            exit 1
+          fi
+          echo "Found wheel file: $WHEEL_FILE"
+
+          # Unzip the wheel file
+          unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
+          echo "Extracted wheel contents"
+
+          # Setup CUDA libraries (cuda.lib and cudart.lib) directory
+          mkdir -p win-torch-wheel-extracted/lib/x64
+          if [ -f "win-torch-wheel/cuda.lib" ]; then
+            mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
+            echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
+          fi
+          if [ -f "win-torch-wheel/cudart.lib" ]; then
+            mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
+            echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
+          fi
+
+          # Verify CUDA libraries are present
+          echo "CUDA libraries:"
+          ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
+
       - name: Parse ref
         id: parse-ref
         run: .github/scripts/parse_ref.py
.github/workflows/_win-build.yml (vendored, 25 changed lines)
@@ -168,6 +168,31 @@ jobs:
         run: |
           .ci/pytorch/win-build.sh

+      # Collect Windows torch libs and CUDA libs for cross-compilation
+      - name: Collect Windows CUDA libs for cross-compilation
+        if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
+        shell: bash
+        run: |
+          set -ex
+
+          # Create directory structure if does not exist
+          mkdir -p /c/${{ github.run_id }}/build-results
+
+          # Copy CUDA libs
+          CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
+
+          if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
+            cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
+          fi
+
+          if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
+            cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
+          fi
+
+          # List collected files
+          echo "Collected CUDA libs:"
+          ls -lah /c/${{ github.run_id }}/build-results/*.lib
+
       # Upload to github so that people can click and download artifacts
       - name: Upload artifacts to s3
         if: steps.build.outcome != 'skipped'
.github/workflows/b200-distributed.yml (vendored, new file, 62 lines)
@@ -0,0 +1,62 @@
name: CI for distributed tests on B200

on:
  pull_request:
    paths:
      - .github/workflows/b200-distributed.yml
  workflow_dispatch:
  push:
    tags:
      - ciflow/b200-distributed/*
  schedule:
    - cron: 46 8 * * *  # about 1:46am PDT

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:

  get-label-type:
    if: github.repository_owner == 'pytorch'
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '10.0'
      test-matrix: |
        { include: [
          { config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
          { config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
        ]}
    secrets: inherit

  linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
    name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
    uses: ./.github/workflows/_linux-test.yml
    needs:
      - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
    with:
      timeout-minutes: 1200
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
      aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
    secrets: inherit
@@ -88,27 +88,27 @@ jobs:
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
       test-matrix: |
         { include: [
-          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
+          { config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
         ]}
     secrets: inherit

.github/workflows/rocm-mi355.yml (vendored, 12 changed lines)
@@ -45,12 +45,12 @@ jobs:
       sync-tag: rocm-build
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
+          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
         ]}
     secrets: inherit

.github/workflows/trunk.yml (vendored, 17 changed lines)
@@ -200,6 +200,23 @@ jobs:
       cuda-arch-list: '8.0'
     secrets: inherit

+  # Test cross-compiled models with Windows libs extracted from wheel
+  cross-compile-linux-test:
+    name: cross-compile-linux-test
+    uses: ./.github/workflows/_linux-test.yml
+    needs:
+      - linux-jammy-cuda12_8-py3_10-gcc11-build
+      - get-label-type
+      - win-vs2022-cuda12_8-py3-build
+    with:
+      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
+      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
+      test-matrix: |
+        { include: [
+          { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
+        ]}
+    secrets: inherit
+
   verify-cachebench-cpu-build:
     name: verify-cachebench-cpu-build
     uses: ./.github/workflows/_linux-build.yml
@@ -2,7 +2,6 @@

 #include <mutex>
 #include <ATen/CachedTensorUtils.h>
-#include <c10/core/GradMode.h>
 #include <c10/util/flat_hash_map.h>

 namespace at::autocast {
@@ -37,29 +36,10 @@ namespace {
 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 using val_type = std::tuple<weakref_type, Tensor>;

-// We maintain separate caches for gradient-enabled and gradient-disabled modes.
-// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
-// are not incorrectly reused in gradient-enabled contexts.
-// This fixes issue #158232 while maintaining optimal performance for both modes.
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
-  return cached_casts_grad_enabled;
+ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
+  return cached_casts;
 }
-
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
-  return cached_casts_grad_disabled;
-}
-
-// Helper function to get the appropriate cache based on current gradient mode.
-// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
-// preventing incorrect cache hits when gradient mode changes.
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
-  return at::GradMode::is_enabled() ?
-    get_cached_casts_grad_enabled() :
-    get_cached_casts_grad_disabled();
-}

 std::mutex cached_casts_mutex;

@@ -106,9 +86,7 @@ thread_local bool cache_enabled = true;

 void clear_cache() {
   const std::lock_guard<std::mutex> lock(cached_casts_mutex);
-  // Clear both caches to ensure consistent behavior regardless of current gradient mode
-  get_cached_casts_grad_enabled().clear();
-  get_cached_casts_grad_disabled().clear();
+  get_cached_casts().clear();
 }

 int increment_nesting() {
@@ -143,11 +121,6 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
   if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
-    //
-    // We maintain separate caches for gradient-enabled and gradient-disabled modes
-    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
-    // with torch.autocast(), while maintaining optimal performance for both training and inference.
-    // This fixes issue #158232 without any performance regression.
     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                           arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                           arg.is_leaf() && !arg.is_view() && cache_enabled &&
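For context, the single cache that the hunk above restores is just a map from the source tensor's key to the cast result, guarded by the mutex. The sketch below shows that lookup-or-insert pattern only; std::unordered_map stands in for ska::flat_hash_map and the names are illustrative, not the ATen implementation.

// Sketch of the cache pattern: look the key up, cast once on a miss, reuse afterwards.
#include <mutex>
#include <unordered_map>

template <typename Key, typename Value, typename CastFn>
const Value& lookup_or_cast(std::unordered_map<Key, Value>& cache,
                            std::mutex& mutex,
                            const Key& key,
                            CastFn cast_fn) {
  const std::lock_guard<std::mutex> lock(mutex);
  auto it = cache.find(key);
  if (it == cache.end()) {
    it = cache.emplace(key, cast_fn(key)).first;  // cache miss: cast once and remember it
  }
  return it->second;                              // cache hit: reuse the earlier cast
}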
@@ -229,10 +229,10 @@ private:
   }


-  static const uint32_t kPhilox10A = 0x9E3779B9;
-  static const uint32_t kPhilox10B = 0xBB67AE85;
-  static const uint32_t kPhiloxSA = 0xD2511F53;
-  static const uint32_t kPhiloxSB = 0xCD9E8D57;
+  static constexpr uint32_t kPhilox10A = 0x9E3779B9;
+  static constexpr uint32_t kPhilox10B = 0xBB67AE85;
+  static constexpr uint32_t kPhiloxSA = 0xD2511F53;
+  static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
 };

 typedef philox_engine Philox4_32;
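A general C++ note on this kind of change (not specific to this header): a static constexpr data member participates directly in constant expressions and, since C++17, is implicitly inline, so no separate out-of-class definition is needed if it is odr-used. A minimal sketch with a made-up struct:

// Illustrative only; this is not the real philox_engine.
#include <cstdint>

struct constants_sketch {
  static constexpr uint32_t kPhilox10A = 0x9E3779B9;  // same Weyl constant as above
};

// Usable in compile-time checks...
static_assert(constants_sketch::kPhilox10A == 0x9E3779B9, "usable at compile time");

// ...and binding a reference to it needs no definition in a .cpp file (C++17).
const uint32_t& weyl_ref = constants_sketch::kPhilox10A;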
@@ -8,6 +8,7 @@
 #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
 #include <ATen/cpu/vec/vec128/vec128_float_neon.h>
 #include <ATen/cpu/vec/vec128/vec128_half_neon.h>
+#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
 #endif

 #include <ATen/cpu/vec/vec128/vec128_convert.h>
aten/src/ATen/cpu/vec/vec128/vec128_int_aarch64.h (new file, 794 lines)
@@ -0,0 +1,794 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#define VEC_INT_NEON_TEMPLATE(vl, bit) \
|
||||
template <> \
|
||||
struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {}; \
|
||||
\
|
||||
template <> \
|
||||
class Vectorized<int##bit##_t> { \
|
||||
using neon_type = int##bit##x##vl##_t; \
|
||||
\
|
||||
private: \
|
||||
neon_type values; \
|
||||
\
|
||||
public: \
|
||||
using value_type = int##bit##_t; \
|
||||
using size_type = int; \
|
||||
static constexpr size_type size() { \
|
||||
return vl; \
|
||||
} \
|
||||
Vectorized() { \
|
||||
values = vdupq_n_s##bit(0); \
|
||||
} \
|
||||
Vectorized(neon_type v) : values(v) {} \
|
||||
Vectorized(int##bit##_t val); \
|
||||
template < \
|
||||
typename... Args, \
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
||||
Vectorized(Args... vals) { \
|
||||
__at_align__ int##bit##_t buffer[size()] = {vals...}; \
|
||||
values = vld1q_s##bit(buffer); \
|
||||
} \
|
||||
operator neon_type() const { \
|
||||
return values; \
|
||||
} \
|
||||
static Vectorized<int##bit##_t> loadu( \
|
||||
const void* ptr, \
|
||||
int64_t count = size()); \
|
||||
void store(void* ptr, int64_t count = size()) const; \
|
||||
template <int64_t mask> \
|
||||
static Vectorized<int##bit##_t> blend( \
|
||||
const Vectorized<int##bit##_t>& a, \
|
||||
const Vectorized<int##bit##_t>& b); \
|
||||
static Vectorized<int##bit##_t> blendv( \
|
||||
const Vectorized<int##bit##_t>& a, \
|
||||
const Vectorized<int##bit##_t>& b, \
|
||||
const Vectorized<int##bit##_t>& mask_) { \
|
||||
return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \
|
||||
} \
|
||||
template <typename step_t> \
|
||||
static Vectorized<int##bit##_t> arange( \
|
||||
value_type base = 0, \
|
||||
step_t step = static_cast<step_t>(1)); \
|
||||
static Vectorized<int##bit##_t> set( \
|
||||
const Vectorized<int##bit##_t>& a, \
|
||||
const Vectorized<int##bit##_t>& b, \
|
||||
int64_t count = size()); \
|
||||
const int##bit##_t& operator[](int idx) const = delete; \
|
||||
int##bit##_t& operator[](int idx) = delete; \
|
||||
Vectorized<int##bit##_t> abs() const { \
|
||||
return vabsq_s##bit(values); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> real() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<int##bit##_t> imag() const { \
|
||||
return vdupq_n_s##bit(0); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> conj() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<int##bit##_t> neg() const { \
|
||||
return vnegq_s##bit(values); \
|
||||
} \
|
||||
int##bit##_t reduce_add() const { \
|
||||
return vaddvq_s##bit(values); \
|
||||
} \
|
||||
int##bit##_t reduce_max() const; \
|
||||
Vectorized<int##bit##_t> operator==( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator!=( \
|
||||
const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> operator<( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator<=( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator>( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator>=( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
|
||||
}; \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator+( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vaddq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator-( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vsubq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator&( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vandq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator|( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vorrq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator^( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return veorq_s##bit(a, b); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this == other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this != other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this > other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this >= other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this < other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this <= other) & Vectorized<int##bit##_t>(1); \
|
||||
}
|
||||
|
||||
VEC_INT_NEON_TEMPLATE(2, 64)
|
||||
VEC_INT_NEON_TEMPLATE(4, 32)
|
||||
VEC_INT_NEON_TEMPLATE(8, 16)
|
||||
VEC_INT_NEON_TEMPLATE(16, 8)
|
||||
|
||||
inline int32_t Vectorized<int32_t>::reduce_max() const {
|
||||
return vmaxvq_s32(values);
|
||||
}
|
||||
|
||||
inline int16_t Vectorized<int16_t>::reduce_max() const {
|
||||
return vmaxvq_s16(values);
|
||||
}
|
||||
|
||||
inline int8_t Vectorized<int8_t>::reduce_max() const {
|
||||
return vmaxvq_s8(values);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator*(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
return vmulq_s32(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator*(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
return vmulq_s16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator*(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
return vmulq_s8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
|
||||
int64x2_t val = a;
|
||||
return ~val;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
|
||||
return vmvnq_s32(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
|
||||
return vmvnq_s16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
|
||||
return vmvnq_s8(a);
|
||||
}
|
||||
|
||||
inline Vectorized<int64_t> Vectorized<int64_t>::operator!=(
|
||||
const Vectorized<int64_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
inline Vectorized<int32_t> Vectorized<int32_t>::operator!=(
|
||||
const Vectorized<int32_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
inline Vectorized<int16_t> Vectorized<int16_t>::operator!=(
|
||||
const Vectorized<int16_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
inline Vectorized<int8_t> Vectorized<int8_t>::operator!=(
|
||||
const Vectorized<int8_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline minimum(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
return vminq_s32(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline minimum(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
return vminq_s16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline minimum(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
return vminq_s8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline maximum(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
return vmaxq_s32(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline maximum(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
return vmaxq_s16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline maximum(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
return vmaxq_s8(a, b);
|
||||
}
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int64_t> Vectorized<int64_t>::blend(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint64x2_t maskArray = {
|
||||
(mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0,
|
||||
(mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s64(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int32_t> Vectorized<int32_t>::blend(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint32x4_t maskArray = {
|
||||
(mask & 1LL) ? 0xFFFFFFFF : 0,
|
||||
(mask & 2LL) ? 0xFFFFFFFF : 0,
|
||||
(mask & 4LL) ? 0xFFFFFFFF : 0,
|
||||
(mask & 8LL) ? 0xFFFFFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s32(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int16_t> Vectorized<int16_t>::blend(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint16x8_t maskArray = {
|
||||
(mask & 1LL) ? 0xFFFF : 0,
|
||||
(mask & 2LL) ? 0xFFFF : 0,
|
||||
(mask & 4LL) ? 0xFFFF : 0,
|
||||
(mask & 8LL) ? 0xFFFF : 0,
|
||||
(mask & 16LL) ? 0xFFFF : 0,
|
||||
(mask & 32LL) ? 0xFFFF : 0,
|
||||
(mask & 64LL) ? 0xFFFF : 0,
|
||||
(mask & 128LL) ? 0xFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s16(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int8_t> Vectorized<int8_t>::blend(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
(mask & 1LL) ? 0xFF : 0,
|
||||
(mask & 2LL) ? 0xFF : 0,
|
||||
(mask & 4LL) ? 0xFF : 0,
|
||||
(mask & 8LL) ? 0xFF : 0,
|
||||
(mask & 16LL) ? 0xFF : 0,
|
||||
(mask & 32LL) ? 0xFF : 0,
|
||||
(mask & 64LL) ? 0xFF : 0,
|
||||
(mask & 128LL) ? 0xFF : 0,
|
||||
(mask & 256LL) ? 0xFF : 0,
|
||||
(mask & 512LL) ? 0xFF : 0,
|
||||
(mask & 1024LL) ? 0xFF : 0,
|
||||
(mask & 2048LL) ? 0xFF : 0,
|
||||
(mask & 4096LL) ? 0xFF : 0,
|
||||
(mask & 8192LL) ? 0xFF : 0,
|
||||
(mask & 16384LL) ? 0xFF : 0,
|
||||
(mask & 32768LL) ? 0xFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s8(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
#define VEC_INT_NEON_OPS(vl, bit) \
|
||||
inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) { \
|
||||
values = vdupq_n_s##bit(val); \
|
||||
} \
|
||||
inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu( \
|
||||
const void* ptr, int64_t count) { \
|
||||
if (count == size()) { \
|
||||
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr)); \
|
||||
} else { \
|
||||
__at_align__ int##bit##_t tmp_values[size()]; \
|
||||
for (const auto i : c10::irange(size())) { \
|
||||
tmp_values[i] = 0; \
|
||||
} \
|
||||
std::memcpy( \
|
||||
tmp_values, \
|
||||
reinterpret_cast<const int##bit##_t*>(ptr), \
|
||||
count * sizeof(int##bit##_t)); \
|
||||
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \
|
||||
} \
|
||||
} \
|
||||
inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count) \
|
||||
const { \
|
||||
if (count == size()) { \
|
||||
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values); \
|
||||
} else { \
|
||||
int##bit##_t tmp_values[size()]; \
|
||||
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values); \
|
||||
std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t)); \
|
||||
} \
|
||||
}
|
||||
|
||||
VEC_INT_NEON_OPS(2, 64)
|
||||
VEC_INT_NEON_OPS(4, 32)
|
||||
VEC_INT_NEON_OPS(8, 16)
|
||||
VEC_INT_NEON_OPS(16, 8)
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator*(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return x * y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator/(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator/(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
int32x4_t x = a;
|
||||
int32x4_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
inline int64_t Vectorized<int64_t>::reduce_max() const {
|
||||
return std::max(values[0], values[1]);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline minimum(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return {std::min(x[0], y[0]), std::min(x[1], y[1])};
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline maximum(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return {std::max(x[0], y[0]), std::max(x[1], y[1])};
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int64_t> Vectorized<int64_t>::arange(
|
||||
int64_t base,
|
||||
step_t step) {
|
||||
const Vectorized<int64_t> base_vec(base);
|
||||
const Vectorized<int64_t> step_vec(step);
|
||||
const int64x2_t step_sizes = {0, 1};
|
||||
return base_vec.values + step_sizes * step_vec.values;
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int32_t> Vectorized<int32_t>::arange(
|
||||
int32_t base,
|
||||
step_t step) {
|
||||
const Vectorized<int32_t> base_vec(base);
|
||||
const Vectorized<int32_t> step_vec(step);
|
||||
const int32x4_t step_sizes = {0, 1, 2, 3};
|
||||
return vmlaq_s32(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int16_t> Vectorized<int16_t>::arange(
|
||||
int16_t base,
|
||||
step_t step) {
|
||||
const Vectorized<int16_t> base_vec(base);
|
||||
const Vectorized<int16_t> step_vec(step);
|
||||
const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
return vmlaq_s16(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) {
|
||||
const Vectorized<int8_t> base_vec(base);
|
||||
const Vectorized<int8_t> step_vec(step);
|
||||
const int8x16_t step_sizes = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
return vmlaq_s8(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator>>(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
uint64x2_t u = vreinterpretq_u64_s64(y);
|
||||
uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)};
|
||||
return x >> vreinterpretq_s64_u64(z);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator>>(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
int32x4_t x = a;
|
||||
int32x4_t y = b;
|
||||
uint32x4_t bound = vdupq_n_u32(31);
|
||||
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
|
||||
return x >> vreinterpretq_s32_u32(z);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator>>(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
int16x8_t x = a;
|
||||
int16x8_t y = b;
|
||||
uint16x8_t bound = vdupq_n_u16(15);
|
||||
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
|
||||
return x >> vreinterpretq_s16_u16(z);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator>>(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
int8x16_t x = a;
|
||||
int8x16_t y = b;
|
||||
uint8x16_t bound = vdupq_n_u8(7);
|
||||
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
|
||||
return x >> z;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator<<(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t y = b;
|
||||
uint64x2_t u = vreinterpretq_u64_s64(y);
|
||||
uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)};
|
||||
return vshlq_s64(a, vreinterpretq_s64_u64(z));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator<<(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
int32x4_t y = b;
|
||||
uint32x4_t bound = vdupq_n_u32(32);
|
||||
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
|
||||
return vshlq_s32(a, vreinterpretq_s32_u32(z));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator<<(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
int16x8_t y = b;
|
||||
uint16x8_t bound = vdupq_n_u16(16);
|
||||
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
|
||||
return vshlq_s16(a, vreinterpretq_s16_u16(z));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator<<(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
int8x16_t y = b;
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
|
||||
return vshlq_s8(a, z);
|
||||
}
|
||||
|
||||
inline Vectorized<int64_t> Vectorized<int64_t>::set(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 2) {
|
||||
return b;
|
||||
} else {
|
||||
int64x2_t c = {b.values[0], a.values[1]};
|
||||
return c;
|
||||
}
|
||||
}
|
||||
|
||||
inline Vectorized<int32_t> Vectorized<int32_t>::set(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 4) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint32x4_t maskArray = {
|
||||
(count >= 1LL) ? 0xFFFFFFFF : 0,
|
||||
(count >= 2LL) ? 0xFFFFFFFF : 0,
|
||||
(count >= 3LL) ? 0xFFFFFFFF : 0,
|
||||
0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s32(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
inline Vectorized<int16_t> Vectorized<int16_t>::set(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 8) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint16x8_t maskArray = {
|
||||
static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0),
|
||||
0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s16(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
inline Vectorized<int8_t> Vectorized<int8_t>::set(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 16) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
||||
0};
|
||||
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s8(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator/(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
Vectorized<int32_t> highBitsA = vmovl_high_s16(a);
|
||||
Vectorized<int32_t> highBitsB = vmovl_high_s16(b);
|
||||
Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a));
|
||||
Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b));
|
||||
int32x4_t highBitsResult = highBitsA / highBitsB;
|
||||
int32x4_t lowBitsResult = lowBitsA / lowBitsB;
|
||||
return vuzp1q_s16(
|
||||
vreinterpretq_s16_s32(lowBitsResult),
|
||||
vreinterpretq_s16_s32(highBitsResult));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator/(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
Vectorized<int16_t> highBitsA = vmovl_high_s8(a);
|
||||
Vectorized<int16_t> highBitsB = vmovl_high_s8(b);
|
||||
Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a));
|
||||
Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b));
|
||||
int16x8_t highBitsResult = highBitsA / highBitsB;
|
||||
int16x8_t lowBitsResult = lowBitsA / lowBitsB;
|
||||
return vuzp1q_s8(
|
||||
vreinterpretq_s8_s16(lowBitsResult),
|
||||
vreinterpretq_s8_s16(highBitsResult));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline clamp(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& min,
|
||||
const Vectorized<int64_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline clamp(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& min,
|
||||
const Vectorized<int32_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline clamp(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& min,
|
||||
const Vectorized<int16_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline clamp(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& min,
|
||||
const Vectorized<int8_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline clamp_max(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline clamp_max(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline clamp_max(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline clamp_max(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline clamp_min(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline clamp_min(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline clamp_min(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline clamp_min(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
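A minimal usage sketch of the new integer specializations, using only members declared in the header above (loadu, the arithmetic operators, reduce_add, store). It assumes an aarch64 build where this header is active, and the function and buffer names are illustrative.

#include <ATen/cpu/vec/vec.h>
#include <cstdint>

// Adds two 4-lane int32 blocks and returns their horizontal sum.
int32_t add_and_reduce(const int32_t* a, const int32_t* b, int32_t* out) {
  using Vec = at::vec::Vectorized<int32_t>;   // Vec::size() == 4 on NEON
  Vec va = Vec::loadu(a);                     // load 4 lanes
  Vec vb = Vec::loadu(b);
  Vec vc = va + vb;                           // lane-wise add (vaddq_s32)
  vc.store(out);                              // write the 4 lanes back
  return vc.reduce_add();                     // horizontal sum (vaddvq_s32)
}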
@@ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum(
 #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
 std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
     at::vec::Vectorized<int8_t> src) {
-  auto s8x8 = vld1_s8(src.operator const int8_t*());
+  auto s8x8 = vget_low_s8(src);
   auto s16x8 = vmovl_s8(s8x8);

   auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
@@ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(

 Vectorized<float> inline convert_int8_half_register_to_float(
     at::vec::Vectorized<int8_t> src) {
-  auto s8x8 = vld1_s8(src.operator const int8_t*());
+  auto s8x8 = vget_low_s8(src);
   auto s16x8 = vmovl_s8(s8x8);

   auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
@@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
  */
 c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
   // The RNG state comprises the seed, and an offset used for Philox.
-  static const size_t seed_size = sizeof(uint64_t);
-  static const size_t offset_size = sizeof(int64_t);
-  static const size_t total_size = seed_size + offset_size;
+  constexpr size_t seed_size = sizeof(uint64_t);
+  constexpr size_t offset_size = sizeof(int64_t);
+  constexpr size_t total_size = seed_size + offset_size;

   auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
   auto rng_state = state_tensor.data_ptr<uint8_t>();
@@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
  * and size of the internal state.
  */
 void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
-  static const size_t seed_size = sizeof(uint64_t);
-  static const size_t offset_size = sizeof(int64_t);
-  static const size_t total_size = seed_size + offset_size;
+  constexpr size_t seed_size = sizeof(uint64_t);
+  constexpr size_t offset_size = sizeof(int64_t);
+  constexpr size_t total_size = seed_size + offset_size;

   detail::check_rng_state(new_state);

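As the comment in get_state() above says, the RNG state is just the seed followed by the Philox offset, 16 bytes in total. A free-standing sketch of that layout (the helper name is illustrative):

#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr std::size_t kSeedSize = sizeof(uint64_t);    // 8 bytes
constexpr std::size_t kOffsetSize = sizeof(int64_t);   // 8 bytes
constexpr std::size_t kTotalSize = kSeedSize + kOffsetSize;

// Pack seed + offset into the byte layout described above.
inline void pack_rng_state(uint8_t (&buf)[kTotalSize], uint64_t seed, int64_t offset) {
  std::memcpy(buf, &seed, kSeedSize);
  std::memcpy(buf + kSeedSize, &offset, kOffsetSize);
}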
@@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (

 namespace at::native {

-static const double SELU_ALPHA = 1.6732632423543772848170429916717;
-static const double SELU_SCALE = 1.0507009873554804934193349852946;
+static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
+static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;

 DEFINE_DISPATCH(elu_stub);
 DEFINE_DISPATCH(elu_backward_stub);
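To make the role of SELU_ALPHA and SELU_SCALE concrete, here is the textbook elementwise SELU definition they feed into (a reference sketch, not the dispatched kernel):

#include <cmath>

constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
constexpr double SELU_SCALE = 1.0507009873554804934193349852946;

// selu(x) = scale * x               for x > 0
//         = scale * alpha * (e^x-1) otherwise
double selu_reference(double x) {
  return x > 0 ? SELU_SCALE * x
               : SELU_SCALE * SELU_ALPHA * (std::exp(x) - 1.0);
}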
@@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
 #if AT_BUILD_WITH_BLAS()
 template <>
 bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
-  auto intmax = std::numeric_limits<int>::max();
+  auto constexpr intmax = std::numeric_limits<int>::max();
   return n <= intmax && incx <= intmax;
 }

@@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
     int64_t incx,
     [[maybe_unused]] float beta,
     int64_t incy) {
-  auto intmax = std::numeric_limits<int>::max();
+  auto constexpr intmax = std::numeric_limits<int>::max();
   return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
       (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
 }
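The guards above only check that 64-bit sizes and strides fit into the 32-bit int arguments the BLAS fast path takes. A generic version of that check (illustrative helper, not part of the file) is:

#include <cstdint>
#include <limits>

// True when a 64-bit dimension can be passed to an API that takes a 32-bit int.
bool fits_in_int(int64_t v) {
  constexpr auto intmax = std::numeric_limits<int>::max();
  return v <= intmax;
}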
@@ -127,7 +127,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor

 template<typename scalar_t>
 C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
-  const static scalar_t kTailValues[] = {
+  constexpr static scalar_t kTailValues[] = {
     0.0810614667953272,
     0.0413406959554092,
     0.0276779256849983,
@@ -139,7 +139,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
     0.00925546218271273,
     0.00833056343336287
   };
-  if (k <= 9) {
+  if (k <= sizeof(kTailValues)/sizeof(scalar_t)) {
     return kTailValues[static_cast<size_t>(k)];
   }
   scalar_t kp1sq = (k + 1) * (k + 1);
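The second hunk above derives the cutoff from the table itself instead of hard-coding 9. A generic version of that guard, with illustrative values and written with a strict `<` so the index always stays inside the array, looks like this:

#include <cstddef>
#include <iterator>

constexpr double kTable[] = {0.25, 0.125, 0.0625};   // illustrative values only

double lookup_or_fallback(std::size_t k, double fallback) {
  // The bound tracks the array, so adding or removing entries never desyncs it.
  return k < std::size(kTable) ? kTable[k] : fallback;
}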
@@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
 template <typename scalar_t>
 static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
   // lanczos approximation
-  static const scalar_t lanczos_sum_expg_scaled_num[13] = {
+  static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
     0.006061842346248906525783753964555936883222,
     0.5098416655656676188125178644804694509993,
     19.51992788247617482847860966235652136208,
@@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
     103794043.1163445451906271053616070238554,
     56906521.91347156388090791033559122686859
   };
-  static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
+  static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
     1.,
     66.,
     1925.,
@@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
 template <typename scalar_t>
 static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
-  static const scalar_t d[25][25] =
+  static constexpr scalar_t d[25][25] =
   {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
     1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
     3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,
@@ -62,7 +62,7 @@
 #include <utility>
 #include <vector>

-static const int MIOPEN_DIM_MAX = 5;
+static constexpr int MIOPEN_DIM_MAX = 5;

 namespace at::meta {

@@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
   // We keep this structure for BC and consider as deprecated.
   // See HelperInterpNearestExact as replacement

-  static const int interp_size = 1;
+  static constexpr int interp_size = 1;

   static inline void init_indices_weights(
     at::ScalarType output_type,
@@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {

 struct HelperInterpLinear : public HelperInterpBase {

-  static const int interp_size = 2;
+  static constexpr int interp_size = 2;

   // Compute indices and weights for each interpolated dimension
   // indices_weights = {
@@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {

 struct HelperInterpCubic : public HelperInterpBase {

-  static const int interp_size = 4;
+  static constexpr int interp_size = 4;

   // Compute indices and weights for each interpolated dimension
   // indices_weights = {
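One practical benefit of switching these interp_size members to static constexpr: since C++17 a static constexpr data member is implicitly inline, so it can be odr-used (for example, bound to a const reference) without a separate out-of-class definition. A small illustration under that assumption (HelperInterp and twice are made-up names):

// C++17: no out-of-class definition needed for interp_size.
struct HelperInterp {
  static constexpr int interp_size = 2;
};

int twice(const int& v) { return 2 * v; }

int demo() {
  return twice(HelperInterp::interp_size);  // fine: no linker error
}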
@@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
 }


-static const int BLOCK_THREADS = 256;
+static constexpr int BLOCK_THREADS = 256;

 template <typename scalar_t, typename accscalar_t>
 #if defined (USE_ROCM)

@@ -36,9 +36,9 @@ namespace at::native {
 namespace {

 #if defined(USE_ROCM)
-static const int BLOCKDIMY = 16;
+static constexpr int BLOCKDIMY = 16;
 #else
-static const int BLOCKDIMY = 32;
+static constexpr int BLOCKDIMY = 32;
 #endif

 template
@@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
   // lanczos approximation
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;

-  static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
+  constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
     0.006061842346248906525783753964555936883222,
     0.5098416655656676188125178644804694509993,
     19.51992788247617482847860966235652136208,
@@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
     103794043.1163445451906271053616070238554,
     56906521.91347156388090791033559122686859
   };
-  static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
+  constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
     1.,
     66.,
     1925.,
@@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {

   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t ax, fac, res, num, numfac;
-  static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
+  constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
     7.09782712893383996843E2 : 88.72283905206835;
-  static const accscalar_t EXP1 = 2.718281828459045;
-  static const accscalar_t lanczos_g = 6.024680040776729583740234375;
+  constexpr accscalar_t EXP1 = 2.718281828459045;
+  constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;

   if (::fabs(a - x) > 0.4 * ::fabs(a)) {
     ax = a * ::log(x) - x - ::lgamma(a);
@@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
   // Compute igam using DLMF 8.11.4. [igam1]

   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
-  static const int MAXITER = 2000;
+  constexpr int MAXITER = 2000;

   int i;
   accscalar_t ans, ax, c, r;
@@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
   accscalar_t fac = 1;
   accscalar_t sum = 0;
   accscalar_t term, logx;
-  static const int MAXITER = 2000;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr int MAXITER = 2000;
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;

   for (n = 1; n < MAXITER; n++) {
@@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
   // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]

   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  static const accscalar_t d[25][25] =
+  constexpr accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t

   int k, n, sgn;
   int maxpow = 0;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
   accscalar_t lambda = x / a;
   accscalar_t sigma = (x - a) / a;
@@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
   int i;
   accscalar_t ans, ax, c, yc, r, t, y, z;
   accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
-  static const int MAXITER = 2000;
-  static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
+  constexpr int MAXITER = 2000;
+  constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
     1.11022302462515654042E-16 : 5.9604644775390625E-8;
-  static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
+  constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
     4.503599627370496e15 : 16777216.;
-  static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
+  constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
     2.22044604925031308085e-16 : 5.9604644775390625E-8;

   ax = _igam_helper_fac(a, x);
@@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t absxma_a;

-  static const accscalar_t SMALL = 20.0;
-  static const accscalar_t LARGE = 200.0;
-  static const accscalar_t SMALLRATIO = 0.3;
-  static const accscalar_t LARGERATIO = 4.5;
+  constexpr accscalar_t SMALL = 20.0;
+  constexpr accscalar_t LARGE = 200.0;
+  constexpr accscalar_t SMALLRATIO = 0.3;
+  constexpr accscalar_t LARGERATIO = 4.5;

   if ((x < 0) || (a < 0)) {
     // out of defined-region of the function
@@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {

   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
   accscalar_t absxma_a;
-  static const accscalar_t SMALL = 20.0;
-  static const accscalar_t LARGE = 200.0;
-  static const accscalar_t SMALLRATIO = 0.3;
-  static const accscalar_t LARGERATIO = 4.5;
+  constexpr accscalar_t SMALL = 20.0;
+  constexpr accscalar_t LARGE = 200.0;
+  constexpr accscalar_t SMALLRATIO = 0.3;
+  constexpr accscalar_t LARGERATIO = 4.5;

   // boundary values following SciPy
   if ((x < 0) || (a < 0)) {
@@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
 const auto digamma_string = jiterator_stringify(
   template <typename T>
   T digamma(T x) {
-    static const double PI_f64 = 3.14159265358979323846;
+    static constexpr double PI_f64 = 3.14159265358979323846;

     // Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
     if (x == 0) {
@@ -3072,9 +3072,9 @@ template <typename scalar_t>
 static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
   // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
   using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
-  static const double PI_f64 = 3.14159265358979323846;
-  const accscalar_t PSI_10 = 2.25175258906672110764;
-  const accscalar_t A[] = {
+  static constexpr double PI_f64 = 3.14159265358979323846;
+  constexpr accscalar_t PSI_10 = 2.25175258906672110764;
+  constexpr accscalar_t A[] = {
     8.33333333333333333333E-2,
     -2.10927960927960927961E-2,
     7.57575757575757575758E-3,
@@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
     return 0;
   }

-  static const int size = 2;
+  static constexpr int size = 2;
 };

 // taken from
@@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
     return 0;
   }

-  static const int size = 4;
+  static constexpr int size = 4;
 };

 template <typename accscalar_t>
@@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
   // else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
   // else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
   // only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
-  static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
+  constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
   if (mat1.dim() == 1 && mat2.dim() == 1) {
     // aten::dot
     return mat1.size(0) > mkldnn_gemm_min_size;
@@ -1,16 +1,16 @@
 #pragma once
 #include <c10/metal/common.h>

-template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
-struct CatLargeSharedParams {
+template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
+struct CatSharedParams {
   int32_t ndim;
   int32_t cat_dim;
   ::c10::metal::array<idx_type_t, N> output_strides;
   ::c10::metal::array<idx_type_t, N> output_sizes;
 };

-template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
-struct CatLargeInputParams {
+template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
+struct CatInputParams {
   idx_type_t cat_dim_offset;
   idx_type_t input_element_offset;
   ::c10::metal::array<idx_type_t, N> input_strides;
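The header change above parameterizes the shared and per-input parameter blocks by index type, so the same cat kernel can be instantiated with either 32-bit or 64-bit offsets. A rough C++ sketch of that idea, with invented names (CatParams, fits_in_int32) and MaxDims standing in for the real max_ndim constant:

#include <cstdint>
#include <limits>

// Parameter block templated on the index type used for sizes/strides.
template <typename IdxT, unsigned MaxDims = 16>
struct CatParams {
  int32_t ndim;
  int32_t cat_dim;
  IdxT output_strides[MaxDims];
  IdxT output_sizes[MaxDims];
};

// Pick the narrow instantiation only when every offset fits in int32_t.
inline bool fits_in_int32(int64_t numel) {
  return numel <= std::numeric_limits<int32_t>::max();
}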
@@ -6,12 +6,12 @@
 using namespace metal;
 using namespace c10::metal;

-template <typename T_in, typename T_out>
-kernel void cat_large(
+template <typename I, typename T_in, typename T_out>
+kernel void cat(
     constant T_in* input [[buffer(0)]],
     device T_out* output [[buffer(1)]],
-    constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
-    constant CatLargeInputParams<>& input_params [[buffer(3)]],
+    constant CatSharedParams<I>& shared_params [[buffer(2)]],
+    constant CatInputParams<I>& input_params [[buffer(3)]],
     uint tid [[thread_position_in_grid]]) {
   auto ndim = shared_params.ndim;
   auto cat_dim = shared_params.cat_dim;
@@ -23,9 +23,9 @@ kernel void cat_large(
   constant auto& input_strides = input_params.input_strides;
   constant auto& input_sizes = input_params.input_sizes;

-  auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
-  int64_t input_offset = 0;
-  int64_t output_offset = 0;
+  auto input_element_idx = static_cast<I>(tid) + input_element_offset;
+  I input_offset = 0;
+  I output_offset = 0;

   for (auto dim = ndim - 1; dim >= 0; dim--) {
     auto dim_size = input_sizes[dim];
@@ -42,41 +42,45 @@ kernel void cat_large(
   output[output_offset] = static_cast<T_out>(input[input_offset]);
 }

-#define REGISTER_CAT_LARGE_OP(T_in, T_out)                            \
-  template [[host_name("cat_large_" #T_in "_" #T_out)]]              \
-  kernel void cat_large<T_in, T_out>(                                \
-      constant T_in * input [[buffer(0)]],                           \
-      device T_out * output [[buffer(1)]],                           \
-      constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
-      constant CatLargeInputParams<> & input_params [[buffer(3)]],   \
+#define REGISTER_CAT_OP(I, T_in, T_out)                           \
+  template [[host_name("cat_" #I "_" #T_in "_" #T_out)]]          \
+  kernel void cat<I, T_in, T_out>(                                \
+      constant T_in * input [[buffer(0)]],                        \
+      device T_out * output [[buffer(1)]],                        \
+      constant CatSharedParams<I> & shared_params [[buffer(2)]],  \
+      constant CatInputParams<I> & input_params [[buffer(3)]],    \
       uint tid [[thread_position_in_grid]]);

-#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
-  REGISTER_CAT_LARGE_OP(float, T_out);               \
-  REGISTER_CAT_LARGE_OP(half, T_out);                \
-  REGISTER_CAT_LARGE_OP(bfloat, T_out);              \
-  REGISTER_CAT_LARGE_OP(int, T_out);                 \
-  REGISTER_CAT_LARGE_OP(uint, T_out);                \
-  REGISTER_CAT_LARGE_OP(long, T_out);                \
-  REGISTER_CAT_LARGE_OP(ulong, T_out);               \
-  REGISTER_CAT_LARGE_OP(short, T_out);               \
-  REGISTER_CAT_LARGE_OP(ushort, T_out);              \
-  REGISTER_CAT_LARGE_OP(char, T_out);                \
-  REGISTER_CAT_LARGE_OP(uchar, T_out);               \
-  REGISTER_CAT_LARGE_OP(bool, T_out);
+#define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \
+  REGISTER_CAT_OP(I, float, T_out);               \
+  REGISTER_CAT_OP(I, half, T_out);                \
+  REGISTER_CAT_OP(I, bfloat, T_out);              \
+  REGISTER_CAT_OP(I, int, T_out);                 \
+  REGISTER_CAT_OP(I, uint, T_out);                \
+  REGISTER_CAT_OP(I, long, T_out);                \
+  REGISTER_CAT_OP(I, ulong, T_out);               \
+  REGISTER_CAT_OP(I, short, T_out);               \
+  REGISTER_CAT_OP(I, ushort, T_out);              \
+  REGISTER_CAT_OP(I, char, T_out);                \
+  REGISTER_CAT_OP(I, uchar, T_out);               \
+  REGISTER_CAT_OP(I, bool, T_out);

-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
-REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
+#define REGISTER_CAT_FOR_INDEX_TYPE(I)           \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float);     \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half);      \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat);    \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int);       \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint);      \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long);      \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong);     \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short);     \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort);    \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char);      \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar);     \
+  REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool);      \
+                                                 \
+  REGISTER_CAT_OP(I, float2, float2);            \
+  REGISTER_CAT_OP(I, half2, half2);

-REGISTER_CAT_LARGE_OP(float2, float2);
-REGISTER_CAT_LARGE_OP(half2, half2);
+REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
+REGISTER_CAT_FOR_INDEX_TYPE(int32_t);
@@ -3,6 +3,7 @@
 #include <ATen/MemoryOverlap.h>
 #include <ATen/WrapDimUtils.h>
 #include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/Pool.h>
 #include <ATen/native/TensorShape.h>
 #include <ATen/native/TypeProperties.h>
 #include <ATen/native/mps/OperationUtils.h>
@@ -69,29 +70,40 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
   }
 }

-// This implementation of cat is used only if one of the inputs or the output is
-// too large to use MPSGraph.
+template <typename T>
+std::string get_type_str();
+
+template <>
+std::string get_type_str<int64_t>() {
+  return "int64_t";
+}
+
+template <>
+std::string get_type_str<int32_t>() {
+  return "int32_t";
+}
+
 // NOTE: `output` is expected to already have the correct size.
-static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
-  CatLargeSharedParams shared_params;
+template <typename idx_type_t>
+static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
+  CatSharedParams<idx_type_t> shared_params;

   shared_params.ndim = output.dim();
   shared_params.cat_dim = dimension;

   for (const auto dim : c10::irange(output.dim())) {
-    shared_params.output_strides[dim] = output.stride(dim);
-    shared_params.output_sizes[dim] = output.size(dim);
+    shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim));
+    shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim));
   }

-  int64_t cat_dim_offset = 0;
+  idx_type_t cat_dim_offset = 0;
   size_t input_idx = 0;
   MPSStream* stream = getCurrentMPSStream();

-  // Launch a separate kernels for each input. This will produce some overhead,
-  // but that should be relatively minimal since at least one of the inputs is
-  // very large. In order to launch only one kernel to process all inputs, we
-  // would have to copy all the input tensor data into a packed buffer, which
-  // would not be ideal.
+  // Launch a separate kernels for each input. This will produce some overhead.
+  // In order to launch only one kernel to process all inputs, we would have to
+  // copy all the input tensor data into a packed buffer, which would not be
+  // ideal.
   for (const Tensor& input : inputs) {
     if (input.numel() == 0) {
       continue;
@@ -104,21 +116,23 @@ static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimen

     for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
       auto num_threads = std::min(max_num_threads, numel_remaining);
-      CatLargeInputParams input_params;
+      CatInputParams<idx_type_t> input_params;

-      input_params.cat_dim_offset = cat_dim_offset;
-      input_params.input_element_offset = input.numel() - numel_remaining;
+      input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset);
+      input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining);

       for (const auto dim : c10::irange(input.dim())) {
-        input_params.input_strides[dim] = input.stride(dim);
-        input_params.input_sizes[dim] = input.size(dim);
+        input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim));
+        input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim));
       }

       dispatch_sync_with_rethrow(stream->queue(), ^() {
         @autoreleasepool {
           id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
-          auto pipeline_state = lib.getPipelineStateForFunc(
-              fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
+          auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}",
+                                                                        get_type_str<idx_type_t>(),
+                                                                        scalarToMetalTypeString(input),
+                                                                        scalarToMetalTypeString(output)));
           getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
           [computeEncoder setComputePipelineState:pipeline_state];
           mtl_setArgs(computeEncoder, input, output, shared_params, input_params);
@@ -294,13 +308,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
               " and out is on ",
               out.device());

-  // TODO: For better performance by eliminating input tensor gathering and post transpose,
-  // TODO: it is better to keep the out tensor's memory format.
-  // TODO: dimension needs to be recomputed as:
-  // TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1
-  if (needsGather(out)) {
-    out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
-  }
   std::vector<int64_t> size(notSkippedTensor.sizes().vec());

   // Compute size of the result in the cat dimension
@@ -331,82 +338,9 @@ TORCH_IMPL_FUNC(cat_out_mps)
   has_large_tensor |= isTooLargeForMPSGraph(out);

   if (has_large_tensor) {
-    return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
-  }
-
-  struct CachedGraph : public MPSCachedGraph {
-    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
-    std::vector<MPSGraphTensor*> inputTensors_;
-    MPSGraphTensor* outputTensor_ = nil;
-  };
-
-  @autoreleasepool {
-    std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
-        (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
-    if (!all_same_dtype) {
-      key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
-    } else {
-      key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
-    }
-    for (auto idx : skipped_tensor_indices) {
-      key += "," + std::to_string(idx);
-    }
-
-    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
-      std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
-      newCachedGraph->inputTensors_.reserve(len_tensor_array);
-
-      for (const auto idx : c10::irange(len_tensor_array)) {
-        const Tensor& tensor = input_tensors[idx];
-        auto scalar_type = getMPSScalarType(tensor.scalar_type());
-        if (tensor.scalar_type() == kBool) {
-          scalar_type = MPSDataTypeInt8;
-        }
-        newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type);
-        if (tensor.scalar_type() != out_dtype) {
-          castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
-                                                toType:getMPSDataType(out_dtype)
-                                                  name:@"castInput"];
-        } else {
-          castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
-        }
-      }
-
-      auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
-      MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
-                                                   dimension:dimension // Maybe convert this from int64_t -> int32
-                                                        name:nil];
-      if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
-        outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
-      }
-      newCachedGraph->outputTensor_ = outputTensor;
-    });
-
-    std::vector<Placeholder> inputPlaceholders;
-    int i = 0;
-    int t_idx = 0;
-    for (const Tensor& tensor : materialized_inputs) {
-      if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
-        auto scalar_type = getMPSScalarType(tensor.scalar_type());
-        if (tensor.scalar_type() == kBool) {
-          scalar_type = MPSDataTypeInt8;
-        }
-        inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type);
-        t_idx++;
-      }
-      i++;
-    }
-
-    auto outputDataType = getMPSScalarType(out.scalar_type());
-    Placeholder outputPlaceholder =
-        Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
-
-    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
-    for (auto& inputPlaceholder : inputPlaceholders) {
-      feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
-    }
-    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
+    return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out);
+  } else {
+    return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out);
   }
 }
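The host-side change funnels every size, stride, and offset through safe_downcast before handing it to the 32-bit kernel variant. A hedged sketch of that kind of checked narrowing (checked_narrow is a hypothetical helper written for this note, not the ATen safe_downcast):

#include <cstdint>
#include <limits>
#include <stdexcept>

// Refuse to silently truncate a 64-bit size/stride when the 32-bit kernel
// instantiation is selected.
template <typename To, typename From>
To checked_narrow(From v) {
  if (v < static_cast<From>(std::numeric_limits<To>::min()) ||
      v > static_cast<From>(std::numeric_limits<To>::max())) {
    throw std::overflow_error("value does not fit in target index type");
  }
  return static_cast<To>(v);
}

// Usage: int32_t stride = checked_narrow<int32_t>(int64_t{4096});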
@@ -6531,6 +6531,7 @@
   dispatch:
     CPU, CUDA: var
     MPS: var_mps
+    MTIA: var_mtia
   tags: core

- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
@@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(

 #if defined(__ARM_NEON__) || defined(__aarch64__)

-const static int PARALLEL_THRESHOLD = 1 << 20;
+constexpr static int PARALLEL_THRESHOLD = 1 << 20;

 // Generic template defaults to naive quantize implementation
 template <typename T>
@@ -1388,7 +1388,7 @@ namespace at::native {
   TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
       "onednn int8 linear: act scale/zp size should be 1/<=1");
   static std::optional<at::Tensor> other = std::nullopt;
-  static const std::string_view binary_post_op = "none";
+  constexpr std::string_view binary_post_op = "none";
   int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
   return linear_int8_with_onednn_weight(
       act, act_scale.item().toDouble(), act_zp,
@@ -16,8 +16,8 @@ namespace {

 #ifdef USE_PYTORCH_QNNPACK

-const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
-const static int qnnpack_softmax_output_zero_point = 0;
+constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
+constexpr static int qnnpack_softmax_output_zero_point = 0;

 bool is_qnnpack_compatible(
     const Tensor& qx,
@@ -110,9 +110,9 @@ class ApplyLogSumExp {
   using ElementCompute = ElementCompute_;
   using ElementLSE = ElementLSE_;

-  static int const kElementsPerAccess = ElementsPerAccess;
-  static int const kCount = kElementsPerAccess;
-  static const ScaleType::Kind kScale =
+  static int constexpr kElementsPerAccess = ElementsPerAccess;
+  static int constexpr kCount = kElementsPerAccess;
+  static constexpr ScaleType::Kind kScale =
       cutlass::epilogue::thread::ScaleType::NoBetaScaling;

   using FragmentOutput = Array<ElementOutput, kCount>;
@@ -14,16 +14,16 @@ using namespace at;

 namespace {

-const auto int_min = std::numeric_limits<int>::min();
-const auto int_max = std::numeric_limits<int>::max();
-const auto long_min = std::numeric_limits<int64_t>::min();
-const auto long_max = std::numeric_limits<int64_t>::max();
-const auto float_lowest = std::numeric_limits<float>::lowest();
-const auto float_min = std::numeric_limits<float>::min();
-const auto float_max = std::numeric_limits<float>::max();
-const auto double_lowest = std::numeric_limits<double>::lowest();
-const auto double_min = std::numeric_limits<double>::min();
-const auto double_max = std::numeric_limits<double>::max();
+constexpr auto int_min = std::numeric_limits<int>::min();
+constexpr auto int_max = std::numeric_limits<int>::max();
+constexpr auto long_min = std::numeric_limits<int64_t>::min();
+constexpr auto long_max = std::numeric_limits<int64_t>::max();
+constexpr auto float_lowest = std::numeric_limits<float>::lowest();
+constexpr auto float_min = std::numeric_limits<float>::min();
+constexpr auto float_max = std::numeric_limits<float>::max();
+constexpr auto double_lowest = std::numeric_limits<double>::lowest();
+constexpr auto double_min = std::numeric_limits<double>::min();
+constexpr auto double_max = std::numeric_limits<double>::max();

 const std::vector<int> ints {
   int_min,
@@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {

 c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
   // The RNG state comprises the seed, and an offset used for Philox.
-  static const size_t seed_size = sizeof(uint64_t);
-  static const size_t offset_size = sizeof(uint64_t);
-  static const size_t total_size = seed_size + offset_size;
+  constexpr size_t seed_size = sizeof(uint64_t);
+  constexpr size_t offset_size = sizeof(uint64_t);
+  constexpr size_t total_size = seed_size + offset_size;

   // The internal state is returned as a CPU byte tensor.
   auto state_tensor = at::detail::empty_cpu(
@@ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
 void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
   at::xpu::assertNotCapturing(
       "Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
-  static const size_t seed_size = sizeof(uint64_t);
-  static const size_t offset_size = sizeof(uint64_t);
-  static const size_t total_size = seed_size + offset_size;
+  constexpr size_t seed_size = sizeof(uint64_t);
+  constexpr size_t offset_size = sizeof(uint64_t);
+  constexpr size_t total_size = seed_size + offset_size;

   at::detail::check_rng_state(new_state);

@@ -6,7 +6,7 @@ import os
 import subprocess
 import sys
 import tempfile
-from typing import Callable
+from collections.abc import Callable

 from torch._inductor.utils import fresh_cache

@@ -1,7 +1,8 @@
 import os
 from collections import defaultdict
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any, Optional

 import matplotlib.pyplot as plt

@@ -1,4 +1,5 @@
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any

 import torch

@@ -1,7 +1,8 @@
 import time
 from argparse import ArgumentParser
 from collections import defaultdict
-from typing import Any, Callable, NamedTuple
+from collections.abc import Callable
+from typing import Any, NamedTuple

 import torch
 from torch.autograd import functional

@@ -1,5 +1,6 @@
 from collections import defaultdict
-from typing import Callable, Optional, Union
+from collections.abc import Callable
+from typing import Optional, Union

 import torch
 from torch import nn, Tensor

@@ -1,5 +1,6 @@
 import dataclasses
-from typing import Callable, Optional
+from collections.abc import Callable
+from typing import Optional


 all_experiments: dict[str, Callable] = {}

@@ -9,8 +9,9 @@ import logging
 import time
 from abc import abstractmethod
 from collections import defaultdict
+from collections.abc import Callable
 from dataclasses import asdict, dataclass, field
-from typing import Any, Callable, Optional
+from typing import Any, Optional

 from tabulate import tabulate
 from tqdm import tqdm

@@ -7,6 +7,7 @@ from pt import (  # noqa: F401
     binary_inplace_test,
     binary_test,
     bmm_test,
+    boolean_test,
     cat_test,
     channel_shuffle_test,
     chunk_test,

@@ -56,6 +56,9 @@ binary_ops_list = op_bench.op_list(
         ["sub", torch.sub],
         ["div", torch.div],
         ["mul", torch.mul],
+        ["asr", torch.bitwise_right_shift],
+        ["lsl", torch.bitwise_left_shift],
+        ["xor", torch.bitwise_xor],
     ],
 )

benchmarks/operator_benchmark/pt/boolean_test.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for boolean operators. Supports both Caffe2/PyTorch."""

# Configs for PT all operator
all_long_configs = op_bench.cross_product_configs(
    M=[8, 128], N=[32, 64], K=[256, 512], device=["cpu", "cuda"], tags=["long"]
)


all_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


class AllBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device):
        self.inputs = {
            "input_one": torch.randint(0, 2, (M, N, K), device=device, dtype=torch.bool)
        }
        self.set_module_name("all")

    def forward(self, input_one):
        return torch.all(input_one)


# The generated test names based on all_short_configs will be in the following pattern:
# all_M8_N16_K32_devicecpu
# all_M8_N16_K32_devicecpu_bwdall
# all_M8_N16_K32_devicecpu_bwd1
# all_M8_N16_K32_devicecpu_bwd2
# ...
# Those names can be used to filter tests.

op_bench.generate_pt_test(all_long_configs + all_short_configs, AllBenchmark)

"""Mircobenchmark for any operator."""


class AnyBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device):
        self.inputs = {
            "input_one": torch.randint(0, 2, (M, N), device=device, dtype=torch.bool)
        }
        self.set_module_name("any")

    def forward(self, input_one):
        return torch.any(input_one)


any_configs = op_bench.cross_product_configs(
    M=[8, 256],
    N=[256, 16],
    device=["cpu", "cuda"],
    tags=["any"],
)

op_bench.generate_pt_test(any_configs, AnyBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
@@ -1,7 +1,8 @@
 import itertools
+from collections.abc import Callable
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, Union
+from typing import Union

 import numpy as np
 from tabulate import tabulate

@@ -3,10 +3,11 @@ import csv
 import itertools
 import random
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, Optional, Union
+from typing import Optional, Union

 import numpy as np
 from tabulate import tabulate

@@ -1,8 +1,8 @@
 import itertools
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import nullcontext
 from dataclasses import asdict, dataclass
-from typing import Callable

 from tabulate import tabulate
 from tqdm import tqdm
@@ -1729,8 +1729,10 @@ def define_buck_targets(
             "torch/csrc/jit/backends/backend_debug_info.cpp",
             "torch/csrc/jit/backends/backend_interface.cpp",
         ],
-        compiler_flags = get_pt_compiler_flags(),
-        fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
+        compiler_flags = get_pt_compiler_flags() + select({
+            "DEFAULT": [],
+            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
+        }),
         # @lint-ignore BUCKLINT link_whole
         link_whole = True,
         linker_flags = get_no_as_needed_linker_flag(),
@@ -2023,6 +2025,9 @@ def define_buck_targets(
             "ovr_config//os:android-x86_64": [
                 "-mssse3",
             ],
+        }) + select({
+            "DEFAULT": [],
+            "ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
         }),
         exported_preprocessor_flags = get_aten_preprocessor_flags(),
         exported_deps = [
@@ -3,11 +3,11 @@ from __future__ import annotations
 import dis
 import inspect
 import sys
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union


 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence

 import torch
 from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

@@ -5,7 +5,7 @@ Python implementation of function wrapping functionality for functorch.dim.
 from __future__ import annotations

 import functools
-from typing import Any, Callable, Optional
+from typing import Any, Optional, TYPE_CHECKING

 import torch
 from torch.utils._pytree import tree_map
@@ -15,6 +15,10 @@ from ._enable_all_layers import EnableAllLayers
 from ._tensor_info import TensorInfo


+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
 def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
     """Handle tensor conversion for torch function integration."""
     return tensor

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import functools
+from collections.abc import Callable
 from types import (
     BuiltinMethodType,
     FunctionType,
@@ -12,7 +13,7 @@ from types import (
     MethodDescriptorType,
     WrapperDescriptorType,
 )
-from typing import Any, Callable
+from typing import Any


 FUNC_TYPES = (

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import functools
-from typing import Callable, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union

 import torch
 from functorch.dim import dims  # noqa: F401
@@ -16,7 +16,7 @@ from ._parsing import (


 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Callable, Sequence

 __all__ = ["rearrange"]

@@ -23,7 +23,7 @@ project-excludes = [
     # ==== below will be enabled directory by directory ====
     # ==== to test Pyrefly on a specific directory, simply comment it out ====
     "torch/_inductor/runtime",
-    "torch/_inductor/codegen",
+    "torch/_inductor/codegen/triton.py",
     # formatting issues, will turn on after adjusting where suppressions can be
     # in import statements
     "torch/linalg/__init__.py",
@@ -1023,7 +1023,7 @@ class DTensorMeshTest(DTensorTestBase):
 DTensorMeshTestWithLocalTensor = create_local_tensor_test_class(
     DTensorMeshTest,
     skipped_tests=[
-        # Submeshes are not supported by local tensor mode
+        # Test asserts must be rewritten for local tensor
         "test_from_local_sub_mesh",
         "test_default_value_sub_mesh",
         "test_redistribute_sub_mesh",
test/distributed/tensor/test_dynamic.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

from unittest.mock import patch

import torch
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor.placement_types import Replicate
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu


class TestDynamic(DTensorTestBase):
    @requires_gpu
    @with_comms
    @parametrize("fake_tensor_cache_enabled", [False, True])
    def test_embedding(self, fake_tensor_cache_enabled):
        with patch.object(
            torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled
        ):
            device_mesh = self.build_device_mesh()

            placements = (Replicate(),)

            num_embeddings = 202048
            embedding_dim = 256
            weight = distribute_tensor(
                torch.rand(
                    [num_embeddings, embedding_dim],
                    dtype=torch.float32,
                    device=GPU_TYPE,
                    requires_grad=True,
                ),
                device_mesh,
                placements,  # [Replicate()],
            )

            def forward(input_batch_inputs_):
                to = weight.to(torch.float32)
                emb = torch.nn.functional.embedding(input_batch_inputs_, to)
                return emb

            arg0 = torch.randint(
                low=0, high=100, size=(2, 512), dtype=torch.int64, device=GPU_TYPE
            )
            arg0 = DTensor.from_local(arg0, device_mesh, placements)

            compiled_forward = torch.compile(forward, fullgraph=True, dynamic=True)
            _out = compiled_forward(arg0)


instantiate_parametrized_tests(TestDynamic)


if __name__ == "__main__":
    run_tests()
@@ -30,6 +30,7 @@ from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
+    create_local_tensor_test_class,
     DTensorTestBase,
     with_comms,
 )
@@ -647,7 +648,7 @@ class TestViewOps(DTensorTestBase):
     @with_comms
     def test_squeeze_(self):
         mesh_2d = init_device_mesh(self.device_type, (3, 2), mesh_dim_names=("a", "b"))
-        torch.manual_seed(self.rank)
+        self.init_manual_seed_for_rank()
         x = torch.randn((1, 4), device=self.device_type)
         dist_x = DTensor.from_local(x, mesh_2d, [Partial(), Shard(1)])
         self._test_op_on_dtensor(
@@ -664,5 +665,13 @@ class TestViewOps(DTensorTestBase):
         self.assertEqual(dist_x.placements, [Partial(), Shard(0)])


+TestViewOpsWithLocalTensor = create_local_tensor_test_class(
+    TestViewOps,
+    skipped_tests=[
+        # Comparing data pointers is not supported for local tensor
+        "test_dtensor_view_op_uneven",
+    ],
+)
+
 if __name__ == "__main__":
     run_tests()
@@ -7,8 +7,13 @@ from dataclasses import dataclass

 import torch
 from torch.multiprocessing.reductions import reduce_tensor
+from torch.testing._internal.common_cuda import SM100OrLater
 from torch.testing._internal.common_distributed import MultiProcContinuousTest
-from torch.testing._internal.common_utils import requires_cuda_p2p_access, run_tests
+from torch.testing._internal.common_utils import (
+    requires_cuda_p2p_access,
+    run_tests,
+    skip_but_pass_in_sandcastle_if,
+)


 # So that tests are written in device-agnostic way
@@ -59,6 +64,10 @@ class CupyAsTensorTest(MultiProcContinuousTest):
     def device(self) -> torch.device:
         return torch.device(device_type, self.rank)

+    @skip_but_pass_in_sandcastle_if(
+        SM100OrLater,
+        "Fails if ran in docker environment without privileged access (https://github.com/pytorch/pytorch/issues/165170)",
+    )
     def test_cupy_as_tensor(self) -> None:
         """
         Test that torch.as_tensor works for cupy array interface
@@ -1664,14 +1664,14 @@ class CuTeLayoutTest(TestCase):
     def test_remap_to_tensor(self):
         """Test the remap_to_tensor method for various scenarios."""
         # Test 1: Consecutive ranks, full world - should return logical groups directly
-        original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
+        original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
         layout1 = _Layout((2, 2), (2, 1))  # row-major 2x2
         result1 = layout1.remap_to_tensor(original_mesh)
         expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
         self.assertEqual(result1, expected1)

         # Test 2: Non-consecutive ranks - should map to actual ranks
-        original_mesh = torch.tensor([10, 20, 30, 40], dtype=torch.int)
+        original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
         layout2 = _Layout((2, 2), (2, 1))
         result2 = layout2.remap_to_tensor(original_mesh)
         expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
@@ -1692,7 +1692,7 @@ class CuTeLayoutTest(TestCase):
         self.assertEqual(result5, expected5)

         # Test 6: Tensor Cute representation of a 2D mesh
-        original_mesh = torch.tensor([0, 2, 1, 3], dtype=torch.int)
+        original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
         layout6 = _Layout((2, 2), (1, 2))  # column-major style
         result6 = layout6.remap_to_tensor(original_mesh)
         expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
@@ -12,6 +12,7 @@ import torch.distributed._symmetric_memory as symm_mem
 import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
 from torch._inductor.runtime.triton_compat import triton
 from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem
+from torch.testing._internal.common_cuda import SM100OrLater
 from torch.testing._internal.common_distributed import MultiProcContinuousTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -264,6 +265,10 @@ def my_reduce_kernel(
     nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation)


+@skip_but_pass_in_sandcastle_if(
+    SM100OrLater,
+    "Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897",
+)
 @instantiate_parametrized_tests
 class NVSHMEMTritonTest(MultiProcContinuousTest):
     def _init_device(self) -> None:
@@ -52,6 +52,9 @@ from torch.testing._internal.common_utils import (

 test_contexts = [nullcontext, _test_mode]

+# Set environment variable to disable multicast for all tests in this module
+os.environ["TORCH_SYMM_MEM_DISABLE_MULTICAST"] = "1"
+
 # So that tests are written in device-agnostic way
 device_type = "cuda"
 device_module = torch.get_device_module(device_type)
@@ -549,6 +552,10 @@ class AsyncTPTest(MultiProcContinuousTest):
     @skipUnless(SM89OrLater, "Requires compute capability >= 8.9")
     @parametrize("scatter_dim", [0, 1])
     @parametrize("rowwise", [True, False])
+    @skipIf(
+        SM100OrLater,
+        "https://github.com/pytorch/pytorch/issues/162940",
+    )
     def test_fused_scaled_matmul_reduce_scatter(
         self, scatter_dim: int, rowwise: bool
     ) -> None:
@@ -510,6 +510,7 @@ class TestDynamoTimed(TestCase):
         raw = dataclasses.asdict(compilation_events[0])
         del raw["feature_usage"]
         del raw["ir_count"]
+        del raw["inductor_provenance"]
         del raw["param_numel"]
         del raw["param_bytes"]
         del raw["param_count"]
@@ -694,6 +695,7 @@ class TestDynamoTimed(TestCase):
         raw = dataclasses.asdict(compilation_events[1])
         del raw["feature_usage"]
         del raw["ir_count"]
+        del raw["inductor_provenance"]
         del raw["guard_latency_us"]
         del raw["param_numel"]
         del raw["param_bytes"]
@@ -911,6 +913,27 @@ class TestDynamoTimed(TestCase):
             compilation_events = [arg[0][0] for arg in log_event.call_args_list]
             self.assertEqual(compilation_events[0].ir_count, second)

+    @dynamo_config.patch(
+        {
+            "log_compilation_metrics": True,
+        }
+    )
+    @inductor_config.patch(
+        {"trace.enabled": True, "trace.provenance_tracking_level": 1},
+    )
+    def test_inductor_provenance(self):
+        module = torch.nn.Linear(6, 66)
+        graph_module = torch.fx.symbolic_trace(module)
+
+        compilation_events = []
+        with mock.patch("torch._dynamo.utils.log_compilation_event") as log_event:
+            torch.compile(graph_module)(torch.randn(6, 6))
+            compilation_events = [arg[0][0] for arg in log_event.call_args_list]
+            self.assertEqual(
+                compilation_events[0].inductor_provenance,
+                {'{"extern_kernels.addmm:1": []}'},
+            )
+
     @dynamo_config.patch({"log_compilation_metrics": True})
     @inductor_config.patch({"force_disable_caches": True})
     def test_dynamic_shape_feature_use(self):
@@ -57,7 +57,7 @@ def graph_capture(model, inputs, with_export):
     with ExitStack() as stack:
         joint_with_descriptors = aot_export_joint_with_descriptors(
             stack,
-            model,
+            gm,
             inputs,
         )
         return joint_with_descriptors.graph_module
@@ -2340,6 +2340,39 @@ class AOTInductorTestsTemplate:
             dynamic_shapes=dynamic_shapes,
         )

+    def test_cond_symint_input_disable_one_pass(self):
+        class M(torch.nn.Module):
+            def forward(self, x, y, z):
+                a = y.shape[0]
+                b = z.shape[0]
+
+                def true_fn(x):
+                    return x + a
+
+                def false_fn(x):
+                    return x + b * z
+
+                return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
+
+        input1 = (
+            torch.ones(3, 3, device=self.device),
+            torch.ones(5, device=self.device),
+            torch.ones(3, 3, device=self.device),
+        )
+        input2 = (
+            torch.ones(10, 3, device=self.device),
+            torch.ones(6, device=self.device),
+            torch.ones(10, 3, device=self.device),
+        )
+        inputs = (input1, input2)
+        dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}}
+        with torch._inductor.config.patch({"triton.autotune_at_compile_time": False}):
+            self.check_model_with_multiple_inputs(
+                M(),
+                inputs,
+                dynamic_shapes=dynamic_shapes,
+            )
+
     def test_while_loop_simple(self):
         inputs = (
             torch.randn((10, 20), device=self.device),
371
test/inductor/test_aoti_cross_compile_windows.py
Normal file
371
test/inductor/test_aoti_cross_compile_windows.py
Normal file
@ -0,0 +1,371 @@
|
||||
# Owner(s): ["module: inductor"]
import os
import platform
import tempfile
import unittest
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import torch
import torch._inductor.config
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu


@dataclass
class ModelTestConfig:
    """Configuration for a model test case."""

    name: str
    model_class: type
    example_inputs: tuple[torch.Tensor, ...]
    dynamic_shapes: Optional[dict[str, Any]] = None
    inductor_configs: Optional[dict[str, Any]] = None
    rtol: float = 1e-4
    atol: float = 1e-4


class WindowsCrossCompilationTestFramework:
    """
    Framework for testing cross-compilation from Linux to Windows.

    Provides reusable logic for creating compile and load test methods.
    """

    _base_path: Optional[Path] = None
    _win_torch_libs_path: Optional[str] = None

    @classmethod
    def base_path(cls) -> Path:
        """Get or create the base path for package files."""
        if cls._base_path is None:
            cls._base_path = Path(tempfile.mkdtemp(prefix="aoti_cross_compile_"))
        return cls._base_path

    @classmethod
    def set_base_path(cls, path: Optional[Path | str] = None) -> None:
        """Set the base path for package files."""
        cls._base_path = Path(path) if path else None

    @classmethod
    def set_win_torch_libs_path(cls, path: Optional[str] = None) -> None:
        """Set the path for Windows torch libs."""
        cls._win_torch_libs_path = path

    @classmethod
    def get_package_path(cls, model_name: str) -> str:
        """Get the path for a model's .pt2 package file."""
        package_dir = cls.base_path()
        package_dir.mkdir(parents=True, exist_ok=True)
        return str(package_dir / f"{model_name}_windows.pt2")

    @classmethod
    def get_win_torch_libs_path(cls) -> str:
        """Get the path for Windows torch libs."""
        if cls._win_torch_libs_path is None:
            raise RuntimeError("Windows torch libs path not set")
        return str(cls._win_torch_libs_path)

    @classmethod
    def create_compile_test(cls, config: ModelTestConfig):
        """Create a compile test method for a model configuration."""

        def compile_test(self):
            if platform.system() == "Windows":
                raise unittest.SkipTest(
                    "This test should run on Linux for cross-compilation"
                )

            self.assertTrue("WINDOWS_CUDA_HOME" in os.environ)

            with torch.no_grad():
                # Windows cross-compilation is only used for GPU.
                # AOTI for CPU should be able to work as native compilation on Windows.
                device = GPU_TYPE
                model = config.model_class().to(device=device)
                example_inputs = config.example_inputs

                # Inputs should already be on GPU_TYPE but ensure they are
                example_inputs = tuple(inp.to(device) for inp in example_inputs)

                # Export the model
                exported = torch.export.export(
                    model, example_inputs, dynamic_shapes=config.dynamic_shapes
                )

                # Prepare inductor configs
                inductor_configs = {
                    "aot_inductor.cross_target_platform": "windows",
                    "aot_inductor.precompile_headers": False,
                    "aot_inductor.package_constants_on_disk_format": "binary_blob",
                    "aot_inductor.package_constants_in_so": False,
                    "aot_inductor.aoti_shim_library_path": cls.get_win_torch_libs_path(),
                }
                if config.inductor_configs:
                    inductor_configs.update(config.inductor_configs)

                # Compile and package directly to the expected location
                package_path = cls.get_package_path(config.name)
                torch._inductor.aoti_compile_and_package(
                    exported,
                    package_path=package_path,
                    inductor_configs=inductor_configs,
                )

                self.assertTrue(
                    os.path.exists(package_path),
                    f"Package file should exist at {package_path}",
                )

        return compile_test

    @classmethod
    def create_load_test(cls, config: ModelTestConfig):
        """Create a load test method for a model configuration."""

        def load_test(self):
            if platform.system() != "Windows":
                raise unittest.SkipTest("This test should run on Windows")

            if not HAS_GPU:
                raise unittest.SkipTest("Test requires GPU")

            package_path = cls.get_package_path(config.name)
            if not os.path.exists(package_path):
                raise unittest.SkipTest(
                    f"Package file not found at {package_path}. "
                    f"Run test_{config.name}_compile first."
                )

            with torch.no_grad():
                # Windows cross-compilation is only used for GPU.
                # AOTI for CPU should be able to work as native compilation on Windows.
                device = GPU_TYPE

                # Create original model for comparison
                original_model = config.model_class().to(device=device)
                example_inputs = config.example_inputs

                # Inputs should already be on GPU_TYPE but ensure they are
                example_inputs = tuple(inp.to(device) for inp in example_inputs)

                # Load the compiled package
                loaded_model = torch._inductor.aoti_load_package(package_path)

                # Test with the same inputs
                original_output = original_model(*example_inputs)
                loaded_output = loaded_model(*example_inputs)

                # Compare outputs
                torch.testing.assert_close(
                    original_output, loaded_output, rtol=config.rtol, atol=config.atol
                )

        return load_test


def auto_generate_tests(test_class):
    """
    Class decorator to automatically generate compile/load test methods
    from _define_* methods that return ModelTestConfig.
    """
    # Find all _define_* methods that return ModelTestConfig
    define_methods = {}
    for name in dir(test_class):
        if name.startswith("_define_") and callable(getattr(test_class, name)):
            method = getattr(test_class, name)
            # Try to call the method to see if it returns ModelTestConfig
            try:
                # Create a temporary instance to call the method
                temp_instance = test_class.__new__(test_class)
                result = method(temp_instance)
                if isinstance(result, ModelTestConfig):
                    define_methods[name] = result
            except Exception:
                # If method fails, skip it
                pass

    # Generate compile/load methods for each discovered definition
    for define_name, config in define_methods.items():
        model_name = define_name[8:]  # Remove '_define_' prefix

        # Create compile test method
        compile_method_name = f"test_{model_name}_compile"
        compile_method = WindowsCrossCompilationTestFramework.create_compile_test(
            config
        )
        compile_method.__name__ = compile_method_name
        compile_method.__doc__ = f"Step 1: Cross-compile {model_name} model on Linux"
        compile_method = requires_gpu()(compile_method)
        setattr(test_class, compile_method_name, compile_method)

        # Create load test method
        load_method_name = f"test_{model_name}_load"
        load_method = WindowsCrossCompilationTestFramework.create_load_test(config)
        load_method.__name__ = load_method_name
        load_method.__doc__ = f"Step 2: Load and test {model_name} model on Windows"
        load_method = requires_gpu()(load_method)
        setattr(test_class, load_method_name, load_method)

    return test_class


@auto_generate_tests
class TestAOTInductorWindowsCrossCompilation(TestCase):
    """
    Test class for AOT Inductor Windows cross-compilation.

    Define test methods that return ModelTestConfig, and the decorator
    will auto-generate compile/load test methods.
    """

    def _define_simple(self):
        """Define the Simple model and its test configuration."""

        class Simple(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(16, 1)
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.fc2(x)
                x = self.sigmoid(x)
                return x

        return ModelTestConfig(
            name="simple",
            model_class=Simple,
            example_inputs=(torch.randn(8, 10, device=GPU_TYPE),),
            dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=1024)}},
        )

    def _define_simple_cnn(self):
        """Define the SimpleCNN model and its test configuration."""

        class SimpleCNN(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 16, 3)
                self.relu = torch.nn.ReLU()
                self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.fc = torch.nn.Linear(16, 10)

            def forward(self, x):
                x = self.conv1(x)
                x = self.relu(x)
                x = self.pool(x)
                x = x.flatten(1)
                x = self.fc(x)
                return x

        return ModelTestConfig(
            name="simple_cnn",
            model_class=SimpleCNN,
            example_inputs=(torch.randn(2, 3, 32, 32, device=GPU_TYPE),),
            dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=16)}},
            rtol=1e-3,
            atol=1e-3,
        )

    def _define_transformer(self):
        """Define the SimpleTransformer model and its test configuration."""

        class SimpleTransformer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.embedding = torch.nn.Linear(128, 256)
                self.attention = torch.nn.MultiheadAttention(256, 8, batch_first=True)
                self.norm1 = torch.nn.LayerNorm(256)
                self.ffn = torch.nn.Sequential(
                    torch.nn.Linear(256, 1024),
                    torch.nn.ReLU(),
                    torch.nn.Linear(1024, 256),
                )
                self.norm2 = torch.nn.LayerNorm(256)
                self.output = torch.nn.Linear(256, 10)

            def forward(self, x):
                # x shape: (batch, seq_len, input_dim)
                x = self.embedding(x)
                attn_out, _ = self.attention(x, x, x)
                x = self.norm1(x + attn_out)
                ffn_out = self.ffn(x)
                x = self.norm2(x + ffn_out)
                x = x.mean(dim=1)  # Global average pooling
                x = self.output(x)
                return x

        return ModelTestConfig(
            name="transformer",
            model_class=SimpleTransformer,
            example_inputs=(torch.randn(4, 16, 128, device=GPU_TYPE),),
            dynamic_shapes={"x": {0: torch.export.Dim("batch", min=1, max=32)}},
            rtol=1e-3,
            atol=1e-3,
        )


if __name__ == "__main__":
    import sys

    from torch._inductor.test_case import run_tests

    # Check for --package-dir argument and remove it before unittest sees it
    package_dir = None
    win_torch_lib_dir = None
    filtered_argv = []
    i = 0
    while i < len(sys.argv):
        if sys.argv[i] == "--package-dir":
            if i + 1 < len(sys.argv):
                package_dir = sys.argv[i + 1]
                i += 2  # Skip both --package-dir and its value
            else:
                print("Error: --package-dir requires a valid directory path")
                sys.exit(1)
        elif sys.argv[i].startswith("--package-dir="):
            package_dir = sys.argv[i].split("=", 1)[1]
            i += 1
        elif sys.argv[i] == "--win-torch-lib-dir":
            if i + 1 < len(sys.argv):
                win_torch_lib_dir = sys.argv[i + 1]
                i += 2  # Skip both --win-torch-lib-dir and its value
            else:
                print("Error: --win-torch-lib-dir requires a valid directory path")
                sys.exit(1)
        elif sys.argv[i].startswith("--win-torch-lib-dir="):
            win_torch_lib_dir = sys.argv[i].split("=", 1)[1]
            i += 1
        else:
            filtered_argv.append(sys.argv[i])
            i += 1

    # Validate and set the base path for package storage
    if package_dir:
        try:
            package_path = Path(package_dir)
            package_path.mkdir(parents=True, exist_ok=True)
            # Test write access
            test_file = package_path / ".test_write"
            test_file.touch()
            test_file.unlink()
            WindowsCrossCompilationTestFramework.set_base_path(package_path)
        except Exception:
            print("Error: --package-dir requires a valid directory path")
            sys.exit(1)

    # Set Windows torch libs path if provided (only needed for compile tests)
    if win_torch_lib_dir:
        WindowsCrossCompilationTestFramework.set_win_torch_libs_path(win_torch_lib_dir)

    # Update sys.argv to remove our custom arguments
    sys.argv = filtered_argv

    if HAS_GPU:
        run_tests(needs="filelock")
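Editor's note: the test file above wraps a fairly small core flow. Below is a hedged sketch of that flow outside the test harness; the model, output path, and shim-library path are placeholders, while the inductor config keys are the ones the compile tests above pass.

import torch

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).sum(dim=-1)

# Cross-compilation targets GPU only, so the model and inputs live on CUDA.
# WINDOWS_CUDA_HOME and the shim library path below are placeholders that must
# point at an extracted Windows torch wheel / CUDA toolkit.
model = TinyModel().to("cuda")
example_inputs = (torch.randn(8, 10, device="cuda"),)
exported = torch.export.export(model, example_inputs)

package_path = torch._inductor.aoti_compile_and_package(
    exported,
    package_path="/tmp/tiny_windows.pt2",  # hypothetical output location
    inductor_configs={
        "aot_inductor.cross_target_platform": "windows",
        "aot_inductor.precompile_headers": False,
        "aot_inductor.package_constants_on_disk_format": "binary_blob",
        "aot_inductor.package_constants_in_so": False,
        "aot_inductor.aoti_shim_library_path": "/path/to/win-torch-wheel-extracted/torch/lib",
    },
)

# On a Windows machine with a GPU, the .pt2 package is then loaded and compared
# against eager, exactly as the generated test_*_load methods do:
#   runner = torch._inductor.aoti_load_package(package_path)
#   torch.testing.assert_close(model(*example_inputs), runner(*example_inputs))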
@ -623,7 +623,8 @@ class TestFP8Lowering(TestCase):
            bias,
        )

        FileCheck().check("SCALING_ROWWISE : tl.constexpr = False").run(code[0])
        FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 0").run(code[0])
        FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 0").run(code[0])
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        # depending on the kernel config (BLOCK_M size, etc) selected during Inductor
@ -768,7 +769,8 @@ class TestFP8Lowering(TestCase):
            bias,
        )

        FileCheck().check("SCALING_ROWWISE : tl.constexpr = True").run(code[0])
        FileCheck().check("SCALE_RECIPE_A : tl.constexpr = 1").run(code[0])
        FileCheck().check("SCALE_RECIPE_B : tl.constexpr = 1").run(code[0])
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)
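Editor's note: the assertions above grep the generated kernel source for constexpr flags. For readers unfamiliar with that idiom, here is a minimal sketch; the function and the checked substring are placeholders, not taken from this PR.

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

def f(a, b):
    return (a @ b).relu()

# run_and_get_code runs the compiled function and also returns the generated source strings.
compiled = torch.compile(f)
result, code = run_and_get_code(compiled, torch.randn(8, 8), torch.randn(8, 8))
# FileCheck asserts that a substring appears in the first generated file,
# mirroring the constexpr checks in the hunks above (the pattern here is arbitrary).
FileCheck().check("def call").run(code[0])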
@ -8423,22 +8423,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        self.assertEqual(fn(x[0:]), x[16:][:16])
        self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])

    def test_index_float_zero(self):
        def fn(arg0, arg1, arg2):
            t1 = torch.tanh(arg0)
            t2 = t1.clone()
            t2.fill_(arg1.item())
            t3 = torch.clamp(t2, 0, arg2.size(0) - 1).to(torch.long)
            return torch.nn.functional.embedding(t3, arg2)

        arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
        arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
        arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device)

        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)

        self.assertEqual(fn(arg0, arg1, arg2), cfn(arg0, arg1, arg2))

    # from GPT2ForSequenceClassification
    @skip_if_gpu_halide
    def test_index_tensor(self):
@ -592,7 +592,6 @@ def graph_executor(
        proto = onnxscript_function.to_function_proto()
        ir_function = ir.serde.deserialize_function(proto)
        onnx_model.functions[identifier] = ir_function
    _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
    _ir_passes.add_opset_imports(onnx_model)
    # Make sure the model is valid
    model_proto = ir.to_proto(onnx_model)
@ -384,143 +384,6 @@ class TestTorchAutocast(TestCase):
        with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
            torch.autocast(device_type=dev)

    @skipIfTorchDynamo()
    def test_autocast_nograd_caching_issue_158232(self):
        """
        Regression test for issue #158232: autocast + no_grad incompatibility

        When torch.no_grad() is nested inside torch.autocast(), the autocast cache
        must not cache tensors created in the no_grad context, because they lack
        gradient tracking. If cached, subsequent operations in gradient-enabled mode
        would incorrectly use the no-gradient cached version.

        Before fix: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
        After fix: Should work correctly
        """
        model = torch.nn.Linear(2, 2)
        inp = torch.randn(8, 2)

        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
            # First forward pass in no_grad context (e.g., shape inference)
            with torch.no_grad():
                out1 = model(inp)
                self.assertFalse(
                    out1.requires_grad, "Output in no_grad should not require grad"
                )

            # Second forward pass with gradients enabled (e.g., training)
            out2 = model(inp)
            self.assertTrue(
                out2.requires_grad,
                "Output should require gradients after exiting no_grad",
            )
            self.assertIsNotNone(
                out2.grad_fn, "Output should have grad_fn after exiting no_grad"
            )

            # Backward pass should work
            loss = out2.mean()
            loss.backward()

        # Verify gradients were computed
        self.assertIsNotNone(model.weight.grad)
        self.assertIsNotNone(model.bias.grad)

    @skipIfTorchDynamo()
    def test_autocast_inference_mode_interaction(self):
        """
        Test that autocast works correctly with torch.inference_mode()

        InferenceMode is a stricter version of no_grad that provides additional
        performance optimizations. Verify it doesn't break with autocast.
        """
        model = torch.nn.Linear(2, 2)
        inp = torch.randn(8, 2)

        # Test 1: inference_mode inside autocast
        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
            with torch.inference_mode():
                out1 = model(inp)
                self.assertFalse(out1.requires_grad)
                self.assertEqual(out1.dtype, torch.bfloat16)

            # After exiting inference_mode, gradients should work
            out2 = model(inp)
            self.assertTrue(out2.requires_grad)
            out2.mean().backward()

        # Test 2: autocast inside inference_mode
        with torch.inference_mode():
            with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
                out = model(inp)
                self.assertFalse(out.requires_grad)
                self.assertEqual(out.dtype, torch.bfloat16)

    def test_autocast_caching_still_works_with_gradients(self):
        """
        Verify that autocast caching still functions correctly when gradients ARE enabled.

        This test ensures the fix for #158232 didn't break normal caching behavior.
        We can't directly observe cache hits, but we verify that repeated operations
        with gradients enabled work correctly.
        """
        model = torch.nn.Linear(2, 2)
        inp = torch.randn(8, 2)

        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
            # Multiple forward passes with gradients enabled
            out1 = model(inp)
            out2 = model(inp)
            out3 = model(inp)

            # All should have gradients
            self.assertTrue(out1.requires_grad)
            self.assertTrue(out2.requires_grad)
            self.assertTrue(out3.requires_grad)

            # All should have grad_fn
            self.assertIsNotNone(out1.grad_fn)
            self.assertIsNotNone(out2.grad_fn)
            self.assertIsNotNone(out3.grad_fn)

            # Backward should work on all
            out1.mean().backward(retain_graph=True)
            out2.mean().backward(retain_graph=True)
            out3.mean().backward()

    @skipIfTorchDynamo()
    def test_autocast_mixed_grad_contexts(self):
        """
        Test complex nesting of gradient contexts within autocast.

        This ensures the gradient mode check works correctly across
        multiple transitions between gradient-enabled and disabled states.
        """
        model = torch.nn.Linear(2, 2)
        inp = torch.randn(8, 2)

        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
            # Pass 1: no_grad
            with torch.no_grad():
                out1 = model(inp)
                self.assertFalse(out1.requires_grad)

            # Pass 2: gradients enabled
            out2 = model(inp)
            self.assertTrue(out2.requires_grad)

            # Pass 3: no_grad again
            with torch.no_grad():
                out3 = model(inp)
                self.assertFalse(out3.requires_grad)

            # Pass 4: gradients enabled again
            out4 = model(inp)
            self.assertTrue(out4.requires_grad)

            # Backward on gradient-enabled outputs
            (out2.mean() + out4.mean()).backward()


if __name__ == "__main__":
    run_tests()
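Editor's note: the regression test above spells out why the autocast cast-cache has to be aware of gradient mode. As a purely conceptual illustration (a toy sketch, not PyTorch's actual autocast cache), the hazard disappears once the cache key reflects whether the cast was produced with gradients enabled:

import torch

# Toy sketch: cache low-precision casts of weights, keyed on (tensor id, dtype, grad mode).
# Including torch.is_grad_enabled() in the key prevents a cast produced under no_grad
# (detached, no grad_fn) from being reused in a gradient-enabled pass.
_cast_cache: dict[tuple[int, torch.dtype, bool], torch.Tensor] = {}

def cached_cast(weight: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    key = (id(weight), dtype, torch.is_grad_enabled())
    if key not in _cast_cache:
        _cast_cache[key] = weight.to(dtype)
    return _cast_cache[key]

w = torch.randn(2, 2, requires_grad=True)
with torch.no_grad():
    assert not cached_cast(w, torch.bfloat16).requires_grad
# A later gradient-enabled pass gets a fresh cast that keeps autograd tracking.
assert cached_cast(w, torch.bfloat16).requires_grad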
@ -56,14 +56,15 @@ if TEST_CUDA:
# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

def xfailIfSM100OrLaterAndCondition(condition_fn):
def xfailIfSM100OrLaterNonRTXAndCondition(condition_fn):
    """
    Conditionally xfail tests on SM100+ based on a condition function.
    Conditionally xfail tests on SM100+ datacenter SKUs based on a condition function.
    The condition function receives the test parameters dict and returns True to xfail.
    """
    computeCapabilityCheck = SM100OrLater and torch.cuda.get_device_capability()[0] != 12
    return decorateIf(
        unittest.expectedFailure,
        lambda params: SM100OrLater and condition_fn(params)
        lambda params: computeCapabilityCheck and condition_fn(params)
    )


@ -163,7 +164,7 @@ class TestMatmulCuda(InductorTestCase):
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @xfailIfSM100OrLaterAndCondition(lambda params: params.get('dtype') == torch.bfloat16 and params.get('size') == 10000)
    @xfailIfSM100OrLaterNonRTXAndCondition(lambda params: params.get('dtype') == torch.bfloat16 and params.get('size') == 10000)
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
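Editor's note: the renamed helper above gates the xfail on compute capability. A short, purely illustrative sketch of what that gate evaluates to on a given machine (assuming SM100OrLater means capability >= (10, 0), as the name suggests):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    is_sm100_or_later = (major, minor) >= (10, 0)
    treated_as_rtx = major == 12  # excluded by the NonRTX variant above
    print(f"sm_{major}{minor}: xfail candidate = {is_sm100_or_later and not treated_as_rtx}")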
@ -1,5 +1,5 @@
|
||||
import functools
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
|
||||
from torchgen.context import native_function_manager
|
||||
|
@ -36,7 +36,7 @@ from __future__ import annotations
|
||||
import itertools
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import yaml
|
||||
|
||||
@ -77,7 +77,7 @@ from .gen_trace_type import should_trace
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
|
||||
|
||||
#
|
||||
|
@ -29,7 +29,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.autograd import (
|
||||
@ -106,7 +106,7 @@ from .gen_trace_type import (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
|
||||
# We don't set or modify grad_fn on these methods. Generally, they return
|
||||
|
@ -196,7 +196,7 @@ class FuzzTemplate:
|
||||
|
||||
class DefaultFuzzTemplate(FuzzTemplate):
|
||||
def __init__(self):
|
||||
from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithNumericsCheck
|
||||
from torchfuzz.checks import EagerVsFullGraphDynamicCompileCheck
|
||||
|
||||
super().__init__(
|
||||
supported_ops=[
|
||||
@ -236,7 +236,7 @@ class DefaultFuzzTemplate(FuzzTemplate):
|
||||
# Regularization
|
||||
"torch.nn.functional.dropout",
|
||||
],
|
||||
check=EagerVsFullGraphDynamicCompileWithNumericsCheck(),
|
||||
check=EagerVsFullGraphDynamicCompileCheck(),
|
||||
)
|
||||
|
||||
def spec_distribution(self):
|
||||
|
@ -241,7 +241,7 @@ if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
try:
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
|
||||
except ImportError:
|
||||
# If importing as a module fails, import from the same directory
|
||||
import os
|
||||
@ -249,7 +249,7 @@ if __name__ == "__main__":
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, current_dir)
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer
|
||||
from multi_process_fuzzer import run_multi_process_fuzzer, run_until_failure
|
||||
|
||||
# Set up command-line argument parsing
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -296,6 +296,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Print detailed output for all runs (not just failures)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop-at-first-failure",
|
||||
action="store_true",
|
||||
help="Pick a random seed and keep iterating until finding a failure (exits with non-zero code)",
|
||||
)
|
||||
|
||||
# Legacy arguments
|
||||
parser.add_argument(
|
||||
@ -337,6 +342,30 @@ if __name__ == "__main__":
|
||||
supported_ops=parsed_supported_ops,
|
||||
op_weights=(parsed_weights if parsed_weights else None),
|
||||
)
|
||||
elif args.stop_at_first_failure:
|
||||
# Stop-at-first-failure mode
|
||||
# Default number of processes
|
||||
if args.processes is None:
|
||||
cpu_count = mp.cpu_count()
|
||||
args.processes = max(1, min(16, int(cpu_count * 0.75)))
|
||||
|
||||
if args.processes < 1:
|
||||
print("❌ Error: Number of processes must be at least 1")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
run_until_failure(
|
||||
num_processes=args.processes,
|
||||
verbose=args.verbose,
|
||||
template=args.template,
|
||||
supported_ops=args.supported_ops,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
elif args.start is not None or args.count is not None:
|
||||
# Multi-process fuzzing mode
|
||||
if args.start is None:
|
||||
|
@ -66,6 +66,12 @@ IGNORE_PATTERNS: list[re.Pattern] = [
|
||||
re.compile(
|
||||
r"torch\._inductor\.exc\.InductorError: CppCompileError: C\+\+ compile error"
|
||||
), # https://github.com/pytorch/pytorch/issues/164686
|
||||
re.compile(
|
||||
r"\.item\(\) # dtype="
|
||||
), # https://github.com/pytorch/pytorch/issues/164725
|
||||
re.compile(
|
||||
r"dimensionality of sizes \(0\) must match dimensionality of strides \(1\)"
|
||||
), # https://github.com/pytorch/pytorch/issues/164814
|
||||
# Add more patterns here as needed, e.g.:
|
||||
# re.compile(r"Some other error message"),
|
||||
]
|
||||
@ -516,3 +522,143 @@ def _print_operation_distribution(results: list[FuzzerResult]) -> None:
|
||||
persist_print(
|
||||
"\n📊 No operation statistics collected (no successful runs with stats)"
|
||||
)
|
||||
|
||||
|
||||
def run_until_failure(
|
||||
num_processes: Optional[int] = None,
|
||||
verbose: bool = False,
|
||||
template: str = "default",
|
||||
supported_ops: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the multi-process fuzzer with a random starting seed, iterating until a failure is found.
|
||||
|
||||
Args:
|
||||
num_processes: Number of worker processes to use
|
||||
verbose: Whether to print detailed output
|
||||
template: The template to use for code generation
|
||||
supported_ops: Comma-separated ops string with optional weights
|
||||
|
||||
Returns:
|
||||
Exits with non-zero code when a failure is found
|
||||
"""
|
||||
import random
|
||||
|
||||
# Pick a random seed to start from
|
||||
initial_seed = random.randint(0, 2**31 - 1)
|
||||
|
||||
persist_print(
|
||||
f"🎲 Starting continuous fuzzing with random initial seed: {initial_seed}"
|
||||
)
|
||||
persist_print(f"🚀 Using {num_processes} processes")
|
||||
persist_print(
|
||||
f"🔧 Command template: python fuzzer.py --seed {{seed}} --template {template}"
|
||||
)
|
||||
persist_print("🎯 Running until first failure is found...")
|
||||
persist_print("=" * 60)
|
||||
|
||||
start_time = time.time()
|
||||
current_seed = initial_seed
|
||||
total_successful = 0
|
||||
total_ignored = 0
|
||||
batch_size = 100 # Process seeds in batches of 100
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Process a batch of seeds
|
||||
seeds = list(range(current_seed, current_seed + batch_size))
|
||||
|
||||
with mp.Pool(processes=num_processes) as pool:
|
||||
future_results = []
|
||||
for seed in seeds:
|
||||
future = pool.apply_async(
|
||||
run_fuzzer_with_seed, (seed, template, supported_ops)
|
||||
)
|
||||
future_results.append((seed, future))
|
||||
|
||||
# Set up progress bar for this batch
|
||||
if HAS_TQDM:
|
||||
from tqdm import tqdm
|
||||
|
||||
pbar = tqdm(
|
||||
total=len(seeds),
|
||||
desc=f"Batch starting at seed {current_seed}",
|
||||
file=sys.stdout,
|
||||
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}] ✅/🚫={postfix}",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
|
||||
|
||||
def write_func(msg):
|
||||
pbar.write(msg)
|
||||
else:
|
||||
pbar = None
|
||||
|
||||
# Collect results as they complete
|
||||
for seed, future in future_results:
|
||||
result: FuzzerResult = future.get()
|
||||
|
||||
if result.ignored_pattern_idx != -1:
|
||||
total_ignored += 1
|
||||
|
||||
if result.success:
|
||||
total_successful += 1
|
||||
elif result.ignored_pattern_idx == -1:
|
||||
# Found a failure that is not ignored!
|
||||
if HAS_TQDM and pbar:
|
||||
pbar.close()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
persist_print("\n" + "=" * 60)
|
||||
persist_print("🎯 FAILURE FOUND!")
|
||||
persist_print("=" * 60)
|
||||
persist_print(f"❌ Failing seed: {result.seed}")
|
||||
persist_print(
|
||||
f"⏱️ Duration for this seed: {result.duration:.2f}s"
|
||||
)
|
||||
persist_print(f"⏱️ Total time elapsed: {elapsed:.2f}s")
|
||||
persist_print(f"✅ Successful seeds tested: {total_successful}")
|
||||
persist_print(f"🚫 Ignored seeds: {total_ignored}")
|
||||
persist_print(
|
||||
f"📊 Total seeds tested: {total_successful + total_ignored + 1}"
|
||||
)
|
||||
persist_print("\n💥 Failure output:")
|
||||
persist_print("-" * 60)
|
||||
print_output_lines(result.output, persist_print)
|
||||
persist_print("-" * 60)
|
||||
persist_print(
|
||||
f"\n🔄 Reproduce with: python fuzzer.py --seed {result.seed} --template {template}"
|
||||
)
|
||||
|
||||
# Exit with non-zero code
|
||||
sys.exit(1)
|
||||
|
||||
# Update progress bar
|
||||
if HAS_TQDM and pbar:
|
||||
pbar.set_postfix_str(f"{total_successful}/{total_ignored}")
|
||||
pbar.update(1)
|
||||
elif verbose:
|
||||
status_emoji = "✅" if result.success else "🚫"
|
||||
persist_print(f"Seed {result.seed}: {status_emoji}")
|
||||
|
||||
# Close progress bar for this batch
|
||||
if HAS_TQDM and pbar:
|
||||
pbar.close()
|
||||
|
||||
# Move to next batch
|
||||
current_seed += batch_size
|
||||
|
||||
except KeyboardInterrupt:
|
||||
persist_print("\n🛑 Interrupted by user (Ctrl+C)")
|
||||
elapsed = time.time() - start_time
|
||||
persist_print("=" * 60)
|
||||
persist_print("📈 SUMMARY (interrupted)")
|
||||
persist_print("=" * 60)
|
||||
persist_print(f"⏱️ Total time: {elapsed:.2f}s")
|
||||
persist_print(f"✅ Successful seeds: {total_successful}")
|
||||
persist_print(f"🚫 Ignored seeds: {total_ignored}")
|
||||
persist_print(f"📊 Total seeds tested: {total_successful + total_ignored}")
|
||||
persist_print(
|
||||
f"⚡ Throughput: {((total_successful + total_ignored) / (elapsed / 3600)):.2f} seeds/hr"
|
||||
)
|
||||
sys.exit(130)
|
||||
|
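Editor's note: a hedged sketch of driving the run_until_failure helper added above directly from Python; the argument values are illustrative, and the intended entry point is the new --stop-at-first-failure flag.

# Illustrative only: calls the continuous-fuzzing helper with explicit arguments.
# It exits the process with a non-zero code once a non-ignored failure is found.
from multi_process_fuzzer import run_until_failure

run_until_failure(
    num_processes=8,      # worker processes
    verbose=False,
    template="default",   # same default template as the CLI
    supported_ops=None,   # optionally a comma-separated op list with weights
)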
@ -5,7 +5,8 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class FlightRecorderLogger:
|
||||
|
@ -4,12 +4,16 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, cast
|
||||
from typing import Any, cast, TYPE_CHECKING
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
def gh_fetch_url_and_headers(
|
||||
url: str,
|
||||
*,
|
||||
|
@ -5,7 +5,7 @@ import json
|
||||
import sys
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
|
||||
_FILE = Path(__file__).absolute()
|
||||
@ -18,7 +18,7 @@ else:
|
||||
import _linter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
from collections.abc import Callable, Iterator, Sequence
|
||||
|
||||
|
||||
GRANDFATHER_LIST = _FILE.parent / "docstring_linter-grandfather.json"
|
||||
|
@ -22,11 +22,15 @@ import os
|
||||
import re
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
from typing import Any, NamedTuple, Optional, TYPE_CHECKING
|
||||
|
||||
from yaml import load
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
# Safely load fast C Yaml loader/dumper if they are available
|
||||
try:
|
||||
from yaml import CSafeLoader as Loader
|
||||
|
@ -65,10 +65,11 @@ import textwrap
|
||||
import time
|
||||
import uuid
|
||||
from ast import literal_eval
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from platform import system as platform_system
|
||||
from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypeVar
|
||||
from typing import Any, cast, NamedTuple, TYPE_CHECKING, TypeVar
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -7,10 +7,14 @@ import json
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast
|
||||
from typing import Any, cast, TYPE_CHECKING
|
||||
from urllib.request import urlopen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
|
@ -6,13 +6,17 @@ import json
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, cast
|
||||
from typing import Any, cast, TYPE_CHECKING
|
||||
from urllib.error import HTTPError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from tools.stats.upload_stats_lib import upload_to_s3
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
FILTER_OUT_USERS = {
|
||||
"pytorchmergebot",
|
||||
"facebook-github-bot",
|
||||
|
@ -9,12 +9,16 @@ import time
|
||||
import zipfile
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Optional
|
||||
from typing import Any, cast, Optional, TYPE_CHECKING
|
||||
|
||||
import boto3 # type: ignore[import]
|
||||
import requests
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
|
||||
|
||||
|
||||
|
@ -107,6 +107,7 @@ TESTS = discover_tests(
        "lazy/test_meta_kernel",
        "lazy/test_extract_compiled_graph",
        "test/inductor/test_aot_inductor_utils",
        "inductor/test_aoti_cross_compile_windows",
        "onnx/test_onnxscript_no_runtime",
        "onnx/test_pytorch_onnx_onnxruntime_cuda",
        "onnx/test_models",
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from warnings import warn
|
||||
|
||||
from tools.testing.target_determination.heuristics.interface import (
|
||||
@ -17,6 +17,10 @@ from tools.testing.target_determination.heuristics.utils import (
|
||||
from tools.testing.test_run import TestRun
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).parents[3]
|
||||
|
||||
keyword_synonyms: dict[str, list[str]] = {
|
||||
|
@ -4,7 +4,7 @@ import math
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from tools.stats.import_test_stats import get_disabled_tests
|
||||
from tools.testing.test_run import ShardedTest, TestRun
|
||||
@ -19,7 +19,7 @@ except ImportError:
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
@ -1442,6 +1442,7 @@ _has_cuda: _bool
_has_magma: _bool
_has_xpu: _bool
_has_mkldnn: _bool
_has_mkldnn_acl: _bool
_has_cudnn: _bool
_has_cusparselt: _bool
has_spectral: _bool
@ -1376,6 +1376,7 @@ class CompilationMetrics:
    recompile_user_contexts: Optional[set[str]] = None
    inline_inbuilt_nn_modules_candidate: Optional[bool] = False
    pytorch_version: Optional[str] = None
    inductor_provenance: Optional[str] = None

    @classmethod
    def create(cls, metrics: dict[str, Any]) -> CompilationMetrics:
@ -42,7 +42,12 @@ import torch.distributed as dist
|
||||
from torch import SymInt, Tensor
|
||||
from torch._dynamo.device_interface import get_interface_for_device
|
||||
from torch._dynamo.exc import SkipFrame
|
||||
from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed
|
||||
from torch._dynamo.utils import (
|
||||
CompileEventLogger,
|
||||
counters,
|
||||
dynamo_timed,
|
||||
get_metrics_context,
|
||||
)
|
||||
from torch._inductor import config, exc, metrics
|
||||
from torch._inductor.codegen.common import (
|
||||
custom_backend_codegen_configs,
|
||||
@ -339,7 +344,7 @@ def sha256_hash(data: bytes) -> str:
|
||||
return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
|
||||
|
||||
|
||||
def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
|
||||
def code_hash(code: str | bytes, extra: str | bytes = "") -> str:
|
||||
hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
|
||||
if extra:
|
||||
extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
|
||||
@ -361,9 +366,7 @@ def get_path(
|
||||
return basename, subdir, path
|
||||
|
||||
|
||||
def get_hash(
|
||||
content: Union[str, bytes], extra: str = "", hash_type: str = "code"
|
||||
) -> str:
|
||||
def get_hash(content: str | bytes, extra: str = "", hash_type: str = "code") -> str:
|
||||
if hash_type in {"amdgcn", "code", "ptx", "spv"}:
|
||||
return code_hash(content, extra)
|
||||
if hash_type in {"cubin", "hsaco", "spv"}:
|
||||
@ -409,7 +412,7 @@ class WritableTempFile:
|
||||
|
||||
|
||||
def write(
|
||||
content: Union[str, bytes],
|
||||
content: str | bytes,
|
||||
extension: str,
|
||||
extra: str = "",
|
||||
hash_type: str = "code",
|
||||
@ -436,7 +439,7 @@ def write_text(text: str) -> str:
|
||||
|
||||
def write_atomic(
|
||||
path_: str,
|
||||
content: Union[str, bytes],
|
||||
content: str | bytes,
|
||||
make_dirs: bool = False,
|
||||
encode_utf_8: bool = False,
|
||||
) -> None:
|
||||
@ -547,7 +550,7 @@ class FxGraphCachePickler(pickle.Pickler):
|
||||
|
||||
def _reduce_tensor(
|
||||
self, t: Tensor
|
||||
) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
|
||||
) -> tuple[Callable[[T], T], tuple[TensorMetadata | TensorMetadataAndValues]]:
|
||||
"""
|
||||
Custom reducer to pickle Tensors. If we see tensors, we know they're constants
|
||||
stored as attributes on the GraphModule.
|
||||
@ -943,7 +946,7 @@ class FxGraphHashDetails:
|
||||
raise AssertionError(f"unknown config type: {str(type(custom_pass))}")
|
||||
|
||||
def _get_custom_pass_detail(
|
||||
self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
|
||||
self, custom_pass: CustomGraphPassType | CustomGraphModulePass
|
||||
) -> Any | None:
|
||||
if not custom_pass:
|
||||
return None
|
||||
@ -1058,7 +1061,7 @@ class GuardedCache(Generic[T]):
|
||||
key: str,
|
||||
local: bool,
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
|
||||
evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool],
|
||||
hints: list[int],
|
||||
) -> tuple[T | None, bytes | None, dict[str, str]]:
|
||||
"""
|
||||
@ -1283,6 +1286,10 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
},
|
||||
payload_fn=lambda: graph.inductor_provenance_stack_traces_str,
|
||||
)
|
||||
if get_metrics_context().in_progress():
|
||||
get_metrics_context().add_to_set(
|
||||
"inductor_provenance", graph.inductor_provenance_stack_traces_str
|
||||
)
|
||||
return graph, cache_info
|
||||
|
||||
@staticmethod
|
||||
@ -1292,7 +1299,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
local: bool,
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
constants: CompiledFxGraphConstants,
|
||||
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
|
||||
evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
|
||||
| None = None,
|
||||
) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
|
||||
"""
|
||||
@ -1543,7 +1550,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
|
||||
remote_cache: RemoteCache[JsonDataTy] | None,
|
||||
is_backward: bool,
|
||||
constants: CompiledFxGraphConstants,
|
||||
evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
|
||||
evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
|
||||
| None = None,
|
||||
) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
|
||||
"""
|
||||
@ -1723,12 +1730,12 @@ class AotCodeCompiler:
|
||||
*,
|
||||
device_type: str,
|
||||
additional_files: list[str],
|
||||
) -> Union[list[Union[str, Weights]], str]:
|
||||
) -> list[Union[str, Weights]] | str:
|
||||
"""
|
||||
Returns the .so path, or returns a list of files that were generated if
|
||||
config.aot_inductor.package=True.
|
||||
"""
|
||||
generated_files: list[Union[str, Weights]] = additional_files # type: ignore[assignment]
|
||||
generated_files: list[str | Weights] = additional_files # type: ignore[assignment]
|
||||
|
||||
_set_gpu_runtime_env() # cpp_extension consults the env
|
||||
|
||||
@ -2342,7 +2349,7 @@ end
|
||||
f.write(json.dumps(qual_name_to_id))
|
||||
generated_files.append(constants_config_json)
|
||||
|
||||
gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = (
|
||||
gpu_codecache: ROCmCodeCache | CUDACodeCache = (
|
||||
ROCmCodeCache() if torch.version.hip else CUDACodeCache()
|
||||
)
|
||||
gpu_kernels_o = gpu_codecache.aot_kernels_o.copy()
|
||||
@ -2555,7 +2562,7 @@ end
|
||||
_libgomp: CDLL | None = None
|
||||
|
||||
|
||||
def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]:
|
||||
def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
|
||||
# This function will be called from generated cpp wrapper code in the JIT mode.
|
||||
# Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them.
|
||||
def convert_arg(arg: Any) -> Any:
|
||||
@ -2698,16 +2705,16 @@ class CppCodeCache:
|
||||
"""Compiles and caches C++ libraries. Users of this class supply the source code to
|
||||
be compiled, while compilation flags are set by CppBuilder."""
|
||||
|
||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
cpp_compile_command_flags: dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
|
||||
def _load_library_inner(path: str, key: str) -> CDLL | ModuleType:
|
||||
return cdll.LoadLibrary(path)
|
||||
|
||||
@classmethod
|
||||
def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
|
||||
def _load_library(cls, path: str, key: str) -> CDLL | ModuleType:
|
||||
try:
|
||||
result = cls._load_library_inner(path, key)
|
||||
result.key = key # type: ignore[union-attr]
|
||||
@ -2910,7 +2917,7 @@ def _worker_compile_cpp(
|
||||
# Customized Python binding for cpp kernels
|
||||
@clear_on_fresh_cache
|
||||
class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
cpp_compile_command_flags = {
|
||||
# kernels have no dependency on libtorch
|
||||
@ -3092,7 +3099,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
|
||||
@clear_on_fresh_cache
|
||||
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
cpp_compile_command_flags = {
|
||||
"include_pytorch": True,
|
||||
@ -3161,7 +3168,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||
|
||||
@clear_on_fresh_cache
|
||||
class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
||||
cache: dict[str, Callable[[], ModuleType | CDLL]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
_standalone_runtime_path: str | None = None
|
||||
prefix = textwrap.dedent(
|
||||
|
@ -950,6 +950,7 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
|
||||
or _all_in_parens(string)
|
||||
):
|
||||
# don't put extra parens for strings that are already wrapped in parens
|
||||
# pyrefly: ignore # bad-return
|
||||
return string
|
||||
return f"({string})"
|
||||
|
||||
@ -1736,7 +1737,9 @@ class KernelArgs:
|
||||
)
|
||||
)
|
||||
for outer, inner in chain(
|
||||
self.input_buffers.items(), self.output_buffers.items()
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
self.input_buffers.items(),
|
||||
self.output_buffers.items(),
|
||||
):
|
||||
if outer in self.inplace_buffers or isinstance(inner, RemovedArg):
|
||||
continue
|
||||
@ -2047,6 +2050,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if increase_kernel_count:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.generated_kernel_count += 1
|
||||
self.args = args or KernelArgs()
|
||||
self.loads = IndentedBuffer()
|
||||
@ -2113,6 +2117,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
|
||||
self.compute = compute
|
||||
self.stores = stores
|
||||
self.cse = cse
|
||||
# pyrefly: ignore # unbound-name
|
||||
if disallow_stores:
|
||||
assert not sb, "unexpected store inside swap_buffers"
|
||||
|
||||
@ -2384,6 +2389,7 @@ class KernelTemplate:
|
||||
class DetailedTemplateSyntaxError(TemplateSyntaxError):
|
||||
def __init__(self, original_error: TemplateSyntaxError) -> None:
|
||||
super().__init__(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
original_error.message,
|
||||
original_error.lineno,
|
||||
original_error.name,
|
||||
@ -2395,6 +2401,7 @@ class KernelTemplate:
|
||||
error_info = f"Error in template at line {self.lineno}\n"
|
||||
error_info += f"Error message: {self.message}\n"
|
||||
if hasattr(self.original_error, "source"):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
lines = self.original_error.source.split("\n")
|
||||
error_info += "Context:\n"
|
||||
start = max(0, self.lineno - 2)
|
||||
|
@ -504,6 +504,7 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
|
||||
if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
|
||||
return cls(
|
||||
node1.scheduler,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
(
|
||||
list(node1.get_outer_nodes())
|
||||
if type(node1) is OuterLoopFusedSchedulerNode
|
||||
@ -1716,6 +1717,7 @@ class CppVecOverrides(CppOverrides):
|
||||
body_vec_var.dtype = dtype
|
||||
other_vec_var.dtype = dtype
|
||||
overrides: type[Union[CppOverrides, CppVecOverrides]] = (
|
||||
# pyrefly: ignore # bad-assignment
|
||||
V.kernel.overrides
|
||||
) # type: ignore[has-type]
|
||||
code.writeline(
|
||||
@ -1759,6 +1761,7 @@ class CppVecOverrides(CppOverrides):
|
||||
csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment]
|
||||
None, index, dtype, V.kernel.compute
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
csevar.update_on_args("index_expr", (expr, dtype), {})
|
||||
return csevar
|
||||
|
||||
@ -2036,6 +2039,7 @@ class CppKernel(Kernel):
|
||||
# mask's dtype should be bool
|
||||
mask.dtype = torch.bool
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self._load_mask = mask
|
||||
try:
|
||||
yield mask
|
||||
@ -2363,6 +2367,7 @@ class CppKernel(Kernel):
|
||||
sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
|
||||
for n in range(len(self.ranges))
|
||||
]
|
||||
# pyrefly: ignore # bad-assignment
|
||||
self.reduction_depth = len(lengths)
|
||||
return (
|
||||
self.itervars[: self.reduction_depth],
|
||||
@ -2610,7 +2615,9 @@ class CppKernel(Kernel):
|
||||
and end == self.ranges[var_id]
|
||||
):
|
||||
end = 1
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
conditions.append(f"{var} >= {cexpr_index(start)}")
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
conditions.append(f"{var} < {cexpr_index(end)}")
|
||||
return True
|
||||
|
||||
@ -4085,6 +4092,7 @@ class CppKernelProxy(CppKernel):
|
||||
and (dt := get_output_dtype(_node)) in DTYPE_LOWP_FP
|
||||
):
|
||||
# No need to promote to float if all users are ops that accepts lowp fp input
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
if all(is_lowp_fp_sink(user, dt) for user in _node.users):
|
||||
continue
|
||||
ops = _node.args[0]
|
||||
@ -4095,12 +4103,14 @@ class CppKernelProxy(CppKernel):
|
||||
_node.replace_all_uses_with(
|
||||
to_type_node, lambda n: n is not to_type_node
|
||||
)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.cpp_to_dtype_count += 1
|
||||
elif (
|
||||
_node.target == "store"
|
||||
and (dt := get_input_dtype(_node)) in DTYPE_LOWP_FP
|
||||
):
|
||||
ops, name, _, value_var, _ = _node.args
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
if is_lowp_fp_source_no_promote(value_var, dt):
|
||||
continue
|
||||
dtype = V.graph.get_dtype(name)
|
||||
@ -4109,6 +4119,7 @@ class CppKernelProxy(CppKernel):
|
||||
"to_dtype", args=(ops, value_var, dtype)
|
||||
)
|
||||
_node.replace_input_with(value_var, to_type_node)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.cpp_to_dtype_count += 1
|
||||
elif _node.target == "reduction":
|
||||
(
|
||||
@ -4178,6 +4189,7 @@ class CppKernelProxy(CppKernel):
|
||||
"to_dtype", args=(ops, value_var, src_dtype)
|
||||
)
|
||||
_node.replace_input_with(value_var, to_type_node)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.cpp_to_dtype_count += 1
|
||||
|
||||
# to_dtype_bitcast act as a lowp fp source:
|
||||
@ -4196,6 +4208,7 @@ class CppKernelProxy(CppKernel):
|
||||
_node.replace_all_uses_with(
|
||||
to_type_node, lambda n: n is not to_type_node
|
||||
)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.cpp_to_dtype_count += 1
|
||||
|
||||
def eliminate_to_dtype(sub_graph: torch.fx.Graph):
|
||||
@ -4289,6 +4302,7 @@ class CppKernelProxy(CppKernel):
|
||||
with kernel_group.new_kernel(cls, *args) as kernel:
|
||||
# Ugly hack to maintain the metrics kernel count since
|
||||
# we only count in CppKernelProxy, not those contained in it
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.generated_kernel_count -= 1
|
||||
|
||||
run(kernel)
|
||||
@ -4360,6 +4374,7 @@ class CppKernelProxy(CppKernel):
|
||||
)
|
||||
|
||||
if len(tiling_indices) == 1:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.generated_cpp_vec_kernel_count += 1
|
||||
loop = self.loop_nest.tile(tiling_indices[0], factor=tiling_factors[0])
|
||||
vec_kernel = codegen_kernel(
|
||||
@ -4386,6 +4401,7 @@ class CppKernelProxy(CppKernel):
|
||||
and tiling_factors[0] == tiling_factors[1]
|
||||
)
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.generated_cpp_vec_kernel_count += 2
|
||||
outer_loop = self.loop_nest.tile(
|
||||
tiling_indices[0], factor=tiling_factors[0]
|
||||
@ -5134,10 +5150,12 @@ class CppScheduling(BaseScheduling):
|
||||
contiguous_index_expr = 0
|
||||
stride = 1
|
||||
for var, range in reversed(
|
||||
# pyrefly: ignore # missing-attribute
|
||||
scheduler_node._body.var_ranges.items()
|
||||
):
|
||||
contiguous_index_expr += stride * var
|
||||
stride *= range
|
||||
# pyrefly: ignore # missing-attribute
|
||||
write_index_expr = scheduler_node._body.get_write_expr(
|
||||
scheduler_buffer.get_name()
|
||||
)
|
||||
@ -5206,6 +5224,7 @@ class CppScheduling(BaseScheduling):
|
||||
)
|
||||
local_buffers.append(local_buffer_used)
|
||||
local_to_global_buffers[local_buffer_used.name] = [] # type: ignore[index]
|
||||
# pyrefly: ignore # index-error
|
||||
local_to_global_buffers[local_buffer_used.name].append(
|
||||
global_buffer,
|
||||
)
|
||||
@ -5450,6 +5469,7 @@ class CppScheduling(BaseScheduling):
|
||||
wrapper = V.graph.wrapper_code
|
||||
debug_handle = set_kernel_post_grad_provenance_tracing(
|
||||
node_schedule, # type: ignore[arg-type]
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
kernel_name,
|
||||
)
|
||||
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
|
||||
@ -5771,6 +5791,7 @@ class LoopNest:
|
||||
loop = self.loops[par_depth.start_depth]
|
||||
loop.parallel = par_depth.parallel_depth
|
||||
if loop.is_reduction:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
metrics.parallel_reduction_count += 1
|
||||
for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
|
||||
self.loops[i].collapsed = True
|
||||
|
@ -396,12 +396,15 @@ def transpose_w(W: _T, trans_w: bool) -> _T:
|
||||
if isinstance(W, ir.IRNode):
|
||||
if trans_w:
|
||||
if not isinstance(W, ir.TensorBox):
|
||||
# pyrefly: ignore # bad-assignment
|
||||
W = ir.TensorBox(W)
|
||||
W = L.permute(W, [1, 0])
|
||||
else:
|
||||
if trans_w:
|
||||
assert isinstance(W, torch.Tensor)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
W = W.transpose(0, 1)
|
||||
# pyrefly: ignore # bad-return
|
||||
return W
|
||||
|
||||
|
||||
@ -412,12 +415,15 @@ def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
|
||||
if B is not None:
|
||||
if isinstance(B, ir.IRNode):
|
||||
if not isinstance(B, ir.TensorBox):
|
||||
# pyrefly: ignore # bad-assignment
|
||||
B = ir.TensorBox(B)
|
||||
assert hasattr(X, "get_size")
|
||||
# pyrefly: ignore # missing-attribute
|
||||
B = L.expand(B, (X.get_size()[0], B.get_size()[-1]))
|
||||
else:
|
||||
assert isinstance(B, torch.Tensor)
|
||||
assert isinstance(X, torch.Tensor)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
B = B.expand(X.shape[0], B.shape[-1])
|
||||
return B
|
||||
|
||||
@ -1043,6 +1049,7 @@ class CppGemmTemplate(CppTemplate):
|
||||
return cls.prep_weight(
|
||||
new_inputs,
|
||||
new_layout,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
micro_gemm,
|
||||
pre_block_weights,
|
||||
use_int8_fast_compensation_path,
|
||||
@ -1066,6 +1073,7 @@ class CppGemmTemplate(CppTemplate):
|
||||
new_input_nodes, _ = cls.prep_weight(
|
||||
new_input_nodes,
|
||||
new_layout,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
micro_gemm,
|
||||
pre_block_weights,
|
||||
use_int8_fast_compensation_path,
|
||||
@ -1470,7 +1478,9 @@ class CppGemmTemplate(CppTemplate):
|
||||
assert isinstance(template_buffer, ir.IRNode)
|
||||
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
|
||||
gemm_output_buffer = ir.Buffer(
|
||||
name=gemm_output_name, layout=template_buffer.layout
|
||||
# pyrefly: ignore # missing-attribute
|
||||
name=gemm_output_name,
|
||||
layout=template_buffer.layout,
|
||||
)
|
||||
current_input_buffer = gemm_output_buffer
|
||||
for i, creator in enumerate(epilogue_creators):
|
||||
@ -1481,6 +1491,7 @@ class CppGemmTemplate(CppTemplate):
|
||||
epilogues.append(
|
||||
ir.ComputedBuffer(
|
||||
name=buffer_name,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
layout=template_buffer.layout,
|
||||
data=creator(current_input_buffer),
|
||||
)
|
||||
@ -1490,7 +1501,9 @@ class CppGemmTemplate(CppTemplate):
|
||||
reindexers.append(None)
|
||||
if i < len(epilogue_creators) - 1:
|
||||
current_input_buffer = ir.Buffer(
|
||||
name=buffer_name, layout=template_buffer.layout
|
||||
# pyrefly: ignore # missing-attribute
|
||||
name=buffer_name,
|
||||
layout=template_buffer.layout,
|
||||
)
|
||||
|
||||
assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
|
||||
@ -1521,6 +1534,7 @@ class CppGemmTemplate(CppTemplate):
|
||||
self.n,
|
||||
self.k,
|
||||
input_dtype=X.get_dtype(),
|
||||
# pyrefly: ignore # missing-attribute
|
||||
input2_dtype=W.get_dtype(),
|
||||
output_dtype=output_dtype,
|
||||
compute_dtype=compute_dtype,
|
||||
|
@ -183,12 +183,14 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
)
|
||||
self.act_mapping = act_mapping
|
||||
self.gemm_grouped_num = gemm_grouped_num
|
||||
# pyrefly: ignore # bad-override
|
||||
self.output_node: list[ir.Buffer] = [
|
||||
ir.Buffer(name="buf_out" + str(idx), layout=layout)
|
||||
for idx in range(gemm_grouped_num)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
def add_choices(
|
||||
cls,
|
||||
choices: list[ChoiceCaller],
|
||||
@ -231,6 +233,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
if isinstance(inputs[idx], torch.Tensor):
|
||||
W = inputs[idx]
|
||||
assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
|
||||
return new_inputs, layout_or_out
|
||||
|
||||
@ -246,8 +249,10 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
new_input = new_inputs[wgt_idx]
|
||||
new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
|
||||
for bias_idx in range(bias_start_idx, len(new_inputs)):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
new_bias = expand_bias(new_inputs[bias_idx], X)
|
||||
assert new_bias is not None
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
new_inputs[bias_idx] = new_bias
|
||||
return new_inputs, layout_or_out
|
||||
|
||||
@ -308,6 +313,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
W_tensor = []
|
||||
for W_node in W_nodes:
|
||||
assert W_node.get_name() in V.graph.constants
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
W_tensor.append(V.graph.constants[W_node.get_name()])
|
||||
new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
|
||||
W_tensor # type: ignore[assignment]
|
||||
@ -324,6 +330,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
template_buffer.inputs[idx] = (
|
||||
ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
|
||||
)
|
||||
# pyrefly: ignore # bad-return
|
||||
return output
|
||||
|
||||
template = DataProcessorTemplateWrapper(
|
||||
@ -362,6 +369,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
cur_idx = bias_start_idx
|
||||
for inp_idx in range(self.gemm_grouped_num):
|
||||
inp = None
|
||||
# pyrefly: ignore # index-error
|
||||
if self.has_bias[inp_idx]:
|
||||
inp = self.input_nodes[cur_idx]
|
||||
cur_idx += 1
|
||||
@ -390,6 +398,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
self.n,
|
||||
self.k,
|
||||
input_dtype=X_list[0].get_dtype(),
|
||||
# pyrefly: ignore # missing-attribute
|
||||
input2_dtype=W_list[0].get_dtype(),
|
||||
output_dtype=output_dtype,
|
||||
compute_dtype=compute_dtype,
|
||||
@ -427,6 +436,7 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||
for x_idx in range(wgt_start_idx):
|
||||
kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
|
||||
for w_idx in range(self.gemm_grouped_num):
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
kernel_args["W" + str(w_idx)] = W_list[w_idx]
|
||||
for inp_idx in range(self.gemm_grouped_num):
|
||||
kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]
|
||||
|
@ -85,6 +85,7 @@ class CppTemplate(KernelTemplate):
|
||||
bmreq = CppBenchmarkRequest(
|
||||
kernel_name=kernel_name,
|
||||
input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
|
||||
extra_args=extra_args,
|
||||
source_code=code,
|
||||
@ -112,6 +113,7 @@ class CppTemplate(KernelTemplate):
|
||||
kernel_hash_name,
|
||||
self.name,
|
||||
self.input_nodes,
|
||||
# pyrefly: ignore # index-error
|
||||
self.output_node[0].get_layout()
|
||||
if isinstance(self.output_node, Iterable)
|
||||
else self.output_node.get_layout(),
|
||||
|