Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-22 22:25:10 +08:00)

Compare commits: annotate_f... → ciflow/roc... (110 commits)
@@ -113,6 +113,7 @@ case "$tag" in
     UCX_COMMIT=${_UCX_COMMIT}
     UCC_COMMIT=${_UCC_COMMIT}
     TRITON=yes
+    INSTALL_MINGW=yes
     ;;
   pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11)
     CUDA_VERSION=13.0.0
@@ -361,6 +362,7 @@ docker build \
   --build-arg "OPENBLAS=${OPENBLAS:-}" \
   --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
   --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
+  --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \
   -f $(dirname ${DOCKERFILE})/Dockerfile \
   -t "$tmp_tag" \
   "$@" \
.ci/docker/common/install_mingw.sh | 10 (new file)
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+set -ex
+
+# Install MinGW-w64 for Windows cross-compilation
+apt-get update
+apt-get install -y g++-mingw-w64-x86-64-posix
+
+echo "MinGW-w64 installed successfully"
+x86_64-w64-mingw32-g++ --version
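As an aside (illustrative, not part of the patch): a minimal C++ source file, with a made-up name and example flags, could be used to smoke-test the toolchain installed above, e.g. x86_64-w64-mingw32-g++ -static hello.cpp -o hello.exe.

// hello.cpp -- illustrative smoke test for the MinGW-w64 cross toolchain.
// Build (example): x86_64-w64-mingw32-g++ -static hello.cpp -o hello.exe
#include <cstdio>

int main() {
  std::printf("built with x86_64-w64-mingw32-g++\n");
  return 0;
}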
@@ -20,7 +20,7 @@ pip_install \

 pip_install coloredlogs packaging
 pip_install onnxruntime==1.23.0
-pip_install onnxscript==0.5.3
+pip_install onnxscript==0.5.4

 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
@@ -103,6 +103,11 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt
 RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
 RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt

+ARG INSTALL_MINGW
+COPY ./common/install_mingw.sh install_mingw.sh
+RUN if [ -n "${INSTALL_MINGW}" ]; then bash ./install_mingw.sh; fi
+RUN rm install_mingw.sh
+
 ARG TRITON
 ARG TRITON_CPU

@@ -485,6 +485,22 @@ test_inductor_aoti() {
   /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
 }

+test_inductor_aoti_cross_compile_for_windows() {
+
+  TEST_REPORTS_DIR=$(pwd)/test/test-reports
+  mkdir -p "$TEST_REPORTS_DIR"
+
+  # Set WINDOWS_CUDA_HOME environment variable
+  WINDOWS_CUDA_HOME="$(pwd)/win-torch-wheel-extracted"
+  export WINDOWS_CUDA_HOME
+
+  echo "WINDOWS_CUDA_HOME is set to: $WINDOWS_CUDA_HOME"
+  echo "Contents:"
+  ls -lah "$(pwd)/win-torch-wheel-extracted/lib/x64/" || true
+
+  python test/inductor/test_aoti_cross_compile_windows.py -k compile --package-dir "$TEST_REPORTS_DIR" --win-torch-lib-dir "$(pwd)/win-torch-wheel-extracted/torch/lib"
+}
+
 test_inductor_cpp_wrapper_shard() {
   if [[ -z "$NUM_TEST_SHARDS" ]]; then
     echo "NUM_TEST_SHARDS must be defined to run a Python test shard"
@@ -900,7 +916,7 @@ test_inductor_set_cpu_affinity(){
   export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD"
   export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"

-  if [[ "${TEST_CONFIG}" != *aarch64* ]]; then
+  if [[ "$(uname -m)" != "aarch64" ]]; then
     # Use Intel OpenMP for x86
     IOMP_LIB="$(dirname "$(which python)")/../lib/libiomp5.so"
     export LD_PRELOAD="$IOMP_LIB":"$LD_PRELOAD"
@@ -914,7 +930,7 @@ test_inductor_set_cpu_affinity(){
   cores=$((cpus / thread_per_core))

   # Set number of cores to 16 on aarch64 for performance runs
-  if [[ "${TEST_CONFIG}" == *aarch64* && $cores -gt 16 ]]; then
+  if [[ "$(uname -m)" == "aarch64" && $cores -gt 16 ]]; then
     cores=16
   fi
   export OMP_NUM_THREADS=$cores
@@ -1667,7 +1683,7 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then
   python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0
 fi
 python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py
-elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then
+elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then
   test_linux_aarch64
 elif [[ "${TEST_CONFIG}" == *backward* ]]; then
   test_forward_backward_compatibility
@@ -1718,6 +1734,8 @@ elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
   test_inductor_triton_cpu
 elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
   test_inductor_micro_benchmark
+elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then
+  test_inductor_aoti_cross_compile_for_windows
 elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
   install_torchvision
   id=$((SHARD_NUMBER-1))
.github/pytorch-probot.yml | 1 (vendored)
@@ -3,6 +3,7 @@ ciflow_tracking_issue: 64124
 ciflow_push_tags:
 - ciflow/b200
 - ciflow/b200-symm-mem
+- ciflow/b200-distributed
 - ciflow/binaries
 - ciflow/binaries_libtorch
 - ciflow/binaries_wheel
.github/scripts/trymerge.py | 2 (vendored)
@@ -1092,7 +1092,7 @@ class GitHubPR:
         editor = node["editor"]
         return GitHubComment(
             body_text=node["bodyText"],
-            created_at=node["createdAt"] if "createdAt" in node else "",
+            created_at=node.get("createdAt", ""),
             author_login=node["author"]["login"],
             author_url=node["author"].get("url", None),
             author_association=node["authorAssociation"],
.github/workflows/_linux-test.yml | 40 (vendored)
@@ -224,6 +224,46 @@ jobs:
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/download-td-artifacts
|
||||
|
||||
- name: Download Windows torch wheel for cross-compilation
|
||||
if: matrix.win_torch_wheel_artifact != ''
|
||||
uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0
|
||||
with:
|
||||
name: ${{ matrix.win_torch_wheel_artifact }}
|
||||
path: win-torch-wheel
|
||||
|
||||
- name: Extract Windows wheel and setup CUDA libraries
|
||||
if: matrix.win_torch_wheel_artifact != ''
|
||||
shell: bash
|
||||
run: |
|
||||
set -x
|
||||
|
||||
# Find the wheel file
|
||||
WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1)
|
||||
if [ -z "$WHEEL_FILE" ]; then
|
||||
echo "Error: No wheel file found in win-torch-wheel directory"
|
||||
exit 1
|
||||
fi
|
||||
echo "Found wheel file: $WHEEL_FILE"
|
||||
|
||||
# Unzip the wheel file
|
||||
unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted
|
||||
echo "Extracted wheel contents"
|
||||
|
||||
# Setup CUDA libraries (cuda.lib and cudart.lib) directory
|
||||
mkdir -p win-torch-wheel-extracted/lib/x64
|
||||
if [ -f "win-torch-wheel/cuda.lib" ]; then
|
||||
mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/
|
||||
echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/"
|
||||
fi
|
||||
if [ -f "win-torch-wheel/cudart.lib" ]; then
|
||||
mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/
|
||||
echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/"
|
||||
fi
|
||||
|
||||
# Verify CUDA libraries are present
|
||||
echo "CUDA libraries:"
|
||||
ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found"
|
||||
|
||||
- name: Parse ref
|
||||
id: parse-ref
|
||||
run: .github/scripts/parse_ref.py
|
||||
|
.github/workflows/_win-build.yml | 25 (vendored)
@@ -168,6 +168,31 @@ jobs:
|
||||
run: |
|
||||
.ci/pytorch/win-build.sh
|
||||
|
||||
# Collect Windows torch libs and CUDA libs for cross-compilation
|
||||
- name: Collect Windows CUDA libs for cross-compilation
|
||||
if: steps.build.outcome != 'skipped' && inputs.cuda-version != 'cpu'
|
||||
shell: bash
|
||||
run: |
|
||||
set -ex
|
||||
|
||||
# Create directory structure if does not exist
|
||||
mkdir -p /c/${{ github.run_id }}/build-results
|
||||
|
||||
# Copy CUDA libs
|
||||
CUDA_PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${{ inputs.cuda-version }}"
|
||||
|
||||
if [ -f "${CUDA_PATH}/lib/x64/cuda.lib" ]; then
|
||||
cp "${CUDA_PATH}/lib/x64/cuda.lib" /c/${{ github.run_id }}/build-results/
|
||||
fi
|
||||
|
||||
if [ -f "${CUDA_PATH}/lib/x64/cudart.lib" ]; then
|
||||
cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/
|
||||
fi
|
||||
|
||||
# List collected files
|
||||
echo "Collected CUDA libs:"
|
||||
ls -lah /c/${{ github.run_id }}/build-results/*.lib
|
||||
|
||||
# Upload to github so that people can click and download artifacts
|
||||
- name: Upload artifacts to s3
|
||||
if: steps.build.outcome != 'skipped'
|
||||
|
.github/workflows/b200-distributed.yml | 62 (vendored, new file)
@@ -0,0 +1,62 @@
|
||||
name: CI for distributed tests on B200
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- .github/workflows/b200-distributed.yml
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- ciflow/b200-distributed/*
|
||||
schedule:
|
||||
- cron: 46 8 * * * # about 1:46am PDT
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
|
||||
get-label-type:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200:
|
||||
name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed-b200
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.dgx.b200.8" },
|
||||
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.dgx.b200.8" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-cuda12_8-py3_10-gcc11-test-distributed-b200:
|
||||
name: linux-jammy-cuda12.8-py3.10-gcc11-test-b200
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs:
|
||||
- linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200
|
||||
with:
|
||||
timeout-minutes: 1200
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
||||
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed-b200.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
secrets: inherit
|
@@ -88,27 +88,27 @@ jobs:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_huggingface_perf_rocm_mi355", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 1, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 2, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 3, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 4, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 5, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 6, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_timm_perf_rocm_mi355", shard: 7, num_shards: 7, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 1, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 2, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 3, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 4, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 5, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 6, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 7, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 8, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
{ config: "inductor_torchbench_perf_rocm_mi355", shard: 9, num_shards: 9, runner: "linux.rocm.gpu.mi355.1" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
.github/workflows/operator_benchmark.yml | 24 (vendored)
@@ -52,3 +52,27 @@ jobs:
|
||||
docker-image: ${{ needs.x86-opbenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.x86-opbenchmark-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
aarch64-opbenchmark-build:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: aarch64-opbenchmark-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
with:
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
aarch64-opbenchmark-test:
|
||||
name: aarch64-opbenchmark-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: aarch64-opbenchmark-build
|
||||
with:
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
docker-image: ${{ needs.aarch64-opbenchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.aarch64-opbenchmark-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
.github/workflows/rocm-mi355.yml | 12 (vendored)
@@ -45,12 +45,12 @@ jobs:
       sync-tag: rocm-build
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
-          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.2" },
+          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
+          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi355.1" },
         ]}
       secrets: inherit
.github/workflows/trunk.yml | 17 (vendored)
@@ -200,6 +200,23 @@ jobs:
|
||||
cuda-arch-list: '8.0'
|
||||
secrets: inherit
|
||||
|
||||
# Test cross-compiled models with Windows libs extracted from wheel
|
||||
cross-compile-linux-test:
|
||||
name: cross-compile-linux-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs:
|
||||
- linux-jammy-cuda12_8-py3_10-gcc11-build
|
||||
- get-label-type
|
||||
- win-vs2022-cuda12_8-py3-build
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
|
||||
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
verify-cachebench-cpu-build:
|
||||
name: verify-cachebench-cpu-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
@@ -2,7 +2,6 @@

 #include <mutex>
 #include <ATen/CachedTensorUtils.h>
-#include <c10/core/GradMode.h>
 #include <c10/util/flat_hash_map.h>

 namespace at::autocast {
@@ -37,29 +36,10 @@ namespace {
 using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
 using val_type = std::tuple<weakref_type, Tensor>;

-// We maintain separate caches for gradient-enabled and gradient-disabled modes.
-// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
-// are not incorrectly reused in gradient-enabled contexts.
-// This fixes issue #158232 while maintaining optimal performance for both modes.
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_enabled() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_enabled;
-  return cached_casts_grad_enabled;
+ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
+  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
+  return cached_casts;
 }
-
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts_grad_disabled() {
-  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts_grad_disabled;
-  return cached_casts_grad_disabled;
-}
-
-// Helper function to get the appropriate cache based on current gradient mode.
-// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
-// preventing incorrect cache hits when gradient mode changes.
-static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
-  return at::GradMode::is_enabled() ?
-      get_cached_casts_grad_enabled() :
-      get_cached_casts_grad_disabled();
-}
-
 std::mutex cached_casts_mutex;

@@ -106,9 +86,7 @@ thread_local bool cache_enabled = true;

 void clear_cache() {
   const std::lock_guard<std::mutex> lock(cached_casts_mutex);
-  // Clear both caches to ensure consistent behavior regardless of current gradient mode
-  get_cached_casts_grad_enabled().clear();
-  get_cached_casts_grad_disabled().clear();
+  get_cached_casts().clear();
 }

 int increment_nesting() {
@@ -143,11 +121,6 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
   if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
-    //
-    // We maintain separate caches for gradient-enabled and gradient-disabled modes
-    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
-    // with torch.autocast(), while maintaining optimal performance for both training and inference.
-    // This fixes issue #158232 without any performance regression.
     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                           arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                           arg.is_leaf() && !arg.is_view() && cache_enabled &&
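For readers unfamiliar with the caching heuristic mentioned in the comments above, here is a rough standalone sketch of the eligibility test (illustrative only; the function name is made up, not PyTorch's implementation):

// should_cache_cast: sketch of the "cache fp32 leaf weights only" heuristic.
// fp32 parameters (leaf tensors with requires_grad) get re-cast on every
// autocast forward pass, so caching their low-precision copies pays off;
// activations and views are deliberately excluded.
#include <ATen/core/Tensor.h>

bool should_cache_cast(const at::Tensor& t, bool cache_enabled) {
  return cache_enabled &&
      t.scalar_type() == at::kFloat &&     // only fp32 sources
      t.requires_grad() && t.is_leaf() &&  // i.e. model weights, not activations
      !t.is_view();                        // views alias other storage
}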
@@ -229,10 +229,10 @@ private:
   }


-  static const uint32_t kPhilox10A = 0x9E3779B9;
-  static const uint32_t kPhilox10B = 0xBB67AE85;
-  static const uint32_t kPhiloxSA = 0xD2511F53;
-  static const uint32_t kPhiloxSB = 0xCD9E8D57;
+  static constexpr uint32_t kPhilox10A = 0x9E3779B9;
+  static constexpr uint32_t kPhilox10B = 0xBB67AE85;
+  static constexpr uint32_t kPhiloxSA = 0xD2511F53;
+  static constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
 };

 typedef philox_engine Philox4_32;
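Background on the change above (illustrative example, not from the patch): for in-class integral constants, static constexpr members are implicitly inline since C++17 and are guaranteed usable in constant expressions, which avoids the separate out-of-class definition an odr-used static const member may otherwise require.

#include <cstdint>

struct Engine {
  static const uint32_t kOld = 0x9E3779B9u;      // may need an out-of-class definition if odr-used (pre-C++17)
  static constexpr uint32_t kNew = 0x9E3779B9u;  // implicitly inline (C++17); fine in constant expressions
};

static_assert(Engine::kNew == 0x9E3779B9u, "usable at compile time");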
@@ -8,6 +8,7 @@
 #include <ATen/cpu/vec/vec128/vec128_bfloat16_neon.h>
 #include <ATen/cpu/vec/vec128/vec128_float_neon.h>
 #include <ATen/cpu/vec/vec128/vec128_half_neon.h>
+#include <ATen/cpu/vec/vec128/vec128_int_aarch64.h>
 #endif

 #include <ATen/cpu/vec/vec128/vec128_convert.h>
aten/src/ATen/cpu/vec/vec128/vec128_int_aarch64.h | 794 (new file)
@@ -0,0 +1,794 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
namespace at::vec {
|
||||
// Note [CPU_CAPABILITY namespace]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// This header, and all of its subheaders, will be compiled with
|
||||
// different architecture flags for each supported set of vector
|
||||
// intrinsics. So we need to make sure they aren't inadvertently
|
||||
// linked together. We do this by declaring objects in an `inline
|
||||
// namespace` which changes the name mangling, but can still be
|
||||
// accessed as `at::vec`.
|
||||
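An illustrative standalone example of the inline-namespace pattern the note describes (not from this header): each per-ISA build gets distinct mangled symbols, yet callers keep using the short qualified name.

namespace mylib {
inline namespace AVX2 {  // imagine a different name per CPU_CAPABILITY build
int add(int a, int b) {
  return a + b;  // mangled as mylib::AVX2::add, so separately compiled flavors cannot collide
}
}  // inline namespace AVX2
}  // namespace mylib

int caller() {
  return mylib::add(1, 2);  // still reachable without naming the inline namespace
}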
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#define VEC_INT_NEON_TEMPLATE(vl, bit) \
|
||||
template <> \
|
||||
struct is_vec_specialized_for<int##bit##_t> : std::bool_constant<true> {}; \
|
||||
\
|
||||
template <> \
|
||||
class Vectorized<int##bit##_t> { \
|
||||
using neon_type = int##bit##x##vl##_t; \
|
||||
\
|
||||
private: \
|
||||
neon_type values; \
|
||||
\
|
||||
public: \
|
||||
using value_type = int##bit##_t; \
|
||||
using size_type = int; \
|
||||
static constexpr size_type size() { \
|
||||
return vl; \
|
||||
} \
|
||||
Vectorized() { \
|
||||
values = vdupq_n_s##bit(0); \
|
||||
} \
|
||||
Vectorized(neon_type v) : values(v) {} \
|
||||
Vectorized(int##bit##_t val); \
|
||||
template < \
|
||||
typename... Args, \
|
||||
typename = std::enable_if_t<(sizeof...(Args) == size())>> \
|
||||
Vectorized(Args... vals) { \
|
||||
__at_align__ int##bit##_t buffer[size()] = {vals...}; \
|
||||
values = vld1q_s##bit(buffer); \
|
||||
} \
|
||||
operator neon_type() const { \
|
||||
return values; \
|
||||
} \
|
||||
static Vectorized<int##bit##_t> loadu( \
|
||||
const void* ptr, \
|
||||
int64_t count = size()); \
|
||||
void store(void* ptr, int64_t count = size()) const; \
|
||||
template <int64_t mask> \
|
||||
static Vectorized<int##bit##_t> blend( \
|
||||
const Vectorized<int##bit##_t>& a, \
|
||||
const Vectorized<int##bit##_t>& b); \
|
||||
static Vectorized<int##bit##_t> blendv( \
|
||||
const Vectorized<int##bit##_t>& a, \
|
||||
const Vectorized<int##bit##_t>& b, \
|
||||
const Vectorized<int##bit##_t>& mask_) { \
|
||||
return vbslq_s##bit(vreinterpretq_u##bit##_s##bit(mask_.values), b, a); \
|
||||
} \
|
||||
template <typename step_t> \
|
||||
static Vectorized<int##bit##_t> arange( \
|
||||
value_type base = 0, \
|
||||
step_t step = static_cast<step_t>(1)); \
|
||||
static Vectorized<int##bit##_t> set( \
|
||||
const Vectorized<int##bit##_t>& a, \
|
||||
const Vectorized<int##bit##_t>& b, \
|
||||
int64_t count = size()); \
|
||||
const int##bit##_t& operator[](int idx) const = delete; \
|
||||
int##bit##_t& operator[](int idx) = delete; \
|
||||
Vectorized<int##bit##_t> abs() const { \
|
||||
return vabsq_s##bit(values); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> real() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<int##bit##_t> imag() const { \
|
||||
return vdupq_n_s##bit(0); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> conj() const { \
|
||||
return values; \
|
||||
} \
|
||||
Vectorized<int##bit##_t> neg() const { \
|
||||
return vnegq_s##bit(values); \
|
||||
} \
|
||||
int##bit##_t reduce_add() const { \
|
||||
return vaddvq_s##bit(values); \
|
||||
} \
|
||||
int##bit##_t reduce_max() const; \
|
||||
Vectorized<int##bit##_t> operator==( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vceqq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator!=( \
|
||||
const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> operator<( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcltq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator<=( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcleq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator>( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcgtq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> operator>=( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return Vectorized<value_type>( \
|
||||
vreinterpretq_s##bit##_u##bit(vcgeq_s##bit(values, other.values))); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> eq(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> ne(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> gt(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> ge(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> lt(const Vectorized<int##bit##_t>& other) const; \
|
||||
Vectorized<int##bit##_t> le(const Vectorized<int##bit##_t>& other) const; \
|
||||
}; \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator+( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vaddq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator-( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vsubq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator&( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vandq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator|( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return vorrq_s##bit(a, b); \
|
||||
} \
|
||||
template <> \
|
||||
Vectorized<int##bit##_t> inline operator^( \
|
||||
const Vectorized<int##bit##_t>& a, const Vectorized<int##bit##_t>& b) { \
|
||||
return veorq_s##bit(a, b); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::eq( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this == other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ne( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this != other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::gt( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this > other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::ge( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this >= other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::lt( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this < other) & Vectorized<int##bit##_t>(1); \
|
||||
} \
|
||||
Vectorized<int##bit##_t> inline Vectorized<int##bit##_t>::le( \
|
||||
const Vectorized<int##bit##_t>& other) const { \
|
||||
return (*this <= other) & Vectorized<int##bit##_t>(1); \
|
||||
}
|
||||
|
||||
VEC_INT_NEON_TEMPLATE(2, 64)
|
||||
VEC_INT_NEON_TEMPLATE(4, 32)
|
||||
VEC_INT_NEON_TEMPLATE(8, 16)
|
||||
VEC_INT_NEON_TEMPLATE(16, 8)
|
||||
|
||||
inline int32_t Vectorized<int32_t>::reduce_max() const {
|
||||
return vmaxvq_s32(values);
|
||||
}
|
||||
|
||||
inline int16_t Vectorized<int16_t>::reduce_max() const {
|
||||
return vmaxvq_s16(values);
|
||||
}
|
||||
|
||||
inline int8_t Vectorized<int8_t>::reduce_max() const {
|
||||
return vmaxvq_s8(values);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator*(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
return vmulq_s32(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator*(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
return vmulq_s16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator*(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
return vmulq_s8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int64_t> operator~(const Vectorized<int64_t>& a) {
|
||||
int64x2_t val = a;
|
||||
return ~val;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int32_t> operator~(const Vectorized<int32_t>& a) {
|
||||
return vmvnq_s32(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int16_t> operator~(const Vectorized<int16_t>& a) {
|
||||
return vmvnq_s16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Vectorized<int8_t> operator~(const Vectorized<int8_t>& a) {
|
||||
return vmvnq_s8(a);
|
||||
}
|
||||
|
||||
inline Vectorized<int64_t> Vectorized<int64_t>::operator!=(
|
||||
const Vectorized<int64_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
inline Vectorized<int32_t> Vectorized<int32_t>::operator!=(
|
||||
const Vectorized<int32_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
inline Vectorized<int16_t> Vectorized<int16_t>::operator!=(
|
||||
const Vectorized<int16_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
inline Vectorized<int8_t> Vectorized<int8_t>::operator!=(
|
||||
const Vectorized<int8_t>& other) const {
|
||||
return ~(*this == other);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline minimum(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
return vminq_s32(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline minimum(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
return vminq_s16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline minimum(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
return vminq_s8(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline maximum(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
return vmaxq_s32(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline maximum(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
return vmaxq_s16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline maximum(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
return vmaxq_s8(a, b);
|
||||
}
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int64_t> Vectorized<int64_t>::blend(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint64x2_t maskArray = {
|
||||
(mask & 1LL) ? 0xFFFFFFFFFFFFFFFF : 0,
|
||||
(mask & 2LL) ? 0xFFFFFFFFFFFFFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s64(maskArray, b.values, a.values);
|
||||
}
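As a plain-C++ reference for the lane-selection comments above (illustrative only, not part of the header): bit i of the compile-time mask decides whether lane i of the result comes from b or from a, which is exactly what the BSL-based blends compute.

#include <array>
#include <cstddef>
#include <cstdint>

template <typename T, std::size_t N>
std::array<T, N> blend_ref(uint64_t mask,
                           const std::array<T, N>& a,
                           const std::array<T, N>& b) {
  std::array<T, N> out{};
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = ((mask >> i) & 1u) ? b[i] : a[i];  // same selection the NEON BSL performs
  }
  return out;
}
// e.g. blend_ref<int32_t, 4>(0b0101, a, b) yields {b[0], a[1], b[2], a[3]}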
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int32_t> Vectorized<int32_t>::blend(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint32x4_t maskArray = {
|
||||
(mask & 1LL) ? 0xFFFFFFFF : 0,
|
||||
(mask & 2LL) ? 0xFFFFFFFF : 0,
|
||||
(mask & 4LL) ? 0xFFFFFFFF : 0,
|
||||
(mask & 8LL) ? 0xFFFFFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s32(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int16_t> Vectorized<int16_t>::blend(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint16x8_t maskArray = {
|
||||
(mask & 1LL) ? 0xFFFF : 0,
|
||||
(mask & 2LL) ? 0xFFFF : 0,
|
||||
(mask & 4LL) ? 0xFFFF : 0,
|
||||
(mask & 8LL) ? 0xFFFF : 0,
|
||||
(mask & 16LL) ? 0xFFFF : 0,
|
||||
(mask & 32LL) ? 0xFFFF : 0,
|
||||
(mask & 64LL) ? 0xFFFF : 0,
|
||||
(mask & 128LL) ? 0xFFFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s16(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
template <int64_t mask>
|
||||
Vectorized<int8_t> Vectorized<int8_t>::blend(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding bit
|
||||
// in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
(mask & 1LL) ? 0xFF : 0,
|
||||
(mask & 2LL) ? 0xFF : 0,
|
||||
(mask & 4LL) ? 0xFF : 0,
|
||||
(mask & 8LL) ? 0xFF : 0,
|
||||
(mask & 16LL) ? 0xFF : 0,
|
||||
(mask & 32LL) ? 0xFF : 0,
|
||||
(mask & 64LL) ? 0xFF : 0,
|
||||
(mask & 128LL) ? 0xFF : 0,
|
||||
(mask & 256LL) ? 0xFF : 0,
|
||||
(mask & 512LL) ? 0xFF : 0,
|
||||
(mask & 1024LL) ? 0xFF : 0,
|
||||
(mask & 2048LL) ? 0xFF : 0,
|
||||
(mask & 4096LL) ? 0xFF : 0,
|
||||
(mask & 8192LL) ? 0xFF : 0,
|
||||
(mask & 16384LL) ? 0xFF : 0,
|
||||
(mask & 32768LL) ? 0xFF : 0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s8(maskArray, b.values, a.values);
|
||||
}
|
||||
|
||||
#define VEC_INT_NEON_OPS(vl, bit) \
|
||||
inline Vectorized<int##bit##_t>::Vectorized(int##bit##_t val) { \
|
||||
values = vdupq_n_s##bit(val); \
|
||||
} \
|
||||
inline Vectorized<int##bit##_t> Vectorized<int##bit##_t>::loadu( \
|
||||
const void* ptr, int64_t count) { \
|
||||
if (count == size()) { \
|
||||
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(ptr)); \
|
||||
} else { \
|
||||
__at_align__ int##bit##_t tmp_values[size()]; \
|
||||
for (const auto i : c10::irange(size())) { \
|
||||
tmp_values[i] = 0; \
|
||||
} \
|
||||
std::memcpy( \
|
||||
tmp_values, \
|
||||
reinterpret_cast<const int##bit##_t*>(ptr), \
|
||||
count * sizeof(int##bit##_t)); \
|
||||
return vld1q_s##bit(reinterpret_cast<const int##bit##_t*>(tmp_values)); \
|
||||
} \
|
||||
} \
|
||||
inline void Vectorized<int##bit##_t>::store(void* ptr, int64_t count) \
|
||||
const { \
|
||||
if (count == size()) { \
|
||||
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(ptr), values); \
|
||||
} else { \
|
||||
int##bit##_t tmp_values[size()]; \
|
||||
vst1q_s##bit(reinterpret_cast<int##bit##_t*>(tmp_values), values); \
|
||||
std::memcpy(ptr, tmp_values, count * sizeof(int##bit##_t)); \
|
||||
} \
|
||||
}
|
||||
|
||||
VEC_INT_NEON_OPS(2, 64)
|
||||
VEC_INT_NEON_OPS(4, 32)
|
||||
VEC_INT_NEON_OPS(8, 16)
|
||||
VEC_INT_NEON_OPS(16, 8)
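The loadu/store pair defined by the macro above handles a partial count by staging through a zero-filled temporary buffer; a scalar model of that idea (illustrative, not the macro's code) looks like this:

#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename T, int N>
void load_partial(T (&dst)[N], const void* src, int64_t count) {
  for (int i = 0; i < N; ++i) {
    dst[i] = T(0);  // zero the tail so a full-width vector load stays well-defined
  }
  std::memcpy(dst, src, static_cast<std::size_t>(count) * sizeof(T));
}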
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator*(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return x * y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator/(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator/(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
int32x4_t x = a;
|
||||
int32x4_t y = b;
|
||||
return x / y;
|
||||
}
|
||||
|
||||
inline int64_t Vectorized<int64_t>::reduce_max() const {
|
||||
return std::max(values[0], values[1]);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline minimum(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return {std::min(x[0], y[0]), std::min(x[1], y[1])};
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline maximum(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
return {std::max(x[0], y[0]), std::max(x[1], y[1])};
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int64_t> Vectorized<int64_t>::arange(
|
||||
int64_t base,
|
||||
step_t step) {
|
||||
const Vectorized<int64_t> base_vec(base);
|
||||
const Vectorized<int64_t> step_vec(step);
|
||||
const int64x2_t step_sizes = {0, 1};
|
||||
return base_vec.values + step_sizes * step_vec.values;
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int32_t> Vectorized<int32_t>::arange(
|
||||
int32_t base,
|
||||
step_t step) {
|
||||
const Vectorized<int32_t> base_vec(base);
|
||||
const Vectorized<int32_t> step_vec(step);
|
||||
const int32x4_t step_sizes = {0, 1, 2, 3};
|
||||
return vmlaq_s32(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int16_t> Vectorized<int16_t>::arange(
|
||||
int16_t base,
|
||||
step_t step) {
|
||||
const Vectorized<int16_t> base_vec(base);
|
||||
const Vectorized<int16_t> step_vec(step);
|
||||
const int16x8_t step_sizes = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
return vmlaq_s16(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <typename step_t>
|
||||
inline Vectorized<int8_t> Vectorized<int8_t>::arange(int8_t base, step_t step) {
|
||||
const Vectorized<int8_t> base_vec(base);
|
||||
const Vectorized<int8_t> step_vec(step);
|
||||
const int8x16_t step_sizes = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
return vmlaq_s8(base_vec, step_sizes, step_vec);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator>>(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t x = a;
|
||||
int64x2_t y = b;
|
||||
uint64x2_t u = vreinterpretq_u64_s64(y);
|
||||
uint64x2_t z = {std::min(u[0], (uint64_t)63), std::min(u[1], (uint64_t)63)};
|
||||
return x >> vreinterpretq_s64_u64(z);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator>>(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
int32x4_t x = a;
|
||||
int32x4_t y = b;
|
||||
uint32x4_t bound = vdupq_n_u32(31);
|
||||
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
|
||||
return x >> vreinterpretq_s32_u32(z);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator>>(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
int16x8_t x = a;
|
||||
int16x8_t y = b;
|
||||
uint16x8_t bound = vdupq_n_u16(15);
|
||||
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
|
||||
return x >> vreinterpretq_s16_u16(z);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator>>(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
int8x16_t x = a;
|
||||
int8x16_t y = b;
|
||||
uint8x16_t bound = vdupq_n_u8(7);
|
||||
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
|
||||
return x >> z;
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline operator<<(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b) {
|
||||
int64x2_t y = b;
|
||||
uint64x2_t u = vreinterpretq_u64_s64(y);
|
||||
uint64x2_t z = {std::min(u[0], (uint64_t)64), std::min(u[1], (uint64_t)64)};
|
||||
return vshlq_s64(a, vreinterpretq_s64_u64(z));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline operator<<(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b) {
|
||||
int32x4_t y = b;
|
||||
uint32x4_t bound = vdupq_n_u32(32);
|
||||
uint32x4_t z = vminq_u32(vreinterpretq_u32_s32(y), bound);
|
||||
return vshlq_s32(a, vreinterpretq_s32_u32(z));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator<<(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
int16x8_t y = b;
|
||||
uint16x8_t bound = vdupq_n_u16(16);
|
||||
uint16x8_t z = vminq_u16(vreinterpretq_u16_s16(y), bound);
|
||||
return vshlq_s16(a, vreinterpretq_s16_u16(z));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator<<(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
int8x16_t y = b;
|
||||
uint8x16_t bound = vdupq_n_u8(8);
|
||||
int8x16_t z = vreinterpretq_s8_u8(vminq_u8(vreinterpretq_u8_s8(y), bound));
|
||||
return vshlq_s8(a, z);
|
||||
}
|
||||
|
||||
inline Vectorized<int64_t> Vectorized<int64_t>::set(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 2) {
|
||||
return b;
|
||||
} else {
|
||||
int64x2_t c = {b.values[0], a.values[1]};
|
||||
return c;
|
||||
}
|
||||
}
|
||||
|
||||
inline Vectorized<int32_t> Vectorized<int32_t>::set(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 4) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint32x4_t maskArray = {
|
||||
(count >= 1LL) ? 0xFFFFFFFF : 0,
|
||||
(count >= 2LL) ? 0xFFFFFFFF : 0,
|
||||
(count >= 3LL) ? 0xFFFFFFFF : 0,
|
||||
0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s32(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
inline Vectorized<int16_t> Vectorized<int16_t>::set(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 8) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint16x8_t maskArray = {
|
||||
static_cast<uint16_t>((count >= 1LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 2LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 3LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 4LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 5LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 6LL) ? 0xFFFF : 0),
|
||||
static_cast<uint16_t>((count >= 7LL) ? 0xFFFF : 0),
|
||||
0};
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s16(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
inline Vectorized<int8_t> Vectorized<int8_t>::set(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b,
|
||||
int64_t count) {
|
||||
if (count == 0) {
|
||||
return a;
|
||||
} else if (count >= 16) {
|
||||
return b;
|
||||
} else {
|
||||
// Build an array of flags: each bit of element is 1 if the corresponding
|
||||
// bit in 'mask' is set, 0 otherwise.
|
||||
uint8x16_t maskArray = {
|
||||
static_cast<uint8_t>((count >= 1LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 2LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 3LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 4LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 5LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 6LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 7LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 8LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 9LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 10LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 11LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 12LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 13LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 14LL) ? 0xFF : 0),
|
||||
static_cast<uint8_t>((count >= 15LL) ? 0xFF : 0),
|
||||
0};
|
||||
|
||||
// Use BSL to select elements from b where the mask is 1, else from a
|
||||
return vbslq_s8(maskArray, b.values, a.values);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline operator/(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& b) {
|
||||
Vectorized<int32_t> highBitsA = vmovl_high_s16(a);
|
||||
Vectorized<int32_t> highBitsB = vmovl_high_s16(b);
|
||||
Vectorized<int32_t> lowBitsA = vmovl_s16(vget_low_s16(a));
|
||||
Vectorized<int32_t> lowBitsB = vmovl_s16(vget_low_s16(b));
|
||||
int32x4_t highBitsResult = highBitsA / highBitsB;
|
||||
int32x4_t lowBitsResult = lowBitsA / lowBitsB;
|
||||
return vuzp1q_s16(
|
||||
vreinterpretq_s16_s32(lowBitsResult),
|
||||
vreinterpretq_s16_s32(highBitsResult));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline operator/(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& b) {
|
||||
Vectorized<int16_t> highBitsA = vmovl_high_s8(a);
|
||||
Vectorized<int16_t> highBitsB = vmovl_high_s8(b);
|
||||
Vectorized<int16_t> lowBitsA = vmovl_s8(vget_low_s8(a));
|
||||
Vectorized<int16_t> lowBitsB = vmovl_s8(vget_low_s8(b));
|
||||
int16x8_t highBitsResult = highBitsA / highBitsB;
|
||||
int16x8_t lowBitsResult = lowBitsA / lowBitsB;
|
||||
return vuzp1q_s8(
|
||||
vreinterpretq_s8_s16(lowBitsResult),
|
||||
vreinterpretq_s8_s16(highBitsResult));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline clamp(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& min,
|
||||
const Vectorized<int64_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline clamp(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& min,
|
||||
const Vectorized<int32_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline clamp(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& min,
|
||||
const Vectorized<int16_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline clamp(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& min,
|
||||
const Vectorized<int8_t>& max) {
|
||||
return minimum(max, maximum(min, a));
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline clamp_max(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline clamp_max(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline clamp_max(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline clamp_max(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& max) {
|
||||
return minimum(max, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int64_t> inline clamp_min(
|
||||
const Vectorized<int64_t>& a,
|
||||
const Vectorized<int64_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int32_t> inline clamp_min(
|
||||
const Vectorized<int32_t>& a,
|
||||
const Vectorized<int32_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int16_t> inline clamp_min(
|
||||
const Vectorized<int16_t>& a,
|
||||
const Vectorized<int16_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
template <>
|
||||
Vectorized<int8_t> inline clamp_min(
|
||||
const Vectorized<int8_t>& a,
|
||||
const Vectorized<int8_t>& min) {
|
||||
return maximum(min, a);
|
||||
}
|
||||
|
||||
} // namespace CPU_CAPABILITY
|
||||
} // namespace at::vec
|
@@ -1377,7 +1377,7 @@ Vectorized<c10::quint8> inline maximum(
 #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
 std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(
     at::vec::Vectorized<int8_t> src) {
-  auto s8x8 = vld1_s8(src.operator const int8_t*());
+  auto s8x8 = vget_low_s8(src);
   auto s16x8 = vmovl_s8(s8x8);

   auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8));
@@ -1402,7 +1402,7 @@ std::pair<Vectorized<float>, Vectorized<float>> inline convert_int8_to_float(

 Vectorized<float> inline convert_int8_half_register_to_float(
     at::vec::Vectorized<int8_t> src) {
-  auto s8x8 = vld1_s8(src.operator const int8_t*());
+  auto s8x8 = vget_low_s8(src);
   auto s16x8 = vmovl_s8(s8x8);

   auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8));
@@ -16,6 +16,8 @@
 #include <c10/util/irange.h>
 #include <c10/core/ScalarType.h>

+#include <ATen/cuda/detail/BLASConstants.h>
+
 #ifdef USE_ROCM
 #include <c10/cuda/CUDAStream.h>
 #include <hipblaslt/hipblaslt-ext.hpp>
@@ -1954,13 +1956,15 @@ void scaled_gemm(
     const void *result_scale_ptr,
     int64_t result_ld,
     ScalarType result_dtype,
-    bool use_fast_accum) {
+    bool use_fast_accum,
+    const std::optional<Tensor>& alpha) {
   // Note: see `cublasCommonArgs` for various non-intuitive manupulations
   // of input arguments to this function.
   const auto computeType = CUBLAS_COMPUTE_32F;
   const auto scaleType = CUDA_R_32F;
-  const float alpha_val = 1.0;
-  const float beta_val = 0.0;
+  // Note: alpha_val may change later depending on user-passed argument
+  float alpha_val = 1.0;
+  float beta_val = 0.0;
   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
@@ -2031,6 +2035,33 @@
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
   }
+
+  // Handle user-passed alpha
+  float *alpha_ptr = &alpha_val;
+  float *beta_ptr = &beta_val;
+
+  if (alpha.has_value()) {
+    auto& a = alpha.value();
+
+    // if device-tensor
+    if (a.is_cuda()) {
+      // NOTE: there are lifetime requirements on device-side pointers for alpha/beta -- the value must be
+      // valid & correct until the cublas call finishes (not is scheduled like host-side values). Thus
+      // we need to use allocations for alpha/beta that have some guarantees on lifetime - a statically
+      // managed 4B buffer for alpha that we'll copy the passed alpha value into, and constant memory
+      // for beta respectively.
+      float *user_alpha_ptr = at::cuda::detail::get_user_alpha_ptr();
+      at::Tensor user_alpha = at::from_blob(user_alpha_ptr, {1}, TensorOptions().device(kCUDA).dtype(kFloat));
+      user_alpha.copy_(a);
+      // Tell cublasLt we're using device-side pointers for alpha/beta
+      auto pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
+      computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_POINTER_MODE, pointer_mode);
+      alpha_ptr = user_alpha.data_ptr<float>();
+      beta_ptr = at::cuda::detail::get_cublas_device_zero();
+    } else {
+      alpha_val = a.item<float>();
+    }
+  }
   // For other data types, use the get_scale_mode function based on scaling type
   // The SCALE_MODE attrs only exist in cuBLAS 12.8+/ROCm 7.0 or in recent hipblaslt,
   // but we must invoke get_scale_mode anyways to trigger the version checks.
@ -2048,6 +2079,7 @@ void scaled_gemm(
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
|
||||
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
@ -2088,10 +2120,10 @@ void scaled_gemm(
|
||||
auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
&alpha_val,
|
||||
alpha_ptr,
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
&beta_val,
|
||||
beta_ptr,
|
||||
Cdesc.descriptor(),
|
||||
Ddesc.descriptor(),
|
||||
all_algos[i].algo,
|
||||
@ -2110,17 +2142,14 @@ void scaled_gemm(
|
||||
cublasStatus_t cublasStatus = cublasLtMatmul(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
&alpha_val,
|
||||
alpha_ptr,
|
||||
mat1_ptr,
|
||||
Adesc.descriptor(),
|
||||
mat2_ptr,
|
||||
Bdesc.descriptor(),
|
||||
&beta_val,
|
||||
#ifdef USE_ROCM
|
||||
beta_ptr,
|
||||
// NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either
|
||||
result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
|
||||
#else
|
||||
nullptr,
|
||||
#endif // ifdef USE_ROCM
|
||||
Cdesc.descriptor(),
|
||||
result_ptr,
|
||||
Ddesc.descriptor(),
|
||||
|
@ -161,7 +161,8 @@ void scaled_gemm(
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum);
bool use_fast_accum,
const std::optional<Tensor>& alpha);

#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

@ -325,9 +325,9 @@ uint64_t CUDAGeneratorImpl::seed() {
*/
c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;

auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>();
@ -346,9 +346,9 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
* and size of the internal state.
*/
void CUDAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(int64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size;

detail::check_rng_state(new_state);

@ -253,7 +253,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
scan_op,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
@ -531,7 +531,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
init_value,
num_items,
at::cuda::getCurrentCUDAStream());
C10_HIP_KERNEL_LAUNCH_CHECK();
C10_CUDA_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,

54
aten/src/ATen/cuda/detail/BLASConstants.cu
Normal file
@ -0,0 +1,54 @@
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/cuda/Exceptions.h>

#include <mutex>

namespace at {
namespace cuda {
namespace detail {

__device__ __constant__ float cublas_one_device;
__device__ __constant__ float cublas_zero_device;

float *get_cublas_device_one() {
static c10::once_flag init_flag;

c10::call_once(init_flag, []() {
const float one = 1.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
});

float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
return ptr;
}

float *get_cublas_device_zero() {
static c10::once_flag init_flag;

c10::call_once(init_flag, []() {
const float zero = 0.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
});

float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
return ptr;
}

float *get_user_alpha_ptr() {
static float *alpha_ptr;

static c10::once_flag init_flag;

c10::call_once(init_flag, []() {
AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
});

return alpha_ptr;
}

} // namespace detail
} // namespace cuda
} // namespace at

11
aten/src/ATen/cuda/detail/BLASConstants.h
Normal file
@ -0,0 +1,11 @@
#pragma once

#include <ATen/core/TensorBase.h>

namespace at::cuda::detail {

float *get_cublas_device_one();
float *get_cublas_device_zero();
float *get_user_alpha_ptr();

} // namespace at::cuda::detail

@ -109,7 +109,8 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->use_fast_accum);
params->use_fast_accum,
std::nullopt /* alpha */);
return OK;
}
};

@ -1,239 +0,0 @@
#pragma once

#include <c10/hip/HIPCachingAllocator.h>

// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10::hip {

// Takes a valid HIPAllocator (of any sort) and turns it into
// an allocator pretending to be a CUDA allocator. See
// Note [Masquerading as CUDA]
class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
HIPCachingAllocator::HIPAllocator* allocator_;
public:
explicit HIPAllocatorMasqueradingAsCUDA(HIPCachingAllocator::HIPAllocator* allocator)
: allocator_(allocator) {}

virtual ~HIPAllocatorMasqueradingAsCUDA() = default;

// From c10::Allocator

DataPtr allocate(size_t size) override {
DataPtr r = allocator_->allocate(size);
r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
return r;
}

bool is_simple_data_ptr(const DataPtr& data_ptr) const override {
return allocator_->is_simple_data_ptr(data_ptr);
}

DeleterFnPtr raw_deleter() const override {
return allocator_->raw_deleter();
}

void copy_data(void* dest, const void* src, std::size_t count) const final {
allocator_->copy_data(dest, src, count);
}

// From DeviceAllocator

bool initialized() override {
return allocator_->initialized();
}

void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
allocator_->emptyCache(mempool_id);
}

void recordStream(const DataPtr& ptr, c10::Stream stream) override {
HIPStream hip_stream = HIPStream(stream);
recordStream(ptr, hip_stream);
}

CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
return allocator_->getDeviceStats(device);
}

void resetAccumulatedStats(c10::DeviceIndex device) override {
allocator_->resetAccumulatedStats(device);
}

void resetPeakStats(c10::DeviceIndex device) override {
allocator_->resetPeakStats(device);
}

// From CUDAAllocator

void* raw_alloc(size_t nbytes) override {
return allocator_->raw_alloc(nbytes);
}

void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override {
return allocator_->raw_alloc_with_stream(nbytes, stream);
}

void raw_delete(void* ptr) override {
allocator_->raw_delete(ptr);
}

void init(int device_count) override {
allocator_->init(device_count);
}

double getMemoryFraction(c10::DeviceIndex device) override {
return allocator_->getMemoryFraction(device);
}

void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
allocator_->setMemoryFraction(fraction, device);
}

std::vector<HIPCachingAllocator::StreamSegmentSize> getExpandableSegmentSizes(c10::DeviceIndex device) override {
return allocator_->getExpandableSegmentSizes(device);
}

void enable(bool value) override {
allocator_->enable(value);
}

bool isEnabled() const override {
return allocator_->isEnabled();
}

void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
allocator_->cacheInfo(device, largestBlock);
}

void* getBaseAllocation(void* ptr, size_t* size) override {
return allocator_->getBaseAllocation(ptr, size);
}

void recordStream(const DataPtr& ptr, HIPStream stream) override {
allocator_->recordStream(ptr, stream);
}

HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override {
return allocator_->snapshot(mempool_id);
}

void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(hipStream_t)> filter) override {
allocator_->beginAllocateToPool(device, mempool_id, filter);
}

void endAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) override {
allocator_->endAllocateToPool(device, mempool_id);
}

void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->releasePool(device, mempool_id);
}

int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) override {
return allocator_->getPoolUseCount(device, mempool_id);
}

void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
HIPAllocator* allocator = nullptr) override {
allocator_->createOrIncrefPool(device, mempool_id, allocator);
}

void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->setUseOnOOM(device, mempool_id);
}

bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) override {
return allocator_->checkPoolLiveAllocations(device, mempool_id, expected_live_allocations);
}

HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) override {
return allocator_->shareIpcHandle(ptr);
}

std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
return allocator_->getIpcDevPtr(handle);
}

bool isHistoryEnabled() override {
return allocator_->isHistoryEnabled();
}

void recordHistory(
bool enabled,
HIPCachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
HIPCachingAllocator::RecordContext when,
bool clearHistory) override {
allocator_->recordHistory(enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}

void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) override {
allocator_->recordAnnotation(md);
}

void pushCompileContext(std::string& md) override {
allocator_->pushCompileContext(md);
}

void popCompileContext() override {
allocator_->popCompileContext();
}

void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) override {
allocator_->attachOutOfMemoryObserver(observer);
}

void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) override {
allocator_->attachAllocatorTraceTracker(tracker);
}

void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
allocator_->enablePeerAccess(dev, dev_to_access);
}

hipError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
hipStream_t stream,
bool p2p_enabled) override {
return allocator_->memcpyAsync(dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}

std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) override {
return allocator_->getCheckpointState(device, id);
}

HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
auto cpd = allocator_->setCheckpointPoolState(device, pps);
for (auto& ptr : cpd.dataptrs_allocd) {
ptr.unsafe_set_device(Device(c10::DeviceType::CUDA, ptr.device().index()));
}
return cpd;
}

std::string name() override {
return allocator_->name();
}

};

} // namespace c10::hip

@ -1,18 +0,0 @@
#include <c10/hip/HIPCachingAllocator.h>
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>

namespace c10 { namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {

HIPCachingAllocator::HIPAllocator* get() {
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
return &allocator;
}

void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream) {
HIPCachingAllocator::recordStream(ptr, stream.hip_stream());
}

} // namespace HIPCachingAllocatorMasqueradingAsCUDA
}} // namespace c10::hip

@ -1,194 +0,0 @@
#pragma once

#include <c10/hip/HIPCachingAllocator.h>
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>

namespace c10 {
// forward declaration
class DataPtr;
namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {

C10_HIP_API HIPCachingAllocator::HIPAllocator* get();
C10_HIP_API void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream);

inline void* raw_alloc(size_t nbytes) {
return get()->raw_alloc(nbytes);
}

inline void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) {
return get()->raw_alloc_with_stream(nbytes, stream);
}

inline void raw_delete(void* ptr) {
return get()->raw_delete(ptr);
}

inline void init(int device_count) {
return get()->init(device_count);
}

inline double getMemoryFraction(c10::DeviceIndex device) {
return get()->getMemoryFraction(device);
}

inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device);
}

inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
return get()->emptyCache(mempool_id);
}

inline void enable(bool value) {
return get()->enable(value);
}

inline bool isEnabled() {
return get()->isEnabled();
}

inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) {
return get()->cacheInfo(device, largestBlock);
}

inline void* getBaseAllocation(void* ptr, size_t* size) {
return get()->getBaseAllocation(ptr, size);
}

inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) {
return get()->getDeviceStats(device);
}

inline void resetAccumulatedStats(c10::DeviceIndex device) {
return get()->resetAccumulatedStats(device);
}

inline void resetPeakStats(c10::DeviceIndex device) {
return get()->resetPeakStats(device);
}

inline HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
return get()->snapshot(mempool_id);
}

inline std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) {
return get()->getCheckpointState(device, id);
}

inline HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) {
return get()->setCheckpointPoolState(device, std::move(pps));
}

inline void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(hipStream_t)> filter) {
get()->beginAllocateToPool(device, mempool_id, std::move(filter));
}

inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->endAllocateToPool(device, mempool_id);
}

inline void recordHistory(
bool enabled,
HIPCachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
HIPCachingAllocator::RecordContext when,
bool clearHistory) {
return get()->recordHistory(
enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}

inline void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) {
return get()->recordAnnotation(md);
}

inline void pushCompileContext(std::string& md) {
return get()->pushCompileContext(md);
}

inline void popCompileContext() {
return get()->popCompileContext();
}

inline bool isHistoryEnabled() {
return get()->isHistoryEnabled();
}

inline bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
return get()->checkPoolLiveAllocations(
device, mempool_id, expected_live_allocations);
}

inline void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) {
return get()->attachOutOfMemoryObserver(std::move(observer));
}

inline void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) {
return get()->attachAllocatorTraceTracker(std::move(tracker));
}

inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->releasePool(device, mempool_id);
}

inline void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
HIPCachingAllocator::HIPAllocator* allocator_ptr = nullptr) {
get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
}

inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->setUseOnOOM(device, mempool_id);
}

inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->getPoolUseCount(device, mempool_id);
}

inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
return get()->getIpcDevPtr(std::move(handle));
}

inline HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) {
return get()->shareIpcHandle(ptr);
}

inline std::string name() {
return get()->name();
}

inline hipError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
hipStream_t stream,
bool p2p_enabled) {
return get()->memcpyAsync(
dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}

inline void enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access) {
return get()->enablePeerAccess(dev, dev_to_access);
}

} // namespace HIPCachingAllocatorMasqueradingAsCUDA
} // namespace hip
} // namespace c10

@ -1,14 +0,0 @@
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>

// THIS IS A MASSIVE HACK. This will BREAK you Caffe2 CUDA code if you
// load ATen_hip, even if you don't ever actually use ATen_hip at runtime.
//
// If you ever link ATen_hip statically into the full library along
// with ATen_cuda (libomnibus), the loading order of this versus the regular
// ATen_cuda will be nondeterministic, and you'll nondeterministically get
// one or the other. (This will be obvious because all of your code
// will fail.)
//
// This hack can be removed once PyTorch is out-of-place HIPified, and
// doesn't pretend CUDA is HIP.
C10_REGISTER_GUARD_IMPL(CUDA, at::cuda::HIPGuardImplMasqueradingAsCUDA)

@ -1,383 +0,0 @@
#pragma once

#include <ATen/hip/HIPConfig.h>

// The includes of HIPGuard.h
#include <c10/hip/impl/HIPGuardImpl.h>
#include <c10/hip/HIPMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/util/Exception.h>

#include <c10/hip/impl/HIPGuardImpl.h>

#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>

// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {

// Note [Masquerading as CUDA]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// c10_hip is very easy to understand: it is HIPified from c10_cuda,
// and anywhere you said CUDA, the source code now says HIP. HIPified
// PyTorch is much harder to understand: it is HIPified from regular
// PyTorch, yes, but NO source-to-source translation from CUDA to
// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
// For example, when you use HIPified PyTorch, you say x.cuda() to
// move a tensor onto ROCm device. We call this situation "HIP
// masquerading as CUDA".
//
// This leads to a very awkward situation when we want to call c10_hip
// code from PyTorch, since c10_hip is expecting things to be called
// HIP, but PyTorch is calling them CUDA (masquerading as HIP). To
// fix this impedance mismatch, we have MasqueradingAsCUDA variants
// for all c10_hip classes. These translate between the "HIP" and "CUDA
// masquerading as HIP" worlds. For example,
// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
// returns CUDA, getDevice() reports the current HIP device as a CUDA
// device.)
//
// We should be able to delete all of these classes entirely once
// we switch PyTorch to calling a HIP a HIP.
//
// When you add a new MasqueradingAsCUDA class/function, you need to
// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
//
//
//
// By the way, note that the cpp file associated with this also
// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
// this HIP implementation.

struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA;
HIPGuardImplMasqueradingAsCUDA() {}
HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA);
}
c10::DeviceType type() const override {
return c10::DeviceType::CUDA;
}
Device exchangeDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
Device old_device = getDevice();
if (old_device.index() != d.index()) {
C10_HIP_CHECK(hipSetDevice(d.index()));
}
return old_device;
}
Device getDevice() const override {
int device;
C10_HIP_CHECK(hipGetDevice(&device));
return Device(c10::DeviceType::CUDA, device);
}
void setDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
C10_HIP_CHECK(hipSetDevice(d.index()));
}
void uncheckedSetDevice(Device d) const noexcept override {
C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
}
Stream getStream(Device d) const override {
return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
}
Stream getDefaultStream(Device d) const override {
return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
}
Stream getNewStream(Device d, int priority = 0) const override {
return getStreamFromPoolMasqueradingAsCUDA(priority, d.index());
}
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
}
Stream exchangeStream(Stream s) const override {
HIPStreamMasqueradingAsCUDA cs(s);
auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
setCurrentHIPStreamMasqueradingAsCUDA(cs);
return old_stream.unwrap();
}
DeviceIndex deviceCount() const noexcept override {
int deviceCnt;
hipError_t _err;
_err = hipGetDeviceCount(&deviceCnt);
if(_err != hipErrorNoDevice && _err != hipSuccess)
C10_HIP_CHECK(_err);
return deviceCnt;
}

// Event-related functions
// Note: hipEventCreateWithFlags should be called on the same device as
// the recording stream's device.
void createEvent(
hipEvent_t* hip_event,
const EventFlag flag) const {
// Maps PyTorch's Event::Flag to HIP flag
auto hip_flag = hipEventDefault;
switch (flag) {
case EventFlag::PYTORCH_DEFAULT:
hip_flag = hipEventDisableTiming;
break;
case EventFlag::BACKEND_DEFAULT:
hip_flag = hipEventDefault;
break;
default:
TORCH_CHECK(false, "HIP event received unknown flag");
}

C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
}

void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override {
if (!event) return;
auto hip_event = static_cast<hipEvent_t>(event);
int orig_device;
C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
C10_HIP_CHECK_WARN(hipSetDevice(device_index));
C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
}

void record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");

hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
HIPStreamMasqueradingAsCUDA hip_stream{stream};

// Moves to stream's device to record
const auto orig_device = getDevice();
setDevice(stream.device());

// Creates the event (lazily)
if (!hip_event) createEvent(&hip_event, flag);
C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
// Makes the void* point to the (possibly just allocated) HIP event
*event = hip_event;

// Resets device
setDevice(orig_device);
}

void block(
void* event,
const Stream& stream) const override {
if (!event) return;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
HIPStreamMasqueradingAsCUDA hip_stream{stream};
const auto orig_device = getDevice();
setDevice(stream.device());
C10_HIP_CHECK(hipStreamWaitEvent(
hip_stream,
hip_event,
/*flags (must be zero)=*/ 0));
setDevice(orig_device);
}

bool queryEvent(void* event) const override {
if (!event) return true;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
const hipError_t err = hipEventQuery(hip_event);
if (err != hipErrorNotReady) C10_HIP_CHECK(err);
else {
// ignore and clear the error if not ready
(void)hipGetLastError();
}
return (err == hipSuccess);
}

// Stream-related functions
bool queryStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
return hip_stream.query();
}

void synchronizeStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
hip_stream.synchronize();
}

void synchronizeEvent(void* event) const override {
if (!event)
return;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
C10_HIP_CHECK(hipEventSynchronize(hip_event));
}

// Note: synchronizeDevice can be safely called from any device
void synchronizeDevice(const c10::DeviceIndex device_index) const override {
int orig_device{-1};
C10_HIP_CHECK(hipGetDevice(&orig_device));
C10_HIP_CHECK(hipSetDevice(device_index));
C10_HIP_CHECK(hipDeviceSynchronize());
C10_HIP_CHECK(hipSetDevice(orig_device));
}

void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
}

double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
const override {
TORCH_CHECK(
event1 && event2,
"Both events must be recorded before calculating elapsed time.");
int orig_device;
C10_HIP_CHECK(hipGetDevice(&orig_device));
C10_HIP_CHECK(hipSetDevice(device_index));
hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1);
hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2);
float time_ms = 0;
// raise hipErrorNotReady if either event is recorded but not yet completed
C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2));
C10_HIP_CHECK(hipSetDevice(orig_device));
return static_cast<double>(time_ms);
}
};

// All of the guards which have HIPGuardImpl burned in need to also have
// variants using HIPGuardImplMasqueradingAsCUDA.

/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with
/// the correct InlineDeviceGuard burned in. Sorry about the
/// copy-pasting.

struct HIPGuardMasqueradingAsCUDA {
explicit HIPGuardMasqueradingAsCUDA() = delete;
explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}

HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;

void set_device(Device device) { guard_.set_device(device); }
void reset_device(Device device) { guard_.reset_device(device); }
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
Device original_device() const { return guard_.original_device(); }
Device current_device() const { return guard_.current_device(); }

private:
c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};

struct OptionalHIPGuardMasqueradingAsCUDA {
explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<Device> device_opt) : guard_(device_opt) {}
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}

OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;

void set_device(Device device) { guard_.set_device(device); }
void reset_device(Device device) { guard_.reset_device(device); }
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
std::optional<Device> original_device() const { return guard_.original_device(); }
std::optional<Device> current_device() const { return guard_.current_device(); }
void reset() { guard_.reset(); }

private:
c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};

struct HIPStreamGuardMasqueradingAsCUDA {
explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;

void reset_stream(Stream stream) { guard_.reset_stream(stream); }

HIPStreamMasqueradingAsCUDA original_stream() const {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
}
HIPStreamMasqueradingAsCUDA current_stream() const {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
}

Device current_device() const { return guard_.current_device(); }
Device original_device() const { return guard_.original_device(); }

private:
c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};

struct OptionalHIPStreamGuardMasqueradingAsCUDA {
explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional<Stream> stream_opt) : guard_(stream_opt) {}

OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;

void reset_stream(Stream stream) { guard_.reset_stream(stream); }

std::optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
auto r = guard_.original_stream();
if (r.has_value()) {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
} else {
return std::nullopt;
}
}

std::optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
auto r = guard_.current_stream();
if (r.has_value()) {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
} else {
return std::nullopt;
}
}

void reset() { guard_.reset(); }

private:
c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};

struct HIPMultiStreamGuardMasqueradingAsCUDA {
explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
: guard_(unwrapStreams(streams)) {}

HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;

private:
c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;

static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
std::vector<Stream> streams;
streams.reserve(hipStreams.size());
for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
streams.push_back(hipStream);
}
return streams;
}
};

}} // namespace c10::hip

@ -1,135 +0,0 @@
#pragma once

#include <c10/hip/HIPStream.h>

// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {

// See Note [Masquerading as CUDA] for motivation

class HIPStreamMasqueradingAsCUDA {
public:

enum Unchecked { UNCHECKED };

explicit HIPStreamMasqueradingAsCUDA(Stream stream)
: HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
// We did the coercion unchecked; check that it was right.
TORCH_CHECK(stream.device().is_cuda() /* !!! */);
}

explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
// Unsafely coerce the "CUDA" stream into a HIP stream
: stream_(
HIPStream(
Stream(
Stream::UNSAFE,
Device(c10::DeviceType::HIP, stream.device_index()),
stream.id())
)
) {}

// New constructor, just for this. Does NOT coerce.
explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}

bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
return stream_ == other.stream_;
}

bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
return stream_ != other.stream_;
}

operator hipStream_t() const { return stream_.stream(); }

operator Stream() const {
// Unsafely coerce HIP stream into a "CUDA" stream
return Stream(Stream::UNSAFE, device(), id());
}

DeviceIndex device_index() const { return stream_.device_index(); }

// Unsafely coerce HIP device into CUDA device
c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }

Device device() const {
// Unsafely coerce HIP device into CUDA device
return Device(c10::DeviceType::CUDA, stream_.device_index());
}

StreamId id() const { return stream_.id(); }
bool query() const { return stream_.query(); }
void synchronize() const { stream_.synchronize(); }
int priority() const { return stream_.priority(); }
hipStream_t stream() const { return stream_.stream(); }

Stream unwrap() const {
// Unsafely coerce HIP stream into "CUDA" stream
return Stream(Stream::UNSAFE, device(), id());
}

c10::StreamData3 pack3() const noexcept {
// Unsafely coerce HIP stream into "CUDA" stream before packing
return unwrap().pack3();
}

static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
DeviceIndex device_index,
c10::DeviceType device_type) {
// NB: constructor manages CUDA->HIP translation for us
return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
stream_id, device_index, device_type));
}

static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }

// New method, gets the underlying HIPStream
HIPStream hip_stream() const { return stream_; }

private:
HIPStream stream_;
};

HIPStreamMasqueradingAsCUDA
inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
}

HIPStreamMasqueradingAsCUDA
inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) {
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device));
}

HIPStreamMasqueradingAsCUDA
inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
}

inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
}

inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
}

inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
setCurrentHIPStream(stream.hip_stream());
}

inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
stream << s.hip_stream() << " (masquerading as CUDA)";
return stream;
}

}} // namespace c10::hip

namespace std {
template <>
struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
} // namespace std

@ -39,7 +39,7 @@ using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<

miopenHandle_t getMiopenHandle() {
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::hip::GetDevice(&device));
AT_CUDA_CHECK(at::cuda::GetDevice(&device));

// Thread local PoolWindows are lazily-initialized
// to avoid initialization issues that caused hangs on Windows.
@ -51,7 +51,7 @@ miopenHandle_t getMiopenHandle() {
pool->newPoolWindow());

auto handle = myPoolWindow->reserve(device);
MIOPEN_CHECK(miopenSetStream(handle, c10::hip::getCurrentHIPStream()));
MIOPEN_CHECK(miopenSetStream(handle, at::cuda::getCurrentCUDAStream()));
return handle;
}

@ -240,8 +240,8 @@ TORCH_META_FUNC(gelu_backward) (

namespace at::native {

static const double SELU_ALPHA = 1.6732632423543772848170429916717;
static const double SELU_SCALE = 1.0507009873554804934193349852946;
static constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
static constexpr double SELU_SCALE = 1.0507009873554804934193349852946;

DEFINE_DISPATCH(elu_stub);
DEFINE_DISPATCH(elu_backward_stub);

@ -286,7 +286,7 @@ template void scal_fast_path<scalar_t>(int *n, scalar_t *a, scalar_t *x, int *in
#if AT_BUILD_WITH_BLAS()
template <>
bool scal_use_fast_path<double>(int64_t n, int64_t incx) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return n <= intmax && incx <= intmax;
}

@ -315,7 +315,7 @@ bool gemv_use_fast_path<float>(
int64_t incx,
[[maybe_unused]] float beta,
int64_t incy) {
auto intmax = std::numeric_limits<int>::max();
auto constexpr intmax = std::numeric_limits<int>::max();
return (m <= intmax) && (n <= intmax) && (lda <= intmax) &&
(incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax);
}

@ -1,12 +1,17 @@
#pragma once

#include <array>
#include <ATen/native/Math.h>
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>

// ROCM hcc doesn't work well with using std:: in kernel functions
// ROCm hip compiler doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#endif
#define compat_exp c10::cuda::compat::exp
#define compat_ceil c10::cuda::compat::ceil
#define compat_floor c10::cuda::compat::floor
@ -16,17 +21,6 @@
#define compat_tan c10::cuda::compat::tan
#define compat_abs c10::cuda::compat::abs
#define compat_log1p c10::cuda::compat::log1p
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_exp c10::hip::compat::exp
#define compat_ceil c10::hip::compat::ceil
#define compat_floor c10::hip::compat::floor
#define compat_log c10::hip::compat::log
#define compat_pow c10::hip::compat::pow
#define compat_sqrt c10::hip::compat::sqrt
#define compat_tan c10::hip::compat::tan
#define compat_abs c10::hip::compat::abs
#define compat_log1p c10::hip::compat::log1p
#else
#define compat_exp std::exp
#define compat_ceil std::ceil
@ -127,7 +121,7 @@ C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, unifor

template<typename scalar_t>
C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
const static scalar_t kTailValues[] = {
constexpr static scalar_t kTailValues[] = {
0.0810614667953272,
0.0413406959554092,
0.0276779256849983,
@ -139,7 +133,7 @@ C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
0.00925546218271273,
0.00833056343336287
};
if (k <= 9) {
if (k < std::size(kTailValues)) {
return kTailValues[static_cast<size_t>(k)];
}
scalar_t kp1sq = (k + 1) * (k + 1);

@ -581,7 +581,7 @@ scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
template <typename scalar_t>
static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
static const scalar_t lanczos_sum_expg_scaled_num[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,
@ -596,7 +596,7 @@ static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const scalar_t lanczos_sum_expg_scaled_denom[13] = {
static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,
@ -712,7 +712,7 @@ static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
template <typename scalar_t>
static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
static const scalar_t d[25][25] =
static constexpr scalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,

@ -62,7 +62,7 @@
#include <utility>
#include <vector>

static const int MIOPEN_DIM_MAX = 5;
static constexpr int MIOPEN_DIM_MAX = 5;

namespace at::meta {

@ -52,13 +52,14 @@ inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
#define MIN(X, Y) min_impl(X,Y)
#endif

// ROCM hcc doesn't work well with using std:: in kernel functions
// ROCm hip compiler doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define compat_pow c10::cuda::compat::pow
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_pow c10::hip::compat::pow
#endif
#define compat_pow c10::cuda::compat::pow
#else
#define compat_pow std::pow
#endif

@ -77,7 +77,7 @@ inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) {
// next broadcast all index tensors together
try {
indices = expand_outplace(indices);
} catch (std::exception& e) {
} catch (std::exception&) {
TORCH_CHECK_INDEX(
false,
"shape mismatch: indexing tensors could not be broadcast together"

@ -1038,7 +1038,7 @@ struct HelperInterpNearest : public HelperInterpBase {
// We keep this structure for BC and consider as deprecated.
// See HelperInterpNearestExact as replacement

static const int interp_size = 1;
static constexpr int interp_size = 1;

static inline void init_indices_weights(
at::ScalarType output_type,
@ -1155,7 +1155,7 @@ struct HelperInterpNearestExact : public HelperInterpNearest {

struct HelperInterpLinear : public HelperInterpBase {

static const int interp_size = 2;
static constexpr int interp_size = 2;

// Compute indices and weights for each interpolated dimension
// indices_weights = {
@ -1275,7 +1275,7 @@ struct HelperInterpLinear : public HelperInterpBase {

struct HelperInterpCubic : public HelperInterpBase {

static const int interp_size = 4;
static constexpr int interp_size = 4;

// Compute indices and weights for each interpolated dimension
// indices_weights = {

@ -1359,7 +1359,8 @@ _scaled_gemm(
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
Tensor& out,
const std::optional<Tensor>& alpha = std::nullopt) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@ -1410,7 +1411,8 @@ _scaled_gemm(
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
use_fast_accum,
alpha);
return out;
}
}

@ -249,7 +249,7 @@ __global__ void max_pool_forward_nhwc(
}


static const int BLOCK_THREADS = 256;
static constexpr int BLOCK_THREADS = 256;

template <typename scalar_t, typename accscalar_t>
#if defined (USE_ROCM)

@ -36,9 +36,9 @@ namespace at::native {
namespace {

#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
static constexpr int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
static constexpr int BLOCKDIMY = 32;
#endif

template

@@ -82,7 +82,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
// lanczos approximation
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t lanczos_sum_expg_scaled_num[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_num[13] = {
0.006061842346248906525783753964555936883222,
0.5098416655656676188125178644804694509993,
19.51992788247617482847860966235652136208,

@@ -97,7 +97,7 @@ __host__ __device__ scalar_t lanczos_sum_expg_scaled(scalar_t x) {
103794043.1163445451906271053616070238554,
56906521.91347156388090791033559122686859
};
static const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
constexpr accscalar_t lanczos_sum_expg_scaled_denom[13] = {
1.,
66.,
1925.,

@@ -126,10 +126,10 @@ __host__ __device__ scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t ax, fac, res, num, numfac;
static const accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t MAXLOG = std::is_same_v<accscalar_t,double> ?
7.09782712893383996843E2 : 88.72283905206835;
static const accscalar_t EXP1 = 2.718281828459045;
static const accscalar_t lanczos_g = 6.024680040776729583740234375;
constexpr accscalar_t EXP1 = 2.718281828459045;
constexpr accscalar_t lanczos_g = 6.024680040776729583740234375;
if (::fabs(a - x) > 0.4 * ::fabs(a)) {
ax = a * ::log(x) - x - ::lgamma(a);

@@ -158,9 +158,9 @@ __host__ __device__ scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
// Compute igam using DLMF 8.11.4. [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const int MAXITER = 2000;
constexpr int MAXITER = 2000;
int i;
accscalar_t ans, ax, c, r;

@@ -196,8 +196,8 @@ __host__ __device__ scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
accscalar_t fac = 1;
accscalar_t sum = 0;
accscalar_t term, logx;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
for (n = 1; n < MAXITER; n++) {

@@ -219,7 +219,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
// Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const accscalar_t d[25][25] =
constexpr accscalar_t d[25][25] =
{{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, -1.9752288294349443e-15},
{-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, -4.13125571381061e-15},
{4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, 8.8592218725911273e-15},
@@ -248,7 +248,7 @@ __host__ __device__ scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t
int k, n, sgn;
int maxpow = 0;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
accscalar_t lambda = x / a;
accscalar_t sigma = (x - a) / a;

@@ -314,12 +314,12 @@ __host__ __device__ scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar
int i;
accscalar_t ans, ax, c, yc, r, t, y, z;
accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
static const int MAXITER = 2000;
static const accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
constexpr int MAXITER = 2000;
constexpr accscalar_t MACHEP = std::is_same_v<accscalar_t, double> ?
1.11022302462515654042E-16 : 5.9604644775390625E-8;
static const accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIG = std::is_same_v<accscalar_t,double> ?
4.503599627370496e15 : 16777216.;
static const accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
constexpr accscalar_t BIGINV = std::is_same_v<accscalar_t,double> ?
2.22044604925031308085e-16 : 5.9604644775390625E-8;
ax = _igam_helper_fac(a, x);
@@ -385,10 +385,10 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
if ((x < 0) || (a < 0)) {
// out of defined-region of the function

@@ -467,10 +467,10 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
accscalar_t absxma_a;
static const accscalar_t SMALL = 20.0;
static const accscalar_t LARGE = 200.0;
static const accscalar_t SMALLRATIO = 0.3;
static const accscalar_t LARGERATIO = 4.5;
constexpr accscalar_t SMALL = 20.0;
constexpr accscalar_t LARGE = 200.0;
constexpr accscalar_t SMALLRATIO = 0.3;
constexpr accscalar_t LARGERATIO = 4.5;
// boundary values following SciPy
if ((x < 0) || (a < 0)) {

@@ -231,7 +231,7 @@ const auto lcm_string = jiterator_stringify(
const auto digamma_string = jiterator_stringify(
template <typename T>
T digamma(T x) {
static const double PI_f64 = 3.14159265358979323846;
static constexpr double PI_f64 = 3.14159265358979323846;
// Short-circuits if x is +/- 0 and returns -/+ ∞ per the C++ standard
if (x == 0) {

@@ -3072,9 +3072,9 @@ template <typename scalar_t>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
static const double PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
static constexpr double PI_f64 = 3.14159265358979323846;
constexpr accscalar_t PSI_10 = 2.25175258906672110764;
constexpr accscalar_t A[] = {
8.33333333333333333333E-2,
-2.10927960927960927961E-2,
7.57575757575757575758E-3,
@@ -1097,11 +1097,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
// threads with different threadIdx.x are independent and will produce results for different outputs.
// In such case, values in each loaded vector always correspond to different outputs.
if (fastest_moving_stride == sizeof(scalar_t)) {
#ifdef USE_ROCM
if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
#else
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
#endif
// Case 1: "vectorize along input"
// Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
// we should avoid vectorization.

@@ -39,9 +39,14 @@ static void std_var_kernel_cuda(TensorIterator& iter, double correction, bool ta
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
// returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
using factor_t = typename c10::scalar_value_type<acc_t>::type;
factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
} else {
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
}
}
static void mean_kernel_cuda(TensorIterator& iter) {
@@ -13,24 +13,19 @@ namespace at::native {
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor {
void operator()(TensorIterator& iter) {
#ifdef USE_ROCM
// Half and BFloat16 can be packed in groups of up to 8 elements and
// can use *_DWORDX4 instructions to achieve that.
const bool is_16_bits =
( (std::is_same<at::Half, scalar_t>::value) ||
(std::is_same<at::BFloat16, scalar_t>::value) );
if (is_16_bits) {
const auto sum_combine = [] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
};
constexpr bool is_16_bits = sizeof(scalar_t) == 2;
if constexpr (is_16_bits) {
gpu_reduce_kernel<scalar_t, out_t, /*vt0=*/4, /*input_vec_size=*/8>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
return;
iter, func_wrapper<out_t>(sum_combine)
);
} else {
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>(sum_combine)
);
}
#endif
gpu_reduce_kernel<scalar_t, out_t>(
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
}
};
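Both reduction hunks above gate a wider-vector launch on a compile-time check of the element size. A hedged sketch of that if constexpr dispatch pattern, using stand-in functions rather than the real gpu_reduce_kernel:

#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the two reduction entry points; in the real code
// the vt0 / input_vec_size template arguments differ between the branches.
template <typename T>
void reduce_vectorized() { std::printf("16-bit path, wider vectors\n"); }

template <typename T>
void reduce_default() { std::printf("default path\n"); }

template <typename scalar_t>
void launch_sum() {
  // The element size decides the path at compile time; the untaken branch is
  // not instantiated, so each branch can use different template arguments.
  constexpr bool is_16_bits = sizeof(scalar_t) == 2;
  if constexpr (is_16_bits) {
    reduce_vectorized<scalar_t>();
  } else {
    reduce_default<scalar_t>();
  }
}

int main() {
  launch_sum<std::uint16_t>();  // stands in for Half/BFloat16
  launch_sum<float>();
}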
@@ -277,7 +277,7 @@ struct BilinearFilterFunctor {
return 0;
}
static const int size = 2;
static constexpr int size = 2;
};
// taken from

@@ -301,7 +301,7 @@ struct BicubicFilterFunctor {
return 0;
}
static const int size = 4;
static constexpr int size = 4;
};
template <typename accscalar_t>
@@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
if constexpr (!rms_norm){
U delta = val - curr_sum.mean;
U new_count = curr_sum.count + 1.f;
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
#else
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
#endif
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
} else{
return {0.f, curr_sum.sigma2 + val * val, 0};

@@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
U count = dataA.count + dataB.count;
U mean, sigma2;
if (count > decltype(dataB.count){0}) {
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
auto coef = __builtin_amdgcn_rcpf(count);
#else
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
#endif
auto nA = dataA.count * coef;
auto nB = dataB.count * coef;
mean = nA*dataA.mean + nB*dataB.mean;
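The layer-norm hunks above put the approximate reciprocal behind a USE_LAYERNORM_FAST_RECIPROCAL build flag. A standalone sketch of the same pattern applied to a Welford mean update; FAST_RECIP and approx_rcp are placeholders for the real macro and hardware intrinsic:

#include <cstdio>
#include <initializer_list>

#ifdef FAST_RECIP
// Stand-in for a hardware reciprocal such as __builtin_amdgcn_rcpf.
static float approx_rcp(float x) { return 1.0f / x; }
#endif

struct Welford { float mean = 0.f, m2 = 0.f, count = 0.f; };

Welford online_update(Welford w, float val) {
  float delta = val - w.mean;
  float new_count = w.count + 1.f;
#ifdef FAST_RECIP
  float new_mean = w.mean + delta * approx_rcp(new_count);  // faster, less exact
#else
  float new_mean = w.mean + delta * (1.f / new_count);      // plain division
#endif
  return {new_mean, w.m2 + delta * (val - new_mean), new_count};
}

int main() {
  Welford w;
  for (float v : {1.f, 2.f, 3.f, 4.f}) w = online_update(w, v);
  std::printf("mean=%f var=%f\n", w.mean, w.m2 / w.count);
}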
@@ -157,7 +157,7 @@ void bgemm_kernel_impl(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
invoker.Run(argument, StreamConfig{stream, false});
}

@@ -11,7 +11,6 @@
#include <numeric>
#include <ATen/ATen.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <ATen/native/hip/ck_gemm.h>
#include <ATen/native/hip/ck_types.h>

@@ -233,7 +232,7 @@ void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
}
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
invoker.Run(argument, StreamConfig{stream, false});
}

@@ -391,7 +390,7 @@ void gemm_impl_wmma(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
}
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
#if 1
invoker.Run(argument, StreamConfig{stream, false});
#else
@@ -278,14 +278,14 @@ BenchmarkCache<size_t> bwd_filter_wssizes;
struct Workspace {
Workspace(size_t size) : size(size), data(NULL) {
data = c10::hip::HIPCachingAllocator::raw_alloc(size);
data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
c10::hip::HIPCachingAllocator::raw_delete(data);
c10::cuda::CUDACachingAllocator::raw_delete(data);
}
}

@@ -587,7 +587,7 @@ void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
wsscache.insert(args.params, perfResults.memory);
if (at::native::_cudnn_get_conv_benchmark_empty_cache()) {
c10::hip::HIPCachingAllocator::emptyCache();
c10::cuda::CUDACachingAllocator::emptyCache();
}
}
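The hunks above (and the flash-attention ones further down) replace the ROCm "masquerading" aliases with the portable c10::cuda/at::cuda entry points, which resolve to the HIP implementations in a ROCm build. A short sketch of that usage, assuming a PyTorch CUDA or ROCm build where these headers are available:

#include <cstddef>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>

// Scratch buffer managed by the caching allocator; the same code path works
// on CUDA and, after hipification, on ROCm.
struct ScratchBuffer {
  void* data = nullptr;
  explicit ScratchBuffer(std::size_t bytes) {
    data = c10::cuda::CUDACachingAllocator::raw_alloc(bytes);
  }
  ~ScratchBuffer() {
    if (data) {
      c10::cuda::CUDACachingAllocator::raw_delete(data);
    }
  }
};

void launch_on_current_stream() {
  // cudaStream_t on CUDA builds, hipStream_t on ROCm builds.
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  (void)stream;  // hand this to whatever kernel invoker is in use
}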
@@ -76,14 +76,14 @@ namespace {
struct DropoutState {
DropoutState(size_t size) : size(size), data(NULL) {
data = c10::hip::HIPCachingAllocator::raw_alloc(size);
data = c10::cuda::CUDACachingAllocator::raw_alloc(size);
}
DropoutState(const DropoutState&) = delete;
DropoutState(DropoutState&&) = default;
DropoutState& operator=(DropoutState&&) = default;
~DropoutState() {
if (data) {
c10::hip::HIPCachingAllocator::raw_delete(data);
c10::cuda::CUDACachingAllocator::raw_delete(data);
}
}

@@ -416,7 +416,7 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
// else if dim = 3, mat1's size = (b * m * n), mat2's size = (b * n * k)
// else called from aten::mv, mat1.size = (m * n), mat2.size = (n)
// only m * n * b * k(if exist) are large enough we can get benefit from mkldnn optimized gemm kernel
static const int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
constexpr int64_t mkldnn_gemm_min_size = 16 * 16 * 16;
if (mat1.dim() == 1 && mat2.dim() == 1) {
// aten::dot
return mat1.size(0) > mkldnn_gemm_min_size;
@@ -1,16 +1,16 @@
#pragma once
#include <c10/metal/common.h>
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeSharedParams {
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
struct CatSharedParams {
int32_t ndim;
int32_t cat_dim;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
};
template <unsigned N = c10::metal::max_ndim, typename idx_type_t = int64_t>
struct CatLargeInputParams {
template <typename idx_type_t = int64_t, unsigned N = c10::metal::max_ndim>
struct CatInputParams {
idx_type_t cat_dim_offset;
idx_type_t input_element_offset;
::c10::metal::array<idx_type_t, N> input_strides;

@@ -6,26 +6,25 @@
using namespace metal;
using namespace c10::metal;
template <typename T_in, typename T_out>
kernel void cat_large(
template <typename I, typename T_in, typename T_out>
kernel void cat(
constant T_in* input [[buffer(0)]],
device T_out* output [[buffer(1)]],
constant CatLargeSharedParams<>& shared_params [[buffer(2)]],
constant CatLargeInputParams<>& input_params [[buffer(3)]],
constant CatSharedParams<I>& shared_params [[buffer(2)]],
constant CatInputParams<I>& input_params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto ndim = shared_params.ndim;
auto cat_dim = shared_params.cat_dim;
constant auto& output_strides = shared_params.output_strides;
constant auto& output_sizes = shared_params.output_sizes;
auto cat_dim_offset = input_params.cat_dim_offset;
auto input_element_offset = input_params.input_element_offset;
constant auto& input_strides = input_params.input_strides;
constant auto& input_sizes = input_params.input_sizes;
auto input_element_idx = static_cast<int64_t>(tid) + input_element_offset;
int64_t input_offset = 0;
int64_t output_offset = 0;
auto input_element_idx = static_cast<I>(tid) + input_element_offset;
I input_offset = 0;
I output_offset = 0;
for (auto dim = ndim - 1; dim >= 0; dim--) {
auto dim_size = input_sizes[dim];

@@ -42,41 +41,45 @@ kernel void cat_large(
output[output_offset] = static_cast<T_out>(input[input_offset]);
}
#define REGISTER_CAT_LARGE_OP(T_in, T_out) \
template [[host_name("cat_large_" #T_in "_" #T_out)]] \
kernel void cat_large<T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatLargeSharedParams<> & shared_params [[buffer(2)]], \
constant CatLargeInputParams<> & input_params [[buffer(3)]], \
#define REGISTER_CAT_OP(I, T_in, T_out) \
template [[host_name("cat_" #I "_" #T_in "_" #T_out)]] \
kernel void cat<I, T_in, T_out>( \
constant T_in * input [[buffer(0)]], \
device T_out * output [[buffer(1)]], \
constant CatSharedParams<I> & shared_params [[buffer(2)]], \
constant CatInputParams<I> & input_params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(T_out) \
REGISTER_CAT_LARGE_OP(float, T_out); \
REGISTER_CAT_LARGE_OP(half, T_out); \
REGISTER_CAT_LARGE_OP(bfloat, T_out); \
REGISTER_CAT_LARGE_OP(int, T_out); \
REGISTER_CAT_LARGE_OP(uint, T_out); \
REGISTER_CAT_LARGE_OP(long, T_out); \
REGISTER_CAT_LARGE_OP(ulong, T_out); \
REGISTER_CAT_LARGE_OP(short, T_out); \
REGISTER_CAT_LARGE_OP(ushort, T_out); \
REGISTER_CAT_LARGE_OP(char, T_out); \
REGISTER_CAT_LARGE_OP(uchar, T_out); \
REGISTER_CAT_LARGE_OP(bool, T_out);
#define REGISTER_CAT_OP_ALL_INPUT_TYPES(I, T_out) \
REGISTER_CAT_OP(I, float, T_out); \
REGISTER_CAT_OP(I, half, T_out); \
REGISTER_CAT_OP(I, bfloat, T_out); \
REGISTER_CAT_OP(I, int, T_out); \
REGISTER_CAT_OP(I, uint, T_out); \
REGISTER_CAT_OP(I, long, T_out); \
REGISTER_CAT_OP(I, ulong, T_out); \
REGISTER_CAT_OP(I, short, T_out); \
REGISTER_CAT_OP(I, ushort, T_out); \
REGISTER_CAT_OP(I, char, T_out); \
REGISTER_CAT_OP(I, uchar, T_out); \
REGISTER_CAT_OP(I, bool, T_out);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(float);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(half);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bfloat);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(int);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uint);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(long);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ulong);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(short);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(ushort);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(char);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(uchar);
REGISTER_CAT_LARGE_OP_ALL_INPUT_TYPES(bool);
#define REGISTER_CAT_FOR_INDEX_TYPE(I) \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, float); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, half); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bfloat); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, int); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uint); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, long); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ulong); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, short); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, ushort); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, char); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, uchar); \
REGISTER_CAT_OP_ALL_INPUT_TYPES(I, bool); \
\
REGISTER_CAT_OP(I, float2, float2); \
REGISTER_CAT_OP(I, half2, half2);
REGISTER_CAT_LARGE_OP(float2, float2);
REGISTER_CAT_LARGE_OP(half2, half2);
REGISTER_CAT_FOR_INDEX_TYPE(int64_t);
REGISTER_CAT_FOR_INDEX_TYPE(int32_t);
@@ -3,6 +3,7 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/Pool.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/mps/OperationUtils.h>

@@ -69,29 +70,40 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in
}
}
// This implementation of cat is used only if one of the inputs or the output is
// too large to use MPSGraph.
template <typename T>
std::string get_type_str();
template <>
std::string get_type_str<int64_t>() {
return "int64_t";
}
template <>
std::string get_type_str<int32_t>() {
return "int32_t";
}
// NOTE: `output` is expected to already have the correct size.
static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatLargeSharedParams shared_params;
template <typename idx_type_t>
static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) {
CatSharedParams<idx_type_t> shared_params;
shared_params.ndim = output.dim();
shared_params.cat_dim = dimension;
for (const auto dim : c10::irange(output.dim())) {
shared_params.output_strides[dim] = output.stride(dim);
shared_params.output_sizes[dim] = output.size(dim);
shared_params.output_strides[dim] = safe_downcast<idx_type_t, int64_t>(output.stride(dim));
shared_params.output_sizes[dim] = safe_downcast<idx_type_t, int64_t>(output.size(dim));
}
int64_t cat_dim_offset = 0;
idx_type_t cat_dim_offset = 0;
size_t input_idx = 0;
MPSStream* stream = getCurrentMPSStream();
// Launch a separate kernels for each input. This will produce some overhead,
// but that should be relatively minimal since at least one of the inputs is
// very large. In order to launch only one kernel to process all inputs, we
// would have to copy all the input tensor data into a packed buffer, which
// would not be ideal.
// Launch a separate kernels for each input. This will produce some overhead.
// In order to launch only one kernel to process all inputs, we would have to
// copy all the input tensor data into a packed buffer, which would not be
// ideal.
for (const Tensor& input : inputs) {
if (input.numel() == 0) {
continue;
@@ -104,21 +116,23 @@ static void cat_out_large_tensor_mps(const ITensorListRef& inputs, int64_t dimen
for (int64_t numel_remaining = input.numel(); numel_remaining > 0; numel_remaining -= max_num_threads) {
auto num_threads = std::min(max_num_threads, numel_remaining);
CatLargeInputParams input_params;
CatInputParams<idx_type_t> input_params;
input_params.cat_dim_offset = cat_dim_offset;
input_params.input_element_offset = input.numel() - numel_remaining;
input_params.cat_dim_offset = safe_downcast<idx_type_t, int64_t>(cat_dim_offset);
input_params.input_element_offset = safe_downcast<idx_type_t, int64_t>(input.numel() - numel_remaining);
for (const auto dim : c10::irange(input.dim())) {
input_params.input_strides[dim] = input.stride(dim);
input_params.input_sizes[dim] = input.size(dim);
input_params.input_strides[dim] = safe_downcast<idx_type_t, int64_t>(input.stride(dim));
input_params.input_sizes[dim] = safe_downcast<idx_type_t, int64_t>(input.size(dim));
}
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(
fmt::format("cat_large_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output)));
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("cat_{}_{}_{}",
get_type_str<idx_type_t>(),
scalarToMetalTypeString(input),
scalarToMetalTypeString(output)));
getMPSProfiler().beginProfileKernel(pipeline_state, "cat", {input});
[computeEncoder setComputePipelineState:pipeline_state];
mtl_setArgs(computeEncoder, input, output, shared_params, input_params);

@@ -294,13 +308,6 @@ TORCH_IMPL_FUNC(cat_out_mps)
" and out is on ",
out.device());
// TODO: For better performance by eliminating input tensor gathering and post transpose,
// TODO: it is better to keep the out tensor's memory format.
// TODO: dimension needs to be recomputed as:
// TODO: dim = 0 --> dim = 0; dim = 1 or 2 --> dim = out.dim()- dim; otherwise dim = dim-1
if (needsGather(out)) {
out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
std::vector<int64_t> size(notSkippedTensor.sizes().vec());
// Compute size of the result in the cat dimension

@@ -331,82 +338,9 @@ TORCH_IMPL_FUNC(cat_out_mps)
has_large_tensor |= isTooLargeForMPSGraph(out);
if (has_large_tensor) {
return mps::cat_out_large_tensor_mps(materialized_inputs, dimension, out);
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphTensor* outputTensor_ = nil;
};
@autoreleasepool {
std::string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
if (!all_same_dtype) {
key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
} else {
key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
}
for (auto idx : skipped_tensor_indices) {
key += "," + std::to_string(idx);
}
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
newCachedGraph->inputTensors_.reserve(len_tensor_array);
for (const auto idx : c10::irange(len_tensor_array)) {
const Tensor& tensor = input_tensors[idx];
auto scalar_type = getMPSScalarType(tensor.scalar_type());
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type);
if (tensor.scalar_type() != out_dtype) {
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
toType:getMPSDataType(out_dtype)
name:@"castInput"];
} else {
castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
}
}
auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data() count:len_tensor_array];
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
dimension:dimension // Maybe convert this from int64_t -> int32
name:nil];
if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"];
}
newCachedGraph->outputTensor_ = outputTensor;
});
std::vector<Placeholder> inputPlaceholders;
int i = 0;
int t_idx = 0;
for (const Tensor& tensor : materialized_inputs) {
if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
auto scalar_type = getMPSScalarType(tensor.scalar_type());
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor, nullptr, true, scalar_type);
t_idx++;
}
i++;
}
auto outputDataType = getMPSScalarType(out.scalar_type());
Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
for (auto& inputPlaceholder : inputPlaceholders) {
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
}
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
return mps::cat_out_mps_impl<int64_t>(materialized_inputs, dimension, out);
} else {
return mps::cat_out_mps_impl<int32_t>(materialized_inputs, dimension, out);
}
}
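The MPS cat path above is now templated on the index type, and the caller picks 32-bit or 64-bit indexing per launch. A hedged, framework-free sketch of that width selection with a checked narrowing helper; checked_narrow is illustrative and not the safe_downcast used in the diff:

#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <vector>

// Narrow an int64_t index to a smaller type, throwing if it does not fit.
template <typename To>
To checked_narrow(std::int64_t v) {
  if (v < std::numeric_limits<To>::min() || v > std::numeric_limits<To>::max()) {
    throw std::overflow_error("index does not fit in the narrower type");
  }
  return static_cast<To>(v);
}

template <typename Idx>
void run_cat(const std::vector<std::int64_t>& sizes) {
  // In the real kernel, strides/sizes/offsets are packed as Idx and the
  // matching specialization ("cat_int32_t_..." or "cat_int64_t_...") is launched.
  for (auto s : sizes) (void)checked_narrow<Idx>(s);
  std::cout << "launched with " << sizeof(Idx) * 8 << "-bit indices\n";
}

int main() {
  std::vector<std::int64_t> small{128, 64}, huge{std::int64_t{1} << 33, 4};
  bool any_large = huge[0] > std::numeric_limits<std::int32_t>::max();
  any_large ? run_cat<std::int64_t>(huge) : run_cat<std::int32_t>(small);
}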
@@ -6531,6 +6531,7 @@
dispatch:
CPU, CUDA: var
MPS: var_mps
MTIA: var_mtia
tags: core
- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)

@@ -3551,7 +3551,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__)
const static int PARALLEL_THRESHOLD = 1 << 20;
constexpr static int PARALLEL_THRESHOLD = 1 << 20;
// Generic template defaults to naive quantize implementation
template <typename T>

@@ -1388,7 +1388,7 @@ namespace at::native {
TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
"onednn int8 linear: act scale/zp size should be 1/<=1");
static std::optional<at::Tensor> other = std::nullopt;
static const std::string_view binary_post_op = "none";
constexpr std::string_view binary_post_op = "none";
int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
return linear_int8_with_onednn_weight(
act, act_scale.item().toDouble(), act_zp,

@@ -16,8 +16,8 @@ namespace {
#ifdef USE_PYTORCH_QNNPACK
const static float qnnpack_softmax_output_scale = 0x1.0p-8f;
const static int qnnpack_softmax_output_zero_point = 0;
constexpr static float qnnpack_softmax_output_scale = 0x1.0p-8f;
constexpr static int qnnpack_softmax_output_zero_point = 0;
bool is_qnnpack_compatible(
const Tensor& qx,
@@ -59,8 +59,6 @@
#include <thrust/transform.h>
#include <thrust/unique.h>
#include <c10/cuda/CUDAMathCompat.h>
namespace at::native {
namespace {

@@ -110,9 +110,9 @@ class ApplyLogSumExp {
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
static int constexpr kElementsPerAccess = ElementsPerAccess;
static int constexpr kCount = kElementsPerAccess;
static constexpr ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;

@@ -37,7 +37,6 @@
#ifdef USE_FLASH_ATTENTION
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
@@ -162,7 +161,7 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
std::optional<int64_t> window_size_right,
const bool return_softmax,
const std::optional<at::Generator>& gen_) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
auto q_dtype = q.dtype();

@@ -348,8 +347,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt");
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
auto q_dtype = q.dtype();

@@ -560,8 +559,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
const at::Tensor& philox_offset) {
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
bool is_dropout = p_dropout > 0.0;

@@ -793,8 +792,8 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
check_gpu_arch(stream);
bool is_dropout = p_dropout > 0.0;
@@ -261,7 +261,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,

@@ -365,7 +365,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
}
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

@@ -261,7 +261,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;

@@ -299,7 +299,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
hipLaunchKernelGGL(
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr);
seed_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[0])), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[1])), at::dtype(at::kLong));
}

@@ -317,7 +317,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
if (seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
ck_tile::stream_config stream_config{stream};
auto traits =

@@ -255,7 +255,7 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
if (is_causal) { window_size_right = 0; }
bool is_dropout = p_dropout > 0.0;
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,

@@ -366,7 +366,7 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea
}
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));

@@ -273,7 +273,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
bool has_lse = true;

@@ -307,7 +307,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
std::lock_guard<std::mutex> lock(gen->mutex_);
auto philox_args = gen->philox_cuda_state(counter_offset);
hipLaunchKernelGGL(
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr);
}
// remove const from attn_bias_

@@ -320,7 +320,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads
if (max_seqlen_k > 0) {
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
auto stream = at::cuda::getCurrentHIPStream().stream();
auto stream = at::cuda::getCurrentCUDAStream().stream();
ck_tile::stream_config stream_config{stream};
auto traits =
@@ -7,7 +7,6 @@
#include <ATen/TensorIndexing.h>
#include <ATen/core/Tensor.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPGraphsUtils.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
@@ -14,16 +14,16 @@ using namespace at;
namespace {
const auto int_min = std::numeric_limits<int>::min();
const auto int_max = std::numeric_limits<int>::max();
const auto long_min = std::numeric_limits<int64_t>::min();
const auto long_max = std::numeric_limits<int64_t>::max();
const auto float_lowest = std::numeric_limits<float>::lowest();
const auto float_min = std::numeric_limits<float>::min();
const auto float_max = std::numeric_limits<float>::max();
const auto double_lowest = std::numeric_limits<double>::lowest();
const auto double_min = std::numeric_limits<double>::min();
const auto double_max = std::numeric_limits<double>::max();
constexpr auto int_min = std::numeric_limits<int>::min();
constexpr auto int_max = std::numeric_limits<int>::max();
constexpr auto long_min = std::numeric_limits<int64_t>::min();
constexpr auto long_max = std::numeric_limits<int64_t>::max();
constexpr auto float_lowest = std::numeric_limits<float>::lowest();
constexpr auto float_min = std::numeric_limits<float>::min();
constexpr auto float_max = std::numeric_limits<float>::max();
constexpr auto double_lowest = std::numeric_limits<double>::lowest();
constexpr auto double_min = std::numeric_limits<double>::min();
constexpr auto double_max = std::numeric_limits<double>::max();
const std::vector<int> ints {
int_min,
@@ -146,9 +146,9 @@ uint64_t XPUGeneratorImpl::seed() {
c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
// The RNG state comprises the seed, and an offset used for Philox.
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
// The internal state is returned as a CPU byte tensor.
auto state_tensor = at::detail::empty_cpu(

@@ -170,9 +170,9 @@ c10::intrusive_ptr<c10::TensorImpl> XPUGeneratorImpl::get_state() const {
void XPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
at::xpu::assertNotCapturing(
"Please ensure to utilize the XPUGeneratorImpl::set_state_index method during capturing.");
static const size_t seed_size = sizeof(uint64_t);
static const size_t offset_size = sizeof(uint64_t);
static const size_t total_size = seed_size + offset_size;
constexpr size_t seed_size = sizeof(uint64_t);
constexpr size_t offset_size = sizeof(uint64_t);
constexpr size_t total_size = seed_size + offset_size;
at::detail::check_rng_state(new_state);
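The XPU generator hunks above treat the RNG state as a fixed-size byte buffer holding a 64-bit seed followed by a 64-bit Philox offset. A standalone sketch of packing and unpacking that layout; the names are illustrative, not the XPUGeneratorImpl API:

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr std::size_t seed_size = sizeof(std::uint64_t);
constexpr std::size_t offset_size = sizeof(std::uint64_t);
constexpr std::size_t total_size = seed_size + offset_size;

// Serialize seed + offset into a contiguous byte buffer (the "state tensor").
std::array<std::uint8_t, total_size> pack_state(std::uint64_t seed, std::uint64_t offset) {
  std::array<std::uint8_t, total_size> buf{};
  std::memcpy(buf.data(), &seed, seed_size);
  std::memcpy(buf.data() + seed_size, &offset, offset_size);
  return buf;
}

// Recover seed + offset from the same fixed layout.
void unpack_state(const std::array<std::uint8_t, total_size>& buf,
                  std::uint64_t& seed, std::uint64_t& offset) {
  std::memcpy(&seed, buf.data(), seed_size);
  std::memcpy(&offset, buf.data() + seed_size, offset_size);
}

int main() {
  std::uint64_t s = 0, o = 0;
  unpack_state(pack_state(42, 7), s, o);
  assert(s == 42 && o == 7);
}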
@@ -6,7 +6,7 @@ import os
import subprocess
import sys
import tempfile
from typing import Callable
from collections.abc import Callable
from torch._inductor.utils import fresh_cache

@@ -4060,7 +4060,7 @@ def run(runner, args, original_dir=None):
else:
optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
experiment = (
speedup_experiment if not args.backend == "torchao" else latency_experiment
speedup_experiment if args.backend != "torchao" else latency_experiment
)
if args.accuracy:
output_filename = f"accuracy_{args.backend}.csv"

@@ -1,7 +1,8 @@
import os
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Optional
import matplotlib.pyplot as plt

@@ -1,4 +1,5 @@
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
import torch

@@ -1,7 +1,8 @@
import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import Any, Callable, NamedTuple
from collections.abc import Callable
from typing import Any, NamedTuple
import torch
from torch.autograd import functional

@@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Callable, Optional, Union
from collections.abc import Callable
from typing import Optional, Union
import torch
from torch import nn, Tensor

@@ -1,5 +1,6 @@
import dataclasses
from typing import Callable, Optional
from collections.abc import Callable
from typing import Optional
all_experiments: dict[str, Callable] = {}

@@ -9,8 +9,9 @@ import logging
import time
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Optional
from typing import Any, Optional
from tabulate import tabulate
from tqdm import tqdm
File diff suppressed because it is too large
@@ -7,6 +7,7 @@ from pt import ( # noqa: F401
binary_inplace_test,
binary_test,
bmm_test,
boolean_test,
cat_test,
channel_shuffle_test,
chunk_test,

@@ -56,6 +56,9 @@ binary_ops_list = op_bench.op_list(
["sub", torch.sub],
["div", torch.div],
["mul", torch.mul],
["asr", torch.bitwise_right_shift],
["lsl", torch.bitwise_left_shift],
["xor", torch.bitwise_xor],
],
)
benchmarks/operator_benchmark/pt/boolean_test.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for boolean operators. Supports both Caffe2/PyTorch."""

# Configs for PT all operator
all_long_configs = op_bench.cross_product_configs(
    M=[8, 128], N=[32, 64], K=[256, 512], device=["cpu", "cuda"], tags=["long"]
)


all_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
    },
    tags=["short"],
)


class AllBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device):
        self.inputs = {
            "input_one": torch.randint(0, 2, (M, N, K), device=device, dtype=torch.bool)
        }
        self.set_module_name("all")

    def forward(self, input_one):
        return torch.all(input_one)


# The generated test names based on all_short_configs will be in the following pattern:
# all_M8_N16_K32_devicecpu
# all_M8_N16_K32_devicecpu_bwdall
# all_M8_N16_K32_devicecpu_bwd1
# all_M8_N16_K32_devicecpu_bwd2
# ...
# Those names can be used to filter tests.

op_bench.generate_pt_test(all_long_configs + all_short_configs, AllBenchmark)

"""Mircobenchmark for any operator."""


class AnyBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device):
        self.inputs = {
            "input_one": torch.randint(0, 2, (M, N), device=device, dtype=torch.bool)
        }
        self.set_module_name("any")

    def forward(self, input_one):
        return torch.any(input_one)


any_configs = op_bench.cross_product_configs(
    M=[8, 256],
    N=[256, 16],
    device=["cpu", "cuda"],
    tags=["any"],
)

op_bench.generate_pt_test(any_configs, AnyBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
@@ -38,12 +38,16 @@ class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase):
op_bench.generate_pt_test(
configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark
)
op_bench.generate_pt_test(
configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short
+ configs.conv_1d_configs_long,
ConvTranspose1dBenchmark,
)

if not torch.backends.mkldnn.is_acl_available():
# convtranpose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654
op_bench.generate_pt_test(
configs.convtranspose_1d_configs_short
+ configs.conv_1d_configs_short
+ configs.conv_1d_configs_long,
ConvTranspose1dBenchmark,
)

"""
@@ -1,7 +1,8 @@
import itertools
from collections.abc import Callable
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, Union
from typing import Union
import numpy as np
from tabulate import tabulate

@@ -3,10 +3,11 @@ import csv
import itertools
import random
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, Optional, Union
from typing import Optional, Union
import numpy as np
from tabulate import tabulate

@@ -270,7 +271,7 @@ def run_single_backend_sdpa(
if config.calculate_bwd_time:
# TODO: debug backward pass for njt
if eager_sdpa and not config.attn_type == "document_mask":
if eager_sdpa and config.attn_type != "document_mask":
d_out = torch.randn_like(out_eager.transpose(1, 2)).transpose(1, 2)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, d_out, retain_graph=True

@@ -1,8 +1,8 @@
import itertools
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from dataclasses import asdict, dataclass
from typing import Callable
from tabulate import tabulate
from tqdm import tqdm
@@ -1729,8 +1729,10 @@ def define_buck_targets(
"torch/csrc/jit/backends/backend_debug_info.cpp",
"torch/csrc/jit/backends/backend_interface.cpp",
],
compiler_flags = get_pt_compiler_flags(),
fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
compiler_flags = get_pt_compiler_flags() + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags
}),
# @lint-ignore BUCKLINT link_whole
link_whole = True,
linker_flags = get_no_as_needed_linker_flag(),

@@ -2023,6 +2025,9 @@ def define_buck_targets(
"ovr_config//os:android-x86_64": [
"-mssse3",
],
}) + select({
"DEFAULT": [],
"ovr_config//os:android": c2_fbandroid_xplat_compiler_flags,
}),
exported_preprocessor_flags = get_aten_preprocessor_flags(),
exported_deps = [
@@ -119,8 +119,9 @@ class C10_API DataPtr {
}
// Unsafely mutates the device on a DataPtr. Under normal use,
// you should never actually need to call this function.
// We need this for the implementation of the hack detailed
// in Note [Masquerading as CUDA]
// We used to need this for the implementation of the hack detailed
// in Note [Masquerading as CUDA], but that hack has been removed.
// Other uses of this function now exist so it cannot be deprecated.
void unsafe_set_device(Device device) {
device_ = device;
}
@@ -4,12 +4,13 @@
#include <c10/util/TypeSafeSignMath.h>
#include <cmath>
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define C10_COMPAT_COPYSIGN c10::hip::compat::copysign
#endif
#define C10_COMPAT_COPYSIGN c10::cuda::compat::copysign
#else
#include <c10/util/copysign.h>
#define C10_COMPAT_COPYSIGN c10::copysign
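The hunk above splits the copysign compatibility macro so CUDA and HIP builds each include their own math-compat header. A generic sketch of the same preprocessor dispatch, with MY_COPYSIGN standing in for C10_COMPAT_COPYSIGN and the backend branches purely illustrative:

#include <cmath>
#include <cstdio>

#if defined(__CUDA_ARCH__)
#define MY_COPYSIGN ::copysign       // device-side CUDA math
#elif defined(__HIPCC__)
#define MY_COPYSIGN ::copysign       // device-side HIP math
#else
#define MY_COPYSIGN std::copysign    // host fallback
#endif

int main() {
  // On a plain host build this resolves to std::copysign and prints -3.000000.
  std::printf("%f\n", MY_COPYSIGN(3.0, -1.0));
}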
@ -120,17 +120,23 @@ inline void initGlobalDevicePoolState() {
|
||||
TORCH_CHECK(
|
||||
gDevicePool.devices.size() <= std::numeric_limits<DeviceIndex>::max(),
|
||||
"Too many XPU devices, DeviceIndex overflowed!");
|
||||
|
||||
#if defined(_WIN32) && SYCL_COMPILER_VERSION < 20250000
|
||||
// The default context feature is disabled by default on Windows for SYCL
|
||||
// compiler versions earlier than 2025.0.0.
|
||||
std::vector<sycl::device> deviceList;
|
||||
for (auto it = gDevicePool.devices.begin(); it != gDevicePool.devices.end();
|
||||
++it) {
|
||||
deviceList.push_back(*(*it));
|
||||
// Check each device's architecture and issue a warning if it is older than
|
||||
// the officially supported range (Intel GPUs starting from Arc (Alchemist)
|
||||
// series).
|
||||
namespace syclex = sycl::ext::oneapi::experimental;
|
||||
for (const auto& device : gDevicePool.devices) {
|
||||
auto architecture = device->get_info<syclex::info::device::architecture>();
|
||||
if (architecture < syclex::architecture::intel_gpu_acm_g10) {
|
||||
TORCH_WARN(
|
||||
"The detected GPU (",
|
||||
device->get_info<sycl::info::device::name>(),
|
||||
") is not officially supported by PyTorch XPU. Running workloads on this device may result in unexpected behavior.\n",
|
||||
"For stable and fully supported execution, please use GPUs based on Intel Arc (Alchemist) series or newer.\n",
|
||||
"Refer to the hardware prerequisites for more information: ",
|
||||
"https://github.com/pytorch/pytorch/blob/main/docs/source/notes/get_start_xpu.rst#hardware-prerequisite");
|
||||
}
|
||||
}
|
||||
gDevicePool.context = std::make_unique<sycl::context>(deviceList);
|
||||
#else
|
||||
|
||||
// The default context is utilized for each Intel GPU device, allowing the
|
||||
// retrieval of the context from any GPU device.
|
||||
const auto& platform = gDevicePool.devices[0]->get_platform();
|
||||
@ -140,7 +146,6 @@ inline void initGlobalDevicePoolState() {
|
||||
#else
|
||||
platform.ext_oneapi_get_default_context());
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void initDevicePoolCallOnce() {
|
||||
@@ -165,9 +170,9 @@ void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) {
#define ASSIGN_DEVICE_ASPECT(member) \
  device_prop->has_##member = raw_device.has(sycl::aspect::member);

#define ASSIGN_EXP_CL_ASPECT(member)                                        \
  device_prop->has_##member = raw_device.ext_oneapi_supports_cl_extension( \
      "cl_intel_" #member, &cl_version);
#define ASSIGN_EXP_CL_ASPECT(member) \
  device_prop->has_##member =        \
      raw_device.ext_oneapi_supports_cl_extension("cl_intel_" #member);

#define ASSIGN_EXP_DEVICE_PROP(property) \
  device_prop->property =                \
@@ -182,8 +187,6 @@ void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) {

  AT_FORALL_XPU_DEVICE_ASPECT(ASSIGN_DEVICE_ASPECT);

  // TODO: Remove cl_version since it is unnecessary.
  sycl::ext::oneapi::experimental::cl_version cl_version;
  AT_FORALL_XPU_EXP_CL_ASPECT(ASSIGN_EXP_CL_ASPECT);

#if SYCL_COMPILER_VERSION >= 20250000
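The ASSIGN_* macros above are expanded through AT_FORALL_XPU_* lists, an X-macro pattern that stamps out one field assignment per listed aspect. Here is a self-contained sketch of the idea; the FORALL_ASPECTS list, the RawDevice query, and the DeviceProp fields are hypothetical stand-ins, not the real SYCL aspects.

#include <iostream>
#include <string>

// Hypothetical aspect list; the real code uses AT_FORALL_XPU_DEVICE_ASPECT
// with the actual SYCL aspects.
#define FORALL_ASPECTS(_) \
  _(fp16)                 \
  _(fp64)                 \
  _(atomic64)

// Generate one bool field per aspect.
struct DeviceProp {
#define DECLARE_ASPECT(member) bool has_##member = false;
  FORALL_ASPECTS(DECLARE_ASPECT)
#undef DECLARE_ASPECT
};

// Stand-in for the raw device query.
struct RawDevice {
  bool has(const std::string& name) const { return name != "fp64"; }
};

void init_device_properties(DeviceProp* device_prop, const RawDevice& raw_device) {
  // Expands to one assignment per aspect in FORALL_ASPECTS.
#define ASSIGN_ASPECT(member) \
  device_prop->has_##member = raw_device.has(#member);
  FORALL_ASPECTS(ASSIGN_ASPECT)
#undef ASSIGN_ASPECT
}

int main() {
  DeviceProp prop;
  RawDevice dev;
  init_device_properties(&prop, dev);
  std::cout << std::boolalpha << "fp16: " << prop.has_fp16
            << ", fp64: " << prop.has_fp64
            << ", atomic64: " << prop.has_atomic64 << "\n";
  return 0;
}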
@@ -1044,6 +1044,17 @@ if(USE_ROCM)
    list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling)
  endif(CMAKE_BUILD_TYPE MATCHES Debug)

  # Get EnVar 'USE_LAYERNORM_FAST_RECIPROCAL' (or default to on).
  if(DEFINED ENV{USE_LAYERNORM_FAST_RECIPROCAL})
    set(USE_LAYERNORM_FAST_RECIPROCAL $ENV{USE_LAYERNORM_FAST_RECIPROCAL})
  else()
    set(USE_LAYERNORM_FAST_RECIPROCAL ON)
  endif()

  if(USE_LAYERNORM_FAST_RECIPROCAL)
    add_definitions(-DUSE_LAYERNORM_FAST_RECIPROCAL)
  endif()

  # needed for compat with newer versions of hip-clang that introduced C++20 mangling rules
  list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17)
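The CMake block reads USE_LAYERNORM_FAST_RECIPROCAL from the environment (defaulting to ON) and, when enabled, adds -DUSE_LAYERNORM_FAST_RECIPROCAL to the whole build. The layer-norm kernels that consume the flag are not part of this excerpt, so the following is only a sketch, under that assumption, of how such a compile-time switch is typically used; the reciprocal_sqrt helper and its fast approximation are hypothetical.

#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

// Hypothetical consumer of the USE_LAYERNORM_FAST_RECIPROCAL definition added
// by the CMake block above.
inline float reciprocal_sqrt(float x) {
#ifdef USE_LAYERNORM_FAST_RECIPROCAL
  // Fast path: cheap initial approximation refined by one Newton step,
  // standing in for a hardware reciprocal/rsqrt intrinsic.
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits = 0x5f3759df - (bits >> 1);
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y * (1.5f - 0.5f * x * y * y);
#else
  // Precise path: plain library math.
  return 1.0f / std::sqrt(x);
#endif
}

int main() {
  // Build with -DUSE_LAYERNORM_FAST_RECIPROCAL to exercise the fast branch.
  std::cout << reciprocal_sqrt(4.0f) << "\n";  // ~0.5 either way
  return 0;
}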
@@ -128,11 +128,12 @@ function(caffe2_print_configuration_summary)
  endif()
  message(STATUS "  USE_ROCM              : ${USE_ROCM}")
  if(${USE_ROCM})
    message(STATUS "    ROCM_VERSION          : ${ROCM_VERSION}")
    message(STATUS "    USE_FLASH_ATTENTION   : ${USE_FLASH_ATTENTION}")
    message(STATUS "    USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
    message(STATUS "    USE_ROCM_CK_SDPA      : ${USE_ROCM_CK_SDPA}")
    message(STATUS "    USE_ROCM_CK_GEMM      : ${USE_ROCM_CK_GEMM}")
    message(STATUS "    ROCM_VERSION                  : ${ROCM_VERSION}")
    message(STATUS "    USE_FLASH_ATTENTION           : ${USE_FLASH_ATTENTION}")
    message(STATUS "    USE_MEM_EFF_ATTENTION         : ${USE_MEM_EFF_ATTENTION}")
    message(STATUS "    USE_ROCM_CK_SDPA              : ${USE_ROCM_CK_SDPA}")
    message(STATUS "    USE_ROCM_CK_GEMM              : ${USE_ROCM_CK_GEMM}")
    message(STATUS "    USE_LAYERNORM_FAST_RECIPROCAL : ${USE_LAYERNORM_FAST_RECIPROCAL}")
  endif()
  message(STATUS "  BUILD_NVFUSER         : ${BUILD_NVFUSER}")
  message(STATUS "  USE_EIGEN_FOR_BLAS    : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
@@ -3,11 +3,11 @@ from __future__ import annotations
import dis
import inspect
import sys
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union


if TYPE_CHECKING:
    from collections.abc import Sequence
    from collections.abc import Callable, Sequence

import torch
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
@@ -5,7 +5,7 @@ Python implementation of function wrapping functionality for functorch.dim.
from __future__ import annotations

import functools
from typing import Any, Callable, Optional
from typing import Any, Optional, TYPE_CHECKING

import torch
from torch.utils._pytree import tree_map
@@ -15,6 +15,10 @@ from ._enable_all_layers import EnableAllLayers
from ._tensor_info import TensorInfo


if TYPE_CHECKING:
    from collections.abc import Callable


def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Handle tensor conversion for torch function integration."""
    return tensor
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import functools
from collections.abc import Callable
from types import (
    BuiltinMethodType,
    FunctionType,
@@ -12,7 +13,7 @@ from types import (
    MethodDescriptorType,
    WrapperDescriptorType,
)
from typing import Any, Callable
from typing import Any


FUNC_TYPES = (
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
from typing import Callable, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Union

import torch
from functorch.dim import dims  # noqa: F401
@@ -16,7 +16,7 @@ from ._parsing import (


if TYPE_CHECKING:
    from collections.abc import Sequence
    from collections.abc import Callable, Sequence

__all__ = ["rearrange"]
@@ -180,6 +180,7 @@ ignore = [
    "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
    "SIM117",
    "SIM118",
    "SIM300", # Yoda condition detected
    "UP007", # keep-runtime-typing
    "UP045", # keep-runtime-typing
    "TC006",
@@ -195,8 +196,7 @@ select = [
    "E",
    "EXE",
    "F",
    "SIM1",
    "SIM911",
    "SIM",
    "W",
    # Not included in flake8
    "FURB",
Some files were not shown because too many files have changed in this diff.