Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-29 11:14:56 +08:00

Compare commits: nofun-hack...csl/lint_t (1 commit)

Commit: 7e223d6b3e
@@ -37,9 +37,9 @@ case ${DOCKER_TAG_PREFIX} in
 rocm*)
 BASE_TARGET=rocm
 PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
-# add gfx950, gfx115x conditionally starting in ROCm 7.0
+# add gfx950 conditionally starting in ROCm 7.0
 if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
-PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
+PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950"
 fi
 EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
 ;;
@@ -19,8 +19,8 @@ pip_install \
 transformers==4.36.2

 pip_install coloredlogs packaging
-pip_install onnxruntime==1.23.0
-pip_install onnxscript==0.5.3
+pip_install onnxruntime==1.22.1
+pip_install onnxscript==0.4.0

 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
@@ -115,9 +115,6 @@ RUN env GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=True pip3 install grpcio
 # cmake-3.28.0 from pip for onnxruntime
 RUN python3 -mpip install cmake==3.28.0

-ADD ./common/patch_libstdc.sh patch_libstdc.sh
-RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh
-
 # build onnxruntime 1.21.0 from sources.
 # it is not possible to build it from sources using pip,
 # so just build it from upstream repository.
@@ -120,8 +120,9 @@ ninja==1.11.1.4
 numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x"
 numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
 #Description: Just-In-Time Compiler for Numerical Functions
-#Pinned versions: 0.55.2, 0.60.0
+#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
 #test that import: test_numba_integration.py
+#For numba issue see https://github.com/pytorch/pytorch/issues/51511
 #Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073

 #numpy

@@ -241,9 +242,10 @@ pygments==2.15.0
 #Pinned versions: 14.1.0
 #test that import:

-scikit-image==0.22.0
+scikit-image==0.19.3 ; python_version < "3.10"
+scikit-image==0.22.0 ; python_version >= "3.10"
 #Description: image processing routines
-#Pinned versions: 0.22.0
+#Pinned versions:
 #test that import: test_nn.py

 #scikit-learn

@@ -339,7 +341,7 @@ onnx==1.18.0
 #Pinned versions:
 #test that import:

-onnxscript==0.5.3
+onnxscript==0.4.0
 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:
@@ -5,7 +5,7 @@ DESIRED_ROCM ?= 7.0
 DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
 PACKAGE_NAME = magma-rocm
 # inherit this from underlying docker image, do not pass this env var to docker
-#PYTORCH_ROCM_ARCH ?= gfx900;gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1150;gfx1151;gfx1200;gfx1201
+#PYTORCH_ROCM_ARCH ?= gfx900;gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201

 DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
 -v $(shell git rev-parse --show-toplevel)/.ci:/builder \

@@ -18,6 +18,7 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
 .PHONY: all
 all: magma-rocm70
 all: magma-rocm64
+all: magma-rocm63

 .PHONY:
 clean:

@@ -33,3 +34,8 @@ magma-rocm70:
 magma-rocm64: DESIRED_ROCM := 6.4
 magma-rocm64:
 $(DOCKER_RUN)
+
+.PHONY: magma-rocm63
+magma-rocm63: DESIRED_ROCM := 6.3
+magma-rocm63:
+$(DOCKER_RUN)
@@ -67,7 +67,7 @@ fi
 # wheels with cxx11-abi

 echo "Checking that the gcc ABI is what we expect"
-if [[ "$(uname)" != 'Darwin' ]]; then
+if [[ "$(uname)" != 'Darwin' && "$(uname -m)" != "s390x" ]]; then
 # We also check that there are cxx11 symbols in libtorch
 #
 echo "Checking that symbols in libtorch.so have the right gcc abi"
@@ -34,14 +34,12 @@ fi


 # Patch numba to avoid CUDA-13 crash, see https://github.com/pytorch/pytorch/issues/162878
-if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
 NUMBA_CUDA_DIR=$(python -c "import os;import numba.cuda; print(os.path.dirname(numba.cuda.__file__))" 2>/dev/null || true)
 if [ -n "$NUMBA_CUDA_DIR" ]; then
 NUMBA_PATCH="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/numba-cuda-13.patch"
 pushd "$NUMBA_CUDA_DIR"
 patch -p4 <"$NUMBA_PATCH"
 popd
 fi
-fi

 echo "Environment variables:"
@@ -15,38 +15,26 @@ if errorlevel 1 exit /b 1
 if not errorlevel 0 exit /b 1

 cd %TMP_DIR_WIN%\build\torch\test

+:: Enable delayed variable expansion to make the list
+setlocal enabledelayedexpansion
+set EXE_LIST=
 for /r "." %%a in (*.exe) do (
-call :libtorch_check "%%~na" "%%~fa"
-if errorlevel 1 goto fail
+set EXE_LIST=!EXE_LIST! cpp/%%~fa
 )

+:: Run python test\run_test.py on the list
+python test\run_test.py --cpp --verbose -i !EXE_LIST! ^
+--exclude ^
+:: Skip verify_api_visibility as it a compile level test
+"cpp/verify_api_visibility" ^
+:: NB: This is not a gtest executable file, thus couldn't be handled by pytest-cpp
+"cpp/c10_intrusive_ptr_benchmark"
+if errorlevel 1 goto fail
+if not errorlevel 0 goto fail

 goto :eof

-:libtorch_check
-
-cd %CWD%
-set CPP_TESTS_DIR=%TMP_DIR_WIN%\build\torch\test
-
-:: Skip verify_api_visibility as it a compile level test
-if "%~1" == "verify_api_visibility" goto :eof
-
-echo Running "%~2"
-if "%~1" == "c10_intrusive_ptr_benchmark" (
-:: NB: This is not a gtest executable file, thus couldn't be handled by pytest-cpp
-call "%~2"
-goto :eof
-)
-
-python test\run_test.py --cpp --verbose -i "cpp/%~1"
-if errorlevel 1 (
-echo %1 failed with exit code %errorlevel%
-goto fail
-)
-if not errorlevel 0 (
-echo %1 failed with exit code %errorlevel%
-goto fail
-)

 :eof
 exit /b 0
@@ -38,7 +38,7 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
 fi

 # TODO: Move this to .ci/docker/requirements-ci.txt
-python -m pip install "psutil==5.9.1" nvidia-ml-py "pytest-shard==0.1.2"
+python -m pip install "psutil==5.9.1" "pynvml==11.4.1" "pytest-shard==0.1.2"

 run_tests() {
 # Run nvidia-smi if available
@@ -66,7 +66,6 @@ readability-simplify-subscript-expr,
 readability-string-compare,
 -readability-redundant-access-specifiers,
 -readability-redundant-control-flow,
--readability-redundant-inline-specifier,
 '
 HeaderFilterRegex: '^(aten/|c10/|torch/).*$'
 WarningsAsErrors: '*'
.github/ci_commit_pins/xla.txt (vendored, 2 changed lines)
@@ -1 +1 @@
-2a9138a26ee257fef05310ad3fecf7c55fe80d73
+0fc62aa26a30ed7ca419d285f285cb5ba02c4394
.github/workflows/_get-changed-files.yml (vendored, 9 changed lines)
@@ -40,15 +40,6 @@ jobs:
 # Use gh CLI to get changed files in the PR with explicit repo
 CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//')

-# See https://github.com/pytorch/pytorch/pull/134215#issuecomment-2332128790
-PYI_FILES_TO_ADD=""
-for file in ${CHANGED_FILES}; do
-if [[ "${file}" == *".pyi.in" ]]; then
-PYI_FILES_TO_ADD="${PYI_FILES_TO_ADD} ${file//.in/}"
-fi
-done
-CHANGED_FILES="${CHANGED_FILES}${PYI_FILES_TO_ADD}"
-
 if [ -z "$CHANGED_FILES" ]; then
 echo "No changed files found, setting to '*'"
 CHANGED_FILES="*"
@@ -63,7 +63,6 @@ jobs:
 # Same as the build job
 python-version: 3.12.7
 test-matrix: ${{ needs.macos-perf-py3-arm64-build.outputs.test-matrix }}
-timeout-minutes: 300
 disable-monitor: false
 monitor-log-interval: 15
 monitor-data-collect-interval: 4
.github/workflows/inductor-periodic.yml (vendored, 10 changed lines)
@@ -106,16 +106,6 @@ jobs:
 { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" },
 { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
 { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-{ config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
 ]}
 secrets: inherit
.github/workflows/periodic.yml (vendored, 6 changed lines)
@@ -213,9 +213,9 @@ jobs:
 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 test-matrix: |
 { include: [
-{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
-{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
-{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
+{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
+{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
+{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
 ]}
 secrets: inherit
.github/workflows/pull.yml (vendored, 2 changed lines)
@@ -127,6 +127,8 @@ jobs:
 uses: ./.github/workflows/_linux-build.yml
 needs: get-label-type
 with:
+# More memory is needed to build with asan
+runner: linux.2xlarge.memory
 runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 build-environment: linux-jammy-py3.10-clang18-asan
 docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
.github/workflows/slow.yml (vendored, 2 changed lines)
@@ -140,6 +140,8 @@ jobs:
 uses: ./.github/workflows/_linux-build.yml
 needs: get-label-type
 with:
+# More memory is needed to build with asan
+runner: linux.2xlarge.memory
 runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 build-environment: linux-jammy-py3.10-clang18-asan
 docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan
.github/workflows/trunk.yml (vendored, 42 changed lines)
@@ -160,10 +160,9 @@ jobs:
 runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
 test-matrix: |
 { include: [
-{ config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" },
-{ config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" },
-{ config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" },
-{ config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" },
+{ config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" },
+{ config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" },
+{ config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" },
 ]}
 secrets: inherit

@@ -190,6 +189,41 @@ jobs:
 runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
 secrets: inherit

+linux-jammy-rocm-py3_10-build:
+if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
+name: linux-jammy-rocm-py3.10
+uses: ./.github/workflows/_linux-build.yml
+needs: get-label-type
+with:
+runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+build-environment: linux-jammy-rocm-py3.10
+docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
+sync-tag: rocm-build
+test-matrix: |
+{ include: [
+{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
+{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
+{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
+]}
+secrets: inherit
+
+linux-jammy-rocm-py3_10-test:
+if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }}
+permissions:
+id-token: write
+contents: read
+name: linux-jammy-rocm-py3.10
+uses: ./.github/workflows/_rocm-test.yml
+needs:
+- linux-jammy-rocm-py3_10-build
+- target-determination
+with:
+build-environment: linux-jammy-rocm-py3.10
+docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
+test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
+tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
+secrets: inherit

 inductor-build:
 name: inductor-build
 uses: ./.github/workflows/_linux-build.yml
@@ -28,7 +28,7 @@ exclude_patterns = [
 'torch/lib/**',
 'venv/**',
 '**/*.pyi',
-"tools/experimental/torchfuzz/**",
+"tools/experimental/dynamic_shapes/torchfuzz/**",
 'tools/test/test_selective_build.py',
 ]
 command = [

@@ -198,7 +198,7 @@ exclude_patterns = [
 'tools/test/gen_operators_yaml_test.py',
 'tools/test/gen_oplist_test.py',
 'tools/test/test_selective_build.py',
-'tools/experimental/torchfuzz/**',
+'tools/experimental/dynamic_shapes/torchfuzz/**',
 ]
 command = [
 'python3',

@@ -1573,7 +1573,6 @@ exclude_patterns = [
 'torch/_inductor/fx_passes/serialized_patterns/**',
 'torch/_inductor/autoheuristic/artifacts/**',
 'test/dynamo/cpython/**',
-'test/test_torchfuzz_repros.py',
 'scripts/**',
 'third_party/**',
 'fb/**',
@@ -53,7 +53,7 @@ ARG CUDA_PATH=cu121
 ARG INSTALL_CHANNEL=whl/nightly
 # Automatically set by buildx
 # pinning version of conda here see: https://github.com/pytorch/pytorch/issues/164574
-RUN /opt/conda/bin/conda install -y python=${PYTHON_VERSION} conda=25.7.0
+RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERSION} conda=25.7.0

 ARG TARGETPLATFORM
@@ -40,6 +40,41 @@ namespace {
 ->conv
 ->rnn
 */
+const std::map<std::string, std::vector<std::string>> _fp32_precisions = {
+{"generic", {{"ieee", "tf32", "bf16", "none"}}},
+{"mkldnn", {{"ieee", "tf32", "bf16", "none"}}},
+{"cuda", {{"ieee", "tf32", "none"}}}};
+
+// Check whether the backend and op are legal
+void check_fp32_prec_backend_and_op(
+const std::string& backend,
+const std::string& op) {
+static std::vector<std::string> backends = {"generic", "mkldnn", "cuda"};
+static std::vector<std::string> operators = {"conv", "matmul", "rnn", "all"};
+TORCH_CHECK(
+std::find(backends.begin(), backends.end(), backend) != backends.end(),
+"Invalid backend: ",
+backend);
+TORCH_CHECK(
+std::find(operators.begin(), operators.end(), op) != operators.end(),
+"Invalid operator: ",
+op);
+if (backend == "generic") {
+TORCH_CHECK(op == "all", "Invalid operation for generic backend: ", op);
+}
+}
+
+// Return whether the precision is supported by backends
+bool validate_fp32_prec(
+const std::string& backend,
+const std::string& precision) {
+auto iterp = _fp32_precisions.find(backend);
+TORCH_CHECK(iterp != _fp32_precisions.end());
+auto precisions = iterp->second;
+bool valid = std::find(precisions.begin(), precisions.end(), precision) !=
+precisions.end();
+return valid;
+}
+
 C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
 TORCH_WARN_ONCE(
@@ -51,54 +86,6 @@ namespace {
 }
 } // namespace

-Float32Backend str2backend(const std::string& name) {
-if (name == "generic")
-return Float32Backend::GENERIC;
-else if (name == "cuda")
-return Float32Backend::CUDA;
-else if (name == "mkldnn")
-return Float32Backend::MKLDNN;
-TORCH_CHECK(false, "Unknown backend: ", name);
-}
-
-Float32Op str2op(const std::string& name) {
-if (name == "all")
-return Float32Op::ALL;
-else if (name == "conv")
-return Float32Op::CONV;
-else if (name == "rnn")
-return Float32Op::RNN;
-else if (name == "matmul")
-return Float32Op::MATMUL;
-TORCH_CHECK(false, "Unknown op: ", name);
-}
-
-Float32Precision str2precision(const std::string& name) {
-if (name == "none")
-return Float32Precision::NONE;
-else if (name == "ieee")
-return Float32Precision::IEEE;
-else if (name == "tf32")
-return Float32Precision::TF32;
-else if (name == "bf16")
-return Float32Precision::BF16;
-TORCH_CHECK(false, "Unknown precision: ", name);
-}
-
-std::string precision2str(Float32Precision prec) {
-switch (prec) {
-case Float32Precision::NONE:
-return "none";
-case Float32Precision::IEEE:
-return "ieee";
-case Float32Precision::TF32:
-return "tf32";
-case Float32Precision::BF16:
-return "bf16";
-}
-TORCH_CHECK(false, "Invalid enum Float32Precision(", static_cast<int>(prec), ")");
-}
-
 Context::Context() = default;

 // TODO: This could be bad juju if someone calls globalContext() in the
@@ -192,10 +179,10 @@ void Context::setUserEnabledNNPACK(bool e) {
 enabled_nnpack = e;
 }

-bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
-if (!op.has_value()) {
-bool allow_tf32_rnn = float32Precision(Float32Backend::CUDA, Float32Op::RNN) == Float32Precision::TF32;
-bool allow_tf32_conv = float32Precision(Float32Backend::CUDA, Float32Op::CONV) == Float32Precision::TF32;
+bool Context::allowTF32CuDNN(const std::string& op) const {
+if (op.empty()){
+bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
+bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
 TORCH_CHECK(
 allow_tf32_rnn == allow_tf32_conv && allow_tf32_rnn == allow_tf32_cudnn,
 "PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,",

@@ -204,15 +191,15 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
 "We suggest only using the new API to set the TF32 flag(s). See also: ",
 "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
 } else {
-return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
+return float32Precision("cuda", op) == "tf32";
 }
 warn_deprecated_fp32_precision_api();
 return allow_tf32_cudnn;
 }

 void Context::setAllowTF32CuDNN(bool b) {
-setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
-setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
+setFloat32Precision("cuda", "rnn", b ? "tf32" : "none");
+setFloat32Precision("cuda", "conv", b ? "tf32" : "none");
 allow_tf32_cudnn = b;
 warn_deprecated_fp32_precision_api();
 }
@@ -318,7 +305,7 @@ void Context::setImmediateMiopen(bool b) {

 bool Context::allowTF32CuBLAS() const {
 bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
-bool allow_tf32_new = float32Precision(Float32Backend::CUDA, Float32Op::MATMUL) == Float32Precision::TF32;
+bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
 TORCH_CHECK(
 legacy_allow_tf32 == allow_tf32_new,
 "PyTorch is checking whether allow_tf32_new is enabled for cuBlas matmul,",
@@ -331,17 +318,17 @@ bool Context::allowTF32CuBLAS() const {

 void Context::setAllowTF32CuBLAS(bool b) {
 float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
-setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, b ? Float32Precision::TF32 : Float32Precision::IEEE);
+setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
 }

 Float32MatmulPrecision Context::float32MatmulPrecision() const {
-bool invalid = float32Precision(Float32Backend::CUDA, Float32Op::MATMUL) == Float32Precision::TF32 &&
+bool invalid = float32Precision("cuda", "matmul") == "tf32" &&
 float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST;
 invalid = invalid ||
-(float32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL) == Float32Precision::BF16 &&
+(float32Precision("mkldnn", "matmul") == "bf16" &&
 float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM);
 invalid = invalid ||
-(float32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL) == Float32Precision::TF32 &&
+(float32Precision("mkldnn", "matmul") == "tf32" &&
 float32_matmul_precision != at::Float32MatmulPrecision::HIGH);
 TORCH_CHECK(
 !invalid,
@@ -353,26 +340,15 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
 return float32_matmul_precision;
 }

-Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op) const {
-std::pair<Float32Backend, Float32Op> key{backend, op};
-auto it = fp32_precision.find(key);
-TORCH_CHECK(it != fp32_precision.end(), "Invalid (backend, op) pair: (", backend, ", ", op, ")");
-Float32Precision precision = it->second;
-if (precision == Float32Precision::NONE) {
-key.second = Float32Op::ALL;
-precision = fp32_precision.find(key)->second;
-}
-if (precision == Float32Precision::NONE) {
-key.first = Float32Backend::GENERIC;
-precision = fp32_precision.find(key)->second;
-}
-
-// "cuda" does not support "bf16"
-if (backend == Float32Backend::CUDA && precision == Float32Precision::BF16) {
-return Float32Precision::NONE;
-}
-return precision;
+std::string Context::float32Precision(const std::string& backend, const std::string& op) const {
+check_fp32_prec_backend_and_op(backend, op);
+auto precision = fp32_precision.find(backend)->second.find(op)->second;
+if (precision == "none")
+precision = fp32_precision.find(backend)->second.find("all")->second;
+if (precision == "none")
+precision = fp32_precision.find("generic")->second.find("all")->second;
+bool valid_prec = validate_fp32_prec(backend, precision);
+return valid_prec ? precision : "none";
 }

 void Context::setFloat32MatmulPrecision(const std::string &s) {
@@ -381,18 +357,18 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
 // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
 if (s_ == "highest") {
 float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
-setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::IEEE);
-setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::IEEE);
+setFloat32Precision("cuda", "matmul", "ieee");
+setFloat32Precision("mkldnn", "matmul", "ieee");
 return true;
 } else if (s_ == "high") {
 float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
-setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::TF32);
-setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::TF32);
+setFloat32Precision("cuda", "matmul", "tf32");
+setFloat32Precision("mkldnn", "matmul", "tf32");
 return true;
 } else if (s_ == "medium") {
 float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
-setFloat32Precision(Float32Backend::CUDA, Float32Op::MATMUL, Float32Precision::TF32);
-setFloat32Precision(Float32Backend::MKLDNN, Float32Op::MATMUL, Float32Precision::BF16);
+setFloat32Precision("cuda", "matmul", "tf32");
+setFloat32Precision("mkldnn", "matmul", "bf16");
 return true;
 }
 return false;
@@ -406,16 +382,25 @@ void Context::setFloat32MatmulPrecision(const std::string &s) {
 "setFloat32MatmulPrecision call has no effect.");
 }

-void Context::setFloat32Precision(Float32Backend backend, Float32Op op, Float32Precision p) {
-auto it = fp32_precision.find(std::make_pair(backend, op));
-TORCH_CHECK(
-it != fp32_precision.end(),
-"Invalid (backend, op) pair: (", backend, ", ", op, ")");
-TORCH_CHECK(
-!(backend == Float32Backend::CUDA && p == Float32Precision::BF16),
-"backend 'cuda' does not support precision 'bf16'");
-it->second = p;
+void Context::setFloat32Precision(const std::string& backend, const std::string& op, const std::string& p) {
+check_fp32_prec_backend_and_op(backend, op);
+if (validate_fp32_prec(backend, p)) {
+fp32_precision[backend][op] = p;
+} else {
+std::string msg;
+auto iterp = _fp32_precisions.find(backend);
+TORCH_CHECK(iterp != _fp32_precisions.end());
+for (const auto& p : iterp->second) {
+msg += p;
+msg += " ";
+}
+TORCH_WARN(
+"you have set wrong precision for backend:",
+backend,
+" setFloat32Precision call has no effect.",
+"Please choose precision from: ",
+msg);
+}
 }

 at::LinalgBackend Context::linalgPreferredBackend() const {
@@ -483,8 +468,8 @@ at::BlasBackend Context::blasPreferredBackend() {
 #if ROCM_VERSION >= 60300
 "gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
 #endif
-#if ROCM_VERSION >= 70000
-"gfx950", "gfx1150", "gfx1151"
+#if ROCM_VERSION >= 60500
+"gfx950"
 #endif
 };
 for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
@@ -25,27 +25,17 @@
 #include <c10/util/CallOnce.h>
 #include <c10/util/Exception.h>
 #include <c10/util/env.h>
-#include <c10/util/hash.h>
 #include <c10/util/irange.h>

 #include <cstdint>
 #include <map>
 #include <mutex>
-#include <unordered_map>

 namespace at {

 class Tensor;

 enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
-enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
-enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
-enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };
-
-TORCH_API Float32Backend str2backend(const std::string& name);
-TORCH_API Float32Op str2op(const std::string& name);
-TORCH_API Float32Precision str2precision(const std::string& name);
-TORCH_API std::string precision2str(Float32Precision prec);

 class TORCH_API Context {
 public:
@@ -346,17 +336,19 @@ class TORCH_API Context {

 void setFloat32MatmulPrecision(const std::string& s);
 void setFloat32Precision(
-Float32Backend backend,
-Float32Op op,
-Float32Precision p);
-bool allowTF32CuDNN(std::optional<Float32Op> op = std::nullopt) const;
+const std::string& backend,
+const std::string& op,
+const std::string& s);
+bool allowTF32CuDNN(const std::string& op = std::string()) const;
 void setAllowTF32CuDNN(bool);
 bool allowTF32OneDNN() const;
 void setAllowTF32OneDNN(bool);
 bool allowTF32CuBLAS() const;
 void setAllowTF32CuBLAS(bool);
 Float32MatmulPrecision float32MatmulPrecision() const;
-Float32Precision float32Precision(Float32Backend backend, Float32Op op) const;
+std::string float32Precision(
+const std::string& backend,
+const std::string& op) const;
 bool allowFP16ReductionCuBLAS() const;
 void setAllowFP16ReductionCuBLAS(bool);
 bool allowBF16ReductionCuBLAS() const;
@@ -483,20 +475,21 @@ class TORCH_API Context {
 bool enable_sparse_tensor_invariant_checks = false;
 bool allow_fp16_reduction_cpu = false;

-using Key = std::pair<Float32Backend, Float32Op>;
-std::unordered_map<Key, Float32Precision, c10::hash<Key>> fp32_precision = {
-{{Float32Backend::GENERIC, Float32Op::ALL}, Float32Precision::NONE},
-{{Float32Backend::MKLDNN, Float32Op::ALL}, Float32Precision::NONE},
-{{Float32Backend::MKLDNN, Float32Op::CONV}, Float32Precision::NONE},
-{{Float32Backend::MKLDNN, Float32Op::RNN}, Float32Precision::NONE},
-{{Float32Backend::MKLDNN, Float32Op::MATMUL}, Float32Precision::NONE},
-{{Float32Backend::CUDA, Float32Op::ALL}, Float32Precision::NONE},
-{{Float32Backend::CUDA, Float32Op::CONV}, Float32Precision::TF32},
-{{Float32Backend::CUDA, Float32Op::RNN}, Float32Precision::TF32},
-{{Float32Backend::CUDA, Float32Op::MATMUL},
-float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
-? Float32Precision::NONE
-: Float32Precision::TF32},
+std::map<std::string, std::map<std::string, std::string>> fp32_precision = {
+{"generic", {{"all", "none"}}},
+{"mkldnn",
+{{"matmul", "none"},
+{"conv", "none"},
+{"rnn", "none"},
+{"all", "none"}}},
+{"cuda",
+{{"matmul",
+float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
+? "none"
+: "tf32"},
+{"conv", "tf32"},
+{"rnn", "tf32"},
+{"all", "none"}}},
 };

 Allocator* prev_allocator_ptr_{nullptr};

@@ -678,4 +671,5 @@ struct TORCH_API ROCmBackwardPassGuard {
 ~ROCmBackwardPassGuard();
 static bool is_backward_pass();
 };
+
 } // namespace at
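Editor's note: for orientation only, a minimal sketch (not part of the change) of how the enum-based fp32-precision API on the old side of these Context hunks is exercised. at::globalContext() is the standard ATen accessor; the enum and method names are taken from the hunks above, everything else is illustrative.

#include <ATen/Context.h>

// Sketch only: drives the enum-keyed fp32_precision table shown above.
void sketch_set_cuda_matmul_tf32() {
  at::Context& ctx = at::globalContext();
  // Store TF32 under the (CUDA, MATMUL) key.
  ctx.setFloat32Precision(
      at::Float32Backend::CUDA, at::Float32Op::MATMUL, at::Float32Precision::TF32);
  // Read it back; per float32Precision() above, a NONE entry falls back to
  // the op=ALL key and then to the GENERIC backend before being returned.
  at::Float32Precision p =
      ctx.float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL);
  TORCH_CHECK(p == at::Float32Precision::TF32);
}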
@@ -179,7 +179,7 @@ void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef
 return;
 }
 const auto src_names = src.names();
-const auto result_dim = result.dim();
+const auto result_dim = static_cast<int64_t>(result.dim());
 const auto src_dim = static_cast<int64_t>(src_names.size());
 const auto excluded_dim = static_cast<int64_t>(excluded_idxs.size());
 TORCH_INTERNAL_ASSERT(src_dim - excluded_dim == result_dim);
@@ -214,7 +214,7 @@ inline Tensor applySlice(
 "step must be greater than zero");

 // See NOTE [nested tensor size for indexing]
-if (self_sizes.has_value() && !self_sizes.value().empty()) {
+if (self_sizes.has_value() && self_sizes.value().size() > 0) {
 // Skip this optimization if we are tracing, as the trace may be polymorphic
 // over the shape of the `self` tensor, and we still want to record
 // the slice.
@@ -273,11 +273,11 @@ void checkLayout(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Layout layout)
 }

 void * maybe_data_ptr(const Tensor& tensor) {
-return tensor.defined() ? tensor.data_ptr() : nullptr;
+return tensor.defined() ? (void *)tensor.data_ptr() : nullptr;
 }

 void * maybe_data_ptr(const TensorArg& tensor) {
-return tensor->defined() ? tensor->data_ptr() : nullptr;
+return tensor->defined() ? (void *)tensor->data_ptr() : nullptr;
 }

 void check_dim_size(
@@ -50,46 +50,6 @@ namespace {
 constexpr size_t MAX_SIZE_INDEX = 64;
 }

-// A large reserved pinned memory segment that is created in advance which is used
-// to allocate small pinned memory requests to avoid calling into expensive APIs.
-// We never free this memory and move up the pointer as we allocate new blocks
-// and when blocks are freed, they are cached in the free lists.
-struct PinnedReserveSegment {
-PinnedReserveSegment(void *start, size_t size) : start_(start), size_(size),
-current_ptr_(start_), initialized_(true) {}
-
-PinnedReserveSegment() : start_(nullptr), size_(0), current_ptr_(nullptr), initialized_(false) {}
-
-bool initialized() {
-return initialized_;
-}
-
-void* allocate(size_t bytes) {
-std::lock_guard<std::mutex> guard(mutex_);
-
-// Round up the requested size to 4KB boundary for all including the small ones.
-size_t rounded_bytes = (bytes + 4096 - 1) & ~(4096 - 1);
-
-if (((uint8_t*)current_ptr_ + rounded_bytes) > ((uint8_t*)start_ + size_)) {
-return nullptr;
-}
-
-void* ptr = current_ptr_;
-current_ptr_ = (uint8_t*)current_ptr_ + rounded_bytes;
-return ptr;
-}
-
-bool owns(void* ptr) {
-return ptr >= start_ && ptr < (uint8_t*)start_ + size_;
-}
-
-std::mutex mutex_;
-void* start_;
-size_t size_;
-void* current_ptr_;
-bool initialized_;
-};
-
 // Struct containing memory allocator summary statistics for host.
 struct TORCH_API HostStats {
 // COUNT: total allocations (active)
@@ -243,21 +203,7 @@ struct CachingHostAllocatorImpl {
 // background.
 if (!pinned_use_background_threads()) {
 process_events();
-}
-
-// Round up the allocation to the nearest power of two to improve reuse.
-// These power of two sizes are also used to index into the free list.
-size_t roundSize = c10::llvm::PowerOf2Ceil(size);
-
-// First, try to allocate from the free list
-auto* block = get_free_block(roundSize);
-if (block) {
-return {block->ptr_, reinterpret_cast<void*>(block)};
-}
-
-// Check in the recently freed blocks with pending events to see if we
-// can reuse them. Call get_free_block again after processing events
-if (pinned_use_background_threads()) {
+} else {
 // Launch the background thread and process events in a loop.
 static bool background_thread_flag [[maybe_unused]] = [this] {
 getBackgroundThreadPool()->run([&]() {

@@ -270,6 +216,16 @@ struct CachingHostAllocatorImpl {
 }();
 }

+// Round up the allocation to the nearest power of two to improve reuse.
+// These power of two sizes are also used to index into the free list.
+size_t roundSize = c10::llvm::PowerOf2Ceil(size);
+
+// First, try to allocate from the free list
+auto* block = get_free_block(roundSize);
+if (block) {
+return {block->ptr_, reinterpret_cast<void*>(block)};
+}
+
 // Slow path: if we can't allocate from the cached free list, we need
 // to create a new block.
 void* ptr = nullptr;
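Editor's note: for context, a small self-contained sketch of the power-of-two rounding this allocator applies before the free-list lookup. The real code calls c10::llvm::PowerOf2Ceil; the helper below is illustrative (for nonzero sizes), not the PyTorch implementation.

#include <cassert>
#include <cstddef>

// Sketch only: round a nonzero request up to the next power of two so that
// equally sized host allocations share one free-list bucket.
static size_t power_of_two_ceil(size_t size) {
  size_t rounded = 1;
  while (rounded < size) {
    rounded <<= 1;
  }
  return rounded;
}

int main() {
  assert(power_of_two_ceil(1) == 1);
  assert(power_of_two_ceil(3000) == 4096);    // small request -> 4 KiB bucket
  assert(power_of_two_ceil(65536) == 65536);  // exact powers map to themselves
  return 0;
}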
@@ -173,4 +173,12 @@ unsigned TensorBase::_register_hook(std::function<TensorBase(const TensorBase&)>
 return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
 }

+std::optional<ScalarType> TensorBase::grad_dtype() const {
+return impl::GetVariableHooks()->grad_dtype(*this);
+}
+
+void TensorBase::set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const {
+return impl::GetVariableHooks()->set_grad_dtype(*this, grad_dtype);
+}
+
 } // namespace at
@@ -930,6 +930,10 @@ public:

 const TensorBase& requires_grad_(bool _requires_grad=true) const;

+std::optional<ScalarType> grad_dtype() const;
+
+void set_grad_dtype(const std::optional<ScalarType>& grad_dtype) const;
+
 // View Variables
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
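Editor's note: as a rough, assumption-laden illustration (not part of the diff), the grad_dtype accessors added on the new side of these TensorBase hunks would be used along these lines; the tensor argument and the Float dtype are made up for the example.

#include <ATen/core/TensorBase.h>
#include <optional>

// Sketch only: exercises the grad_dtype()/set_grad_dtype() accessors added
// above; per the hunks, both simply forward to the VariableHooks interface.
void sketch_grad_dtype(const at::TensorBase& t) {
  // Query the current gradient-dtype override, if any.
  std::optional<at::ScalarType> current = t.grad_dtype();
  // Ask for gradients of this tensor to be kept in float32.
  t.set_grad_dtype(at::ScalarType::Float);
  // Passing std::nullopt expresses "no override".
  t.set_grad_dtype(std::nullopt);
  (void)current;
}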
@@ -117,7 +117,7 @@ C10_HOST_DEVICE inline T cauchy(T val, T median, T sigma) {
 template <>
 C10_HOST_DEVICE inline double cauchy(double val, double median, double sigma) {
 // https://en.wikipedia.org/wiki/Cauchy_distribution#Cumulative_distribution_function
-return median + sigma * at::tan(c10::pi<double> * (val - 0.5));
+return median + sigma * at::tan(c10::pi<double> * (val - static_cast<double>(0.5)));
 }

 /**
@@ -68,6 +68,8 @@ struct TORCH_API VariableHooksInterface {
 const c10::OperatorHandle& op,
 c10::DispatchKeySet dispatch_keys,
 torch::jit::Stack* stack) const = 0;
+virtual std::optional<c10::ScalarType> grad_dtype(const TensorBase&) const = 0;
+virtual void set_grad_dtype(const TensorBase&, const std::optional<c10::ScalarType>&) const = 0;
 };

 TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
@@ -2,7 +2,7 @@

 namespace c10 {

-inline BoxedKernel::BoxedKernel() : boxed_kernel_func_(nullptr) {}
+inline BoxedKernel::BoxedKernel() : functor_(), boxed_kernel_func_(nullptr) {}

 inline BoxedKernel::BoxedKernel(
 std::unique_ptr<OperatorKernel> functor,
@@ -20,7 +20,9 @@ make_unique_base(Args&&... args) {
 } // namespace detail

 inline KernelFunction::KernelFunction()
-: unboxed_kernel_func_(nullptr), sym_unboxed_kernel_func_(nullptr) {}
+: boxed_kernel_func_(),
+unboxed_kernel_func_(nullptr),
+sym_unboxed_kernel_func_(nullptr) {}

 inline KernelFunction::~KernelFunction() {
 if (tokens_) {
@@ -76,7 +76,13 @@ void _print_dispatch_trace(const std::string& label, const std::string& op_name,

 OpRegistrationListener::~OpRegistrationListener()= default;

-Dispatcher::Dispatcher(): backendFallbackKernels_(), listeners_(std::make_unique<detail::RegistrationListenerList>()), guard_(std::make_shared<Guard>())
+Dispatcher::Dispatcher()
+: operators_()
+, operatorLookupTable_()
+, backendFallbackKernels_()
+, listeners_(std::make_unique<detail::RegistrationListenerList>())
+, cond_var_()
+, guard_(std::make_shared<Guard>())
 {}

 Dispatcher::~Dispatcher() {
@@ -96,7 +96,7 @@ class TORCH_API Dispatcher final {
 friend class TypedOperatorHandle;

 struct Guard final {
-Guard() : alive(true) {}
+Guard() : alive(true), mutex() {}
 std::atomic<bool> alive;
 std::mutex mutex;
 };
@@ -62,7 +62,17 @@ static const auto& getDispatchTableIndexToKey() {
 }

 OperatorEntry::OperatorEntry(OperatorName&& operator_name)
-: name_(std::move(operator_name)), dispatchTable_(), dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized()), is_observed_(ObservedOperators::isObserved(name_))
+: name_(std::move(operator_name))
+, schema_()
+#ifndef C10_MOBILE
+, tags_()
+#endif
+, dispatchTable_()
+, dispatchKeyExtractor_(DispatchKeyExtractor::makeUninitialized())
+, kernels_()
+, cpp_signature_()
+, sym_cpp_signature_()
+, is_observed_(ObservedOperators::isObserved(name_))
 {
 // Pick up any backend fallbacks that were registered prior to this
 // OperatorEntry being created.
@@ -114,7 +114,7 @@ constexpr bool allowlist_contains(std::string_view allowlist, std::string_view i
 }
 next++;
 } else {
-if (allowlist.substr(cur) == item) {
+if (allowlist.substr(cur).compare(item) == 0) {
 return true;
 }
 break;
@@ -73,7 +73,7 @@ c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(

 std::optional<FunctionSchema> inferred_schema = std::nullopt;
 for (const auto& kernel : options.kernels) {
-if (nullptr != kernel.inferred_function_schema) {
+if (nullptr != kernel.inferred_function_schema.get()) {
 if (!inferred_schema.has_value()) {
 inferred_schema = *kernel.inferred_function_schema;
 break;
@ -411,6 +411,7 @@ public:
|
|||||||
|
|
||||||
Options()
|
Options()
|
||||||
: schemaOrName_(std::nullopt)
|
: schemaOrName_(std::nullopt)
|
||||||
|
, kernels()
|
||||||
, aliasAnalysisKind_(std::nullopt)
|
, aliasAnalysisKind_(std::nullopt)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
@ -419,6 +420,7 @@ public:
|
|||||||
struct KernelRegistrationConfig final {
|
struct KernelRegistrationConfig final {
|
||||||
KernelRegistrationConfig()
|
KernelRegistrationConfig()
|
||||||
: dispatch_key(std::nullopt)
|
: dispatch_key(std::nullopt)
|
||||||
|
, func()
|
||||||
, cpp_signature(std::nullopt)
|
, cpp_signature(std::nullopt)
|
||||||
, inferred_function_schema(nullptr)
|
, inferred_function_schema(nullptr)
|
||||||
{}
|
{}
|
||||||
|
|||||||
@ -905,7 +905,7 @@ class Vectorized8 : public Vectorizedi {
|
|||||||
// Because loadu(const void* ptr, T count) requires zero initialization for
|
// Because loadu(const void* ptr, T count) requires zero initialization for
|
||||||
// upper 128 bits. However, by using _mm256_castsi128_si256, the upper 128
|
// upper 128 bits. However, by using _mm256_castsi128_si256, the upper 128
|
||||||
// bits of the result are undefined.
|
// bits of the result are undefined.
|
||||||
// TODO<leslie> We can use _mm256_zextsi128_si256 in the future,
|
// TODO<leslie> We can use _mm256_zextsi128_si256 in the furture,
|
||||||
// since gcc 9.3 doesn't support it now.
|
// since gcc 9.3 doesn't support it now.
|
||||||
__m128i input_128 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr));
|
__m128i input_128 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ptr));
|
||||||
return _mm256_castsi128_si256(input_128);
|
return _mm256_castsi128_si256(input_128);
|
||||||
@ -1844,7 +1844,7 @@ Vectorized<int16_t> inline shift_256_16(
|
|||||||
c0 = _mm256_srav_epi32(a0, b0);
|
c0 = _mm256_srav_epi32(a0, b0);
|
||||||
c0 = _mm256_shuffle_epi8(c0, ctl_1_0);
|
c0 = _mm256_shuffle_epi8(c0, ctl_1_0);
|
||||||
|
|
||||||
// Perform shifting the same way for input array elements with
|
// Peform shifting the same way for input array elements with
|
||||||
// idx%2==1.
|
// idx%2==1.
|
||||||
__m256i a1 = _mm256_and_si256(a, keep_1);
|
__m256i a1 = _mm256_and_si256(a, keep_1);
|
||||||
__m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
|
__m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
|
||||||
@ -2180,7 +2180,7 @@ Vectorized<T> inline shift_256_8(
|
|||||||
c0 = _mm256_srlv_epi32(a0, b0);
|
c0 = _mm256_srlv_epi32(a0, b0);
|
||||||
c0 = _mm256_shuffle_epi8(c0, ctl_3_0);
|
c0 = _mm256_shuffle_epi8(c0, ctl_3_0);
|
||||||
|
|
||||||
// Perform shifting the same way for input array elements with
|
// Peform shifting the same way for input array elements with
|
||||||
// idx%4==1.
|
// idx%4==1.
|
||||||
__m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3);
|
__m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3);
|
||||||
__m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
|
__m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0);
|
||||||
@ -2193,7 +2193,7 @@ Vectorized<T> inline shift_256_8(
|
|||||||
c1 = _mm256_srlv_epi32(a1, b1);
|
c1 = _mm256_srlv_epi32(a1, b1);
|
||||||
c1 = _mm256_shuffle_epi8(c1, ctl_3_1);
|
c1 = _mm256_shuffle_epi8(c1, ctl_3_1);
|
||||||
|
|
||||||
// Perform shifting the same way for input array elements with
|
// Peform shifting the same way for input array elements with
|
||||||
// idx%4==2.
|
// idx%4==2.
|
||||||
__m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3);
|
__m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3);
|
||||||
__m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0);
|
__m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0);
|
||||||
@ -2206,7 +2206,7 @@ Vectorized<T> inline shift_256_8(
|
|||||||
c2 = _mm256_srlv_epi32(a2, b2);
|
c2 = _mm256_srlv_epi32(a2, b2);
|
||||||
c2 = _mm256_shuffle_epi8(c2, ctl_3_2);
|
c2 = _mm256_shuffle_epi8(c2, ctl_3_2);
|
||||||
|
|
||||||
// Perform shifting the same way for input array elements with
|
// Peform shifting the same way for input array elements with
|
||||||
// idx%4==3.
|
// idx%4==3.
|
||||||
__m256i a3 = _mm256_and_si256(a, keep_3);
|
__m256i a3 = _mm256_and_si256(a, keep_3);
|
||||||
__m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0);
|
__m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0);
|
||||||
|
|||||||
@ -1088,7 +1088,7 @@ class Vectorized8 : public Vectorizedi {
|
|||||||
// Because loadu(const void* ptr, T count) requires zero initialization for
|
// Because loadu(const void* ptr, T count) requires zero initialization for
|
||||||
// upper 384 bits. However, by using _mm512_castsi128_si512, the upper 384
|
// upper 384 bits. However, by using _mm512_castsi128_si512, the upper 384
|
||||||
// bits of the result are undefined.
|
// bits of the result are undefined.
|
||||||
// TODO<leslie> We can use _mm512_zextsi128_si512 in the future,
|
// TODO<leslie> We can use _mm512_zextsi128_si512 in the furture,
|
||||||
// since gcc 9.3 doesn't support it now.
|
// since gcc 9.3 doesn't support it now.
|
||||||
__m128i input_128 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));
|
__m128i input_128 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr));
|
||||||
return _mm512_castsi128_si512(input_128);
|
return _mm512_castsi128_si512(input_128);
|
||||||
@ -2022,7 +2022,7 @@ Vectorized<T> inline shift_512_8(
|
|||||||
c0 = _mm512_srlv_epi16(a0, b0);
|
c0 = _mm512_srlv_epi16(a0, b0);
|
||||||
c0 = _mm512_shuffle_epi8(c0, ctl_1_0);
|
c0 = _mm512_shuffle_epi8(c0, ctl_1_0);
|
||||||
|
|
||||||
// Perform shifting the same way for input array elements with
|
// Peform shifting the same way for input array elements with
|
||||||
// idx%2==1.
|
// idx%2==1.
|
||||||
__m512i a1 = _mm512_and_si512(a, keep_1);
|
__m512i a1 = _mm512_and_si512(a, keep_1);
|
||||||
__m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0);
|
__m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0);
|
||||||
|
|||||||
@ -323,7 +323,7 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
|
|||||||
descriptor_.reset(raw_descriptor);
|
descriptor_.reset(raw_descriptor);
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
|
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
|
||||||
// NOLINTNEXTLINE(bugprone-sizeof-expression)
|
// NOLINTNEXTLINE(bugprone-sizeof-expression)
|
||||||
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value)));
|
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value)));
|
||||||
}
|
}
|
||||||
@ -345,7 +345,7 @@ class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
|
|||||||
descriptor_.reset(raw_descriptor);
|
descriptor_.reset(raw_descriptor);
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
|
inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
|
||||||
TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
|
TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -360,7 +360,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
|||||||
descriptor_.reset(raw_descriptor);
|
descriptor_.reset(raw_descriptor);
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
|
inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
|
||||||
TORCH_CUDABLAS_CHECK(::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
|
TORCH_CUDABLAS_CHECK(::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -395,7 +395,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
|||||||
computeType = CUBLAS_COMPUTE_64F;
|
computeType = CUBLAS_COMPUTE_64F;
|
||||||
scaleType = CUDA_R_64F;
|
scaleType = CUDA_R_64F;
|
||||||
} else if constexpr (std::is_same_v<Dtype, float>) {
|
} else if constexpr (std::is_same_v<Dtype, float>) {
|
||||||
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
|
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
|
||||||
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
|
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
|
} else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
|
||||||
@ -1270,7 +1270,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
|||||||
}
|
}
|
||||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx11", "gfx12"})) { //no CK GEMM version
|
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
|
||||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||||
} else{
|
} else{
|
||||||
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
||||||
@ -1559,7 +1559,7 @@ bool gemm_and_bias(
|
|||||||
computeType = CUBLAS_COMPUTE_64F;
|
computeType = CUBLAS_COMPUTE_64F;
|
||||||
scaleType = CUDA_R_64F;
|
scaleType = CUDA_R_64F;
|
||||||
} else if constexpr (std::is_same_v<Dtype, float>) {
|
} else if constexpr (std::is_same_v<Dtype, float>) {
|
||||||
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
|
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
|
||||||
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
|
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
|
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
|
||||||
|
|||||||
@ -109,7 +109,7 @@ void CUDAGeneratorState::increase(uint64_t increment) {
|
|||||||
offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
|
offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
|
||||||
// Ensures the increment does not cause overflow.
|
// Ensures the increment does not cause overflow.
|
||||||
TORCH_INTERNAL_ASSERT(
|
TORCH_INTERNAL_ASSERT(
|
||||||
offset_intragraph_ <= std::numeric_limits<uint64_t>::max() - increment,
|
offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
|
||||||
"Increment causes overflow in the offset value.");
|
"Increment causes overflow in the offset value.");
|
||||||
offset_intragraph_ += increment;
|
offset_intragraph_ += increment;
|
||||||
} else {
|
} else {
|
||||||
@ -461,7 +461,7 @@ void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) {
|
|||||||
*/
|
*/
|
||||||
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
|
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
|
||||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||||
uint64_t offset = state_->offset_intragraph_;
|
uint32_t offset = state_->offset_intragraph_;
|
||||||
state_->increase(increment);
|
state_->increase(increment);
|
||||||
return PhiloxCudaState(
|
return PhiloxCudaState(
|
||||||
state_->seed_extragraph_.data_ptr<int64_t>(),
|
state_->seed_extragraph_.data_ptr<int64_t>(),
|
||||||
|
|||||||
@ -96,16 +96,16 @@ struct CUDAGraph;
|
|||||||
struct CUDAGeneratorState : public c10::intrusive_ptr_target {
|
struct CUDAGeneratorState : public c10::intrusive_ptr_target {
|
||||||
uint64_t seed_;
|
uint64_t seed_;
|
||||||
uint64_t philox_offset_per_thread_;
|
uint64_t philox_offset_per_thread_;
|
||||||
uint64_t offset_intragraph_;
|
uint32_t offset_intragraph_;
|
||||||
bool capturing_{};
|
bool capturing_{};
|
||||||
std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
|
std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
|
||||||
at::TensorBase seed_extragraph_;
|
at::TensorBase seed_extragraph_{};
|
||||||
at::TensorBase offset_extragraph_;
|
at::TensorBase offset_extragraph_{};
|
||||||
|
|
||||||
CUDAGeneratorState(
|
CUDAGeneratorState(
|
||||||
uint64_t seed = default_rng_seed_val,
|
uint64_t seed = default_rng_seed_val,
|
||||||
uint64_t philox_offset_per_thread = 0,
|
uint64_t philox_offset_per_thread = 0,
|
||||||
uint64_t offset_intragraph = 0)
|
uint32_t offset_intragraph = 0)
|
||||||
: seed_(seed),
|
: seed_(seed),
|
||||||
philox_offset_per_thread_(philox_offset_per_thread),
|
philox_offset_per_thread_(philox_offset_per_thread),
|
||||||
offset_intragraph_(offset_intragraph) {}
|
offset_intragraph_(offset_intragraph) {}
|
||||||
@ -167,7 +167,7 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
|
|||||||
CUDAGeneratorImpl* clone_impl() const override;
|
CUDAGeneratorImpl* clone_impl() const override;
|
||||||
|
|
||||||
c10::intrusive_ptr<CUDAGeneratorState> state_;
|
c10::intrusive_ptr<CUDAGeneratorState> state_;
|
||||||
std::atomic_flag no_reset_rnn_state_;
|
std::atomic_flag no_reset_rnn_state_{};
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace cuda::detail {
|
namespace cuda::detail {
|
||||||
|
|||||||
@ -56,7 +56,7 @@ struct TORCH_CUDA_CPP_API CUDAGraph {
|
|||||||
|
|
||||||
// the ID assigned by cuda during graph capture,
|
// the ID assigned by cuda during graph capture,
|
||||||
// used to identify when a stream is participating in capture
|
// used to identify when a stream is participating in capture
|
||||||
CaptureId_t capture_id_ = 0;
|
CaptureId_t capture_id_ = -1;
|
||||||
|
|
||||||
// uuid used to request a particular private mempool from CUDACachingAllocator.
|
// uuid used to request a particular private mempool from CUDACachingAllocator.
|
||||||
// By default, this will be set to {id_, 0}.
|
// By default, this will be set to {id_, 0}.
|
||||||
|
|||||||
@ -6,15 +6,43 @@
|
|||||||
#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
|
#define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// cuSparse Generic API added in CUDA 10.1
|
||||||
|
// Windows support added in CUDA 11.0
|
||||||
|
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
|
||||||
|
#define AT_USE_CUSPARSE_GENERIC_API() 1
|
||||||
|
#else
|
||||||
|
#define AT_USE_CUSPARSE_GENERIC_API() 0
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0
|
||||||
|
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
|
||||||
|
(CUSPARSE_VERSION < 12000)
|
||||||
|
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
|
||||||
|
#else
|
||||||
|
#define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
|
||||||
|
(CUSPARSE_VERSION >= 12000)
|
||||||
|
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
|
||||||
|
#else
|
||||||
|
#define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(USE_ROCM)
|
#if defined(USE_ROCM)
|
||||||
// hipSparse const API added in v2.4.0
|
// hipSparse const API added in v2.4.0
|
||||||
#if HIPSPARSE_VERSION >= 200400
|
#if HIPSPARSE_VERSION >= 200400
|
||||||
|
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
|
||||||
|
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
|
||||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||||
#else
|
#else
|
||||||
|
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
|
||||||
|
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1
|
||||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||||
#endif
|
#endif
|
||||||
#else // USE_ROCM
|
#else // USE_ROCM
|
||||||
|
#define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
|
||||||
|
#define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
|
||||||
#define AT_USE_HIPSPARSE_GENERIC_API() 0
|
#define AT_USE_HIPSPARSE_GENERIC_API() 0
|
||||||
#endif // USE_ROCM
|
#endif // USE_ROCM
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,8 @@ cusparseStatus_t destroyConstDnMat(const cusparseDnMatDescr* dnMatDescr) {
|
|||||||
return cusparseDestroyDnMat(const_cast<cusparseDnMatDescr*>(dnMatDescr));
|
return cusparseDestroyDnMat(const_cast<cusparseDnMatDescr*>(dnMatDescr));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// If a specific GPU model does not provide native support for a given data
|
// If a specific GPU model does not provide native support for a given data
|
||||||
@ -208,4 +210,6 @@ CuSparseSpMatCsrDescriptor::CuSparseSpMatCsrDescriptor(const Tensor& input, int6
|
|||||||
descriptor_.reset(raw_descriptor);
|
descriptor_.reset(raw_descriptor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
||||||
|
|
||||||
} // namespace at::cuda::sparse
|
} // namespace at::cuda::sparse
|
||||||
|
|||||||
@ -35,6 +35,7 @@ class CuSparseDescriptor {
|
|||||||
std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
|
std::unique_ptr<T, CuSparseDescriptorDeleter<T, destructor>> descriptor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
|
||||||
template <typename T, cusparseStatus_t (*destructor)(const T*)>
|
template <typename T, cusparseStatus_t (*destructor)(const T*)>
|
||||||
struct ConstCuSparseDescriptorDeleter {
|
struct ConstCuSparseDescriptorDeleter {
|
||||||
void operator()(T* x) {
|
void operator()(T* x) {
|
||||||
@ -57,6 +58,7 @@ class ConstCuSparseDescriptor {
|
|||||||
protected:
|
protected:
|
||||||
std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
|
std::unique_ptr<T, ConstCuSparseDescriptorDeleter<T, destructor>> descriptor_;
|
||||||
};
|
};
|
||||||
|
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS || AT_USE_HIPSPARSE_CONST_DESCRIPTORS
|
||||||
|
|
||||||
#if defined(USE_ROCM)
|
#if defined(USE_ROCM)
|
||||||
using cusparseMatDescr = std::remove_pointer_t<hipsparseMatDescr_t>;
|
using cusparseMatDescr = std::remove_pointer_t<hipsparseMatDescr_t>;
|
||||||
@ -121,8 +123,39 @@ class TORCH_CUDA_CPP_API CuSparseBsrsm2Info
|
|||||||
|
|
||||||
#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
|
#endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
|
||||||
|
|
||||||
|
#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
||||||
|
|
||||||
cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
|
cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
|
||||||
|
|
||||||
|
#if AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS()
|
||||||
|
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
|
||||||
|
: public CuSparseDescriptor<cusparseDnMatDescr, &cusparseDestroyDnMat> {
|
||||||
|
public:
|
||||||
|
explicit CuSparseDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
|
||||||
|
};
|
||||||
|
|
||||||
|
class TORCH_CUDA_CPP_API CuSparseConstDnMatDescriptor
|
||||||
|
: public CuSparseDescriptor<const cusparseDnMatDescr, &destroyConstDnMat> {
|
||||||
|
public:
|
||||||
|
explicit CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset = -1);
|
||||||
|
cusparseDnMatDescr* unsafe_mutable_descriptor() const {
|
||||||
|
return const_cast<cusparseDnMatDescr*>(descriptor());
|
||||||
|
}
|
||||||
|
cusparseDnMatDescr* unsafe_mutable_descriptor() {
|
||||||
|
return const_cast<cusparseDnMatDescr*>(descriptor());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class TORCH_CUDA_CPP_API CuSparseDnVecDescriptor
|
||||||
|
: public CuSparseDescriptor<cusparseDnVecDescr, &cusparseDestroyDnVec> {
|
||||||
|
public:
|
||||||
|
explicit CuSparseDnVecDescriptor(const Tensor& input);
|
||||||
|
};
|
||||||
|
|
||||||
|
class TORCH_CUDA_CPP_API CuSparseSpMatDescriptor
|
||||||
|
: public CuSparseDescriptor<cusparseSpMatDescr, &cusparseDestroySpMat> {};
|
||||||
|
|
||||||
|
#elif AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
|
||||||
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
|
class TORCH_CUDA_CPP_API CuSparseDnMatDescriptor
|
||||||
: public ConstCuSparseDescriptor<
|
: public ConstCuSparseDescriptor<
|
||||||
cusparseDnMatDescr,
|
cusparseDnMatDescr,
|
||||||
@ -161,6 +194,7 @@ cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type);
|
|||||||
: public ConstCuSparseDescriptor<
|
: public ConstCuSparseDescriptor<
|
||||||
cusparseSpMatDescr,
|
cusparseSpMatDescr,
|
||||||
&cusparseDestroySpMat> {};
|
&cusparseDestroySpMat> {};
|
||||||
|
#endif // AT_USE_CUSPARSE_CONST_DESCRIPTORS() || AT_USE_HIPSPARSE_CONST_DESCRIPTORS()
|
||||||
|
|
||||||
class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
|
class TORCH_CUDA_CPP_API CuSparseSpMatCsrDescriptor
|
||||||
: public CuSparseSpMatDescriptor {
|
: public CuSparseSpMatDescriptor {
|
||||||
@ -249,4 +283,6 @@ class TORCH_CUDA_CPP_API CuSparseSpGEMMDescriptor
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
||||||
|
|
||||||
} // namespace at::cuda::sparse
|
} // namespace at::cuda::sparse
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
#include <future>
|
#include <future>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
namespace at::cuda {
|
namespace at::cuda {
|
||||||
namespace {
|
namespace {
|
||||||
@ -71,20 +72,9 @@ using Block = HostBlock<CUDAStream>;
|
|||||||
struct CUDACachingHostAllocatorImpl
|
struct CUDACachingHostAllocatorImpl
|
||||||
: public CachingHostAllocatorImpl<CUDAStream, EventPool::Event> {
|
: public CachingHostAllocatorImpl<CUDAStream, EventPool::Event> {
|
||||||
private:
|
private:
|
||||||
ska::flat_hash_map<void*, bool> use_host_register;
|
std::unordered_map<void*, bool> use_host_register;
|
||||||
|
|
||||||
void allocate_host_memory(size_t size, void** ptr) override {
|
void allocate_host_memory(size_t size, void** ptr) override {
|
||||||
// try allocating from reserve segment first before calling into expensive APIs
|
|
||||||
if (get_reserve_segment().initialized()) {
|
|
||||||
*ptr = get_reserve_segment().allocate(size);
|
|
||||||
if (*ptr != nullptr) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
allocate_host_memory_slowpath(size, ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
void allocate_host_memory_slowpath(size_t size, void** ptr) {
|
|
||||||
// Pinned memory pointers allocated by any device can be directly used by
|
// Pinned memory pointers allocated by any device can be directly used by
|
||||||
// any other device, regardless of the current device at the time of
|
// any other device, regardless of the current device at the time of
|
||||||
// allocation, since we assume unified addressing. So we grab any existing
|
// allocation, since we assume unified addressing. So we grab any existing
|
||||||
@ -123,18 +113,6 @@ struct CUDACachingHostAllocatorImpl
|
|||||||
}
|
}
|
||||||
|
|
||||||
void free_block(Block* block) override {
|
void free_block(Block* block) override {
|
||||||
// We never free blocks from the reserve segment
|
|
||||||
if (get_reserve_segment().initialized()) {
|
|
||||||
// Check if the block is from the reserve segment
|
|
||||||
if (get_reserve_segment().owns(block->ptr_)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
free_block_slowpath(block);
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_block_slowpath(Block* block) {
|
|
||||||
auto start = std::chrono::steady_clock::now();
|
auto start = std::chrono::steady_clock::now();
|
||||||
// Users may change the allocator config at will. torch unit tests do this.
|
// Users may change the allocator config at will. torch unit tests do this.
|
||||||
// However, allocations using cudaHostRegister should use corresonding
|
// However, allocations using cudaHostRegister should use corresonding
|
||||||
@ -194,20 +172,6 @@ struct CUDACachingHostAllocatorImpl
|
|||||||
return event_pool->get(idx);
|
return event_pool->get(idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
PinnedReserveSegment& get_reserve_segment() {
|
|
||||||
static auto reserve_segment = [&]() {
|
|
||||||
if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_reserve_segment_size_mb() > 0) {
|
|
||||||
void *ptr;
|
|
||||||
size_t sz = c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_reserve_segment_size_mb() * 1024 * 1024;
|
|
||||||
allocate_host_memory_slowpath(sz, &ptr);
|
|
||||||
return PinnedReserveSegment(ptr, sz);
|
|
||||||
} else {
|
|
||||||
return PinnedReserveSegment();
|
|
||||||
}
|
|
||||||
} ();
|
|
||||||
return reserve_segment;
|
|
||||||
}
|
|
||||||
|
|
||||||
TaskThreadPool* getThreadPool() {
|
TaskThreadPool* getThreadPool() {
|
||||||
static TaskThreadPool* pool = new TaskThreadPool(
|
static TaskThreadPool* pool = new TaskThreadPool(
|
||||||
static_cast<int>(c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
static_cast<int>(c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
||||||
@ -222,15 +186,15 @@ struct CUDACachingHostAllocatorImpl
|
|||||||
size_t numThreads,
|
size_t numThreads,
|
||||||
size_t pageSize) {
|
size_t pageSize) {
|
||||||
uintptr_t start = (uintptr_t)ptr + (size * i / numThreads);
|
uintptr_t start = (uintptr_t)ptr + (size * i / numThreads);
|
||||||
uintptr_t end = start + (size / numThreads);
|
uintptr_t end = (uintptr_t)start + (size / numThreads);
|
||||||
if (i == (numThreads - 1)) {
|
if (i == (numThreads - 1)) {
|
||||||
end = (uintptr_t)ptr + size;
|
end = (uintptr_t)ptr + size;
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre-fault/map the pages by setting the first byte of the page
|
// pre-fault/map the pages by setting the first byte of the page
|
||||||
uintptr_t alignedStart =
|
uintptr_t alignedStart =
|
||||||
((start + pageSize - 1) & ~(pageSize - 1));
|
(((uintptr_t)start + pageSize - 1) & ~(pageSize - 1));
|
||||||
for (uintptr_t p = alignedStart; p < (end); p += pageSize) {
|
for (uintptr_t p = alignedStart; p < ((uintptr_t)end); p += pageSize) {
|
||||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||||
memset((void*)p, 0, 1);
|
memset((void*)p, 0, 1);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -310,7 +310,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
|||||||
// FP32 data type calculations based on the value of the allow_tf32 flag.
|
// FP32 data type calculations based on the value of the allow_tf32 flag.
|
||||||
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
|
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
|
||||||
if (!NoTF32Guard::should_disable_tf32() &&
|
if (!NoTF32Guard::should_disable_tf32() &&
|
||||||
at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
|
at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
|
||||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
|
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
|
||||||
} else {
|
} else {
|
||||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||||
|
|||||||
@ -122,7 +122,7 @@ struct DeviceThreadHandlePool : public std::enable_shared_from_this<DeviceThread
|
|||||||
|
|
||||||
// Called by the destructor. Releases this thread's handles back into the pool.
|
// Called by the destructor. Releases this thread's handles back into the pool.
|
||||||
void release() {
|
void release() {
|
||||||
if(!my_handles.empty()) {
|
if(my_handles.size() > 0) {
|
||||||
auto parent = weak_parent.lock();
|
auto parent = weak_parent.lock();
|
||||||
if (!parent) {
|
if (!parent) {
|
||||||
// If this thread exits after atexit handlers have completed, the
|
// If this thread exits after atexit handlers have completed, the
|
||||||
|
|||||||
@ -19,7 +19,7 @@ struct PhiloxCudaState {
|
|||||||
// Called if graph capture is underway
|
// Called if graph capture is underway
|
||||||
PhiloxCudaState(int64_t* seed,
|
PhiloxCudaState(int64_t* seed,
|
||||||
int64_t* offset_extragraph,
|
int64_t* offset_extragraph,
|
||||||
uint64_t offset_intragraph) {
|
uint32_t offset_intragraph) {
|
||||||
seed_.ptr = seed;
|
seed_.ptr = seed;
|
||||||
offset_.ptr = offset_extragraph;
|
offset_.ptr = offset_extragraph;
|
||||||
offset_intragraph_ = offset_intragraph;
|
offset_intragraph_ = offset_intragraph;
|
||||||
@ -36,7 +36,7 @@ struct PhiloxCudaState {
|
|||||||
|
|
||||||
Payload seed_{};
|
Payload seed_{};
|
||||||
Payload offset_{};
|
Payload offset_{};
|
||||||
uint64_t offset_intragraph_ = 0;
|
uint32_t offset_intragraph_ = 0;
|
||||||
bool captured_ = false;
|
bool captured_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -162,7 +162,7 @@ inline std::string ComputeTypeFor() {
|
|||||||
// ROCBLAS and hipBLASLt.
|
// ROCBLAS and hipBLASLt.
|
||||||
template <>
|
template <>
|
||||||
inline std::string ComputeTypeFor<float>() {
|
inline std::string ComputeTypeFor<float>() {
|
||||||
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) != at::Float32Precision::TF32) {
|
if (at::globalContext().float32Precision("cuda", "matmul") != "tf32") {
|
||||||
return "f32_r";
|
return "f32_r";
|
||||||
} else {
|
} else {
|
||||||
return "xf32_r";
|
return "xf32_r";
|
||||||
|
|||||||
@ -506,7 +506,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
|
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
|
||||||
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) {
|
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") {
|
||||||
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
|
computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
|
||||||
}
|
}
|
||||||
HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
|
HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F);
|
||||||
|
|||||||
@ -141,7 +141,7 @@ class RocblasGemmOp : public Callable<GemmParams<T>> {
|
|||||||
|
|
||||||
TuningStatus Call(const GemmParams<T>* params) override {
|
TuningStatus Call(const GemmParams<T>* params) override {
|
||||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||||
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
|
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
|
||||||
return FAIL; // no support for TF32 in rocBLAS
|
return FAIL; // no support for TF32 in rocBLAS
|
||||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||||
@ -209,7 +209,7 @@ class RocblasGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
|
|||||||
|
|
||||||
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
TuningStatus Call(const GemmStridedBatchedParams<T>* params) override {
|
||||||
auto input_output_type = RocBlasDataTypeFor<T>();
|
auto input_output_type = RocBlasDataTypeFor<T>();
|
||||||
if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32 && input_output_type == rocblas_datatype_f32_r)
|
if (at::globalContext().float32Precision("cuda", "matmul") == "tf32" && input_output_type == rocblas_datatype_f32_r)
|
||||||
return FAIL; // no support for TF32 in rocBLAS
|
return FAIL; // no support for TF32 in rocBLAS
|
||||||
auto compute_type = RocBlasComputeTypeFor<T>();
|
auto compute_type = RocBlasComputeTypeFor<T>();
|
||||||
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
auto h_a = DoCastForHalfOrBfloat16(params->alpha);
|
||||||
|
|||||||
@ -404,6 +404,8 @@ TuningContext::TuningContext() :
|
|||||||
max_warmup_iterations_{0},
|
max_warmup_iterations_{0},
|
||||||
icache_flush_{true},
|
icache_flush_{true},
|
||||||
rotating_buffer_size_{-1},
|
rotating_buffer_size_{-1},
|
||||||
|
filename_{},
|
||||||
|
untuned_file_{},
|
||||||
results_count_from_input_file_{0},
|
results_count_from_input_file_{0},
|
||||||
is_shutting_down_{false}
|
is_shutting_down_{false}
|
||||||
{
|
{
|
||||||
|
|||||||
@ -141,7 +141,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
|
|||||||
size[i] = (int) t.size(i);
|
size[i] = (int) t.size(i);
|
||||||
}
|
}
|
||||||
for (const auto i : c10::irange(dim, pad)) {
|
for (const auto i : c10::irange(dim, pad)) {
|
||||||
size[i] = 1;
|
size[i] = (int) 1;
|
||||||
}
|
}
|
||||||
dim = std::max(dim, pad);
|
dim = std::max(dim, pad);
|
||||||
cudnnTensorFormat_t filter_format{};
|
cudnnTensorFormat_t filter_format{};
|
||||||
|
|||||||
@ -176,7 +176,7 @@ struct LinalgCheckMatrixUnaryRuleHelper;
|
|||||||
|
|
||||||
template <char const *op_name, typename F, F Func, typename A, typename... T>
|
template <char const *op_name, typename F, F Func, typename A, typename... T>
|
||||||
struct LinalgCheckMatrixUnaryRuleHelper<op_name, F, Func, typelist<A, T...>> {
|
struct LinalgCheckMatrixUnaryRuleHelper<op_name, F, Func, typelist<A, T...>> {
|
||||||
static Tensor check_and_reshape_input(const Tensor& tensor, std::optional<int64_t> batch_dim) {
|
static inline Tensor check_and_reshape_input(const Tensor& tensor, std::optional<int64_t> batch_dim) {
|
||||||
TORCH_CHECK(rankWithoutBatchDim(tensor, batch_dim) >= 2, op_name, ": The input tensor A must have at least 2 dimensions.");
|
TORCH_CHECK(rankWithoutBatchDim(tensor, batch_dim) >= 2, op_name, ": The input tensor A must have at least 2 dimensions.");
|
||||||
return moveBatchDimToFront(tensor, batch_dim);
|
return moveBatchDimToFront(tensor, batch_dim);
|
||||||
}
|
}
|
||||||
@ -222,7 +222,7 @@ struct LinalgCheckMatrixBinaryRuleHelper;
|
|||||||
|
|
||||||
template <char const *op_name, typename F, F Func, typename A, typename B, typename... T>
|
template <char const *op_name, typename F, F Func, typename A, typename B, typename... T>
|
||||||
struct LinalgCheckMatrixBinaryRuleHelper<op_name, F, Func, typelist<A, B, T...>> {
|
struct LinalgCheckMatrixBinaryRuleHelper<op_name, F, Func, typelist<A, B, T...>> {
|
||||||
static std::tuple<Tensor, Tensor> check_inputs_and_reshape_inputs(
|
static inline std::tuple<Tensor, Tensor> check_inputs_and_reshape_inputs(
|
||||||
const Tensor& first, std::optional<int64_t> first_bdim,
|
const Tensor& first, std::optional<int64_t> first_bdim,
|
||||||
const Tensor& second, std::optional<int64_t> second_bdim) {
|
const Tensor& second, std::optional<int64_t> second_bdim) {
|
||||||
TORCH_CHECK(rankWithoutBatchDim(first, first_bdim) >= 2,
|
TORCH_CHECK(rankWithoutBatchDim(first, first_bdim) >= 2,
|
||||||
|
|||||||
@ -58,7 +58,7 @@ scalar_t dot_impl(int64_t n, const scalar_t *x, int64_t incx, const scalar_t *y,
|
|||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
scalar_t vdot_impl(int64_t n, const scalar_t *x, int64_t incx, const scalar_t *y, int64_t incy);
|
scalar_t vdot_impl(int64_t n, const scalar_t *x, int64_t incx, const scalar_t *y, int64_t incy);
|
||||||
|
|
||||||
static constexpr bool lda_cond(int64_t m, int64_t n, int64_t lda) {
|
static constexpr inline bool lda_cond(int64_t m, int64_t n, int64_t lda) {
|
||||||
return n == 1 || lda >= std::max<int64_t>(1L, m);
|
return n == 1 || lda >= std::max<int64_t>(1L, m);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -991,7 +991,7 @@ std::size_t UnsafeUkernelKeyHasher<PackKey>::operator()(const PackKey& key) cons
|
|||||||
template <typename key_t, typename value_t>
|
template <typename key_t, typename value_t>
|
||||||
struct KernelCache {
|
struct KernelCache {
|
||||||
using kstore_t = std::unordered_map<key_t, std::shared_ptr<value_t>, UnsafeUkernelKeyHasher<key_t>>;
|
using kstore_t = std::unordered_map<key_t, std::shared_ptr<value_t>, UnsafeUkernelKeyHasher<key_t>>;
|
||||||
static std::shared_ptr<value_t>&& fetch_or_create(
|
static inline std::shared_ptr<value_t>&& fetch_or_create(
|
||||||
const key_t& key,
|
const key_t& key,
|
||||||
const std::function<std::shared_ptr<value_t>()>& callback) {
|
const std::function<std::shared_ptr<value_t>()>& callback) {
|
||||||
auto&& search = get_store().find(key);
|
auto&& search = get_store().find(key);
|
||||||
@ -1003,7 +1003,7 @@ struct KernelCache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static kstore_t& get_store() {
|
static inline kstore_t& get_store() {
|
||||||
static thread_local kstore_t cache_kernels;
|
static thread_local kstore_t cache_kernels;
|
||||||
return cache_kernels;
|
return cache_kernels;
|
||||||
}
|
}
|
||||||
@ -1067,7 +1067,7 @@ struct GemmHelper {
|
|||||||
struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
|
struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
|
||||||
// Fetch/create GemmHelper object and execute brgemm with batch size = 1
|
// Fetch/create GemmHelper object and execute brgemm with batch size = 1
|
||||||
template <typename scalar_t_a, typename scalar_t_b, typename scalar_t_c>
|
template <typename scalar_t_a, typename scalar_t_b, typename scalar_t_c>
|
||||||
static void call(
|
static inline void call(
|
||||||
int64_t M,
|
int64_t M,
|
||||||
int64_t N,
|
int64_t N,
|
||||||
int64_t K,
|
int64_t K,
|
||||||
@ -1118,12 +1118,12 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
|
|||||||
.execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data());
|
.execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::shared_ptr<GemmHelper>& get_current() {
|
static inline std::shared_ptr<GemmHelper>& get_current() {
|
||||||
static thread_local std::shared_ptr<GemmHelper> current;
|
static thread_local std::shared_ptr<GemmHelper> current;
|
||||||
return current;
|
return current;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool device_check(ScalarType dtype) {
|
static inline bool device_check(ScalarType dtype) {
|
||||||
if (!at::globalContext().userEnabledMkldnn()) {
|
if (!at::globalContext().userEnabledMkldnn()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1153,7 +1153,7 @@ using pack_t = dnnl::ukernel::brgemm_pack_B;
|
|||||||
using pack_t = dnnl::ukernel::transform;
|
using pack_t = dnnl::ukernel::transform;
|
||||||
#endif
|
#endif
|
||||||
struct Pack : public KernelCache <PackKey, pack_t> {
|
struct Pack : public KernelCache <PackKey, pack_t> {
|
||||||
static void call(
|
static inline void call(
|
||||||
int64_t K,
|
int64_t K,
|
||||||
int64_t N,
|
int64_t N,
|
||||||
int64_t ld_in,
|
int64_t ld_in,
|
||||||
@ -1182,7 +1182,7 @@ struct Pack : public KernelCache <PackKey, pack_t> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool could_pack(ScalarType dtype) {
|
static inline bool could_pack(ScalarType dtype) {
|
||||||
if (!at::globalContext().userEnabledMkldnn()) {
|
if (!at::globalContext().userEnabledMkldnn()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -702,7 +702,7 @@ static void check_shape_forward(const at::Tensor& input,
|
|||||||
// If kernel size is incorrect
|
// If kernel size is incorrect
|
||||||
std::ostringstream input_ss;
|
std::ostringstream input_ss;
|
||||||
std::ostringstream kernel_ss;
|
std::ostringstream kernel_ss;
|
||||||
std::string separator;
|
std::string separator = "";
|
||||||
|
|
||||||
for (int i = 0, len = input_shape.size(); i < len; ++i) {
|
for (int i = 0, len = input_shape.size(); i < len; ++i) {
|
||||||
input_ss << separator << input_shape[i];
|
input_ss << separator << input_shape[i];
|
||||||
@ -1019,7 +1019,7 @@ static Tensor convolution_same(
|
|||||||
|
|
||||||
if (symmetric_padding) {
|
if (symmetric_padding) {
|
||||||
// All backends handle symmetric padding natively
|
// All backends handle symmetric padding natively
|
||||||
SymDimVector output_padding(dim);
|
SymDimVector output_padding(static_cast<size_t>(dim));
|
||||||
return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
|
return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
|
||||||
false, output_padding, groups);
|
false, output_padding, groups);
|
||||||
}
|
}
|
||||||
@ -1039,7 +1039,7 @@ static Tensor convolution_same(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
|
auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
|
||||||
SymDimVector output_padding(dim);
|
SymDimVector output_padding(static_cast<size_t>(dim));
|
||||||
return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
|
return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
|
||||||
dilation, false, output_padding, groups);
|
dilation, false, output_padding, groups);
|
||||||
}
|
}
|
||||||
@ -1174,7 +1174,7 @@ at::Tensor convolution(
|
|||||||
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
||||||
return at::_convolution(input, weight, bias, stride, padding, dilation,
|
return at::_convolution(input, weight, bias, stride, padding, dilation,
|
||||||
transposed, output_padding, groups,
|
transposed, output_padding, groups,
|
||||||
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN(at::Float32Op::CONV));
|
ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN("conv"));
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor convolution_overrideable(
|
at::Tensor convolution_overrideable(
|
||||||
@ -1319,7 +1319,7 @@ ConvBackend select_conv_backend(
|
|||||||
params.benchmark = ctx.benchmarkCuDNN();
|
params.benchmark = ctx.benchmarkCuDNN();
|
||||||
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
||||||
params.cudnn_enabled = ctx.userEnabledCuDNN();
|
params.cudnn_enabled = ctx.userEnabledCuDNN();
|
||||||
params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
|
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
|
||||||
|
|
||||||
auto input = input_r;
|
auto input = input_r;
|
||||||
auto weight = weight_r;
|
auto weight = weight_r;
|
||||||
@ -1699,7 +1699,7 @@ at::Tensor _convolution(
|
|||||||
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
|
c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt);
|
||||||
const Tensor& bias_r = *bias_r_maybe_owned;
|
const Tensor& bias_r = *bias_r_maybe_owned;
|
||||||
|
|
||||||
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN(at::Float32Op::CONV));
|
return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN("conv"));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
|
std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable(
|
||||||
@ -1997,7 +1997,7 @@ std::tuple<Tensor, Tensor, Tensor> convolution_backward(
|
|||||||
params.benchmark = ctx.benchmarkCuDNN();
|
params.benchmark = ctx.benchmarkCuDNN();
|
||||||
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms();
|
||||||
params.cudnn_enabled = ctx.userEnabledCuDNN();
|
params.cudnn_enabled = ctx.userEnabledCuDNN();
|
||||||
params.allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
|
params.allow_tf32 = ctx.allowTF32CuDNN("conv");
|
||||||
|
|
||||||
// Validate inputs.
|
// Validate inputs.
|
||||||
check_shape_backward(input, weight.sizes(), params);
|
check_shape_backward(input, weight.sizes(), params);
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||||
#include <ATen/native/Copy.h>
|
#include <ATen/native/Copy.h>
|
||||||
|
#include <ATen/native/Copy.h>
|
||||||
|
|
||||||
#include <ATen/core/Tensor.h>
|
#include <ATen/core/Tensor.h>
|
||||||
#include <ATen/Dispatch.h>
|
#include <ATen/Dispatch.h>
|
||||||
|
|||||||
@ -70,7 +70,7 @@ Tensor constant_pad_nd(const Tensor& self, IntArrayRef pad, const Scalar& value)
|
|||||||
new_shape.emplace_back(input_sizes[i]);
|
new_shape.emplace_back(input_sizes[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto i : c10::irange(l_pad)) {
|
for (const auto i : c10::irange((size_t)l_pad)) {
|
||||||
auto pad_idx = pad.size() - ((i + 1) * 2);
|
auto pad_idx = pad.size() - ((i + 1) * 2);
|
||||||
auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1];
|
auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1];
|
||||||
TORCH_CHECK(new_dim >= 0, "The input size ", input_sizes[l_diff + i], ", plus negative padding ",
|
TORCH_CHECK(new_dim >= 0, "The input size ", input_sizes[l_diff + i], ", plus negative padding ",
|
||||||
|
|||||||
@ -47,7 +47,7 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar
|
|||||||
int64_t sgn = (xstep > 0) - (xstep < 0);
|
int64_t sgn = (xstep > 0) - (xstep < 0);
|
||||||
size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
|
size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
|
||||||
} else {
|
} else {
|
||||||
size_d = std::ceil((end.to<double>() - start.to<double>())
|
size_d = std::ceil(static_cast<double>(end.to<double>() - start.to<double>())
|
||||||
/ step.to<double>());
|
/ step.to<double>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -107,6 +107,11 @@ void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes) {
|
|||||||
storage->set_nbytes(size_bytes);
|
storage->set_nbytes(size_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call the sparse implementation in SparseTensor.cpp directly.
|
||||||
|
// A dynamic dispatch here is NOT necessary, so I didn't put
|
||||||
|
// this function in native_functions.yaml
|
||||||
|
const Tensor& resize_as_sparse_(const Tensor& self, const Tensor& src);
|
||||||
|
|
||||||
// TODO(VitalyFedyunin): Move it to HTML docs.
|
// TODO(VitalyFedyunin): Move it to HTML docs.
|
||||||
//
|
//
|
||||||
// Strides of the output tensor of `resize_as_` operator is defined by input
|
// Strides of the output tensor of `resize_as_` operator is defined by input
|
||||||
|
|||||||
@ -145,6 +145,12 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
namespace at::native {
|
||||||
|
|
||||||
|
AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);
|
||||||
|
|
||||||
|
} // namespace at::native
|
||||||
|
|
||||||
namespace at::meta {
|
namespace at::meta {
|
||||||
|
|
||||||
TORCH_META_FUNC(gather)
|
TORCH_META_FUNC(gather)
|
||||||
|
|||||||
@ -73,6 +73,7 @@
|
|||||||
#include <ATen/ops/where_native.h>
|
#include <ATen/ops/where_native.h>
|
||||||
#include <ATen/ops/zeros_like.h>
|
#include <ATen/ops/zeros_like.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
@ -124,7 +124,7 @@ struct IsUnique {};
|
|||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
struct IsUnique<scalar_t, false> {
|
struct IsUnique<scalar_t, false> {
|
||||||
bool operator() (scalar_t* data_ptr, int64_t i) {
|
inline bool operator() (scalar_t* data_ptr, int64_t i) {
|
||||||
if (i == 0) { return true; }
|
if (i == 0) { return true; }
|
||||||
return c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1]);
|
return c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1]);
|
||||||
}
|
}
|
||||||
@ -132,7 +132,7 @@ struct IsUnique<scalar_t, false> {
|
|||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
struct IsUnique<scalar_t, true> {
|
struct IsUnique<scalar_t, true> {
|
||||||
bool operator() (scalar_t* data_ptr, int64_t i) {
|
inline bool operator() (scalar_t* data_ptr, int64_t i) {
|
||||||
if (i == 0) { return true; }
|
if (i == 0) { return true; }
|
||||||
return (c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1]))
|
return (c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1]))
|
||||||
&& !(_isnan(data_ptr[i]) && _isnan(data_ptr[i - 1]));
|
&& !(_isnan(data_ptr[i]) && _isnan(data_ptr[i - 1]));
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include <ATen/OpMathType.h>
|
#include <ATen/OpMathType.h>
|
||||||
#include <ATen/TensorUtils.h>
|
#include <ATen/TensorUtils.h>
|
||||||
|
#include <ATen/OpMathType.h>
|
||||||
#include <ATen/core/Tensor.h>
|
#include <ATen/core/Tensor.h>
|
||||||
#include <ATen/cpu/vec/functional.h>
|
#include <ATen/cpu/vec/functional.h>
|
||||||
#include <ATen/cpu/vec/vec.h>
|
#include <ATen/cpu/vec/vec.h>
|
||||||
|
|||||||
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
namespace ao::sparse {
|
namespace ao::sparse {
|
||||||
|
|
||||||
|
int register_linear_params();
|
||||||
|
|
||||||
#ifdef USE_FBGEMM
|
#ifdef USE_FBGEMM
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
namespace ao::sparse {
|
namespace ao::sparse {
|
||||||
|
|
||||||
|
int register_linear_params();
|
||||||
|
|
||||||
#ifdef USE_FBGEMM
|
#ifdef USE_FBGEMM
|
||||||
namespace {
|
namespace {
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace ao::sparse {
|
namespace ao::sparse {
|
||||||
|
int register_linear_params();
|
||||||
|
|
||||||
#ifdef USE_FBGEMM
|
#ifdef USE_FBGEMM
|
||||||
|
|
||||||
|
|||||||
@ -22,7 +22,7 @@ static inline void cpu_atomic_add_float(float* dst, float fvalue)
|
|||||||
old_value.floatV = *dst;
|
old_value.floatV = *dst;
|
||||||
new_value.floatV = old_value.floatV + fvalue;
|
new_value.floatV = old_value.floatV + fvalue;
|
||||||
|
|
||||||
unsigned* old_intV = &old_value.intV;
|
unsigned* old_intV = (unsigned*)(&old_value.intV);
|
||||||
while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
|
while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) {
|
||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
__asm__ __volatile__("yield;" : : : "memory");
|
__asm__ __volatile__("yield;" : : : "memory");
|
||||||
|
|||||||
@ -118,7 +118,7 @@ gemm_notrans_(
|
|||||||
scale_(m, n, beta, c, ldc);
|
scale_(m, n, beta, c, ldc);
|
||||||
|
|
||||||
// c += alpha * (a @ b)
|
// c += alpha * (a @ b)
|
||||||
const uint64_t unsigned_m = m;
|
const uint64_t unsigned_m = static_cast<int64_t>(m);
|
||||||
const uint64_t i_m = unsigned_m / 4;
|
const uint64_t i_m = unsigned_m / 4;
|
||||||
for (const uint64_t l : c10::irange(k)) {
|
for (const uint64_t l : c10::irange(k)) {
|
||||||
for (const uint64_t j : c10::irange(n)) {
|
for (const uint64_t j : c10::irange(n)) {
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <ATen/OpMathType.h>
|
#include <ATen/OpMathType.h>
|
||||||
#include <ATen/native/cpu/utils.h>
|
#include <ATen/native/cpu/utils.h>
|
||||||
|
#include <ATen/OpMathType.h>
|
||||||
|
|
||||||
namespace at::native {
|
namespace at::native {
|
||||||
inline namespace CPU_CAPABILITY {
|
inline namespace CPU_CAPABILITY {
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
#include <ATen/cpu/vec/functional.h>
|
#include <ATen/cpu/vec/functional.h>
|
||||||
#include <ATen/cpu/vec/vec.h>
|
#include <ATen/cpu/vec/vec.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
|
#include <ATen/OpMathType.h>
|
||||||
|
|
||||||
// [Note AVX-SSE transitions] In general we avoid calls into cmath for code
|
// [Note AVX-SSE transitions] In general we avoid calls into cmath for code
|
||||||
// compiled with AVX/AVX2 This is because of SSE-AVX transitions and a bug in
|
// compiled with AVX/AVX2 This is because of SSE-AVX transitions and a bug in
|
||||||
|
|||||||
@ -240,7 +240,7 @@ static void unfolded2d_copy(
|
|||||||
int64_t output_height,
|
int64_t output_height,
|
||||||
int64_t output_width) {
|
int64_t output_width) {
|
||||||
at::parallel_for(
|
at::parallel_for(
|
||||||
0, n_input_plane * kH * kW, 0, [&](int64_t start, int64_t end) {
|
0, (int64_t)n_input_plane * kH * kW, 0, [&](int64_t start, int64_t end) {
|
||||||
for (const auto k : c10::irange(start, end)) {
|
for (const auto k : c10::irange(start, end)) {
|
||||||
int64_t nip = k / (kH * kW);
|
int64_t nip = k / (kH * kW);
|
||||||
int64_t rest = k % (kH * kW);
|
int64_t rest = k % (kH * kW);
|
||||||
@ -316,7 +316,7 @@ static void unfolded2d_copy(
|
|||||||
for (int64_t x = 0; x < output_width; x++)
|
for (int64_t x = 0; x < output_width; x++)
|
||||||
memcpy(
|
memcpy(
|
||||||
dst + (size_t)y * output_width + x,
|
dst + (size_t)y * output_width + x,
|
||||||
src + (size_t)iy * input_width + ix + x * dW,
|
src + (size_t)iy * input_width + ix + (int64_t)x * dW,
|
||||||
sizeof(scalar_t) * (1));
|
sizeof(scalar_t) * (1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -906,7 +906,7 @@ static void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
|||||||
// Round to nearest integer
|
// Round to nearest integer
|
||||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||||
|
|
||||||
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
|
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride;
|
||||||
|
|
||||||
// LHS offset at the beginning of the row
|
// LHS offset at the beginning of the row
|
||||||
*((float*)(dst_ptr)) = recip_scale0;
|
*((float*)(dst_ptr)) = recip_scale0;
|
||||||
@ -1048,7 +1048,7 @@ static void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
|||||||
zero_point0 = (std::min)(zero_point0, qmax);
|
zero_point0 = (std::min)(zero_point0, qmax);
|
||||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||||
|
|
||||||
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
|
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride;
|
||||||
|
|
||||||
*((float*)(dst_ptr)) = recip_scale0;
|
*((float*)(dst_ptr)) = recip_scale0;
|
||||||
dst_ptr += sizeof(float);
|
dst_ptr += sizeof(float);
|
||||||
|
|||||||
@ -285,8 +285,8 @@ static bool isSupportedHipLtROCmArch(int index) {
|
|||||||
#if ROCM_VERSION >= 60300
|
#if ROCM_VERSION >= 60300
|
||||||
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
"gfx1100", "gfx1101", "gfx1200", "gfx1201", "gfx908",
|
||||||
#endif
|
#endif
|
||||||
#if ROCM_VERSION >= 70000
|
#if ROCM_VERSION >= 60500
|
||||||
"gfx950", "gfx1150", "gfx1151"
|
"gfx950"
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
return at::detail::getCUDAHooks().isGPUArch(archs, index);
|
return at::detail::getCUDAHooks().isGPUArch(archs, index);
|
||||||
@ -1919,7 +1919,7 @@ Tensor& _mm_dtype_out_cuda(const Tensor& self, const Tensor& mat2, const at::Sca
|
|||||||
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
TORCH_CHECK(out_dtype == out.scalar_type(), "out_dtype must be the same as the dtype of the provided out tensor");
|
||||||
|
|
||||||
|
|
||||||
addmm_out_cuda_impl(out, out, self, mat2, 0, 1);
|
addmm_out_cuda_impl(const_cast<Tensor&>(out), out, self, mat2, 0, 1);
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -102,7 +102,13 @@ __host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::comple
 }

 void launch_logcumsumexp_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
-  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
+  // Compile time for CUDA-11.4 is 3x slower than with CUDA-11.6+, specifically for complex numbers
+#if defined(FBCODE_CAFFE2) || defined(OVRSOURCE)
+#define _LCME_DISPATCH AT_DISPATCH_FLOATING_TYPES_AND2
+#else
+#define _LCME_DISPATCH AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2
+#endif
+  _LCME_DISPATCH(ScalarType::Half, ScalarType::BFloat16,
       self.scalar_type(), "logcumsumexp_cuda",
       [&]() {
         using opmath_t = at::opmath_type<scalar_t>;
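For reference, the `_log_add_exp_helper` named in the hunk header is the usual numerically stable log-sum-exp update that logcumsumexp folds over a dimension. The scalar sketch below shows only the identity log(e^a + e^b) = max(a, b) + log1p(exp(-|a - b|)); it is independent of the kernel's actual code.

    #include <algorithm>
    #include <cmath>

    // Stable log(exp(a) + exp(b)): factor out the larger exponent so nothing overflows.
    double log_add_exp(double a, double b) {
      if (std::isinf(a) && a < 0) return b;  // exp(-inf) == 0
      if (std::isinf(b) && b < 0) return a;
      double m = std::max(a, b);
      return m + std::log1p(std::exp(-std::fabs(a - b)));
    }

A cumulative version is just a left-to-right fold of this update over the scanned dimension.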
@@ -230,7 +230,7 @@ constexpr int BLOCK_THREADS = 256;
 constexpr int RADIX_BITS = 8;
 constexpr int RADIX_DIGITS = 1 << RADIX_BITS; // 2 ^ RADIX_BITS
 constexpr int RADIX_MASK = (RADIX_DIGITS - 1);
-static_assert(RADIX_DIGITS <= BLOCK_THREADS, "RADIX_DIGITS must be <= BLOCK_THREADS");
+static_assert(RADIX_DIGITS <= BLOCK_THREADS, "radixFindKthValues kernel requires RADIX_DIGITS <= BLOCK_THREADS");
 constexpr int MIN_ITEMS_PER_THREAD = 4;
 constexpr int MAX_ITEMS_PER_THREAD = 64;

@@ -242,10 +242,11 @@ __global__ void fill(T* x, T value, IndexType size) {
   }
 }

-// compute local histogram for each block
+// find the kth smallest value,
+// for largest topk, k_to_find = slice_size - k + 1
 template <typename T, typename IndexType, typename Bitwise, int Dim>
 C10_LAUNCH_BOUNDS_1(BLOCK_THREADS)
-__global__ void computeBlockDigitCounts(
+__global__ void radixFindKthValues(
     at::cuda::detail::TensorInfo<const T, IndexType> input,
     uint32_t slice_size,
     uint32_t* ks_to_find, // size: num_slices, unused arg but for mysterious reasons perf is better when it's present
@@ -320,51 +321,12 @@ __global__ void computeBlockDigitCounts(
   }
 }

-// compute global histogram and cumsum for each row
-__global__ void computeDigitCumSum(
-    short* counts,
-    uint32_t* digit_cum_sum,
-    uint32_t blocks_per_slice) {
-  int tidx = threadIdx.x + blockIdx.x * blockDim.x;
-  int digit_idx = threadIdx.x;
-  uint32_t slice_idx = blockIdx.x;
-
-  typedef cub::BlockScan<uint32_t, RADIX_DIGITS> BlockScan;
-  __shared__ typename BlockScan::TempStorage scan_storage;
-  // accumulates counters from multiple blocks
-  uint32_t digit_count = 0;
-  if (threadIdx.x < RADIX_DIGITS) {
-    constexpr int HISTO_ACCUM_TILE = 4;
-    uint32_t rounds = blocks_per_slice / HISTO_ACCUM_TILE;
-    for (int iter = 0; iter < rounds; iter++) {
-      int base = HISTO_ACCUM_TILE * iter;
-#pragma unroll
-      for (int j = 0; j < HISTO_ACCUM_TILE; j++) {
-        int blk = base + j;
-        digit_count += counts[(slice_idx * blocks_per_slice + blk) * RADIX_DIGITS + digit_idx];
-      }
-    }
-    for (int blk = HISTO_ACCUM_TILE * rounds; blk < blocks_per_slice; blk++) {
-      digit_count += counts[(slice_idx * blocks_per_slice + blk) * RADIX_DIGITS + digit_idx];
-    }
-
-  }
-  // compute the block-wide inclusive prefix sum
-  uint32_t digit_count_cumsum;
-  BlockScan(scan_storage).InclusiveSum(digit_count, digit_count_cumsum);
-  __syncthreads();
-  if (threadIdx.x < RADIX_DIGITS) {
-    digit_cum_sum[tidx] = digit_count_cumsum;
-  }
-}
-
 // Assumption: k can not be larger than UINT32_MAX
 template <typename Bitwise, typename T>
 C10_LAUNCH_BOUNDS_1(RADIX_DIGITS) // one thread per digit
 __global__ void computeBlockwiseWithinKCounts(
     Bitwise* desires_in, // size: num_slices
     short* counts, // size: num_slices * blocks_per_slice * radix_digits
-    uint32_t* digit_cum_sum,
     uint32_t* ks_to_find_in, // size: num_slices
     uint32_t blocks_per_slice,
     int current_bit,
@@ -376,7 +338,7 @@ __global__ void computeBlockwiseWithinKCounts(
     Bitwise* desires_out,
     uint32_t num_blocks
 ) {
-  // This kernel should be launched with the same number of blocks as the `computeBlockDigitCounts` kernel.
+  // This kernel should be launched with the same number of blocks as the `radixFindKthValues` kernel.
   int tidx = threadIdx.x;
   uint32_t block_idx = getLinearBlockId<uint32_t>();
   uint32_t slice_idx = block_idx / blocks_per_slice;
@@ -389,15 +351,36 @@ __global__ void computeBlockwiseWithinKCounts(
   if (block_idx >= num_blocks) {
     return;
   }
+  typedef cub::BlockScan<uint32_t, BLOCK_THREADS> BlockScan;
+  union __align__(16) TempStorage {
+    uint32_t digit_count_cumsum[RADIX_DIGITS]; // only used if this it the last block for this slice
+    typename BlockScan::TempStorage scan_storage;
+  };
+  __shared__ TempStorage temp_storage;
+
+  // accumulates counters from multiple blocks
+  uint32_t digit_count = 0;
+  if (tidx < RADIX_DIGITS) {
+    for (int blk = 0; blk < blocks_per_slice; ++blk) {
+      digit_count += counts[(slice_idx * blocks_per_slice + blk) * RADIX_DIGITS + tidx];
+    }
+  }
+
+  // compute the block-wide inclusive prefix sum
+  uint32_t digit_count_cumsum;
+  BlockScan(temp_storage.scan_storage).InclusiveSum(digit_count, digit_count_cumsum);
+  __syncthreads();
+  // every thread also need the perfix_sum of it's left value for comparison, so save a copy in shared mem
+  if (tidx < RADIX_DIGITS) {
+    temp_storage.digit_count_cumsum[tidx] = digit_count_cumsum;
+  }
+  __syncthreads();
+
   __shared__ Bitwise desired;
   uint32_t k_to_find = ks_to_find_in[slice_idx];

   if (tidx < RADIX_DIGITS) {
-    uint32_t position = slice_idx * RADIX_DIGITS + tidx;
-    uint32_t digit_count_cumsum = digit_cum_sum[position];
-    uint32_t digit_count_cumsum_left = (tidx == 0) ? 0 : digit_cum_sum[position - 1];
+    uint32_t digit_count_cumsum_left = (tidx == 0) ? 0 : temp_storage.digit_count_cumsum[tidx - 1];

     // if not the last pass: update desired and ks_to_find
     // if last pass: write out the kth value
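The added block computes an inclusive prefix sum of per-digit counts with cub::BlockScan and then reads each thread's left neighbour to obtain the exclusive value. A minimal host-side sketch of that inclusive-scan-plus-shift pattern, independent of cub and of this kernel, is:

    #include <cstdint>
    #include <vector>

    // Given per-digit counts, return the inclusive prefix sum; the exclusive value
    // for digit d is then cumsum[d - 1] (or 0 for d == 0), i.e. how many elements
    // fall in strictly smaller digits.
    std::vector<uint32_t> inclusive_scan(const std::vector<uint32_t>& counts) {
      std::vector<uint32_t> cumsum(counts.size());
      uint32_t running = 0;
      for (size_t d = 0; d < counts.size(); ++d) {
        running += counts[d];
        cumsum[d] = running;
      }
      return cumsum;
    }

The digit that contains the k-th value is the first d with cumsum_left < k <= cumsum[d], which is the standard radix-select comparison this kernel family performs once per pass.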
@@ -483,7 +466,7 @@ template <typename Bitwise>
 __global__ void computeBlockwiseKthCounts(
     Bitwise* desires, // size: num_slices
     short* counts, // size: num_slices * blocks_per_slice * radix_digits
-    uint32_t num_blocks, // the number of blocks used by `computeBlockDigitCounts` kernel
+    uint32_t num_blocks, // the number of blocks used by `radixFindKthValues` kernel
     uint32_t blocks_per_slice,
     // outputs:
     uint32_t* kthCounts // size: num_slices * blocks_per_slice == num_blocks
@@ -666,7 +649,9 @@ void launch(
   T* kthValues = reinterpret_cast<T*>(kthValues_buffer.get());

   TORCH_CHECK(blocks_per_slice <= std::numeric_limits<uint32_t>::max(), "blocks_per_slice larger than uint32 maximum is not supported");
+  auto semaphores_buffer = allocator.allocate(numInputSlices * sizeof(uint32_t));
+  uint32_t* semaphores = reinterpret_cast<uint32_t*>(semaphores_buffer.get());
+  AT_CUDA_CHECK(cudaMemsetAsync(semaphores, 0, numInputSlices * sizeof(uint32_t), stream));

   auto ks_to_find_buffer = allocator.allocate(2 * numInputSlices * sizeof(uint32_t));
   uint32_t* ks_to_find = reinterpret_cast<uint32_t*>(ks_to_find_buffer.get());
@@ -683,10 +668,6 @@ void launch(
   static_assert(MAX_ITEMS_PER_THREAD * BLOCK_THREADS < std::numeric_limits<short>::max(),
     "blockwise counter too large");

-  auto digit_cum_sum_buffer = allocator.allocate(numInputSlices * RADIX_DIGITS * sizeof(uint32_t));
-  uint32_t* digit_cum_sum = reinterpret_cast<uint32_t*>(digit_cum_sum_buffer.get());
-  AT_CUDA_CHECK(cudaMemsetAsync(digit_cum_sum, 0, numInputSlices * RADIX_DIGITS * sizeof(uint32_t), stream));
-
 #if CUB_SUPPORTS_SCAN_BY_KEY()
   auto withinKCounts_buffer = allocator.allocate(num_blocks * sizeof(uint32_t));
   uint32_t* withinKCounts = reinterpret_cast<uint32_t*>(withinKCounts_buffer.get());
@@ -710,7 +691,7 @@ void launch(

   // iterate radix bits for multiple passes
   for (int current_bit = sizeof(T) * 8 - RADIX_BITS; current_bit >= 0; current_bit -= RADIX_BITS) {
-    computeBlockDigitCounts<T, IndexType, Bitwise, Dim><<<grid, block, 0, stream>>>(
+    radixFindKthValues<T, IndexType, Bitwise, Dim><<<grid, block, 0, stream>>>(
        input,
        inputSliceSize,
        ks_to_find_in, // unused arg
@@ -723,14 +704,10 @@ void launch(
        desired_in,
        counts);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

-    computeDigitCumSum<<<numInputSlices, RADIX_DIGITS, 0, stream>>>(counts, digit_cum_sum, blocks_per_slice);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-
    // we unconditionally call this kernel to update desired/ks_to_find/kthValues
    // if cub supports scan_by_key we additionally do k counts
    computeBlockwiseWithinKCounts<Bitwise, T><<<grid, RADIX_DIGITS, 0, stream>>>(
-      desired_in, counts, digit_cum_sum, ks_to_find_in, blocks_per_slice, current_bit, largest, withinKCounts, kthValues, ks_to_find_out, desired_out, num_blocks);
+      desired_in, counts, ks_to_find_in, blocks_per_slice, current_bit, largest, withinKCounts, kthValues, ks_to_find_out, desired_out, num_blocks);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    // swap desired/ks_to_find in and out for next iter
    auto tmp_desired = desired_in;
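Taken together, the topk hunks above move the per-slice digit cumulative sum between kernels: per-block digit histograms are accumulated and prefix-summed to locate the digit that contains the k-th value, one pass of RADIX_BITS at a time. The host-side sketch below shows the same radix-select idea on a plain array of unsigned keys; it is illustrative only and unrelated to the CUDA launch code (it assumes 1 <= k <= keys.size()).

    #include <cstdint>
    #include <vector>

    // Radix select: value of the k-th smallest (1-based) element among 32-bit unsigned keys.
    // Scans 8 bits per pass, narrowing a "desired" prefix the way the GPU kernels do with
    // their per-digit histograms and prefix sums.
    uint32_t radix_select(const std::vector<uint32_t>& keys, uint32_t k) {
      constexpr int RADIX_BITS = 8;
      constexpr int RADIX_DIGITS = 1 << RADIX_BITS;
      uint32_t desired = 0, desired_mask = 0;
      for (int bit = 32 - RADIX_BITS; bit >= 0; bit -= RADIX_BITS) {
        // Histogram of the current digit, restricted to keys matching the prefix so far.
        uint32_t counts[RADIX_DIGITS] = {0};
        for (uint32_t key : keys) {
          if ((key & desired_mask) == desired) {
            counts[(key >> bit) & (RADIX_DIGITS - 1)]++;
          }
        }
        // Walk the (implicit) prefix sum to find the digit containing the k-th element.
        uint32_t cumsum = 0;
        for (int d = 0; d < RADIX_DIGITS; ++d) {
          if (cumsum + counts[d] >= k) {
            desired |= (uint32_t)d << bit;
            desired_mask |= (uint32_t)(RADIX_DIGITS - 1) << bit;
            k -= cumsum;  // k is now relative to the selected digit bucket
            break;
          }
          cumsum += counts[d];
        }
      }
      return desired;
    }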
@@ -127,7 +127,8 @@ void apply_ldl_solve_cusolver(
     const Tensor& pivots,
     const Tensor& B,
     bool upper) {
-#if !(defined(CUDART_VERSION) && defined(CUSOLVER_VERSION))
+#if !(defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && \
+    CUSOLVER_VERSION >= 11102)
   TORCH_CHECK(
       false,
       "Calling torch.linalg.ldl_solve on a CUDA tensor requires compiling ",
@@ -169,10 +169,7 @@ std::string repro_from_args(const ConvolutionParams& params) {
   ss << "If that doesn't trigger the error, please include your original repro script when reporting this issue.\n\n";
   ss << "import torch\n";
   ss << "torch.backends.cuda.matmul.allow_tf32 = "
-     << pybool(
-            at::globalContext().float32Precision(
-                at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
-            at::Float32Precision::TF32)
+     << pybool(at::globalContext().float32Precision("cuda", "matmul") == "tf32")
      << "\n";
   ss << "torch.backends.cudnn.benchmark = "
      << pybool(at::globalContext().benchmarkCuDNN()) << "\n";
@@ -729,7 +726,7 @@ Tensor cudnn_convolution_relu(

   auto& ctx = at::globalContext();
   bool benchmark = ctx.benchmarkCuDNN();
-  bool allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
+  bool allow_tf32 = ctx.allowTF32CuDNN("conv");
   auto _bias = bias_t.has_value()
       ? bias_t.value()
       : at::zeros(
@@ -787,7 +784,7 @@ Tensor cudnn_convolution_add_relu(
   }

   auto& ctx = at::globalContext();
-  bool allow_tf32 = ctx.allowTF32CuDNN(at::Float32Op::CONV);
+  bool allow_tf32 = ctx.allowTF32CuDNN("conv");
   bool benchmark = ctx.benchmarkCuDNN();
   auto _alpha = alpha.has_value() ? alpha.value().to<float>() : 1.0;
   auto _bias = bias_t.has_value()
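As context, `repro_from_args` assembles a small Python script that mirrors the current backend flags so a failure can be replayed. The snippet below is a simplified, self-contained sketch of that pattern; the real function emits considerably more state, and the `pybool` helper here is a local stand-in for the one used in the surrounding code.

    #include <iostream>
    #include <sstream>
    #include <string>

    // Render a C++ bool as a Python literal.
    static std::string pybool(bool b) { return b ? "True" : "False"; }

    std::string repro_prefix(bool allow_tf32, bool benchmark) {
      std::ostringstream ss;
      ss << "import torch\n";
      ss << "torch.backends.cuda.matmul.allow_tf32 = " << pybool(allow_tf32) << "\n";
      ss << "torch.backends.cudnn.benchmark = " << pybool(benchmark) << "\n";
      return ss.str();
    }

    int main() {
      std::cout << repro_prefix(/*allow_tf32=*/true, /*benchmark=*/false);
    }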
@@ -76,6 +76,7 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(

 #else // AT_CUDNN_ENABLED

+#include <ATen/cudnn/Descriptors.h>
 #include <ATen/cudnn/Types.h>
 #include <ATen/cudnn/Utils.h>

@@ -283,9 +284,9 @@ std::tuple<Tensor, Tensor> _cudnn_ctc_loss_tensor(
   checkBackend(c, {*targets}, Backend::CUDA);
   const auto batch_size = log_probs->size(1);
   int64_t input_lengths_size =
-      !input_lengths_.sizes().empty() ? input_lengths_.size(0) : 1;
+      input_lengths_.sizes().size() ? input_lengths_.size(0) : 1;
   int64_t target_lengths_size =
-      !target_lengths_.sizes().empty() ? target_lengths_.size(0) : 1;
+      target_lengths_.sizes().size() ? target_lengths_.size(0) : 1;
   TORCH_CHECK(
       input_lengths_size == batch_size,
       "input_lengths needs to have size to match batch_size");
@@ -142,6 +142,8 @@ void run_cudnn_SDP_bprop_nestedtensor(
 namespace at {
 namespace native {

+#include <cudnn_frontend.h>
+
 namespace fe = cudnn_frontend;

 constexpr uint8_t MAX_MHA_DIM = 4;
@@ -1377,7 +1379,7 @@ void run_cudnn_SDP_fprop(
   cudnnHandle_t handle = getCudnnHandle();

   // NB: The key initialization will round up sequence length, stride data etc.
-  // if use_ragged_in_dense is enabled (to allow multiple sequence lengths to
+  // if use_ragged_in_dense is enabled (to allow multiple sequence lenghths to
   // reuse the same cached value/graph)
   auto key = MHACacheKeyWrapper(
       b,
@@ -245,7 +245,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
         datatype,
         input_datatype,
         algo,
-        at::globalContext().allowTF32CuDNN(at::Float32Op::RNN));
+        at::globalContext().allowTF32CuDNN("rnn"));
 #else
     rnn_desc.set(
         handle,
@@ -261,7 +261,7 @@ descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
         datatype,
         input_datatype,
         algo,
-        at::globalContext().allowTF32CuDNN(at::Float32Op::RNN));
+        at::globalContext().allowTF32CuDNN("rnn"));
 #endif
     return rnn_desc;
   }
@@ -772,21 +772,13 @@ void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {

 template <>
 void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
-  static const std::vector<std::string> wmma_archs = {
-      "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201",
-#if ROCM_VERSION >= 70000
-      "gfx1150", "gfx1151"
-#endif
-  };
-  if (at::detail::getCUDAHooks().isGPUArch(wmma_archs)) {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  std::string_view arch(dprops->gcnArchName);
+  if (arch == "gfx1100") {
     dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGS(at::BFloat16));
-  }
-  else if (at::detail::getCUDAHooks().isGPUArch({"gfx9"})) {
+  } else{
     dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
-  else {
-    TORCH_CHECK(false, "gemm_internal_ck<at::BFloat16> unsupported gfx arch");
-  }
 }

 } // namespace at::native
@@ -599,21 +599,11 @@ void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) {

 template <>
 void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
-  static const std::vector<std::string> wmma_archs = {
-      "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201",
-#if ROCM_VERSION >= 70000
-      "gfx1150", "gfx1151"
-#endif
-  };
-  if (at::detail::getCUDAHooks().isGPUArch(wmma_archs)) {
+  if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) {
     dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half));
-  }
-  else if (at::detail::getCUDAHooks().isGPUArch({"gfx9"})) {
+  } else{
     dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half));
   }
-  else {
-    TORCH_CHECK(false, "gemm_internal_ck<at::Half> unsupported gfx arch");
-  }
 }

 } // namespace at::native
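Both hunks above differ in how the CK GEMM entry point picks a kernel family: a hard-coded gfx1100 check on one side versus an architecture-list lookup with an explicit error for unsupported targets on the other. The sketch below shows only that dispatch shape in plain C++; `is_arch_one_of` is a hypothetical stand-in for the runtime query (getCUDAHooks().isGPUArch in the real code), and the prefix-match semantics are an assumption.

    #include <stdexcept>
    #include <string>
    #include <string_view>
    #include <vector>

    // Stand-in for "is this GPU one of these architectures?" (prefix match assumed).
    bool is_arch_one_of(std::string_view arch, const std::vector<std::string>& list) {
      for (const auto& a : list) {
        if (arch.substr(0, a.size()) == a) return true;
      }
      return false;
    }

    void gemm_dispatch(std::string_view arch) {
      static const std::vector<std::string> wmma_archs = {
          "gfx1100", "gfx1101", "gfx1102", "gfx1200", "gfx1201"};
      if (is_arch_one_of(arch, wmma_archs)) {
        // dispatch the WMMA-based GEMM here
      } else if (is_arch_one_of(arch, {"gfx9"})) {
        // dispatch the default CK GEMM here
      } else {
        throw std::runtime_error("unsupported gfx arch");
      }
    }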
@@ -38,6 +38,7 @@ REGISTER_NO_CPU_DISPATCH(mkldnn_convolution_transpose_backward_stub)

 #include <ATen/native/mkldnn/MKLDNNCommon.h>
 #include <ATen/native/mkldnn/Utils.h>
+#include <ATen/native/ConvUtils.h>
 #include <c10/util/irange.h>

 namespace at::native {
@@ -104,7 +105,7 @@ static void check_shape_forward(const Tensor& input,
   // If kernel size is incorrect
   std::ostringstream input_ss;
   std::ostringstream kernel_ss;
-  std::string separator;
+  std::string separator = "";

   for (int i = 0, len = input_shape.size(); i < len; ++i) {
     input_ss << separator << input_shape[i];
@@ -155,12 +156,12 @@ static void check_shape_forward(const Tensor& input,
 //

 static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
-  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::BF16 &&
+  return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" &&
       mkldnn_bf16_device_check();
 }

 static bool mkldnn_conv_enabled_fpmath_mode_tf32(){
-  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::CONV) == at::Float32Precision::TF32 &&
+  return at::globalContext().float32Precision("mkldnn", "conv") == "tf32" &&
       cpuinfo_has_x86_amx_fp16();
 }

@@ -69,12 +69,12 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
 namespace at::native {

 static bool use_mkldnn_bf32_linear() {
-  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16 &&
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" &&
       mkldnn_bf16_device_check();
 }

 static bool use_mkldnn_tf32_linear() {
-  return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32 &&
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "tf32" &&
       cpuinfo_has_x86_amx_fp16();
 }

@@ -111,11 +111,11 @@ static bool use_mkldnn_fp16_matmul() {
 }

 static bool use_mkldnn_bf32_matmul() {
-  return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::BF16;
+  return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16";
 }

 static bool use_mkldnn_tf32_matmul() {
-  return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision(at::Float32Backend::MKLDNN, at::Float32Op::MATMUL) == at::Float32Precision::TF32;
+  return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32";
 }

 // returns an ideep::tensor
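The same precision comparison is repeated across the conv and matmul helpers above. If one were consolidating, a small wrapper along the lines of the sketch below would keep the checks in one place; it uses the enum-based float32Precision overload visible on the "-" side of these hunks, it is illustrative only, and it assumes the ATen context header also declares the Float32* enums.

    #include <ATen/Context.h>  // at::globalContext(); Float32* enum declarations assumed available

    // Illustrative helper: does the global context request this reduced-precision
    // fp32 mode for a given MKLDNN op?
    static bool mkldnn_fp32_mode_is(at::Float32Op op, at::Float32Precision prec) {
      return at::globalContext().float32Precision(at::Float32Backend::MKLDNN, op) == prec;
    }

    // Example use, mirroring mkldnn_conv_enabled_fpmath_mode_bf16():
    //   mkldnn_fp32_mode_is(at::Float32Op::CONV, at::Float32Precision::BF16)
    //       && mkldnn_bf16_device_check();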
@@ -316,7 +316,7 @@ Tensor NestedTensor_to_padded_tensor_generic(
   TORCH_CHECK(
       (int64_t)output_size_.size() == ret_val.dim(),
       "Length of output_size does not match NestedTensor dims. Broadcasting is not supported.");
-  for (int64_t i = 0; i < ret_val.dim(); i++) {
+  for (int64_t i = 0; i < (int64_t)ret_val.dim(); i++) {
     TORCH_CHECK(
         output_size_[i] >= ret_val.size(i),
         "Value in output_size is less than NestedTensor padded size. Truncation is not supported.");
@@ -146,12 +146,12 @@ inline TensorQuantizationParams ChooseQuantizationParams(
   // The arithmetic error on the zero point computed from either pair
   // will be roughly machine_epsilon * (sum of absolute values of terms)
   // so we want to use the variant that adds the smaller terms.
-  double zero_point_from_min = qmin - min / scale;
-  double zero_point_from_max = qmax - max / scale;
+  double zero_point_from_min = qmin - min / static_cast<double>(scale);
+  double zero_point_from_max = qmax - max / static_cast<double>(scale);
   double zero_point_from_min_error =
-      std::abs(qmin) - std::abs(min / scale);
+      std::abs(qmin) - std::abs(min / static_cast<double>(scale));
   double zero_point_from_max_error =
-      std::abs(qmax) - std::abs(max / scale);
+      std::abs(qmax) - std::abs(max / static_cast<double>(scale));
   double initial_zero_point =
       zero_point_from_min_error < zero_point_from_max_error
           ? zero_point_from_min
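The surrounding function follows the standard affine quantization scheme: scale = (max - min) / (qmax - qmin), and a zero point chosen from whichever of the two algebraically equivalent candidates has the smaller rounding error, then nudged into [qmin, qmax]. The standalone sketch below mirrors that computation with illustrative values and names; it is not the file's code.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    struct QuantParams { double scale; int32_t zero_point; };

    QuantParams choose_quant_params(double min, double max, int qmin = 0, int qmax = 255) {
      min = std::min(min, 0.0);                 // the representable range must contain 0
      max = std::max(max, 0.0);
      double scale = (max - min) / double(qmax - qmin);
      if (scale == 0.0) scale = 1.0;            // degenerate all-zero range
      // Two equivalent candidates; keep the one whose error term is smaller,
      // mirroring the comparison in the hunk above.
      double zp_from_min = qmin - min / scale;
      double zp_from_max = qmax - max / scale;
      double err_min = std::abs(qmin) - std::abs(min / scale);
      double err_max = std::abs(qmax) - std::abs(max / scale);
      double initial_zp = err_min < err_max ? zp_from_min : zp_from_max;
      // Nudge into the representable range and round to an integer.
      double nudged = std::min(std::max(initial_zp, double(qmin)), double(qmax));
      return {scale, static_cast<int32_t>(std::lround(nudged))};
    }

    int main() {
      auto qp = choose_quant_params(-1.0, 3.0);
      std::cout << "scale=" << qp.scale << " zero_point=" << qp.zero_point << "\n";
    }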
@@ -560,7 +560,7 @@ float hsum_sq(const int32_t* A, int len) {
   alignas(64) float temp[8];
   _mm256_store_ps(temp, sum_ps);
   for (const auto k : c10::irange(8)) {
-    row_sum += temp[k];
+    row_sum += static_cast<float>(temp[k]);
   }
 #elif defined(CPU_CAPABILITY_AVX512)
   __m512 sum_ps = _mm512_setzero_ps();
@@ -574,7 +574,7 @@ float hsum_sq(const int32_t* A, int len) {
   alignas(64) float temp[16];
   _mm512_store_ps(temp, sum_ps);
   for (const auto k : c10::irange(16)) {
-    row_sum += temp[k];
+    row_sum += static_cast<float>(temp[k]);
   }
 #endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512
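`hsum_sq` reduces a vector register by spilling it to an aligned scalar buffer and summing in a plain loop. Below is a minimal AVX2-only sketch of that store-then-sum pattern, assuming an AVX-capable build; it is not the file's code.

    #include <immintrin.h>

    // Horizontal sum of the 8 floats in an AVX register via an aligned spill buffer.
    static float hsum256(__m256 v) {
      alignas(32) float tmp[8];       // _mm256_store_ps requires 32-byte alignment
      _mm256_store_ps(tmp, v);
      float s = 0.0f;
      for (int i = 0; i < 8; ++i) {
        s += tmp[i];
      }
      return s;
    }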
@@ -1282,7 +1282,7 @@ template <bool ReLUFused = false>
 void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
   int64_t zero_point = out.q_zero_point();
   float scale = static_cast<float>(out.q_scale());
-  float inv_scale = 1.0f / scale;
+  float inv_scale = static_cast<float>(1.0f / scale);
   int64_t self_zero_point = self.q_zero_point();
   float self_scale = static_cast<float>(self.q_scale());

@@ -2915,7 +2915,7 @@ void fake_quantize_learnable_channel_grad_kernel_cpu(
   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
   *dx_output = (*dy_input) * (xqi >= quant_min && xqi <= quant_max);
   // Calculate gradients for scale and zero point.
-  float xfqi = ((std::max(std::min(xqi, quant_max), quant_min) - (*zero_point_input)) * (*scale_input));
+  float xfqi = static_cast<float>((std::max(std::min(xqi, quant_max), quant_min) - (*zero_point_input)) * (*scale_input));
   if (xqi < quant_min || xqi > quant_max) {
     *dzero_point_output = (*dy_input) * (-1) * (*scale_input) * grad_factor;
     *dscale_output = ((xqi < quant_min) ? ((*dy_input) * dscale_small) : ((*dy_input) * dscale_big)) * grad_factor;
@@ -4415,7 +4415,7 @@ void _qmul_tensor_cpu_impl(
   uint8_t y_data = *(y_ptr + idx);
   int32_t x_val = static_cast<int32_t>(x_data) - x_zero_point;
   int32_t y_val = static_cast<int32_t>(y_data) - y_zero_point;
-  int32_t out_val = x_val * y_val;
+  int32_t out_val = static_cast<int32_t>(x_val * y_val);
   float out_val_f = (float)out_val * multiplier;
   if constexpr (std::is_same<T, float>::value) {
     *(out_ptr + idx) = out_val_f;
@@ -1198,7 +1198,7 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
       kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
   ideep::tensor src(src_desc, act_contig.data_ptr());
   // weights & bias
-  ideep::tensor& weights = *(weight_);
+  ideep::tensor& weights = *(weight_.get());
   bool with_bias = bias_.has_value();
   const auto& kernel_size = weights.get_dims();
   // dst
@@ -812,7 +812,7 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(

   auto is_input_qint8 = input.scalar_type() == c10::ScalarType::QInt8;
   auto input_contig = input.expect_contiguous();
-  auto& w = *weight_;
+  auto& w = *(weight_.get());
   auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
   auto input_dims = {M, K};
   auto input_data_type = is_input_qint8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::u8;
@@ -545,7 +545,7 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
       /*reduce_range=*/reduce_range);
   const std::vector<int32_t>& src_zero_point = std::vector<int32_t>(1, q_params.zero_point);
   // weights, dst
-  auto w = *weight_;
+  auto w = *(weight_.get());
   auto dst_dims = {x.get_dim(0), w.get_dim(1)};
   const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/q_params.scale);
   const ideep::scale_t& weights_scales = w.get_scale();
@@ -12,6 +12,7 @@
 #include <ATen/quantized/Quantizer.h>
 #include <c10/core/QScheme.h>
 #include <c10/util/irange.h>
+#include <torch/library.h>

 #include <utility>

@@ -10,6 +10,7 @@
 #include <ATen/quantized/Quantizer.h>
 #include <c10/core/QScheme.h>
 #include <c10/util/irange.h>
+#include <torch/library.h>

 int register_linear_params();

@@ -65,7 +65,7 @@ Tensor& addmv_out_sparse_compressed(
     return result.zero_();
   } else {
     return at::mul_out(
-        result,
+        const_cast<Tensor&>(result),
        self,
        at::native::scalar_tensor(
            beta,
@@ -1330,18 +1330,18 @@ Tensor reduce_sparse_csr_cpu_template(const Tensor& sparse, IntArrayRef dims_to_

 template <typename scalar_t>
 struct ReductionAddOp {
-  scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
+  inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
     return a + b;
   }
-  scalar_t identity() const { return 0; }
+  inline scalar_t identity() const { return 0; }
 };

 template <typename scalar_t>
 struct ReductionMulOp {
-  scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
+  inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
     return a * b;
   }
-  scalar_t identity() const { return 1; }
+  inline scalar_t identity() const { return 1; }
 };

 } // namespace
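These functors pair a binary operation with its identity element so that a generic reducer can fold over an arbitrary slice, including an empty one. The sketch below shows how such a pair is typically consumed; it is independent of the sparse CSR code above.

    #include <iostream>
    #include <vector>

    template <typename scalar_t>
    struct ReductionAddOp {
      scalar_t operator()(const scalar_t& a, const scalar_t& b) const { return a + b; }
      scalar_t identity() const { return 0; }
    };

    // Generic fold: starts from the op's identity so empty inputs are well defined.
    template <typename scalar_t, typename Op>
    scalar_t reduce(const std::vector<scalar_t>& values, const Op& op) {
      scalar_t acc = op.identity();
      for (const auto& v : values) acc = op(acc, v);
      return acc;
    }

    int main() {
      std::vector<double> v{1.0, 2.0, 3.5};
      std::cout << reduce(v, ReductionAddOp<double>{}) << "\n";  // prints 6.5
    }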
Some files were not shown because too many files have changed in this diff.