Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-24 15:44:58 +08:00)

Compare commits: quint-bits ... llama4-sta (171 commits)
@@ -103,5 +103,5 @@ fi
 # It depends on torch and triton. We don't want to install
 # triton and torch from production on Docker CI images
 if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
-  pip_install helion --no-deps
+  pip_install helion==0.0.10 --no-deps
 fi

@@ -1,7 +1,7 @@
 sphinx==5.3.0
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 5.3.0
--e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
+-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2
 
 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
 # but it doesn't seem to work and hangs around idly. The initial thought that it is probably

@@ -50,8 +50,8 @@ IPython==8.12.0
 #Pinned versions: 8.12.0
 
 myst-nb==0.17.2
-#Description: This is used to generate PyTorch functorch docs
-#Pinned versions: 0.13.2
+#Description: This is used to generate PyTorch functorch and torch.compile docs.
+#Pinned versions: 0.17.2
 
 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 python-etcd==0.4.5

@@ -59,4 +59,3 @@ sphinx-copybutton==0.5.0
 sphinx-design==0.4.0
 sphinxcontrib-mermaid==1.0.0
 myst-parser==0.18.1
-myst-nb
@@ -50,6 +50,9 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then
   export ATEN_THREADING=NATIVE
 fi
 
+# Enable LLVM dependency for TensorExpr testing
+export USE_LLVM=/opt/llvm
+export LLVM_DIR=/opt/llvm/lib/cmake/llvm
 
 if ! which conda; then
   # In ROCm CIs, we are doing cross compilation on build machines with

@@ -189,6 +192,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then
   export USE_ASAN=1
   export REL_WITH_DEB_INFO=1
   export UBSAN_FLAGS="-fno-sanitize-recover=all"
+  unset USE_LLVM
 fi
 
 if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then
@@ -462,7 +462,7 @@ test_inductor_aoti() {
   # rebuild with the build cache with `BUILD_AOT_INDUCTOR_TEST` enabled
   /usr/bin/env CMAKE_FRESH=1 BUILD_AOT_INDUCTOR_TEST=1 "${BUILD_COMMAND[@]}"
 
-  /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference -dist=loadfile
+  /usr/bin/env "${TEST_ENVS[@]}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference cpp/test_vec_half_AVX2 -dist=loadfile
 }
 
 test_inductor_cpp_wrapper_shard() {
@@ -1039,10 +1039,20 @@ test_libtorch_api() {
mkdir -p $TEST_REPORTS_DIR

OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml
"$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml
else
# Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest"

# On s390x, pytorch is built without llvm.
# Even if it would be built with llvm, llvm currently doesn't support used features on s390x and
# test fails with errors like:
# JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer
# unknown file: Failure
# C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) }
if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
  python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
fi
fi

# quantization is not fully supported on s390x yet
7  .github/actionlint.yaml (vendored)

@@ -53,13 +53,12 @@ self-hosted-runner:
  - linux.rocm.gpu.mi250
  - linux.rocm.gpu.2
  - linux.rocm.gpu.4
  # MI300 runners
  - linux.rocm.gpu.mi300.2
  - linux.rocm.gpu.mi300.4
  # gfx942 runners
  - linux.rocm.gpu.gfx942.2
  - linux.rocm.gpu.gfx942.4
  - rocm-docker
  # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
  - macos-m1-stable
  - macos-m1-13
  - macos-m1-14
  # GitHub-hosted MacOS runners
  - macos-latest-xlarge
2  .github/ci_commit_pins/audio.txt (vendored)

@@ -1 +1 @@
-f6dfe1231dcdd221a68416e49ab85c2575cbb824
+bf305f538005f2e900f8850ed57146024a8bc559

2  .github/ci_commit_pins/vllm.txt (vendored)

@@ -1 +1 @@
-8f605ee30912541126c0fe46d0c8c413101b600a
+ca9e2be3ed6320b51f52f536595cd24e254f8bb2
@@ -2,7 +2,7 @@ boto3==1.35.42
 cmake==3.27.*
 expecttest==0.3.0
 fbscribelogger==0.1.7
-filelock==3.13.1
+filelock==3.18.0
 hypothesis==6.56.4
 librosa>=0.6.2
 mpmath==1.3.0
4  .github/scripts/trymerge.py (vendored)

@@ -1891,7 +1891,9 @@ def validate_revert(
         else pr.get_comment_by_id(comment_id)
     )
     if comment.editor_login is not None:
-        raise PostCommentError("Don't want to revert based on edited command")
+        raise PostCommentError(
+            "Halting the revert as the revert comment has been edited."
+        )
     author_association = comment.author_association
     author_login = comment.author_login
     allowed_reverters = ["COLLABORATOR", "MEMBER", "OWNER"]
4  .github/workflows/_rocm-test.yml (vendored)

@@ -269,8 +269,8 @@ jobs:
           # copy test results back to the mounted workspace, needed sudo, resulting permissions were correct
           docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "cd ../pytorch && sudo cp -R test/test-reports ../workspace/test"
 
-      - name: Change permissions (only needed for MI300 and MI355 kubernetes runners for now)
-        if: ${{ always() && steps.test.conclusion && (contains(matrix.runner, 'mi300') || contains(matrix.runner, 'mi355')) }}
+      - name: Change permissions (only needed for kubernetes runners for now)
+        if: ${{ always() && steps.test.conclusion && (contains(matrix.runner, 'gfx942') || contains(matrix.runner, 'mi355')) }}
         run: |
           docker exec -t "${{ env.CONTAINER_NAME }}" sh -c "sudo chown -R 1001:1001 test"
@ -88,23 +88,23 @@ jobs:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
4  .github/workflows/inductor-rocm-mi300.yml (vendored)

@@ -47,8 +47,8 @@ jobs:
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
       test-matrix: |
         { include: [
-          { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
-          { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
+          { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
+          { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
         ]}
       secrets: inherit
1  .github/workflows/mac-mps.yml (vendored)

@@ -28,7 +28,6 @@ jobs:
      # than our AWS macos-m1-14 runners
      test-matrix: |
        { include: [
          { config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
          { config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m1-14" },
          { config: "test_mps", shard: 1, num_shards: 1, runner: "macos-m2-15" },
        ]}
6  .github/workflows/periodic-rocm-mi300.yml (vendored)

@@ -59,9 +59,9 @@ jobs:
       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
       test-matrix: |
         { include: [
-          { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] },
-          { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] },
-          { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi300.4", owners: ["module:rocm", "oncall:distributed"] },
+          { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] },
+          { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] },
+          { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] },
         ]}
       secrets: inherit
12
.github/workflows/rocm-mi300.yml
vendored
12
.github/workflows/rocm-mi300.yml
vendored
@ -48,12 +48,12 @@ jobs:
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
2  .github/workflows/rocm-mi355.yml (vendored)

@@ -3,7 +3,7 @@ name: rocm-mi355
 on:
   workflow_dispatch:
   schedule:
-    - cron: 30 9 * * * # about 2:30am PDT
+    - cron: 30 11,1 * * * # about 4:30am PDT and 6:30pm PDT
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
1  .github/workflows/trunk.yml (vendored)

@@ -94,7 +94,6 @@ jobs:
          { config: "default", shard: 1, num_shards: 3, runner: "macos-m1-stable" },
          { config: "default", shard: 2, num_shards: 3, runner: "macos-m1-stable" },
          { config: "default", shard: 3, num_shards: 3, runner: "macos-m1-stable" },
          { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-13" },
          { config: "mps", shard: 1, num_shards: 1, runner: "macos-m1-14" },
          { config: "mps", shard: 1, num_shards: 1, runner: "macos-m2-15" },
        ]}
@@ -164,7 +164,7 @@ init_command = [
     'types-setuptools==79.0.0.20250422',
     'types-jinja2==2.11.9',
     'types-colorama==0.4.6',
-    'filelock==3.13.1',
+    'filelock==3.18.0',
     'junitparser==2.1.1',
     'rich==14.1.0',
     'pyyaml==6.0.2',
@@ -679,6 +679,7 @@ cc_library(
         [
             "torch/*.h",
             "torch/csrc/**/*.h",
+            "torch/nativert/**/*.h",
            "torch/csrc/distributed/c10d/**/*.hpp",
            "torch/lib/libshm/*.h",
        ],
@@ -564,7 +564,7 @@ if(MSVC)
   set(CMAKE_NINJA_CMCLDEPS_RC OFF)
   if(MSVC_Z7_OVERRIDE)
     # CMake set debug flags to use /Z7
-    set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT Embedded)
+    set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT "$<$<CONFIG:Debug,RelWithDebInfo>:Embedded>")
   endif()
   foreach(
     flag_var

@@ -872,6 +872,14 @@ cmake_dependent_option(
   "USE_CUDA OR USE_ROCM;NOT MSVC"
   OFF)
 
+cmake_dependent_option(
+  USE_FBGEMM_GENAI
+  "Whether to build FBGEMM GenAI quantized GEMM kernels.\
+  Will be disabled if not supported by the platform"
+  OFF
+  "USE_CUDA OR USE_ROCM"
+  OFF)
+
 # CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
 # Eff Attention won't
 cmake_dependent_option(

@@ -905,6 +913,10 @@ if(USE_FBGEMM)
   string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
 endif()
 
+if(USE_FBGEMM_GENAI)
+  string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM_GENAI")
+endif()
+
 if(USE_PYTORCH_QNNPACK)
   string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_QNNPACK")
 endif()
18
CODEOWNERS
18
CODEOWNERS
@ -14,7 +14,6 @@
|
||||
/torch/csrc/autograd/ @albanD @soulitzer
|
||||
/torch/autograd/ @albanD @soulitzer
|
||||
/tools/autograd/ @albanD @soulitzer
|
||||
/torch/header_only_apis.txt @janeyx99
|
||||
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
|
||||
/torch/optim/ @albanD @janeyx99
|
||||
/test/test_public_bindings.py @albanD
|
||||
@ -51,12 +50,12 @@ nn/qat/ @jerryzh168
|
||||
/torch/csrc/distributed/c10d/Ops.* @kwen2501
|
||||
|
||||
# ONNX Export
|
||||
/torch/_dynamo/backends/onnxrt.py @wschin
|
||||
/torch/csrc/jit/passes/onnx.h @titaiwangms @shubhambhokare1
|
||||
/torch/csrc/jit/passes/onnx.cpp @titaiwangms @shubhambhokare1
|
||||
/torch/csrc/jit/passes/onnx/ @titaiwangms @shubhambhokare1
|
||||
/torch/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin
|
||||
/test/onnx/ @titaiwangms @shubhambhokare1 @justinchuby @wschin
|
||||
/torch/_dynamo/backends/onnxrt.py @titaiwangms @xadupre @justinchuby
|
||||
/torch/csrc/jit/passes/onnx.h @titaiwangms @xadupre
|
||||
/torch/csrc/jit/passes/onnx.cpp @titaiwangms @xadupre
|
||||
/torch/csrc/jit/passes/onnx/ @titaiwangms @xadupre
|
||||
/torch/onnx/ @titaiwangms @xadupre @justinchuby
|
||||
/test/onnx/ @titaiwangms @xadupre @justinchuby
|
||||
|
||||
# CI
|
||||
/.ci @pytorch/pytorch-dev-infra
|
||||
@ -196,3 +195,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed
|
||||
/torch/utils/_cxx_pytree.py @XuehaiPan
|
||||
/torch/utils/pytree/ @XuehaiPan
|
||||
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
|
||||
|
||||
# Relating to libtorch ABI
|
||||
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
|
||||
/torch/headeronly/ @janeyx99
|
||||
/torch/header_only_apis.txt @janeyx99
|
||||
|
||||
@ -247,6 +247,50 @@ if(USE_MEM_EFF_ATTENTION)
|
||||
list(APPEND ATen_ATTENTION_KERNEL_SRCS ${mem_eff_attention_cuda_kernels_cu})
|
||||
endif()
|
||||
|
||||
IF(USE_FBGEMM_GENAI AND USE_ROCM AND NOT "gfx942" IN_LIST PYTORCH_ROCM_ARCH)
|
||||
message(WARNING "Unsupported ROCM arch for FBGEMM GenAI, will set USE_FBGEMM_GENAI to OFF")
|
||||
set(USE_FBGEMM_GENAI off)
|
||||
endif()
|
||||
|
||||
# FBGEMM GenAI
|
||||
IF(USE_FBGEMM_GENAI)
|
||||
set(FBGEMM_THIRD_PARTY ${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/)
|
||||
set(FBGEMM_GENAI_DIR ${PROJECT_SOURCE_DIR}/third_party/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize)
|
||||
|
||||
if(USE_ROCM)
|
||||
# Only include the kernels we want to build to avoid increasing binary size.
|
||||
file(GLOB_RECURSE fbgemm_genai_native_rocm_hip
|
||||
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped*.hip"
|
||||
"${FBGEMM_GENAI_DIR}/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip")
|
||||
set_source_files_properties(${fbgemm_genai_native_rocm_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||
|
||||
# Add additional HIPCC compiler flags for performance
|
||||
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
|
||||
-mllvm
|
||||
-amdgpu-coerce-illegal-types=1
|
||||
-mllvm
|
||||
-enable-post-misched=0
|
||||
-mllvm
|
||||
-greedy-reverse-local-assignment=1
|
||||
-fhip-new-launch-api)
|
||||
|
||||
hip_add_library(
|
||||
fbgemm_genai STATIC
|
||||
${fbgemm_genai_native_rocm_hip}
|
||||
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${FBGEMM_GENAI_EXTRA_HIPCC_FLAGS})
|
||||
set_target_properties(fbgemm_genai PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(fbgemm_genai PRIVATE FBGEMM_GENAI_NO_EXTENDED_SHAPES)
|
||||
|
||||
target_include_directories(fbgemm_genai PUBLIC
|
||||
# FBGEMM version of Composable Kernel is used due to some customizations
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/include
|
||||
${FBGEMM_THIRD_PARTY}/composable_kernel/library/include
|
||||
${FBGEMM_GENAI_DIR}/include/
|
||||
${FBGEMM_GENAI_DIR}/common/include/
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# XNNPACK
|
||||
file(GLOB native_xnnpack "native/xnnpack/*.cpp")
|
||||
|
||||
|
||||
@ -1,55 +1 @@
|
||||
#pragma once
|
||||
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
|
||||
/* GCC or clang-compatible compiler, targeting x86/x86-64 */
|
||||
#include <x86intrin.h>
|
||||
#elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
|
||||
/* Clang-compatible compiler, targeting arm neon */
|
||||
#include <arm_neon.h>
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
/* CLANG-compatible compiler, targeting ARM with SVE */
|
||||
#include <arm_sve.h>
|
||||
#endif
|
||||
#elif defined(_MSC_VER)
|
||||
/* Microsoft C/C++-compatible compiler */
|
||||
#include <intrin.h>
|
||||
#if _MSC_VER <= 1900
|
||||
#define _mm256_extract_epi64(X, Y) \
|
||||
(_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2))
|
||||
#define _mm256_extract_epi32(X, Y) \
|
||||
(_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4))
|
||||
#define _mm256_extract_epi16(X, Y) \
|
||||
(_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8))
|
||||
#define _mm256_extract_epi8(X, Y) \
|
||||
(_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16))
|
||||
#endif
|
||||
#elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
|
||||
/* GCC-compatible compiler, targeting ARM with NEON */
|
||||
#include <arm_neon.h>
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
/* GCC-compatible compiler, targeting ARM with SVE */
|
||||
#include <arm_sve.h>
|
||||
#endif
|
||||
#if defined(MISSING_ARM_VLD1)
|
||||
#include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
|
||||
#elif defined(MISSING_ARM_VST1)
|
||||
#include <ATen/cpu/vec/vec256/missing_vst1_neon.h>
|
||||
#endif
|
||||
#elif defined(__GNUC__) && defined(__IWMMXT__)
|
||||
/* GCC-compatible compiler, targeting ARM with WMMX */
|
||||
#include <mmintrin.h>
|
||||
#elif defined(__s390x__)
|
||||
// targets Z/architecture
|
||||
// we will include vecintrin later
|
||||
#elif (defined(__GNUC__) || defined(__xlC__)) && \
|
||||
(defined(__VEC__) || defined(__ALTIVEC__))
|
||||
/* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */
|
||||
#include <altivec.h>
|
||||
/* We need to undef those tokens defined by <altivec.h> to avoid conflicts
|
||||
with the C++ types. => Can still use __bool/__vector */
|
||||
#undef bool
|
||||
#undef vector
|
||||
#undef pixel
|
||||
#elif defined(__GNUC__) && defined(__SPE__)
|
||||
/* GCC-compatible compiler, targeting PowerPC with SPE */
|
||||
#include <spe.h>
|
||||
#endif
|
||||
#include <torch/headeronly/cpu/vec/intrinsics.h>
|
||||
|
||||
@ -1,396 +1 @@
|
||||
/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */
|
||||
|
||||
__extension__ extern __inline uint8x8x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_u8_x2(const uint8_t* __a) {
|
||||
uint8x8x2_t ret;
|
||||
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int8x8x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_s8_x2(const int8_t* __a) {
|
||||
int8x8x2_t ret;
|
||||
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline uint16x4x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_u16_x2(const uint16_t* __a) {
|
||||
uint16x4x2_t ret;
|
||||
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int16x4x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_s16_x2(const int16_t* __a) {
|
||||
int16x4x2_t ret;
|
||||
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline uint32x2x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_u32_x2(const uint32_t* __a) {
|
||||
uint32x2x2_t ret;
|
||||
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int32x2x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_s32_x2(const int32_t* __a) {
|
||||
int32x2x2_t ret;
|
||||
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline uint64x1x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_u64_x2(const uint64_t* __a) {
|
||||
uint64x1x2_t ret;
|
||||
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int64x1x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_s64_x2(const int64_t* __a) {
|
||||
int64x1x2_t ret;
|
||||
__builtin_aarch64_simd_oi __o;
|
||||
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline float16x4x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_f16_x2(const float16_t* __a) {
|
||||
float16x4x2_t ret;
|
||||
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline float32x2x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_f32_x2(const float32_t* __a) {
|
||||
float32x2x2_t ret;
|
||||
asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline float64x1x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_f64_x2(const float64_t* __a) {
|
||||
float64x1x2_t ret;
|
||||
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline poly8x8x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_p8_x2(const poly8_t* __a) {
|
||||
poly8x8x2_t ret;
|
||||
asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline poly16x4x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_p16_x2(const poly16_t* __a) {
|
||||
poly16x4x2_t ret;
|
||||
asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline poly64x1x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1_p64_x2(const poly64_t* __a) {
|
||||
poly64x1x2_t ret;
|
||||
asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline uint8x16x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_u8_x2(const uint8_t* __a) {
|
||||
uint8x16x2_t ret;
|
||||
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int8x16x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_s8_x2(const int8_t* __a) {
|
||||
int8x16x2_t ret;
|
||||
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline uint16x8x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_u16_x2(const uint16_t* __a) {
|
||||
uint16x8x2_t ret;
|
||||
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int16x8x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_s16_x2(const int16_t* __a) {
|
||||
int16x8x2_t ret;
|
||||
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline uint32x4x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_u32_x2(const uint32_t* __a) {
|
||||
uint32x4x2_t ret;
|
||||
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int32x4x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_s32_x2(const int32_t* __a) {
|
||||
int32x4x2_t ret;
|
||||
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline uint64x2x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_u64_x2(const uint64_t* __a) {
|
||||
uint64x2x2_t ret;
|
||||
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline int64x2x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_s64_x2(const int64_t* __a) {
|
||||
int64x2x2_t ret;
|
||||
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline float16x8x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_f16_x2(const float16_t* __a) {
|
||||
float16x8x2_t ret;
|
||||
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline float32x4x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_f32_x2(const float32_t* __a) {
|
||||
float32x4x2_t ret;
|
||||
asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline float64x2x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_f64_x2(const float64_t* __a) {
|
||||
float64x2x2_t ret;
|
||||
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline poly8x16x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_p8_x2(const poly8_t* __a) {
|
||||
poly8x16x2_t ret;
|
||||
asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline poly16x8x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_p16_x2(const poly16_t* __a) {
|
||||
poly16x8x2_t ret;
|
||||
asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
__extension__ extern __inline poly64x2x2_t
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vld1q_p64_x2(const poly64_t* __a) {
|
||||
poly64x2x2_t ret;
|
||||
asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a));
|
||||
return ret;
|
||||
}
|
||||
|
||||
/* vst1x2 */
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_s64_x2(int64_t* __a, int64x1x2_t val) {
|
||||
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) {
|
||||
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_f64_x2(float64_t* __a, float64x1x2_t val) {
|
||||
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_s8_x2(int8_t* __a, int8x8x2_t val) {
|
||||
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) {
|
||||
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_s16_x2(int16_t* __a, int16x4x2_t val) {
|
||||
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) {
|
||||
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_s32_x2(int32_t* __a, int32x2x2_t val) {
|
||||
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) {
|
||||
asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) {
|
||||
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) {
|
||||
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_f16_x2(float16_t* __a, float16x4x2_t val) {
|
||||
asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_f32_x2(float32_t* __a, float32x2x2_t val) {
|
||||
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) {
|
||||
asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_s8_x2(int8_t* __a, int8x16x2_t val) {
|
||||
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) {
|
||||
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_s16_x2(int16_t* __a, int16x8x2_t val) {
|
||||
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) {
|
||||
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_s32_x2(int32_t* __a, int32x4x2_t val) {
|
||||
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_s64_x2(int64_t* __a, int64x2x2_t val) {
|
||||
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) {
|
||||
asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) {
|
||||
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) {
|
||||
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) {
|
||||
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_f16_x2(float16_t* __a, float16x8x2_t val) {
|
||||
asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
|
||||
asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_f64_x2(float64_t* __a, float64x2x2_t val) {
|
||||
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
|
||||
__extension__ extern __inline void
|
||||
__attribute__((__always_inline__, __gnu_inline__, __artificial__))
|
||||
vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) {
|
||||
asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val));
|
||||
}
|
||||
#include <torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h>
|
||||
|
||||
@@ -1,7 +1 @@
-/* Workaround for missing vst1q_f32_x2 in gcc-8. */
-
-__extension__ extern __inline void
-__attribute__((__always_inline__, __gnu_inline__, __artificial__))
-vst1q_f32_x2(float32_t* __a, float32x4x2_t val) {
-  asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val));
-}
+#include <torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h>
@ -3,50 +3,12 @@
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <torch/headeronly/cpu/vec/vec_half.h>
|
||||
|
||||
namespace at::vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
|
||||
!defined(__APPLE__)
|
||||
static inline uint16_t float2half_scalar(float val) {
|
||||
#if defined(CPU_CAPABILITY_AVX2)
|
||||
#if defined(_MSC_VER)
|
||||
__m256 v = _mm256_set1_ps(val);
|
||||
__m128i o =
|
||||
_mm256_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
return static_cast<std::uint16_t>(_mm_cvtsi128_si32(o));
|
||||
#else
|
||||
return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
#endif
|
||||
#elif defined(CPU_CAPABILITY_AVX512)
|
||||
__m512 v = _mm512_set1_ps(val);
|
||||
__m256i o =
|
||||
_mm512_cvtps_ph(v, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
return static_cast<std::uint16_t>(
|
||||
_mm_cvtsi128_si32(_mm256_castsi256_si128(o)));
|
||||
#endif
|
||||
}
|
||||
|
||||
static inline float half2float_scalar(uint16_t val) {
|
||||
#if defined(CPU_CAPABILITY_AVX2)
|
||||
#if defined(_MSC_VER)
|
||||
__m128i v = _mm_cvtsi32_si128(val);
|
||||
__m256 o = _mm256_cvtph_ps(v);
|
||||
return _mm256_cvtss_f32(o);
|
||||
#else
|
||||
return _cvtsh_ss(val);
|
||||
#endif
|
||||
#elif defined(CPU_CAPABILITY_AVX512)
|
||||
__m256i v =
|
||||
_mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||
__m512 o = _mm512_cvtph_ps(v);
|
||||
return _mm512_cvtss_f32(o);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// Transpose a [2, 32] matrix to [32, 2]
|
||||
// Note: the output leading dimension should be 2,
|
||||
// that is, the output must be contiguous
|
||||
|
||||
@@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl
   }
 
   bool pinned_use_background_threads() override {
-    return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
+    return c10::CachingAllocator::AcceleratorAllocatorConfig::
         pinned_use_background_threads();
   }
 
@ -21,6 +21,10 @@
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
#include <fbgemm_gpu/torch_ops.h>
|
||||
#endif
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
@@ -1216,7 +1220,7 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
 // - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme
 // - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme
 // - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
-// - `use_fast_accum`: if true, enables fast float8 accumulation
+// - `use_fast_accum`: if true, enables fast float8 accumulation. Backends may ignore this option if not applicable.
 // - `out`: a reference to the output tensor
 
 Tensor&
@ -1525,6 +1529,7 @@ namespace {
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// For TMA transfers, strides of output tensor have to be either
|
||||
// 1, or aligned to 16 bytes.
|
||||
const auto last_dim = out_size.size() - 1;
|
||||
@ -1536,9 +1541,10 @@ namespace {
|
||||
} else {
|
||||
out_stride = {out_size[1] * size_padded, size_padded, 1};
|
||||
}
|
||||
auto out = at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
|
||||
|
||||
return out;
|
||||
return at::empty_strided(out_size, out_stride, mat_a.options().dtype(out_dtype_));
|
||||
#else
|
||||
return at::empty(out_size, mat_a.options().dtype(out_dtype_));
|
||||
#endif
|
||||
}
|
||||
|
||||
bool check_valid_strides_and_return_transposed(const Tensor& mat) {
|
||||
@@ -1619,12 +1625,9 @@ const std::optional<at::Tensor>& bias,
 const std::optional<at::Tensor>& scale_result,
 std::optional<c10::ScalarType> out_dtype,
 bool use_fast_accum) {
-#ifndef USE_ROCM
-  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true);
-  TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0");
+  bool allowed_device = _scaled_mm_allowed_device();
+  TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0, or ROCm MI300+");
 
-  TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
-  TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
  TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
  TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
  TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");

@@ -1664,6 +1667,10 @@ bool use_fast_accum) {
 
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype);
 
+#ifndef USE_ROCM
+  TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
+  TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type());
+
   at::cuda::detail::f8f8bf16_grouped_mm(
       mat_a,
       mat_b,

@@ -1674,12 +1681,23 @@ bool use_fast_accum) {
       use_fast_accum,
       out);
   return out;
 
 #else
-  TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
+#ifdef USE_FBGEMM_GENAI
+  TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_a.scalar_type());
+  TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fnuz, "Expected mat_a to be Float8_e4m3fnuz matrix got ", mat_b.scalar_type());
+
+  fbgemm_gpu::f8f8bf16_rowwise_grouped_mm(
+      mat_a,
+      // FBGEMM expects B matrix shape to be (.., N, K)
+      mat_b.transpose(-2, -1),
+      scale_a,
+      scale_b,
+      offs,
+      out);
+  return out;
+#else
+  TORCH_CHECK(false, "grouped gemm is not supported without USE_FBGEMM_GENAI on ROCM")
+#endif
 #endif
 
 }
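For orientation, a small self-contained sketch of the two FP8 dtypes that the CUDA and ROCm/FBGEMM GenAI branches above check for. This snippet is illustrative only and assumes a recent PyTorch build with FP8 dtypes; it does not call the private torch._scaled_grouped_mm entry point, whose scale layouts depend on the scaling recipe described in the comments above.

import torch

# CUDA path expects torch.float8_e4m3fn inputs; the ROCm/FBGEMM GenAI path
# expects torch.float8_e4m3fnuz (the "fnuz" variant used on MI300-class GPUs).
x = torch.randn(4, 4)
a = x.to(torch.float8_e4m3fn)
b = x.to(torch.float8_e4m3fnuz)
print(a.dtype, b.dtype)
# The two formats also have different representable ranges:
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0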
@ -38,17 +38,19 @@ static inline std::string _cudaGetErrorEnum(cufftResult error)
|
||||
return "CUFFT_INVALID_SIZE";
|
||||
case CUFFT_UNALIGNED_DATA:
|
||||
return "CUFFT_UNALIGNED_DATA";
|
||||
case CUFFT_INCOMPLETE_PARAMETER_LIST:
|
||||
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
|
||||
case CUFFT_INVALID_DEVICE:
|
||||
return "CUFFT_INVALID_DEVICE";
|
||||
case CUFFT_PARSE_ERROR:
|
||||
return "CUFFT_PARSE_ERROR";
|
||||
case CUFFT_NO_WORKSPACE:
|
||||
return "CUFFT_NO_WORKSPACE";
|
||||
case CUFFT_NOT_IMPLEMENTED:
|
||||
return "CUFFT_NOT_IMPLEMENTED";
|
||||
#if !defined(USE_ROCM)
|
||||
#if CUDA_VERSION <= 12090
|
||||
case CUFFT_INCOMPLETE_PARAMETER_LIST:
|
||||
return "CUFFT_INCOMPLETE_PARAMETER_LIST";
|
||||
case CUFFT_PARSE_ERROR:
|
||||
return "CUFFT_PARSE_ERROR";
|
||||
#endif
|
||||
#if !defined(USE_ROCM) && CUDA_VERSION <= 12090
|
||||
case CUFFT_LICENSE_ERROR:
|
||||
return "CUFFT_LICENSE_ERROR";
|
||||
#endif
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-field-initializers")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
|
||||
|
||||
// Determine if the architecture supports rowwise scaled mm
|
||||
// Currently failing on windows with:
|
||||
@ -44,6 +45,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-field-initializers")
|
||||
|
||||
#include <ATen/native/cuda/cutlass_common.cuh>
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
C10_DIAGNOSTIC_POP()
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
// Two warninngs in Cutlass included header files
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
|
||||
|
||||
// Determine if the architecture supports rowwise scaled mm
|
||||
// Currently failing on windows with:
|
||||
@ -44,6 +45,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
C10_DIAGNOSTIC_POP()
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ namespace at::cuda::jit {
 // Copied from aten/src/ATen/cuda/llvm_basic.cpp, then modified as above.
 // If not compiling for ROCm, return the original get_traits_string().
 std::string get_traits_string_but_hiprtc_safe() {
-#if defined(USE_ROCM) && ROCM_VERSION < 70000
+#if defined(USE_ROCM) && HIP_VERSION_MAJOR < 7
   return R"ESCAPE(
 namespace std {
 
@@ -342,8 +342,8 @@ Tensor rms_norm_symint(
 
   if (weight_opt.has_value() && weight_opt.value().defined() && weight_opt.value().dtype() != input.dtype()) {
     TORCH_WARN_ONCE(
-      "Mismatch dtype between input and module: input dtype = ", input.dtype(),
-      ", module dtype = ", weight_opt.value().dtype(), ", Can not dispatch to fused implementation"
+      "Mismatch dtype between input and weight: input dtype = ", input.dtype(),
+      ", weight dtype = ", weight_opt.value().dtype(), ", Cannot dispatch to fused implementation."
     );
     return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
   }
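A minimal Python sketch of the situation the warning above covers, assuming a build that includes this change: calling rms_norm with a weight whose dtype differs from the input still works, but is routed to the unfused composite implementation after the one-time warning. The exact output dtype of the composite path is not asserted here.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.float32)  # mismatched dtype on purpose
y = F.rms_norm(x, normalized_shape=(8,), weight=w, eps=1e-6)
print(y.dtype)  # computed via the composite path, with the warning emitted once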
@ -22,6 +22,22 @@ struct PoolingParams {
|
||||
bool return_indices;
|
||||
};
|
||||
|
||||
template <unsigned N = 5, typename idx_type_t = int32_t>
|
||||
struct AvgPoolingParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
::c10::metal::array<idx_type_t, N> input_sizes;
|
||||
::c10::metal::array<idx_type_t, N> input_strides;
|
||||
::c10::metal::array<idx_type_t, N> output_sizes;
|
||||
::c10::metal::array<idx_type_t, N> output_strides;
|
||||
::c10::metal::array<idx_type_t, N - 2> kernel_size;
|
||||
::c10::metal::array<idx_type_t, N - 2> stride;
|
||||
::c10::metal::array<idx_type_t, N - 2> padding;
|
||||
bool count_include_pad;
|
||||
bool has_divisor_override;
|
||||
int32_t divisor_override;
|
||||
};
|
||||
|
||||
template <unsigned N = 5, typename idx_type_t = int32_t>
|
||||
struct PoolingBackwardParams {
|
||||
int32_t dims;
|
||||
|
||||
@ -292,12 +292,154 @@ kernel void max_pool_backward(
|
||||
pooling_dims);
|
||||
}
|
||||
|
||||
#define REGISTER_MAX_POOL_OP(DTYPE) \
|
||||
template <typename T>
|
||||
struct AvgPoolIterBounds {
|
||||
T start;
|
||||
T end;
|
||||
T count;
|
||||
};
|
||||
|
||||
template <int32_t dim>
|
||||
AvgPoolIterBounds<int32_t> get_avg_pool_input_iter_bounds(
|
||||
constant int32_t* input_sizes,
|
||||
thread int32_t (&pooling_dim_indices)[3],
|
||||
constant int32_t* kernel_size,
|
||||
constant int32_t* stride,
|
||||
constant int32_t* padding,
|
||||
bool count_include_pad) {
|
||||
auto start = stride[dim] * pooling_dim_indices[dim] - padding[dim];
|
||||
auto end = start + kernel_size[dim];
|
||||
auto end_corrected = min(start + kernel_size[dim], input_sizes[dim]);
|
||||
auto start_corrected = (start < 0) ? 0 : start;
|
||||
auto count = count_include_pad
|
||||
? (min(end, input_sizes[dim] + padding[dim]) - start)
|
||||
: (end_corrected - start_corrected);
|
||||
return {start_corrected, end_corrected, count};
|
||||
}
|
||||
|
||||
// Iterates through all the input elements that this kernel needs to
|
||||
// average over. Specialized for 3 pooling dimensions.
|
||||
template <typename T>
|
||||
void avg_pool_3d_input_iter(
|
||||
constant T* input,
|
||||
device T* output,
|
||||
constant int32_t* input_sizes,
|
||||
constant int32_t* input_strides,
|
||||
thread int32_t (&pooling_dim_indices)[3],
|
||||
constant int32_t* kernel_size,
|
||||
constant int32_t* stride,
|
||||
constant int32_t* padding,
|
||||
bool count_include_pad,
|
||||
bool has_divisor_override,
|
||||
int32_t divisor_override) {
|
||||
auto bounds0 = get_avg_pool_input_iter_bounds<0>(
|
||||
input_sizes,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad);
|
||||
auto bounds1 = get_avg_pool_input_iter_bounds<1>(
|
||||
input_sizes,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad);
|
||||
auto bounds2 = get_avg_pool_input_iter_bounds<2>(
|
||||
input_sizes,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad);
|
||||
|
||||
T value_sum = 0;
|
||||
auto divisor = has_divisor_override
|
||||
? divisor_override
|
||||
: (bounds0.count) * (bounds1.count) * (bounds2.count);
|
||||
auto size12 = input_sizes[1] * input_sizes[2];
|
||||
|
||||
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
|
||||
auto offset0 = input_strides[0] * i0;
|
||||
|
||||
for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
|
||||
auto offset1 = input_strides[1] * i1;
|
||||
|
||||
for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
|
||||
auto offset2 = input_strides[2] * i2;
|
||||
auto input_value = input[offset0 + offset1 + offset2];
|
||||
value_sum += input_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
*output = value_sum / static_cast<T>(divisor);
|
||||
}
|
||||
|
||||
// Kernel computes one element of the output per kernel call.
|
||||
template <typename T>
|
||||
kernel void avg_pool(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant AvgPoolingParams<5>& params [[buffer(2)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
auto pooling_dims = params.pooling_dims;
|
||||
auto dims = params.dims;
|
||||
auto input_sizes = params.input_sizes.data();
|
||||
auto input_strides = params.input_strides.data();
|
||||
auto output_sizes = params.output_sizes.data();
|
||||
auto output_strides = params.output_strides.data();
|
||||
auto kernel_size = params.kernel_size.data();
|
||||
auto stride = params.stride.data();
|
||||
auto padding = params.padding.data();
|
||||
auto leading_dims = dims - pooling_dims;
|
||||
|
||||
// This buffer keeps track of the pooling dimension indices of this thread's
|
||||
// element of the output. We need to fill it with the proper values below.
|
||||
int32_t pooling_dim_indices[3];
|
||||
|
||||
PoolOffsets offsets = find_pool_offsets(
|
||||
output_sizes,
|
||||
output_strides,
|
||||
/*indices_strides=*/nullptr,
|
||||
input_strides,
|
||||
pooling_dim_indices,
|
||||
dims,
|
||||
leading_dims,
|
||||
/*return_indices=*/false,
|
||||
tid);
|
||||
|
||||
output += offsets.output;
|
||||
input += offsets.input_leading;
|
||||
input_sizes += leading_dims;
|
||||
input_strides += leading_dims;
|
||||
|
||||
avg_pool_3d_input_iter<T>(
|
||||
input,
|
||||
output,
|
||||
input_sizes,
|
||||
input_strides,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
params.count_include_pad,
|
||||
params.has_divisor_override,
|
||||
params.divisor_override);
|
||||
}
|
||||
|
||||
#define REGISTER_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
|
||||
@ -309,19 +451,19 @@ kernel void max_pool_backward(
|
||||
constant PoolingBackwardParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
REGISTER_MAX_POOL_OP(float);
|
||||
REGISTER_MAX_POOL_OP(half);
|
||||
REGISTER_MAX_POOL_OP(int);
|
||||
REGISTER_MAX_POOL_OP(long);
|
||||
REGISTER_MAX_POOL_OP(short);
|
||||
REGISTER_MAX_POOL_OP(char);
|
||||
REGISTER_MAX_POOL_OP(uchar);
|
||||
REGISTER_MAX_POOL_OP(bool);
|
||||
REGISTER_POOL_OP(float);
|
||||
REGISTER_POOL_OP(half);
|
||||
REGISTER_POOL_OP(int);
|
||||
REGISTER_POOL_OP(long);
|
||||
REGISTER_POOL_OP(short);
|
||||
REGISTER_POOL_OP(char);
|
||||
REGISTER_POOL_OP(uchar);
|
||||
REGISTER_POOL_OP(bool);
|
||||
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(float);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(half);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_MAX_POOL_OP(bfloat);
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -418,8 +418,9 @@ Tensor& exponential_mps_(Tensor& self, double lambda, std::optional<Generator> g
    MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil];
    return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil];
  };
  auto eps = std::numeric_limits<float>::epsilon();
  return mps::random_mps_impl<double>(self,
                                      0.0,
                                      eps,
                                      1.0,
                                      std::nullopt,
                                      std::nullopt,

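For context on this hunk (my reading; the diff itself does not explain it): the sampler is an inverse-CDF transform of a uniform draw, and the added eps looks like a guard that keeps the logarithm away from zero.

// Inverse-CDF identity behind samplers of this shape:
//   U ~ Uniform(0, 1)   =>   -log(1 - U) / lambda   ~   Exponential(lambda)
// log() must never be evaluated at 0, hence the epsilon-adjusted range passed
// to random_mps_impl above.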
@ -14,6 +14,7 @@
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#include <ATen/ops/max_pool2d_backward_native.h>
#include <ATen/ops/max_pool2d_native.h>
#include <ATen/ops/max_pool2d_with_indices_backward_native.h>
@ -265,13 +266,13 @@ using PoolSizes = std::tuple<int32_t,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int32_t>>;
|
||||
std::optional<std::vector<int32_t>>>;
|
||||
|
||||
static PoolSizes process_pool_sizes(const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
std::optional<IntArrayRef> dilation_opt,
|
||||
bool ceil_mode,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
@ -305,18 +306,22 @@ static PoolSizes process_pool_sizes(const Tensor& input,
|
||||
pooling_dims,
|
||||
" ints");
|
||||
|
||||
TORCH_CHECK(dilation.size() == 1 || dilation.size() == pooling_dims,
|
||||
op_name,
|
||||
": dilation must be either a single int, or a tuple of ",
|
||||
pooling_dims,
|
||||
" ints");
|
||||
if (dilation_opt.has_value()) {
|
||||
auto dilation = dilation_opt.value();
|
||||
TORCH_CHECK(dilation.size() == 1 || dilation.size() == pooling_dims,
|
||||
op_name,
|
||||
": dilation must be either a single int, or a tuple of ",
|
||||
pooling_dims,
|
||||
" ints");
|
||||
}
|
||||
|
||||
int32_t leading_dims = input.dim() - pooling_dims;
|
||||
|
||||
const auto kernel_size_expanded = copy_and_maybe_expand(kernel_size, pooling_dims);
|
||||
const auto stride_expanded = copy_and_maybe_expand(stride.empty() ? kernel_size : stride, pooling_dims);
|
||||
const auto padding_expanded = copy_and_maybe_expand(padding, pooling_dims);
|
||||
const auto dilation_expanded = copy_and_maybe_expand(dilation, pooling_dims);
|
||||
const auto dilation_expanded = dilation_opt.has_value() ? copy_and_maybe_expand(dilation_opt.value(), pooling_dims)
|
||||
: std::vector<int32_t>(pooling_dims, 1);
|
||||
|
||||
for (const auto dim : c10::irange(pooling_dims)) {
|
||||
TORCH_CHECK(padding_expanded[dim] >= 0, op_name, ": pad must be non-negative");
|
||||
@ -362,7 +367,12 @@ static PoolSizes process_pool_sizes(const Tensor& input,
|
||||
output_size[leading_dims + dim] = output_pooling_size[dim];
|
||||
}
|
||||
|
||||
return PoolSizes(dims, output_size, kernel_size_expanded, stride_expanded, padding_expanded, dilation_expanded);
|
||||
return PoolSizes(dims,
|
||||
output_size,
|
||||
kernel_size_expanded,
|
||||
stride_expanded,
|
||||
padding_expanded,
|
||||
dilation_opt.has_value() ? std::make_optional(dilation_expanded) : std::nullopt);
|
||||
}
|
||||
|
||||
static void max_pool_with_indices_out_mps_template(const Tensor& output,
|
||||
@ -375,8 +385,10 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
|
||||
bool ceil_mode,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
auto [dims, output_size, kernel_size, stride, padding, dilation] =
|
||||
auto [dims, output_size, kernel_size, stride, padding, dilation_opt] =
|
||||
process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name);
|
||||
TORCH_INTERNAL_ASSERT(dilation_opt.has_value());
|
||||
auto dilation = dilation_opt.value();
|
||||
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));
|
||||
const bool return_indices = indices.defined();
|
||||
|
||||
@ -442,7 +454,7 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
|
||||
bool ceil_mode,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
auto [dims, output_size, kernel_size, stride, padding, dilation] =
|
||||
auto [dims, output_size, kernel_size, stride, padding, dilation_opt] =
|
||||
process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name);
|
||||
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
@ -601,6 +613,62 @@ static void avg_pool2d_template(const Tensor& input,
|
||||
op_name);
|
||||
}
|
||||
|
||||
static void avg_pool_out_mps_template(const Tensor& output,
|
||||
const Tensor& input,
|
||||
IntArrayRef _kernel_size,
|
||||
IntArrayRef _stride,
|
||||
IntArrayRef _padding,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
std::optional<int64_t> divisor_override,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
auto [dims, output_size, kernel_size, stride, padding, _] =
|
||||
process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);
|
||||
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
output.resize_(output_size, memory_format);
|
||||
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
const auto numThreads = output.numel();
|
||||
|
||||
AvgPoolingParams<5> params;
|
||||
|
||||
params.dims = dims;
|
||||
params.pooling_dims = pooling_dims;
|
||||
params.count_include_pad = count_include_pad;
|
||||
params.has_divisor_override = divisor_override.has_value();
|
||||
if (divisor_override.has_value()) {
|
||||
params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
|
||||
}
|
||||
|
||||
for (const auto dim : c10::irange(dims)) {
|
||||
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
|
||||
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
|
||||
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
|
||||
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
|
||||
}
|
||||
|
||||
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
|
||||
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
|
||||
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
auto PSO = lib.getPipelineStateForFunc("avg_pool_" + scalarToMetalTypeString(input));
|
||||
|
||||
getMPSProfiler().beginProfileKernel(PSO, op_name, {input});
|
||||
[computeEncoder setComputePipelineState:PSO];
|
||||
mtl_setArgs(computeEncoder, input, output, params);
|
||||
|
||||
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
|
||||
getMPSProfiler().endProfileKernel(PSO);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor mps_max_pool2d(const Tensor& input,
|
||||
@ -876,4 +944,25 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps)
|
||||
"avg_pool2d_backward");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool3d_out_mps)
|
||||
(const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
std::optional<int64_t> divisor_override,
|
||||
const Tensor& output) {
|
||||
mps::avg_pool_out_mps_template(output,
|
||||
input,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
/*pooling_dims=*/3,
|
||||
"avg_pool3d");
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -7124,18 +7124,21 @@
  dispatch:
    CPU: _scaled_mm_cpu
    CUDA: _scaled_mm_cuda
  tags: needs_exact_strides

- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: _scaled_mm_out_cpu
    CUDA: _scaled_mm_out_cuda
  tags: needs_exact_strides


- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
  variants: function
  dispatch:
    CUDA: _scaled_grouped_mm_cuda
  tags: needs_exact_strides

- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
  variants: function
@ -12334,6 +12337,7 @@
  dispatch:
    CPU: avg_pool3d_out_cpu
    CUDA: avg_pool3d_out_cuda
    MPS: avg_pool3d_out_mps
    MkldnnCPU: mkldnn_avg_pool3d_out

- func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor

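With the MPS dispatch entry registered above, a 5-D MPS tensor should now route through the new Metal kernel; a minimal ATen-level sketch of what that enables (shapes and values are illustrative, not from the diff):

#include <ATen/ATen.h>

// Requires a macOS build with the MPS backend available.
at::Tensor x = at::randn({1, 3, 8, 8, 8}, at::TensorOptions().device(at::kMPS));
at::Tensor y = at::avg_pool3d(x, /*kernel_size=*/{2, 2, 2});  // dispatches to avg_pool3d_out_mps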
@ -955,7 +955,10 @@ static at::Tensor fp8_qlinear_onednn_ref(
  std::vector<int64_t> w_scales_new_shape(weight.dim(), 1);
  w_scales_new_shape[0] = -1;
  auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
  auto y_f32 = at::linear(dqx, dqw, bias);
  auto y_f32 = at::linear(dqx, dqw);
  if (bias.has_value()) {
    y_f32 += bias.value().to(at::kFloat);
  }
  if (binary_post_op == "none") {
    if (unary_post_op == "relu") {
      at::relu_(y_f32);

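Spelled out, the reference math in this block is as follows (per output channel o; the "stays in fp32" rationale on the last line is my inference, not stated in the diff):

//   dqw[o][i] = float(w_fp8[o][i]) * w_scale[o]   // per-output-channel dequant
//   y_f32     = dqx @ dqw^T                       // at::linear without bias
//   y_f32    += float(bias)                       // added separately, presumably so the
//                                                 // reference accumulates in fp32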
@ -1,8 +1,7 @@
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <thread>


@ -10,7 +9,7 @@
// numbers of threads set and also whether the scheduler
// will throw an exception when multiple threads call
// their first parallel construct.
static void test(int given_num_threads) {
void test(int given_num_threads) {
  auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
  ASSERT_TRUE(given_num_threads >= 0);
  ASSERT_EQ(at::get_num_threads(), given_num_threads);
@ -20,7 +19,7 @@ static void test(int given_num_threads) {
  }
}

TEST(ThreadInitTest, ThreadInit) {
int main() {
  at::init_num_threads();

  at::set_num_threads(4);
@ -33,11 +32,13 @@ TEST(ThreadInitTest, ThreadInit) {

#if !AT_PARALLEL_NATIVE
  at::set_num_threads(5);
  ASSERT_EQ(at::get_num_threads(), 5);
  ASSERT_TRUE(at::get_num_threads() == 5);
#endif

  // test inter-op settings
  at::set_num_interop_threads(5);
  ASSERT_EQ(at::get_num_interop_threads(), 5);
  ASSERT_ANY_THROW(at::set_num_interop_threads(6));

  return 0;
}

@ -13,6 +13,7 @@ flaky_models = {
    "gluon_inception_v3",
    "detectron2_maskrcnn_r_101_c4",
    "XGLMForCausalLM",  # discovered in https://github.com/pytorch/pytorch/pull/128148
    "detectron2_fcos_r_50_fpn",
}

@ -346,7 +346,7 @@ vgg16,pass,0

vision_maskrcnn,fail_accuracy,30
vision_maskrcnn,fail_accuracy,29

@ -1,32 +1,32 @@
add_loop_eager,compile_time_instruction_count,3070000000,0.10
add_loop_eager,compile_time_instruction_count,3070000000,0.1

add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.10
add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.1

add_loop_inductor,compile_time_instruction_count,30280000000,0.10
add_loop_inductor,compile_time_instruction_count,30280000000,0.1

add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.10
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.1

add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.10
add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1

basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.10
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1

basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.10
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,0.1

basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.10
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.1

@ -34,56 +34,56 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000

update_hint_regression,compile_time_instruction_count,1719000000,0.10
update_hint_regression,compile_time_instruction_count,1719000000,0.1

sum_floordiv_regression,compile_time_instruction_count,966100000,0.10
sum_floordiv_regression,compile_time_instruction_count,966100000,0.1

symint_sum,compile_time_instruction_count,3237000000,0.10
symint_sum,compile_time_instruction_count,3237000000,0.1

symint_sum_loop,compile_time_instruction_count,4299000000,0.10
symint_sum_loop,compile_time_instruction_count,4299000000,0.1

aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.10
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.1

aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.10
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.1

aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.10
aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.1

aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.10
aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.1

aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.10
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.1

aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.10
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.1

mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.10
mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1

mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.10
mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1

basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.10
basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1

basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.10
basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.1

@ -944,6 +944,7 @@ def define_buck_targets(
        [
            ("torch/csrc/api/include", "torch/**/*.h"),
            ("", "torch/csrc/**/*.h"),
            ("", "torch/nativert/**/*.h"),
            ("", "torch/headeronly/**/*.h"),
            ("", "torch/script.h"),
            ("", "torch/library.h"),

@ -593,11 +593,13 @@ libtorch_core_jit_sources = sorted(jit_sources_full)

libtorch_nativert_sources = [
    "torch/nativert/ModelRunner.cpp",
    "torch/nativert/graph/Graph.cpp",
    "torch/nativert/graph/GraphPasses.cpp",
    "torch/nativert/graph/GraphSignature.cpp",
    "torch/nativert/graph/Serialization.cpp",
    "torch/nativert/graph/TensorMeta.cpp",
    "torch/nativert/graph/GraphUtils.cpp",
    "torch/nativert/executor/DelegateExecutor.cpp",
    "torch/nativert/executor/Placement.cpp",
    "torch/nativert/executor/ExecutionPlanner.cpp",
@ -864,6 +866,7 @@ libtorch_python_core_sources = [
    "torch/csrc/QScheme.cpp",
    "torch/csrc/Module.cpp",
    "torch/csrc/PyInterpreter.cpp",
    "torch/csrc/PyInterpreterHooks.cpp",
    "torch/csrc/python_dimname.cpp",
    "torch/csrc/Size.cpp",
    "torch/csrc/Storage.cpp",
@ -986,6 +989,7 @@ libtorch_python_core_sources = [
    "torch/csrc/utils/verbose.cpp",
    "torch/csrc/cpu/Module.cpp",
    "torch/csrc/instruction_counter/Module.cpp",
    "torch/nativert/python/Bindings.cpp",
] + lazy_tensor_core_python_sources

libtorch_python_distributed_core_sources = [

241  c10/core/AllocatorConfig.cpp  Normal file
@ -0,0 +1,241 @@
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/util/env.h>
|
||||
|
||||
namespace c10::CachingAllocator {
|
||||
|
||||
namespace {
|
||||
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;
|
||||
constexpr size_t kMB = 1024 * 1024ul;
|
||||
constexpr size_t kRoundUpPowerOfTwoStart = 1 * kMB; // 1MB
|
||||
constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB
|
||||
} // anonymous namespace
|
||||
|
||||
AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() {
|
||||
static AcceleratorAllocatorConfig instance;
|
||||
#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env, deprecated) \
|
||||
auto env##_name = c10::utils::get_env(#env); \
|
||||
if (env##_name.has_value()) { \
|
||||
if (deprecated) { \
|
||||
TORCH_WARN_ONCE(#env " is deprecated, use PYTORCH_ALLOC_CONF instead"); \
|
||||
} \
|
||||
instance.parseArgs(env##_name.value()); \
|
||||
return true; \
|
||||
}
|
||||
static bool env_flag [[maybe_unused]] = []() {
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF, false)
|
||||
// Keep this for backwards compatibility
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF, /*deprecated=*/true)
|
||||
C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF, /*deprecated=*/true)
|
||||
return false;
|
||||
}();
|
||||
#undef C10_ALLOCATOR_CONFIG_PARSE_ENV
|
||||
return instance;
|
||||
}
|
||||
|
||||
AcceleratorAllocatorConfig::AcceleratorAllocatorConfig() {
|
||||
roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0);
|
||||
}
|
||||
|
||||
size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) {
|
||||
size_t log_size = (63 - llvm::countLeadingZeros(size));
|
||||
|
||||
// Our intervals start at 1MB and end at 64GB
|
||||
const size_t interval_start =
|
||||
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart);
|
||||
const size_t interval_end =
|
||||
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd);
|
||||
TORCH_CHECK_VALUE(
|
||||
interval_end - interval_start == kRoundUpPowerOfTwoIntervals,
|
||||
"kRoundUpPowerOfTwoIntervals mismatch");
|
||||
|
||||
size_t index =
|
||||
(log_size > interval_start) ? (log_size - interval_start) : 0ul;
|
||||
index = std::min(index, kRoundUpPowerOfTwoIntervals - 1);
|
||||
return instance().roundup_power2_divisions_[index];
|
||||
}
|
||||
|
||||
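A quick worked example of the index arithmetic above (the request size is my own illustration, not from the diff):

// size = 3 MiB  ->  log_size = 63 - clz(3 << 20) = 21
// interval_start = 63 - clz(1 MiB) = 20, so index = 21 - 20 = 1,
// i.e. a 3 MiB request is governed by roundup_power2_divisions_[1]
// (the 2 MiB - 4 MiB bucket); sizes under 2 MiB fall back to index 0.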
size_t AcceleratorAllocatorConfig::parseMaxSplitSize(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB;
|
||||
constexpr size_t max_allowed_split_size_mb =
|
||||
std::numeric_limits<size_t>::max() / kMB;
|
||||
|
||||
size_t val_env = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
val_env >= min_allowed_split_size_mb,
|
||||
"CachingAllocator option max_split_size_mb too small, must be >= ",
|
||||
min_allowed_split_size_mb);
|
||||
val_env = std::min(val_env, max_allowed_split_size_mb);
|
||||
max_split_size_ = val_env * kMB;
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB;
|
||||
constexpr size_t max_allowed_split_size_mb =
|
||||
std::numeric_limits<size_t>::max() / kMB;
|
||||
|
||||
size_t val_env = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
val_env >= min_allowed_split_size_mb,
|
||||
"CachingAllocator option max_non_split_rounding_mb too small, must be >= ",
|
||||
min_allowed_split_size_mb);
|
||||
val_env = std::min(val_env, max_allowed_split_size_mb);
|
||||
max_non_split_rounding_size_ = val_env * kMB;
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
double val_env = tokenizer.toDouble(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
val_env > 0 && val_env < 1.0,
|
||||
"garbage_collect_threshold is invalid, set it in (0.0, 1.0)");
|
||||
garbage_collection_threshold_ = val_env;
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
bool first_value = true;
|
||||
|
||||
if (tokenizer[++i] == "[") {
|
||||
size_t last_index = 0;
|
||||
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
|
||||
while (++i < tokenizer.size() && tokenizer[i] != "]") {
|
||||
size_t value_index = i;
|
||||
tokenizer.checkToken(++i, ":");
|
||||
size_t value = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
value == 0 || llvm::isPowerOf2_64(value),
|
||||
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ");
|
||||
|
||||
if (tokenizer[value_index] == ">") {
|
||||
std::fill(
|
||||
std::next(
|
||||
roundup_power2_divisions_.begin(),
|
||||
static_cast<std::vector<size_t>::difference_type>(
|
||||
last_index + 1)),
|
||||
roundup_power2_divisions_.end(),
|
||||
value);
|
||||
} else {
|
||||
size_t boundary = tokenizer.toSizeT(value_index);
|
||||
TORCH_CHECK_VALUE(
|
||||
llvm::isPowerOf2_64(boundary),
|
||||
"For roundups, the intervals have to be power of 2 ");
|
||||
|
||||
size_t index = 63 - llvm::countLeadingZeros(boundary);
|
||||
index =
|
||||
std::clamp(index, size_t{0}, roundup_power2_divisions_.size() - 1);
|
||||
|
||||
if (first_value) {
|
||||
std::fill(
|
||||
roundup_power2_divisions_.begin(),
|
||||
std::next(
|
||||
roundup_power2_divisions_.begin(),
|
||||
static_cast<std::vector<size_t>::difference_type>(index)),
|
||||
value);
|
||||
first_value = false;
|
||||
}
|
||||
roundup_power2_divisions_[index] = value;
|
||||
last_index = index;
|
||||
}
|
||||
|
||||
if (tokenizer[i + 1] != "]") {
|
||||
tokenizer.checkToken(++i, ",");
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
i < tokenizer.size(),
|
||||
"Expected closing bracket ']' in ConfigTokenizer but reached end of config");
|
||||
} else { // Keep this for backwards compatibility
|
||||
size_t value = tokenizer.toSizeT(i);
|
||||
TORCH_CHECK_VALUE(
|
||||
llvm::isPowerOf2_64(value),
|
||||
"For roundups, the divisions has to be power of 2 ");
|
||||
std::fill(
|
||||
roundup_power2_divisions_.begin(),
|
||||
roundup_power2_divisions_.end(),
|
||||
value);
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t AcceleratorAllocatorConfig::parseExpandableSegments(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
use_expandable_segments_ = tokenizer.toBool(++i);
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t AcceleratorAllocatorConfig::parsePinnedUseBackgroundThreads(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
pinned_use_background_threads_ = tokenizer.toBool(++i);
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
|
||||
// The following option will be reset to its default value if not explicitly
|
||||
// set each time.
|
||||
max_split_size_ = std::numeric_limits<size_t>::max();
|
||||
roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0);
|
||||
garbage_collection_threshold_ = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(last_allocator_settings_mutex_);
|
||||
last_allocator_settings_ = env;
|
||||
}
|
||||
|
||||
ConfigTokenizer tokenizer(env);
|
||||
for (size_t i = 0; i < tokenizer.size(); i++) {
|
||||
const auto& key = tokenizer[i];
|
||||
if (key == "max_split_size_mb") {
|
||||
i = parseMaxSplitSize(tokenizer, i);
|
||||
} else if (key == "max_non_split_rounding_mb") {
|
||||
i = parseMaxNonSplitRoundingSize(tokenizer, i);
|
||||
} else if (key == "garbage_collection_threshold") {
|
||||
i = parseGarbageCollectionThreshold(tokenizer, i);
|
||||
} else if (key == "roundup_power2_divisions") {
|
||||
i = parseRoundUpPower2Divisions(tokenizer, i);
|
||||
} else if (key == "expandable_segments") {
|
||||
i = parseExpandableSegments(tokenizer, i);
|
||||
} else if (key == "pinned_use_background_threads") {
|
||||
i = parsePinnedUseBackgroundThreads(tokenizer, i);
|
||||
} else {
|
||||
// If a device-specific configuration parser hook is registered, it will
|
||||
// check if the key is unrecognized.
|
||||
if (device_config_parser_hook_) {
|
||||
TORCH_CHECK(
|
||||
keys_.find(key) != keys_.end(),
|
||||
"Unrecognized key '",
|
||||
key,
|
||||
"' in Accelerator allocator config.");
|
||||
}
|
||||
i = tokenizer.skipKey(i);
|
||||
}
|
||||
|
||||
if (i + 1 < tokenizer.size()) {
|
||||
tokenizer.checkToken(++i, ",");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10::CachingAllocator
|
||||
372  c10/core/AllocatorConfig.h  Normal file
@ -0,0 +1,372 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/llvmMathExtras.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace c10::CachingAllocator {
|
||||
|
||||
// "large" allocations may be packed in 20 MiB blocks
|
||||
const size_t kLargeBuffer = 20971520;
|
||||
|
||||
// A utility class for tokenizing allocator configuration strings into discrete
|
||||
// parts. For example, the config string:
|
||||
// "key1:val1,key2:[val2,val3]"
|
||||
// is tokenized into:
|
||||
// "key1", ":", "val1", ",", "key2", ":", "[", "val2", ",", "val3", "]",
|
||||
//
|
||||
// Tokens include keys, values, and special characters (':', ',', '[', ']').
|
||||
// Whitespace is ignored.
|
||||
class ConfigTokenizer {
|
||||
public:
|
||||
explicit ConfigTokenizer(const std::string& env) {
|
||||
std::string buffer;
|
||||
for (char ch : env) {
|
||||
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
|
||||
if (!buffer.empty()) {
|
||||
config_.emplace_back(std::move(buffer));
|
||||
buffer.clear();
|
||||
}
|
||||
config_.emplace_back(1, ch);
|
||||
} else if (!std::isspace(static_cast<unsigned char>(ch))) {
|
||||
buffer += ch;
|
||||
}
|
||||
}
|
||||
if (!buffer.empty()) {
|
||||
config_.emplace_back(std::move(buffer));
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& operator[](size_t i) const {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
i < config_.size(), "Index out of bounds in ConfigTokenizer");
|
||||
return config_[i];
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
return config_.size();
|
||||
}
|
||||
|
||||
bool checkToken(size_t i, const std::string& token) const {
|
||||
checkIndex(i);
|
||||
return config_[i] == token;
|
||||
}
|
||||
|
||||
size_t toSizeT(size_t i) const {
|
||||
checkIndex(i);
|
||||
return std::stoull(config_[i]);
|
||||
}
|
||||
|
||||
double toDouble(size_t i) const {
|
||||
checkIndex(i);
|
||||
return std::stod(config_[i]);
|
||||
}
|
||||
|
||||
bool toBool(size_t i) const {
|
||||
checkIndex(i);
|
||||
const auto& token = config_[i];
|
||||
if (token == "True") {
|
||||
return true;
|
||||
} else if (token == "False") {
|
||||
return false;
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(
|
||||
false,
|
||||
"Expected 'True' or 'False' at index ",
|
||||
i,
|
||||
" in ConfigTokenizer but got '",
|
||||
token,
|
||||
"'");
|
||||
}
|
||||
}
|
||||
|
||||
// Skips the current token group and returns the index of the value token.
|
||||
// Assumes the current index `i` points to a key name in a key-value pair.
|
||||
size_t skipKey(size_t i) const {
|
||||
// Expect a colon after the key
|
||||
checkToken(++i, ":");
|
||||
|
||||
++i; // Move to the value
|
||||
checkIndex(i);
|
||||
if (config_[i] != "[") {
|
||||
// Value is a single token (not a list) -> return its index
|
||||
return i;
|
||||
}
|
||||
|
||||
// Skip tokens inside the list until matching ']'
|
||||
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
|
||||
while (++i < config_.size() && config_[i] != "]") {
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
i < config_.size(),
|
||||
"Expected closing bracket ']' in ConfigTokenizer but reached end of config");
|
||||
|
||||
return i; // Return the index of the closing ']'
|
||||
}
|
||||
|
||||
private:
|
||||
void checkIndex(size_t i) const {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
i < config_.size(), "Index out of bounds in ConfigTokenizer");
|
||||
}
|
||||
|
||||
std::vector<std::string> config_;
|
||||
};
|
||||
|
||||
/**
|
||||
* Note [AcceleratorAllocatorConfig design]
|
||||
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
* This class configures memory allocation for both device and host memory. A
|
||||
* single `AcceleratorAllocatorConfig` instance is shared across all accelerator
|
||||
* backends, such as CUDA and XPU, under the assumption that relevant
|
||||
* environment variables apply uniformly to all accelerators. Device-specific
|
||||
* configuration extensions are supported via hooks (see
|
||||
* `registerDeviceConfigParserHook`).
|
||||
*
|
||||
* Recommended design:
|
||||
* - Place common configurations in `AcceleratorAllocatorConfig`.
|
||||
* - Extend backend-specific configurations in corresponding device-specific
|
||||
* classes, such as `CUDAAllocatorConfig`, etc.
|
||||
*
|
||||
* Scope:
|
||||
* - Configuration options must be environment-variable driven.
|
||||
*
|
||||
* Naming Convention:
|
||||
* - Public API names in `AcceleratorAllocatorConfig` should be device-generic.
|
||||
* - Members prefixed with `pinned_` are specific to the host/pinned allocator.
|
||||
* - Environment variable names should be generic across backends.
|
||||
* - Comma-separated key-value pairs in the format: `key:value`. Use square
|
||||
* brackets `[]` for list values Example: `key1:123, key2:[val1,val2]`
|
||||
*
|
||||
* Environment Variables:
|
||||
* - The primary environment variable for configuration is `PYTORCH_ALLOC_CONF`.
|
||||
* - For backward compatibility, `PYTORCH_CUDA_ALLOC_CONF` is also supported
|
||||
* with lower priority.
|
||||
*/
|
||||
|
||||
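As an illustration of the configuration format described in the note above (the option values are my own; setAllocatorSettings is the helper declared near the end of this header):

#include <c10/core/AllocatorConfig.h>

// Equivalent to exporting PYTORCH_ALLOC_CONF with the same string before startup.
void configure_allocator_example() {
  c10::CachingAllocator::setAllocatorSettings(
      "max_split_size_mb:256,"
      "garbage_collection_threshold:0.8,"
      "roundup_power2_divisions:[256:8,>:2]");
}

As I read the parser shown earlier in AllocatorConfig.cpp, the bracketed form assigns 8 divisions to every size bucket up through the 256 MB boundary and 2 divisions to everything larger.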
class C10_API AcceleratorAllocatorConfig {
|
||||
public:
|
||||
static AcceleratorAllocatorConfig& instance();
|
||||
|
||||
C10_DISABLE_COPY_AND_ASSIGN(AcceleratorAllocatorConfig);
|
||||
AcceleratorAllocatorConfig(AcceleratorAllocatorConfig&&) = delete;
|
||||
AcceleratorAllocatorConfig& operator=(AcceleratorAllocatorConfig&&) = delete;
|
||||
~AcceleratorAllocatorConfig() = default;
|
||||
|
||||
/* Device allocator settings */
|
||||
|
||||
// Returns the maximum block size (in MB) that is allowed to be split. The
|
||||
// default is unlimited (all blocks can be split).
|
||||
static size_t max_split_size() {
|
||||
return instance().max_split_size_;
|
||||
}
|
||||
|
||||
// Returns the maximum block size (in MB) that is allowed to be rounded up
|
||||
// without requiring splitting when searching for a free block. The default is
|
||||
// 20 MiB.
|
||||
static size_t max_non_split_rounding_size() {
|
||||
return instance().max_non_split_rounding_size_;
|
||||
}
|
||||
|
||||
// Return the number of divisions used when rounding up allocation sizes (in
|
||||
// MB) to the nearest power-of-2 boundary.
|
||||
static size_t roundup_power2_divisions(size_t size);
|
||||
|
||||
// Returns the vector of division factors used for rounding up allocation
|
||||
// sizes. These divisions apply to size intervals between 1MB and 64GB.
|
||||
static const std::vector<size_t>& roundup_power2_divisions() {
|
||||
return instance().roundup_power2_divisions_;
|
||||
}
|
||||
|
||||
// Returns the threshold that triggers garbage collection when the ratio of
|
||||
// used memory to maximum allowed memory exceeds this value. The default is 0,
|
||||
// meaning no garbage collection is triggered. The value should be in the
|
||||
// range (0.0, 1.0).
|
||||
static double garbage_collection_threshold() {
|
||||
return instance().garbage_collection_threshold_;
|
||||
}
|
||||
|
||||
// Returns whether the expandable segment feature is enabled. This allows the
|
||||
// allocator to start with one segment that grows as needed, rather than
|
||||
// creating a new segment for each allocation. Default is false (expandable
|
||||
// segments disabled).
|
||||
static bool use_expandable_segments() {
|
||||
return instance().use_expandable_segments_;
|
||||
}
|
||||
|
||||
/* Host allocator settings */
|
||||
|
||||
// Returns whether the pinned host allocator uses background threads for
|
||||
// processing events. This is useful for improving performance in scenarios
|
||||
// where many small allocations are made. Default is false (background threads
|
||||
// disabled).
|
||||
static bool pinned_use_background_threads() {
|
||||
return instance().pinned_use_background_threads_;
|
||||
}
|
||||
|
||||
/* Settings for both device and host allocator */
|
||||
|
||||
// Returns the current allocator settings as a string. This string is useful
|
||||
// to expand device-specific allocator configurations
|
||||
static std::string last_allocator_settings() {
|
||||
std::lock_guard<std::mutex> lock(instance().last_allocator_settings_mutex_);
|
||||
return instance().last_allocator_settings_;
|
||||
}
|
||||
|
||||
// Returns the set of valid keys for the allocator configuration.
|
||||
// This set is used to validate the presence and correctness of keys in
|
||||
// device-specific configuration parsers.
|
||||
static const std::unordered_set<std::string>& getKeys() {
|
||||
return keys_;
|
||||
}
|
||||
|
||||
// Registers a device-specific configuration parser hook and its key. This
|
||||
// allows backends to parse additional device-specific configuration options
|
||||
// from the environment variable. The hook should be a function that takes a
|
||||
// string (the environment variable value) and parses it to set
|
||||
// device-specific configuration options. The hook will be called when the
|
||||
// environment variable is parsed. If a hook is already registered, it will be
|
||||
// replaced with the new one.
|
||||
static void registerDeviceConfigParserHook(
|
||||
std::function<void(const std::string&)>&& hook,
|
||||
const std::unordered_set<std::string>& keys) {
|
||||
device_config_parser_hook_ = std::move(hook);
|
||||
for (auto& key : keys) {
|
||||
TORCH_CHECK(
|
||||
keys_.insert(key).second,
|
||||
"Duplicated key '",
|
||||
key,
|
||||
"' found in device-specific configuration parser hook registration");
|
||||
}
|
||||
}
|
||||
|
||||
// Calls the registered device-specific configuration parser hook with the
|
||||
// provided environment string. This allows backends to parse additional
|
||||
// device-specific configuration options from the environment variable.
|
||||
// If no hook is registered, this function does nothing.
|
||||
static void callDeviceConfigParserHook(const std::string& env) {
|
||||
if (device_config_parser_hook_) {
|
||||
device_config_parser_hook_(env);
|
||||
}
|
||||
}
|
||||
|
||||
// Parses the environment variable `env` to update the allocator settings.
|
||||
// If the environment variable is not set, it does nothing.
|
||||
// The configuration string should be a comma-separated list of key-value
|
||||
// pairs, where each key is a configuration option and the value is the
|
||||
// corresponding setting. For example:
|
||||
// "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true"
|
||||
void parseArgs(const std::string& env);
|
||||
|
||||
private:
|
||||
AcceleratorAllocatorConfig();
|
||||
|
||||
/* Internal functions for device allocator */
|
||||
|
||||
// Parse `max_split_size_mb` from environment variable.
|
||||
size_t parseMaxSplitSize(const ConfigTokenizer& tokenizer, size_t i);
|
||||
// Parse `max_non_split_rounding_mb` from environment variable.
|
||||
size_t parseMaxNonSplitRoundingSize(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
// Parse `garbage_collection_threshold` from environment variable.
|
||||
size_t parseGarbageCollectionThreshold(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
// Parse `roundup_power2_divisions` from environment variable.
|
||||
size_t parseRoundUpPower2Divisions(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
// Parse `expandable_segments` from environment variable.
|
||||
size_t parseExpandableSegments(const ConfigTokenizer& tokenizer, size_t i);
|
||||
|
||||
/* Internal functions for host allocator */
|
||||
|
||||
// Parse `pinned_use_background_threads` from environment variable.
|
||||
size_t parsePinnedUseBackgroundThreads(
|
||||
const ConfigTokenizer& tokenizer,
|
||||
size_t i);
|
||||
|
||||
/* The following members are specifically used for the device allocator. */
|
||||
|
||||
// The maximum block size that is allowed to be split.
|
||||
std::atomic<size_t> max_split_size_{std::numeric_limits<size_t>::max()};
|
||||
// The maximum allowable extra size of a memory block without requiring
|
||||
// splitting when searching for a free block.
|
||||
std::atomic<size_t> max_non_split_rounding_size_{kLargeBuffer};
|
||||
// Used to store how memory allocations of different sizes should be rounded
|
||||
// up to the nearest power of 2 divisions.
|
||||
std::vector<size_t> roundup_power2_divisions_;
|
||||
// The threshold that triggers garbage collection when the ratio of used
|
||||
// memory to maximum allowed memory exceeds this value.
|
||||
std::atomic<double> garbage_collection_threshold_{0};
|
||||
// A flag to enable expandable segments feature.
|
||||
std::atomic<bool> use_expandable_segments_{false};
|
||||
|
||||
/* The following members are specifically used for the host allocator. */
|
||||
|
||||
// A flag to enable background thread for processing events.
|
||||
std::atomic<bool> pinned_use_background_threads_{false};
|
||||
|
||||
/* The following members are used for both device and host allocator. */
|
||||
|
||||
// Record the last allocator config environment setting.
|
||||
std::mutex last_allocator_settings_mutex_;
|
||||
std::string last_allocator_settings_;
|
||||
|
||||
// Optional hook for parsing additional device-specific allocator settings.
|
||||
// This allows backends (e.g., CUDA, XPU) to register a custom parser for
|
||||
// their own environment configuration extensions.
|
||||
inline static std::function<void(const std::string&)>
|
||||
device_config_parser_hook_{nullptr};
|
||||
|
||||
// A set of valid configuration keys, including both common and
|
||||
// device-specific options. This set is used to validate the presence and
|
||||
// correctness of keys during parsing.
|
||||
inline static std::unordered_set<std::string> keys_{
|
||||
"max_split_size_mb",
|
||||
"max_non_split_rounding_mb",
|
||||
"garbage_collection_threshold",
|
||||
"roundup_power2_divisions",
|
||||
"expandable_segments",
|
||||
"pinned_use_background_threads"};
|
||||
};
|
||||
|
||||
C10_API inline void setAllocatorSettings(const std::string& env) {
|
||||
AcceleratorAllocatorConfig::instance().parseArgs(env);
|
||||
AcceleratorAllocatorConfig::callDeviceConfigParserHook(env);
|
||||
}
|
||||
|
||||
C10_API inline std::string getAllocatorSettings() {
|
||||
return AcceleratorAllocatorConfig::instance().last_allocator_settings();
|
||||
}
|
||||
|
||||
struct DeviceConfigParserHookRegistry {
|
||||
explicit DeviceConfigParserHookRegistry(
|
||||
std::function<void(const std::string&)>&& hook,
|
||||
const std::unordered_set<std::string>& keys) {
|
||||
// Use static method to avoid static initialization order fiasco issues
|
||||
AcceleratorAllocatorConfig::registerDeviceConfigParserHook(
|
||||
std::move(hook), keys);
|
||||
}
|
||||
};
|
||||
|
||||
// Assume each config parser has `parseArgs` and `getKeys` methods
|
||||
#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(parser_cls) \
|
||||
namespace { \
|
||||
static at::CachingAllocator::DeviceConfigParserHookRegistry \
|
||||
g_device_config_parse_hook_registry_instance( \
|
||||
[](const std::string& env) { \
|
||||
parser_cls::instance().parseArgs(env); \
|
||||
}, \
|
||||
parser_cls::getKeys()); \
|
||||
}
|
||||
|
||||
} // namespace c10::CachingAllocator
|
||||
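To show the shape the registration macro above expects from a backend, here is a hypothetical parser class (the class name and its single key are made up for illustration):

#include <string>
#include <unordered_set>

// Hypothetical device-specific config parser; not part of the diff.
class MyBackendAllocatorConfig {
 public:
  static MyBackendAllocatorConfig& instance() {
    static MyBackendAllocatorConfig inst;
    return inst;
  }
  void parseArgs(const std::string& env) {
    // parse backend-specific keys out of `env` here
  }
  static const std::unordered_set<std::string>& getKeys() {
    static const std::unordered_set<std::string> keys{"my_backend_option"};
    return keys;
  }
};

REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(MyBackendAllocatorConfig)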
@ -240,24 +240,4 @@ struct C10_API PyInterpreter {
  void disarm() noexcept;
};

// PyInterpreterStatus describes what the state of its interpreter tag
// is, relative to the thread currently holding the GIL.
enum class PyInterpreterStatus {
  // We just allocated the Tensor, it hasn't escaped to other threads,
  // we know that it definitely hasn't been tagged to be associated
  // with an interpreter.
  DEFINITELY_UNINITIALIZED,
  // We queried the interpreter field and it looked uninitialized. But
  // another thread may have raced with us to tag it with some other
  // interpreter id. So we will have to do a CEX to make sure we can
  // actually nab it.
  MAYBE_UNINITIALIZED,
  // We queried the interpreter field and it was tagged to belong to us.
  // This means we have sole write access (as we hold the GIL for this
  // interpreter)
  TAGGED_BY_US,
  // Someone else tagged this. We can't use this TensorImpl from Python.
  TAGGED_BY_OTHER,
};

} // namespace c10::impl

32  c10/core/impl/PyInterpreterHooks.cpp  Normal file
@ -0,0 +1,32 @@
#include <c10/core/impl/PyInterpreterHooks.h>

namespace c10::impl {

// Define the registry
C10_DEFINE_REGISTRY(
    PyInterpreterHooksRegistry,
    PyInterpreterHooksInterface,
    PyInterpreterHooksArgs)

const PyInterpreterHooksInterface& getPyInterpreterHooks() {
  auto create_impl = [] {
#if !defined C10_MOBILE
    auto hooks = PyInterpreterHooksRegistry()->Create(
        "PyInterpreterHooks", PyInterpreterHooksArgs{});
    if (hooks) {
      return hooks;
    }
#endif
    // Return stub implementation that will throw errors when methods are called
    return std::make_unique<PyInterpreterHooksInterface>();
  };
  static auto hooks = create_impl();
  return *hooks;
}

// Main function to get global PyInterpreter
PyInterpreter* getGlobalPyInterpreter() {
  return getPyInterpreterHooks().getPyInterpreter();
}

} // namespace c10::impl
39  c10/core/impl/PyInterpreterHooks.h  Normal file
@ -0,0 +1,39 @@
#pragma once

#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Export.h>
#include <c10/util/Registry.h>
#include <memory>

namespace c10::impl {

// Minimal interface for PyInterpreter hooks
struct C10_API PyInterpreterHooksInterface {
  virtual ~PyInterpreterHooksInterface() = default;

  // Get the PyInterpreter instance
  // Stub implementation throws error when Python is not available
  virtual PyInterpreter* getPyInterpreter() const {
    TORCH_CHECK(
        false,
        "PyTorch was compiled without Python support. "
        "Cannot access Python interpreter from C++.");
  }
};

struct C10_API PyInterpreterHooksArgs{};

C10_DECLARE_REGISTRY(
    PyInterpreterHooksRegistry,
    PyInterpreterHooksInterface,
    PyInterpreterHooksArgs);

#define REGISTER_PYTHON_HOOKS(clsname) \
  C10_REGISTER_CLASS(PyInterpreterHooksRegistry, clsname, clsname)

// Get the global PyInterpreter hooks instance
C10_API const PyInterpreterHooksInterface& getPyInterpreterHooks();

C10_API PyInterpreter* getGlobalPyInterpreter();

} // namespace c10::impl
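For orientation, a concrete hooks implementation would plug into this registry roughly as sketched below; the class name, constructor, and returned pointer are my assumptions (the build-file hunk earlier suggests the real implementation is torch/csrc/PyInterpreterHooks.cpp):

// Hypothetical registration sketch; not part of the diff.
c10::impl::PyInterpreter* g_example_interpreter = nullptr;  // placeholder owned by the binding layer

struct ExamplePyInterpreterHooks : c10::impl::PyInterpreterHooksInterface {
  explicit ExamplePyInterpreterHooks(c10::impl::PyInterpreterHooksArgs) {}
  c10::impl::PyInterpreter* getPyInterpreter() const override {
    return g_example_interpreter;
  }
};
REGISTER_PYTHON_HOOKS(ExamplePyInterpreterHooks);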
@ -34,29 +34,12 @@ PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
|
||||
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
|
||||
}
|
||||
|
||||
void PyObjectSlot::unchecked_clear_pyobj(PyInterpreter* interpreter) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(interpreter == pyobj_interpreter_.load());
|
||||
pyobj_ = nullptr;
|
||||
}
|
||||
|
||||
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter) {
|
||||
return *interpreter;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"cannot access PyObject for Tensor on interpreter ",
|
||||
(*pyobj_interpreter_.load())->name());
|
||||
}
|
||||
|
||||
bool PyObjectSlot::check_interpreter(PyInterpreter* interpreter) {
|
||||
return interpreter == pyobj_interpreter();
|
||||
}
|
||||
|
||||
bool PyObjectSlot::has_pyobj_nonhermetic() {
|
||||
return check_pyobj(pyobj_interpreter(), /*ignore_hermetic_tls=*/true)
|
||||
.has_value();
|
||||
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
|
||||
}
|
||||
|
||||
bool PyObjectSlot::owns_pyobj() {
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
#include <c10/core/impl/HermeticPyObjectTLS.h>
|
||||
#include <c10/core/impl/PyInterpreter.h>
|
||||
#include <c10/core/impl/PyInterpreterHooks.h>
|
||||
#include <c10/util/python_stub.h>
|
||||
#include <optional>
|
||||
|
||||
@ -24,52 +25,9 @@ struct C10_API PyObjectSlot {
|
||||
//
|
||||
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
|
||||
// PyObject if necessary!
|
||||
void init_pyobj(
|
||||
PyInterpreter* self_interpreter,
|
||||
PyObject* pyobj,
|
||||
PyInterpreterStatus status) {
|
||||
impl::PyInterpreter* expected = nullptr;
|
||||
switch (status) {
|
||||
case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
|
||||
// caller guarantees there is no multithreaded access; if there is
|
||||
// no data race OK to do a relaxed store
|
||||
pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
|
||||
break;
|
||||
case impl::PyInterpreterStatus::TAGGED_BY_US:
|
||||
// no tagging is necessary, the tag is already correct
|
||||
break;
|
||||
case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED:
|
||||
// attempt to claim this TensorImpl with the specified interpreter
|
||||
// tag
|
||||
if (pyobj_interpreter_.compare_exchange_strong(
|
||||
expected, self_interpreter, std::memory_order_acq_rel)) {
|
||||
break;
|
||||
}
|
||||
// test if, actually, it was already tagged by us! this situation can't
|
||||
// be caused by a race, but it could be caused by a situation
|
||||
// where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED
|
||||
// (because they didn't pre-check the tag) when actually it was
|
||||
// owned by the interpreter
|
||||
if (expected == self_interpreter) {
|
||||
break;
|
||||
}
|
||||
// fallthrough, we lost the race. We are guaranteed not to lose the
|
||||
// race with ourself, as calls to init_pyobj with the same interpreter
|
||||
// ID must be sequentialized by the GIL
|
||||
[[fallthrough]];
|
||||
case impl::PyInterpreterStatus::TAGGED_BY_OTHER:
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"cannot allocate PyObject for Tensor on interpreter ",
|
||||
self_interpreter,
|
||||
" that has already been used by another torch deploy interpreter ",
|
||||
pyobj_interpreter_.load());
|
||||
}
|
||||
|
||||
// we are the ONLY thread that can have gotten to this point. It is not
|
||||
// possible to conflict with another zero interpreter as access is protected
|
||||
// by GIL
|
||||
// NB: owns_pyobj tag is initially false
|
||||
void init_pyobj(PyObject* pyobj) {
|
||||
pyobj_interpreter_.store(
|
||||
getGlobalPyInterpreter(), std::memory_order_relaxed);
|
||||
pyobj_ = pyobj;
|
||||
}
|
||||
|
||||
@ -94,49 +52,25 @@ struct C10_API PyObjectSlot {
|
||||
//
|
||||
// NB: this lives in header so that we can avoid actually creating the
|
||||
// std::optional
|
||||
std::optional<PyObject*> check_pyobj(
|
||||
PyInterpreter* self_interpreter,
|
||||
bool ignore_hermetic_tls = false) const {
|
||||
// Note [Memory ordering on Python interpreter tag]
|
||||
|
||||
// @todo alban: I'm not too sure what's going on here, we can probably delete
|
||||
// it but it's worthwhile making sure
|
||||
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
|
||||
impl::PyInterpreter* interpreter =
|
||||
pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter == nullptr) {
|
||||
// NB: This never returns DEFINITELY_UNINITIALIZED because there is
|
||||
// always the possibility that another thread races to initialize
|
||||
// after we query here. The only time when we can conclude a tensor
|
||||
// is definitely uninitialized is when we have just allocated it and
|
||||
// it cannot have escaped to other threads yet
|
||||
return std::nullopt;
|
||||
} else if (interpreter == self_interpreter) {
|
||||
// NB: pyobj_ could still be null!
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
}
|
||||
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"cannot access PyObject for Tensor on interpreter ",
|
||||
(*self_interpreter)->name(),
|
||||
" that has already been used by another torch deploy interpreter ",
|
||||
(*pyobj_interpreter_.load())->name());
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the PyObject field for an interpreter, in situations where we
|
||||
// statically know the tensor is tagged with our interpreter.
|
||||
void unchecked_clear_pyobj(PyInterpreter* interpreter);
|
||||
|
||||
PyInterpreter& load_pyobj_interpreter() const;
|
||||
|
||||
// Check if the PyObjectSlot's interpreter is the same as the specified
|
||||
// interpreter
|
||||
bool check_interpreter(PyInterpreter* interpreter);
|
||||
|
||||
// Check if the PyObjectSlot is holding a PyObject, owned or non-owned
|
||||
bool has_pyobj_nonhermetic();
|
||||
|
||||
bool owns_pyobj();
|
||||
|
||||
void set_owns_pyobj(bool b);
|
||||
|
||||
@@ -1,389 +1,119 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

#include <cuda_runtime_api.h>

namespace c10::cuda::CUDACachingAllocator {

constexpr size_t kRoundUpPowerOfTwoIntervals = 16;

CUDAAllocatorConfig::CUDAAllocatorConfig()
: m_max_split_size(std::numeric_limits<size_t>::max()),
m_max_non_split_rounding_size(kLargeBuffer),
m_garbage_collection_threshold(0),
m_pinned_num_register_threads(1),
m_expandable_segments(false),
#if CUDA_VERSION >= 12030
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::UNSPECIFIED),
#else
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::POSIX_FD),
#endif
m_release_lock_on_cudamalloc(false),
m_pinned_use_cuda_host_register(false),
m_pinned_use_background_threads(false) {
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}

size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
size_t log_size = (63 - llvm::countLeadingZeros(size));

// Our intervals start at 1MB and end at 64GB
const size_t interval_start =
63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
const size_t interval_end =
63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
TORCH_CHECK(
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
"kRoundUpPowerOfTwoIntervals mismatch");

int index = static_cast<int>(log_size) - static_cast<int>(interval_start);

index = std::max(0, index);
index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
return instance().m_roundup_power2_divisions[index];
}

void CUDAAllocatorConfig::lexArgs(
const std::string& env,
std::vector<std::string>& config) {
std::vector<char> buf;

for (char ch : env) {
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
buf.clear();
}
config.emplace_back(1, ch);
} else if (ch != ' ') {
buf.emplace_back(ch);
}
}
if (!buf.empty()) {
config.emplace_back(buf.begin(), buf.end());
}
}

void CUDAAllocatorConfig::consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c) {
TORCH_CHECK(
i < config.size() && config[i] == std::string(1, c),
"Error parsing CachingAllocator settings, expected ",
c,
"");
}

size_t CUDAAllocatorConfig::parseMaxSplitSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_split_size_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_split_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
}
return i;
}

size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
constexpr int mb = 1024 * 1024;
if (++i < config.size()) {
size_t val1 = stoi(config[i]);
TORCH_CHECK(
val1 > kLargeBuffer / mb,
"CachingAllocator option max_non_split_rounding_mb too small, must be > ",
kLargeBuffer / mb,
"");
val1 = std::max(val1, kLargeBuffer / mb);
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
m_max_non_split_rounding_size = val1 * 1024 * 1024;
} else {
TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
}
return i;
}

size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
double val1 = stod(config[i]);
TORCH_CHECK(
val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
TORCH_CHECK(
val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
m_garbage_collection_threshold = val1;
} else {
TORCH_CHECK(
false, "Error, expecting garbage_collection_threshold value", "");
}
return i;
}

size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
bool first_value = true;

if (++i < config.size()) {
if (std::string_view(config[i]) == "[") {
size_t last_index = 0;
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
while (++i < config.size() && std::string_view(config[i]) != "]") {
const std::string& val1 = config[i];
size_t val2 = 0;

consumeToken(config, ++i, ':');
if (++i < config.size()) {
val2 = stoi(config[i]);
} else {
TORCH_CHECK(
false, "Error parsing roundup_power2_divisions value", "");
}
TORCH_CHECK(
val2 == 0 || llvm::isPowerOf2_64(val2),
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
"");

if (std::string_view(val1) == ">") {
std::fill(
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
last_index)),
m_roundup_power2_divisions.end(),
val2);
} else {
size_t val1_long = stoul(val1);
TORCH_CHECK(
llvm::isPowerOf2_64(val1_long),
"For roundups, the intervals have to be power of 2 ",
"");

size_t index = 63 - llvm::countLeadingZeros(val1_long);
index = std::max((size_t)0, index);
index = std::min(index, m_roundup_power2_divisions.size() - 1);

if (first_value) {
std::fill(
m_roundup_power2_divisions.begin(),
std::next(
m_roundup_power2_divisions.begin(),
static_cast<std::vector<unsigned long>::difference_type>(
index)),
val2);
first_value = false;
}
if (index < m_roundup_power2_divisions.size()) {
m_roundup_power2_divisions[index] = val2;
}
last_index = index;
}

if (std::string_view(config[i + 1]) != "]") {
consumeToken(config, ++i, ',');
}
}
} else { // Keep this for backwards compatibility
size_t val1 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val1),
"For roundups, the divisions has to be power of 2 ",
"");
std::fill(
m_roundup_power2_divisions.begin(),
m_roundup_power2_divisions.end(),
val1);
}
} else {
TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
}
return i;
}

size_t CUDAAllocatorConfig::parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync) {
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {
// For ease of maintenance and understanding, the CUDA and ROCm
// implementations of this function are separated. This avoids having many
// #ifdef's throughout.
#ifdef USE_ROCM
// Ease burden on ROCm users by allowing either cuda or hip tokens.
// cuda token is broken up to prevent hipify matching it.
#define PYTORCH_TOKEN1 \
"cud" \
"aMallocAsync"
#define PYTORCH_TOKEN2 "hipMallocAsync"
consumeToken(config, ++i, ':');
if (++i < config.size()) {
tokenizer.checkToken(++i, ":");
i++; // Move to the value after the colon
TORCH_CHECK_VALUE(
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
(tokenizer[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
if (m_is_allocator_loaded) {
bool aync_allocator_at_runtime = (tokenizer[i] != "native");
TORCH_CHECK(
((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
(config[i] == PYTORCH_TOKEN2)),
"Unknown allocator backend, "
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
used_cudaMallocAsync =
(config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
TORCH_INTERNAL_ASSERT(
config[i] == get()->name() ||
(config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time, ",
config[i],
aync_allocator_at_runtime == m_use_async_allocator,
"Allocator async backend parsed at runtime != allocator async backend parsed at load time, ",
aync_allocator_at_runtime,
" != ",
get()->name());
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
m_use_async_allocator);
}
m_use_async_allocator =
(tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2);
// CUDA allocator is always loaded at the start of the program
m_is_allocator_loaded = true;

#if defined(CUDA_VERSION)
if (m_use_async_allocator) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
#endif

return i;
#undef PYTORCH_TOKEN1
#undef PYTORCH_TOKEN2
#else // USE_ROCM
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
((config[i] == "native") || (config[i] == "cudaMallocAsync")),
"Unknown allocator backend, "
"options are native and cudaMallocAsync");
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
if (used_cudaMallocAsync) {
#if CUDA_VERSION >= 11040
int version = 0;
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
TORCH_CHECK(
version >= 11040,
"backend:cudaMallocAsync requires CUDA runtime "
"11.4 or newer, but cudaDriverGetVersion returned ",
version);
#else
TORCH_CHECK(
false,
"backend:cudaMallocAsync requires PyTorch to be built with "
"CUDA 11.4 or newer, but CUDA_VERSION is ",
CUDA_VERSION);
#endif
}
TORCH_INTERNAL_ASSERT(
config[i] == get()->name(),
"Allocator backend parsed at runtime != "
"allocator backend parsed at load time");
} else {
TORCH_CHECK(false, "Error parsing backend value", "");
}
return i;
#endif // USE_ROCM
}

void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
void CUDAAllocatorConfig::parseArgs(const std::string& env) {
// If empty, set the default values
m_max_split_size = std::numeric_limits<size_t>::max();
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
m_garbage_collection_threshold = 0;
bool used_cudaMallocAsync = false;
bool used_native_specific_option = false;

if (!env.has_value()) {
return;
}
{
std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
m_last_allocator_settings = env.value();
}

std::vector<std::string> config;
lexArgs(env.value(), config);

for (size_t i = 0; i < config.size(); i++) {
std::string_view config_item_view(config[i]);
if (config_item_view == "max_split_size_mb") {
i = parseMaxSplitSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "max_non_split_rounding_mb") {
i = parseMaxNonSplitRoundingSize(config, i);
used_native_specific_option = true;
} else if (config_item_view == "garbage_collection_threshold") {
i = parseGarbageCollectionThreshold(config, i);
used_native_specific_option = true;
} else if (config_item_view == "roundup_power2_divisions") {
i = parseRoundUpPower2Divisions(config, i);
used_native_specific_option = true;
} else if (config_item_view == "backend") {
i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
} else if (config_item_view == "expandable_segments") {
used_native_specific_option = true;
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for expandable_segments");
config_item_view = config[i];
m_expandable_segments = (config_item_view == "True");
c10::CachingAllocator::ConfigTokenizer tokenizer(env);
for (size_t i = 0; i < tokenizer.size(); i++) {
const auto& key = tokenizer[i];
if (key == "backend") {
i = parseAllocatorConfig(tokenizer, i);
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
config_item_view == "release_lock_on_hipmalloc" ||
config_item_view ==
key == "release_lock_on_hipmalloc" ||
key ==
"release_lock_on_c"
"udamalloc") {
used_native_specific_option = true;
consumeToken(config, ++i, ':');
++i;
TORCH_CHECK(
i < config.size() &&
(std::string_view(config[i]) == "True" ||
std::string_view(config[i]) == "False"),
"Expected a single True/False argument for release_lock_on_cudamalloc");
config_item_view = config[i];
m_release_lock_on_cudamalloc = (config_item_view == "True");
tokenizer.checkToken(++i, ":");
m_release_lock_on_cudamalloc = tokenizer.toBool(++i);
} else if (
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
// use, accept both. We must break up the string to prevent hipify here.
config_item_view == "pinned_use_hip_host_register" ||
config_item_view ==
key == "pinned_use_hip_host_register" ||
key ==
"pinned_use_c"
"uda_host_register") {
i = parsePinnedUseCudaHostRegister(config, i);
i = parsePinnedUseCudaHostRegister(tokenizer, i);
used_native_specific_option = true;
} else if (config_item_view == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(config, i);
used_native_specific_option = true;
} else if (config_item_view == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(config, i);
} else if (key == "pinned_num_register_threads") {
i = parsePinnedNumRegisterThreads(tokenizer, i);
used_native_specific_option = true;
} else {
const auto& keys =
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
TORCH_CHECK(
false, "Unrecognized CachingAllocator option: ", config_item_view);
keys.find(key) != keys.end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");
i = tokenizer.skipKey(i);
}

if (i + 1 < config.size()) {
consumeToken(config, ++i, ',');
if (i + 1 < tokenizer.size()) {
tokenizer.checkToken(++i, ",");
}
}

if (used_cudaMallocAsync && used_native_specific_option) {
if (m_use_async_allocator && used_native_specific_option) {
TORCH_WARN(
"backend:cudaMallocAsync ignores max_split_size_mb,"
"roundup_power2_divisions, and garbage_collect_threshold.");
@@ -391,64 +121,33 @@ void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
}

size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_cuda_host_register");
m_pinned_use_cuda_host_register = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_cuda_host_register value", "");
}
tokenizer.checkToken(++i, ":");
m_pinned_use_cuda_host_register = tokenizer.toBool(++i);

return i;
}

size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
size_t val2 = stoi(config[i]);
TORCH_CHECK(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
} else {
TORCH_CHECK(
false, "Error, expecting pinned_num_register_threads value", "");
}
tokenizer.checkToken(++i, ":");
size_t val2 = tokenizer.toSizeT(++i);
TORCH_CHECK_VALUE(
llvm::isPowerOf2_64(val2),
"Number of register threads has to be power of 2 ",
"");
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
TORCH_CHECK_VALUE(
val2 <= maxThreads,
"Number of register threads should be less than or equal to " +
std::to_string(maxThreads),
"");
m_pinned_num_register_threads = val2;
return i;
}

size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
size_t i) {
consumeToken(config, ++i, ':');
if (++i < config.size()) {
TORCH_CHECK(
(config[i] == "True" || config[i] == "False"),
"Expected a single True/False argument for pinned_use_background_threads");
m_pinned_use_background_threads = (config[i] == "True");
} else {
TORCH_CHECK(
false, "Error, expecting pinned_use_background_threads value", "");
}
return i;
}

// General caching allocator utilities
void setAllocatorSettings(const std::string& env) {
CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
}
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig)

} // namespace c10::cuda::CUDACachingAllocator

@@ -1,16 +1,12 @@
#pragma once

#include <c10/core/AllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>

#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <string>
#include <vector>

namespace c10::cuda::CUDACachingAllocator {

enum class Expandable_Segments_Handle_Type : int {
@@ -22,21 +18,28 @@ enum class Expandable_Segments_Handle_Type : int {
// Environment config parser
class C10_CUDA_API CUDAAllocatorConfig {
public:
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.")
static size_t max_split_size() {
return instance().m_max_split_size;
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
}
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.")
static double garbage_collection_threshold() {
return instance().m_garbage_collection_threshold;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
garbage_collection_threshold();
}

static bool expandable_segments() {
bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig::
use_expandable_segments();
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
if (instance().m_expandable_segments) {
if (enabled) {
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
}
return false;
#else
return instance().m_expandable_segments;
return enabled;
#endif
}

@@ -62,8 +65,11 @@ class C10_CUDA_API CUDAAllocatorConfig {
return instance().m_pinned_num_register_threads;
}

C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.")
static bool pinned_use_background_threads() {
return instance().m_pinned_use_background_threads;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
pinned_use_background_threads();
}

static size_t pinned_max_register_threads() {
@@ -73,92 +79,105 @@ class C10_CUDA_API CUDAAllocatorConfig {
return 128;
}

// This is used to round-up allocation size to nearest power of 2 divisions.
// More description below in function roundup_power2_next_division
// As an example, if we want 4 divisions between 2's power, this can be done
// using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
static size_t roundup_power2_divisions(size_t size);
C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
static size_t roundup_power2_divisions(size_t size) {
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions(size);
}
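
[Editor's note, not part of the diff: a minimal usage sketch of the knob documented in the comment above, using only entry points that appear elsewhere in this compare view (setAllocatorSettings and AcceleratorAllocatorConfig in the new AllocatorConfig_test.cpp). The numbers are illustrative.]

#include <c10/core/AllocatorConfig.h>

// Hypothetical example: request 4 divisions between consecutive powers of two,
// equivalent to launching with PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4.
void example_roundup_config() {
  c10::CachingAllocator::setAllocatorSettings("roundup_power2_divisions:4");
  // Each size interval now reports 4 divisions; the new unit test below
  // asserts the same behaviour for several interval sizes.
  auto divs = c10::CachingAllocator::AcceleratorAllocatorConfig::
      roundup_power2_divisions(/*size=*/256 * 1024 * 1024);
  (void)divs; // expected to be 4 under this setting
}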

C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
static std::vector<size_t> roundup_power2_divisions() {
return instance().m_roundup_power2_divisions;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
roundup_power2_divisions();
}

C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.")
static size_t max_non_split_rounding_size() {
return instance().m_max_non_split_rounding_size;
return c10::CachingAllocator::AcceleratorAllocatorConfig::
max_non_split_rounding_size();
}

C10_DEPRECATED_MESSAGE(
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.")
static std::string last_allocator_settings() {
std::lock_guard<std::mutex> lock(
instance().m_last_allocator_settings_mutex);
return instance().m_last_allocator_settings;
return c10::CachingAllocator::getAllocatorSettings();
}

static bool use_async_allocator() {
return instance().m_use_async_allocator;
}

static const std::unordered_set<std::string>& getKeys() {
return keys_;
}

static CUDAAllocatorConfig& instance() {
static CUDAAllocatorConfig* s_instance = ([]() {
auto inst = new CUDAAllocatorConfig();
auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
if (!env.has_value()) {
// For backward compatibility, check for the old environment variable
// PYTORCH_CUDA_ALLOC_CONF.
env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
}
#ifdef USE_ROCM
// convenience for ROCm users, allow alternative HIP token
if (!env.has_value()) {
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
inst->parseArgs(env);
if (env.has_value()) {
inst->parseArgs(env.value());
}
return inst;
})();
return *s_instance;
}
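
[Editor's note, not part of the diff: a rough sketch of the lookup precedence that instance() above establishes, assuming a POSIX environment for setenv and that the variables are set before the config singleton is first constructed.]

#include <cstdlib>

// Hypothetical illustration: the generic PYTORCH_ALLOC_CONF wins, and the
// legacy PYTORCH_CUDA_ALLOC_CONF is only consulted when the generic variable
// is absent (PYTORCH_HIP_ALLOC_CONF is a further ROCm-only fallback).
void example_env_precedence() {
  setenv("PYTORCH_ALLOC_CONF", "backend:cudaMallocAsync", /*overwrite=*/1);
  setenv("PYTORCH_CUDA_ALLOC_CONF", "backend:native", /*overwrite=*/1);
  // On the first call to CUDAAllocatorConfig::instance() after this point,
  // "backend:cudaMallocAsync" is parsed, so use_async_allocator() reports true.
}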

void parseArgs(const std::optional<std::string>& env);
void parseArgs(const std::string& env);

private:
CUDAAllocatorConfig();
CUDAAllocatorConfig() = default;

static void lexArgs(const std::string& env, std::vector<std::string>& config);
static void consumeToken(
const std::vector<std::string>& config,
size_t i,
const char c);
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
size_t parseMaxNonSplitRoundingSize(
const std::vector<std::string>& config,
size_t i);
size_t parseGarbageCollectionThreshold(
const std::vector<std::string>& config,
size_t i);
size_t parseRoundUpPower2Divisions(
const std::vector<std::string>& config,
size_t i);
size_t parseAllocatorConfig(
const std::vector<std::string>& config,
size_t i,
bool& used_cudaMallocAsync);
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);
size_t parsePinnedUseCudaHostRegister(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);
size_t parsePinnedNumRegisterThreads(
const std::vector<std::string>& config,
size_t i);
size_t parsePinnedUseBackgroundThreads(
const std::vector<std::string>& config,
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
size_t i);

std::atomic<size_t> m_max_split_size;
std::atomic<size_t> m_max_non_split_rounding_size;
std::vector<size_t> m_roundup_power2_divisions;
std::atomic<double> m_garbage_collection_threshold;
std::atomic<size_t> m_pinned_num_register_threads;
std::atomic<bool> m_expandable_segments;
std::atomic<Expandable_Segments_Handle_Type>
m_expandable_segments_handle_type;
std::atomic<bool> m_release_lock_on_cudamalloc;
std::atomic<bool> m_pinned_use_cuda_host_register;
std::atomic<bool> m_pinned_use_background_threads;
std::string m_last_allocator_settings;
std::mutex m_last_allocator_settings_mutex;
std::atomic<size_t> m_pinned_num_register_threads{1};
std::atomic<Expandable_Segments_Handle_Type> m_expandable_segments_handle_type
#if CUDA_VERSION >= 12030
{Expandable_Segments_Handle_Type::UNSPECIFIED};
#else
{Expandable_Segments_Handle_Type::POSIX_FD};
#endif
std::atomic<bool> m_release_lock_on_cudamalloc{false};
std::atomic<bool> m_pinned_use_cuda_host_register{false};
std::atomic<bool> m_use_async_allocator{false};
std::atomic<bool> m_is_allocator_loaded{false};
inline static std::unordered_set<std::string> keys_{
"backend",
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_cud"
"amalloc",
"pinned_use_cud"
"a_host_register",
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
"release_lock_on_hipmalloc",
"pinned_use_hip_host_register",
"pinned_num_register_threads"};
};

// General caching allocator utilities
C10_CUDA_API void setAllocatorSettings(const std::string& env);
// Keep this for backwards compatibility
using c10::CachingAllocator::setAllocatorSettings;

} // namespace c10::cuda::CUDACachingAllocator

@@ -1,7 +1,6 @@
#include <c10/cuda/CUDACachingAllocator.h>

#include <c10/core/impl/GPUTrace.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
@@ -64,10 +63,6 @@ namespace cuda::CUDACachingAllocator {
using namespace c10::CachingAllocator;
using namespace c10::CachingDeviceAllocator;

// Included here as this is externally used in CUDAAllocatorConfig
const size_t kLargeBuffer =
20971520; // "large" allocations may be packed in 20 MiB blocks

namespace Native {

//
@@ -843,8 +838,7 @@ struct AllocParams {
size_t size,
cudaStream_t stream,
BlockPool* pool,
size_t alloc_size,
DeviceStats& stats)
size_t alloc_size)
: search_key(device, stream, size), pool(pool), alloc_size(alloc_size) {}

c10::DeviceIndex device() const {
@@ -1231,7 +1225,7 @@ class DeviceCachingAllocator {
DeviceCachingAllocator()
: large_blocks(/*small=*/false), small_blocks(/*small=*/true) {
stats.max_split_size =
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
context_recorder_.store(nullptr);
}

@@ -1341,7 +1335,7 @@ class DeviceCachingAllocator {
size_t size = round_size(orig_size);
auto& pool = get_pool(size, stream);
const size_t alloc_size = get_allocation_size(size);
AllocParams params(device, size, stream, &pool, alloc_size, stats);
AllocParams params(device, size, stream, &pool, alloc_size);
params.stat_types = get_stat_types_for_pool(pool);

// First, try to get a block from the existing pool.
@@ -1356,7 +1350,8 @@ class DeviceCachingAllocator {
// Do garbage collection if the flag is set.
if (C10_UNLIKELY(
set_fraction &&
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
AcceleratorAllocatorConfig::garbage_collection_threshold() >
0.0)) {
garbage_collect_cached_blocks(context);
}
// Attempt allocate
@@ -1388,7 +1383,7 @@ class DeviceCachingAllocator {
beginAllocateToPool(mempool_id, filter);
auto& mempool = get_pool(size, stream);
AllocParams mempool_params(
device, size, stream, &mempool, alloc_size, stats);
device, size, stream, &mempool, alloc_size);
mempool_params.stat_types = get_stat_types_for_pool(mempool);
block_found = get_free_block(mempool_params);
endAllocateToPool(mempool_id);
@@ -1608,7 +1603,7 @@ class DeviceCachingAllocator {
stats.active_bytes[stat_type].increase(block->size);
stats.requested_bytes[stat_type].increase(block->requested_size);
});
if (block->size >= CUDAAllocatorConfig::max_split_size())
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_allocations.increase(1);

auto allocated_bytes_gauge =
@@ -1659,7 +1654,7 @@ class DeviceCachingAllocator {
block->pool->owner_MempoolId(),
context ? context : block->context_when_allocated);

if (block->size >= CUDAAllocatorConfig::max_split_size())
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_allocations.decrease(1);

if (!block->stream_uses.empty()) {
@@ -1929,8 +1924,7 @@ class DeviceCachingAllocator {
block_state.size,
block_state.stream,
&pool,
block_state.size,
stats);
block_state.size);
pool.blocks.erase(curr_block);
params.block = curr_block;
params.stat_types = get_stat_types_for_pool(pool);
@@ -2209,7 +2203,8 @@ class DeviceCachingAllocator {
if (size < kMinBlockSize) {
return kMinBlockSize;
} else {
auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
auto divisions =
AcceleratorAllocatorConfig::roundup_power2_divisions(size);
if (divisions > 1 && size > (kMinBlockSize * divisions)) {
return roundup_power2_next_division(size, divisions);
} else {
@@ -2699,7 +2694,7 @@ class DeviceCachingAllocator {
if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
return remaining >= kMinBlockSize;
} else {
return (size < CUDAAllocatorConfig::max_split_size()) &&
return (size < AcceleratorAllocatorConfig::max_split_size()) &&
(remaining > kSmallSize);
}
}
@@ -2719,7 +2714,7 @@ class DeviceCachingAllocator {

if (C10_UNLIKELY(
set_fraction &&
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
// Track block reuse interval only when garbage collection is enabled.
++pool.get_free_blocks_call_count;
}
@@ -2761,13 +2756,13 @@ class DeviceCachingAllocator {
}

// Do not return an oversized block for a large request
if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()))
if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()))
return false;
// Allow oversized block size to be rounded up but within a limit
if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->size >=
p.size() + CUDAAllocatorConfig::max_non_split_rounding_size()))
p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size()))
return false;
p.block = *it;
pool.blocks.erase(it);
@@ -2790,7 +2785,7 @@ class DeviceCachingAllocator {
// therefore should be of less overheads.

size_t gc_threshold = static_cast<size_t>(
CUDAAllocatorConfig::garbage_collection_threshold() *
AcceleratorAllocatorConfig::garbage_collection_threshold() *
static_cast<double>(allowed_memory_maximum));
// No need to trigger GC yet
if (total_allocated_memory <= gc_threshold) {
@@ -2938,7 +2933,7 @@ class DeviceCachingAllocator {
stats.segment[stat_type].increase(1);
stats.reserved_bytes[stat_type].increase(size);
});
if (size >= CUDAAllocatorConfig::max_split_size())
if (size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_segments.increase(1);
auto reserved_bytes_gauge =
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
@@ -2967,7 +2962,7 @@ class DeviceCachingAllocator {
bool release_available_cached_blocks(
const AllocParams& p,
const std::shared_ptr<GatheredContext>& context) {
if (CUDAAllocatorConfig::max_split_size() ==
if (AcceleratorAllocatorConfig::max_split_size() ==
std::numeric_limits<size_t>::max())
return false;
BlockPool& pool = *p.pool;
@@ -2975,8 +2970,8 @@ class DeviceCachingAllocator {
// because of std::unique_ptr, block cannot be trivially copied
// Use constructor for search key.
Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
key.size = (key.size < CUDAAllocatorConfig::max_split_size())
? CUDAAllocatorConfig::max_split_size()
key.size = (key.size < AcceleratorAllocatorConfig::max_split_size())
? AcceleratorAllocatorConfig::max_split_size()
: key.size;
auto it = pool.blocks.lower_bound(&key);
if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
@@ -2989,7 +2984,7 @@ class DeviceCachingAllocator {
--it; // Back up one item. Now on the largest block for the correct
// stream
while ((totalReleased < key.size) &&
((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) &&
((*it)->stream == p.stream())) {
auto cur = it;
bool is_first = cur == pool.blocks.begin();
@@ -3114,7 +3109,7 @@ class DeviceCachingAllocator {
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
.current);

if (block->size >= CUDAAllocatorConfig::max_split_size())
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
stats.oversize_segments.decrease(1);
pool->blocks.erase(block);
delete block;
@@ -3741,8 +3736,8 @@ class NativeCachingAllocator : public CUDAAllocator {

auto& md = result.config_metadata;
md.garbage_collection_threshold =
CUDAAllocatorConfig::garbage_collection_threshold();
md.max_split_size = CUDAAllocatorConfig::max_split_size();
AcceleratorAllocatorConfig::garbage_collection_threshold();
md.max_split_size = AcceleratorAllocatorConfig::max_split_size();
md.pinned_num_register_threads =
CUDAAllocatorConfig::pinned_num_register_threads();
md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
@@ -3750,9 +3745,10 @@ class NativeCachingAllocator : public CUDAAllocator {
CUDAAllocatorConfig::release_lock_on_cudamalloc();
md.pinned_use_host_register =
CUDAAllocatorConfig::pinned_use_cuda_host_register();
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
md.last_allocator_settings =
AcceleratorAllocatorConfig::last_allocator_settings();
md.roundup_power2_divisions =
CUDAAllocatorConfig::roundup_power2_divisions();
AcceleratorAllocatorConfig::roundup_power2_divisions();

return result;
}
@@ -4130,49 +4126,10 @@ CUDAAllocator* allocator();
} // namespace CudaMallocAsync

struct BackendStaticInitializer {
// Parses env for backend at load time, duplicating some logic from
// CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
// runtime). Defers verbose exceptions and error checks, including Cuda
// version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
// works, maybe we should move all of CUDAAllocatorConfig here?
CUDAAllocator* parseEnvForBackend() {
auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (!val.has_value()) {
val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
}
#endif
if (val.has_value()) {
const std::string& config = val.value();

std::regex exp("[\\s,]+");
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);

for (auto option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
std::sregex_token_iterator end2;
std::vector<std::string> kv(it2, end2);
if (kv.size() >= 2) {
if (kv[0] == "backend") {
#ifdef USE_ROCM
// convenience for ROCm users to allow either CUDA or HIP env var
if (kv[1] ==
"cud"
"aMallocAsync" ||
kv[1] == "hipMallocAsync")
#else
if (kv[1] == "cudaMallocAsync")
#endif
return CudaMallocAsync::allocator();
if (kv[1] == "native")
return &Native::allocator;
}
}
}
// If the environment variable is set, we use the CudaMallocAsync allocator.
if (CUDAAllocatorConfig::use_async_allocator()) {
return CudaMallocAsync::allocator();
}
return &Native::allocator;
}

@@ -1,6 +1,7 @@
#pragma once

#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
@@ -49,10 +50,9 @@ namespace c10::cuda::CUDACachingAllocator {

// Preserved only for BC reasons
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::CachingAllocator::kLargeBuffer;
using c10::CachingDeviceAllocator::DeviceStats;

extern const size_t kLargeBuffer;

typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();

// Struct containing info of an allocation block (i.e. a fractional part of a

@@ -5,15 +5,86 @@

namespace c10 {
namespace metal {
namespace detail {
template <typename T>
struct simd_type {
using t = T;
};

// Helper that allows one to run simd ops over bfl16 by upcasting them to fp32
template <typename T>
using simd_type_t = typename simd_type<T>::t;

#if __METAL_VERSION__ >= 310
template <>
struct simd_type<bfloat> {
using t = float;
};
#endif
} // namespace detail

template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_sum(T val) {
return ::metal::simd_sum(val);
return T(::metal::simd_sum(detail::simd_type_t<T>(val)));
}

template <typename T>
inline ::metal::enable_if_t<!::metal::is_same_v<T, long>, T> simd_prod(T val) {
return ::metal::simd_product(val);
return T(::metal::simd_product(detail::simd_type_t<T>(val)));
}

// Extend simd_broadcast to 64-bit integral types using int2 trick
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) == 8, bool> =
true>
inline T simd_broadcast(T val, ushort lane_id) {
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), lane_id));
}

template <
typename T,
::metal::enable_if_t<!::metal::is_integral_v<T> || sizeof(T) != 8, bool> =
true>
inline T simd_broadcast(T val, ushort lane_id) {
return ::metal::simd_broadcast(val, lane_id);
}

// Floating simd_min/max with nan propagation
template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_max(T val) {
if (::metal::simd_any(::metal::isnan(val))) {
return ::metal::numeric_limits<T>::quiet_NaN();
}
return T(::metal::simd_max(detail::simd_type_t<T>(val)));
}

template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline T simd_min(T val) {
if (::metal::simd_any(::metal::isnan(val))) {
return ::metal::numeric_limits<T>::quiet_NaN();
}
return T(::metal::simd_min(detail::simd_type_t<T>(val)));
}

template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
true>
inline T simd_max(T val) {
return ::metal::simd_max(val);
}

template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T> && sizeof(T) != 8, bool> =
true>
inline T simd_min(T val) {
return ::metal::simd_min(val);
}

// Metal does not support SIMD reductions over 64-bit types, but it could be
@@ -28,7 +99,7 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_sum(T val) {
val += as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
return simd_broadcast(val, 0);
}

template <typename T>
@@ -37,7 +108,78 @@ inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_prod(T val) {
val *= as_type<T>(
::metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), i));
}
return as_type<T>(::metal::simd_broadcast(as_type<int2>(val), 0));
return simd_broadcast(val, 0);
}

template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_max(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val = ::metal::max(
val,
as_type<T>(::metal::simd_shuffle_and_fill_down(
as_type<int2>(val), int2(0), i)));
}
return simd_broadcast(val, 0);
}

template <typename T>
inline ::metal::enable_if_t<::metal::is_same_v<T, long>, T> simd_min(T val) {
for (ushort i = simdgroup_size / 2; i > 0; i /= 2) {
val = ::metal::min(
val,
as_type<T>(::metal::simd_shuffle_and_fill_down(
as_type<int2>(val), int2(0), i)));
}
return simd_broadcast(val, 0);
}

// argmin/argmax helpers using simd_ballot
template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
const auto rc = simd_min(val);
const auto vote = ::metal::simd_ballot(val == rc);
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}

template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
const auto rc = simd_min(val);
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}

template <
typename T,
::metal::enable_if_t<::metal::is_integral_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
const auto rc = simd_max(val);
const auto vote = ::metal::simd_ballot(val == rc);
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}

template <
typename T,
::metal::enable_if_t<::metal::is_floating_point_v<T>, bool> = true>
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
const auto rc = simd_max(val);
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
}

template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmin(ARG_T val, IDX_T idx_val) {
auto rc = simd_argmin(val);
return {rc.first, simd_broadcast(idx_val, rc.second)};
}

template <typename ARG_T, typename IDX_T>
inline c10::metal::pair<ARG_T, IDX_T> simd_argmax(ARG_T val, IDX_T idx_val) {
auto rc = simd_argmax(val);
return {rc.first, simd_broadcast(idx_val, rc.second)};
}

// Below algorithms are written with hardcoded assumption that simdgroup is 32
@@ -88,6 +230,44 @@ opmath_t<T> threadgroup_prod(
return data[0];
}

template <typename T>
T threadgroup_max(threadgroup T* data, T val, unsigned idx, unsigned size) {
auto rc = simd_max(val);
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_max(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}

template <typename T>
T threadgroup_min(threadgroup T* data, T val, unsigned idx, unsigned size) {
auto rc = simd_min(val);
if (idx % simdgroup_size == 0) {
data[idx / simdgroup_size] = rc;
}
if (size > simdgroup_size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_min(data[idx]);
if (idx == 0) {
data[0] = rc1;
}
}
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return data[0];
}

template <typename T>
float3 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
@@ -123,52 +303,58 @@ float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
return rc;
}

template <typename T>
T threadgroup_max(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::max(rc, data[idx]);
template <typename ARG_T, typename IDX_T>
IDX_T threadgroup_argmax(
threadgroup ARG_T* arg_data,
threadgroup IDX_T* idx_data,
ARG_T val,
IDX_T idx_val,
unsigned idx,
unsigned size) {
auto rc = simd_argmax(val, idx_val);
if (size <= simdgroup_size) {
return rc.second;
}
return rc;
}

template <typename T>
T threadgroup_min(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
T rc = data[0];
for (unsigned idx = 1; idx < size; ++idx) {
rc = ::c10::metal::min(rc, data[idx]);
if (idx % simdgroup_size == 0) {
arg_data[idx / simdgroup_size] = rc.first;
idx_data[idx / simdgroup_size] = rc.second;
}
return rc;
}

template <typename T>
int threadgroup_argmax(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
int rc = 0;
for (unsigned idx = 1; idx < size; ++idx) {
if (data[idx] > data[rc]) {
rc = idx;
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_argmax(arg_data[idx], idx_data[idx]);
if (idx == 0) {
idx_data[0] = rc1.second;
}
}
return rc;
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return idx_data[0];
}

template <typename T>
int threadgroup_argmin(threadgroup T* data, unsigned size) {
// TODO: This should be moved to the callee
template <typename ARG_T, typename IDX_T>
IDX_T threadgroup_argmin(
threadgroup ARG_T* arg_data,
threadgroup IDX_T* idx_data,
ARG_T val,
IDX_T idx_val,
unsigned idx,
unsigned size) {
auto rc = simd_argmin(val, idx_val);
if (size <= simdgroup_size) {
return rc.second;
}
if (idx % simdgroup_size == 0) {
arg_data[idx / simdgroup_size] = rc.first;
idx_data[idx / simdgroup_size] = rc.second;
}
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
int rc = 0;
for (unsigned idx = 1; idx < size; ++idx) {
if (data[idx] < data[rc]) {
rc = idx;
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
auto rc1 = simd_argmin(arg_data[idx], idx_data[idx]);
if (idx == 0) {
idx_data[0] = rc1.second;
}
}
return rc;
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
return idx_data[0];
}

} // namespace metal

@@ -330,5 +330,11 @@ inline float log1p(float x) {
return rc;
}

template <typename T1, typename T2 = T1>
struct pair {
T1 first;
T2 second;
};

} // namespace metal
} // namespace c10

c10/test/core/AllocatorConfig_test.cpp (new file, 130 lines)
@@ -0,0 +1,130 @@
#include <c10/core/AllocatorConfig.h>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace c10::CachingAllocator;
|
||||
constexpr size_t kMB = 1024 * 1024ul;
|
||||
|
||||
struct ExtendedAllocatorConfig {
|
||||
static ExtendedAllocatorConfig& instance() {
|
||||
static ExtendedAllocatorConfig instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Returns the device-specific option value in bytes.
|
||||
static size_t device_specific_option() {
|
||||
return instance().device_specific_option_;
|
||||
}
|
||||
|
||||
static const std::unordered_set<std::string>& getKeys() {
|
||||
return keys_;
|
||||
}
|
||||
|
||||
void parseArgs(const std::string& env) {
|
||||
// Parse device-specific options from the environment variable
|
||||
ConfigTokenizer tokenizer(env);
|
||||
for (size_t i = 0; i < tokenizer.size(); i++) {
|
||||
const auto& key = tokenizer[i];
|
||||
if (key == "device_specific_option_mb") {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
device_specific_option_ = tokenizer.toSizeT(++i) * kMB;
|
||||
} else {
|
||||
i = tokenizer.skipKey(i);
|
||||
}
|
||||
|
||||
if (i + 1 < tokenizer.size()) {
|
||||
tokenizer.checkToken(++i, ",");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Device-specific option, e.g., memory limit for a specific device.
|
||||
std::atomic<size_t> device_specific_option_{0};
|
||||
inline static std::unordered_set<std::string> keys_{
|
||||
"device_specific_option_mb"};
|
||||
};
|
||||
|
||||
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(ExtendedAllocatorConfig)
|
||||
|
||||
TEST(AllocatorConfigTest, allocator_config_test) {
|
||||
std::string env =
|
||||
"max_split_size_mb:40,"
|
||||
"max_non_split_rounding_mb:30,"
|
||||
"garbage_collection_threshold:0.5,"
|
||||
"roundup_power2_divisions:[64:8,128:2,256:4,512:2,1024:4,>:1],"
|
||||
"expandable_segments:True,"
|
||||
"pinned_use_background_threads:True,"
|
||||
"device_specific_option_mb:64";
|
||||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 40 * kMB);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::max_non_split_rounding_size(), 30 * kMB);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.5);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(32 * kMB), 8);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 2);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 2);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 1);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 1);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(8192 * kMB), 1);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), true);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), true);
|
||||
EXPECT_EQ(ExtendedAllocatorConfig::device_specific_option(), 64 * kMB);
|
||||
|
||||
env =
|
||||
"max_split_size_mb:20,"
|
||||
"max_non_split_rounding_mb:40,"
|
||||
"garbage_collection_threshold:0.8";
|
||||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 20 * kMB);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::max_non_split_rounding_size(), 40 * kMB);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.8);
|
||||
|
||||
// roundup_power2_divisions knob array syntax
|
||||
env = "roundup_power2_divisions:[128:8,256:16,512:1,2048:8,>:2]";
|
||||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 8);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 16);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 1);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 0);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 2);
|
||||
|
||||
// roundup_power2_divisions single value syntax for backward compatibility
|
||||
env = "roundup_power2_divisions:4";
|
||||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 4);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4);
|
||||
EXPECT_EQ(
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 4);
|
||||
|
||||
env = "expandable_segments:False,";
|
||||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), false);
|
||||
|
||||
env = "pinned_use_background_threads:False";
|
||||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), false);
|
||||
|
||||
env = "foo:123,bar:456";
|
||||
ASSERT_THROW(c10::CachingAllocator::setAllocatorSettings(env), c10::Error);
|
||||
}
|
||||
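As a usage note for the test above: the same comma-separated option string can be applied from application code through the entry point the test exercises. A minimal sketch, illustrative only and not part of this diff; it assumes nothing beyond the header and function already shown above. The same option grammar is what users commonly pass through the PYTORCH_CUDA_ALLOC_CONF environment variable for the CUDA backend.

#include <c10/core/AllocatorConfig.h>

int main() {
  // Same key:value grammar the test parses, applied programmatically before
  // any allocations are made.
  c10::CachingAllocator::setAllocatorSettings(
      "max_split_size_mb:40,"
      "garbage_collection_threshold:0.5,"
      "roundup_power2_divisions:[64:8,128:2,256:4,512:2,1024:4,>:1]");
  return 0;
}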
@@ -1,340 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
#if defined(CL_SYCL_LANGUAGE_VERSION)
|
||||
#include <CL/sycl.hpp> // for SYCL 1.2.1
|
||||
#elif defined(SYCL_LANGUAGE_VERSION)
|
||||
#include <sycl/sycl.hpp> // for SYCL 2020
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
|
||||
:
|
||||
#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
|
||||
__CUDA_ARCH__ >= 800
|
||||
x(__bfloat16_as_ushort(__float2bfloat16(value)))
|
||||
#elif defined(__SYCL_DEVICE_ONLY__) && \
|
||||
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
|
||||
#else
|
||||
// RNE by default
|
||||
x(detail::round_to_nearest_even(value))
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
/// Implicit conversions
|
||||
inline C10_HOST_DEVICE BFloat16::operator float() const {
|
||||
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
||||
return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
|
||||
#elif defined(__SYCL_DEVICE_ONLY__) && \
|
||||
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
|
||||
#else
|
||||
return detail::f32_from_bits(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
||||
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
|
||||
x = *reinterpret_cast<const unsigned short*>(&value);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
|
||||
return *reinterpret_cast<const __nv_bfloat16*>(&x);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
inline C10_HOST_DEVICE BFloat16::BFloat16(
|
||||
const sycl::ext::oneapi::bfloat16& value) {
|
||||
x = *reinterpret_cast<const unsigned short*>(&value);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
|
||||
return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
|
||||
}
|
||||
#endif
|
||||
|
||||
// CUDA intrinsics
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
|
||||
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
|
||||
#else
|
||||
return *ptr;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16
|
||||
operator+(const BFloat16& a, const BFloat16& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16
|
||||
operator-(const BFloat16& a, const BFloat16& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16
|
||||
operator*(const BFloat16& a, const BFloat16& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
|
||||
a.x = a.x | b.x;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
|
||||
a.x = a.x ^ b.x;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
|
||||
a.x = a.x & b.x;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
|
||||
return a + static_cast<BFloat16>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
|
||||
return a - static_cast<BFloat16>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
|
||||
return a * static_cast<BFloat16>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
|
||||
return a / static_cast<BFloat16>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
|
||||
return a + static_cast<BFloat16>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
|
||||
return a - static_cast<BFloat16>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
|
||||
return a * static_cast<BFloat16>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
|
||||
return a / static_cast<BFloat16>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
|
||||
return static_cast<BFloat16>(a) / b;
|
||||
}
|
||||
|
||||
// Overloading < and > operators, because std::max and std::min use them.
|
||||
|
||||
inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
|
||||
return float(lhs) > float(rhs);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
|
||||
return float(lhs) < float(rhs);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::BFloat16> {
|
||||
public:
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = true;
|
||||
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
|
||||
static constexpr auto has_denorm_loss =
|
||||
numeric_limits<float>::has_denorm_loss;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 8;
|
||||
static constexpr int digits10 = 2;
|
||||
static constexpr int max_digits10 = 4;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -125;
|
||||
static constexpr int min_exponent10 = -37;
|
||||
static constexpr int max_exponent = 128;
|
||||
static constexpr int max_exponent10 = 38;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before =
|
||||
numeric_limits<float>::tinyness_before;
|
||||
|
||||
static constexpr c10::BFloat16 min() {
|
||||
return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 lowest() {
|
||||
return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 max() {
|
||||
return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 epsilon() {
|
||||
return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 round_error() {
|
||||
return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 infinity() {
|
||||
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 quiet_NaN() {
|
||||
return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 signaling_NaN() {
|
||||
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
|
||||
}
|
||||
static constexpr c10::BFloat16 denorm_min() {
|
||||
return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
|
||||
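For reference, the numeric_limits bit patterns above can be checked by hand: a bfloat16 pattern is simply the upper 16 bits of the corresponding float32 pattern. A standalone sketch (plain C++, no c10 dependency; the helper name is illustrative):

#include <cstdint>
#include <cstring>
#include <cstdio>

static float bf16_bits_to_float(uint16_t bits) {
  // Same rule as detail::f32_from_bits: widen to 32 bits and reinterpret.
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &widened, sizeof(out));
  return out;
}

int main() {
  std::printf("min        = %g\n", bf16_bits_to_float(0x0080));  // 2^-126 ~= 1.17549e-38
  std::printf("max        = %g\n", bf16_bits_to_float(0x7F7F));  // ~3.38953e+38
  std::printf("epsilon    = %g\n", bf16_bits_to_float(0x3C00));  // 2^-7 = 0.0078125
  std::printf("denorm_min = %g\n", bf16_bits_to_float(0x0001));  // 2^-133 ~= 9.18355e-41
  return 0;
}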
@@ -1,116 +1 @@
|
||||
#pragma once
|
||||
|
||||
// Defines the bfloat16 type (brain floating-point). This representation uses
|
||||
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iosfwd>
|
||||
#include <ostream>
|
||||
|
||||
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#if defined(CL_SYCL_LANGUAGE_VERSION)
|
||||
#include <CL/sycl.hpp> // for SYCL 1.2.1
|
||||
#elif defined(SYCL_LANGUAGE_VERSION)
|
||||
#include <sycl/sycl.hpp> // for SYCL 2020
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
|
||||
float res = 0;
|
||||
uint32_t tmp = src;
|
||||
tmp <<= 16;
|
||||
|
||||
#if defined(USE_ROCM) && defined(__HIPCC__)
|
||||
float* tempRes;
|
||||
|
||||
// We should be using memcpy in order to respect the strict aliasing rule
|
||||
// but it fails in the HIP environment.
|
||||
tempRes = reinterpret_cast<float*>(&tmp);
|
||||
res = *tempRes;
|
||||
#else
|
||||
std::memcpy(&res, &tmp, sizeof(tmp));
|
||||
#endif
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
|
||||
uint32_t res = 0;
|
||||
|
||||
#if defined(USE_ROCM) && defined(__HIPCC__)
|
||||
// We should be using memcpy in order to respect the strict aliasing rule
|
||||
// but it fails in the HIP environment.
|
||||
uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src);
|
||||
res = *tempRes;
|
||||
#else
|
||||
std::memcpy(&res, &src, sizeof(res));
|
||||
#endif
|
||||
|
||||
return res >> 16;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
|
||||
#if defined(USE_ROCM) && defined(__HIPCC__)
|
||||
if (src != src) {
|
||||
#elif defined(_MSC_VER)
|
||||
if (isnan(src)) {
|
||||
#else
|
||||
if (std::isnan(src)) {
|
||||
#endif
|
||||
return UINT16_C(0x7FC0);
|
||||
} else {
|
||||
const uint32_t U32 = c10::bit_cast<uint32_t>(src);
|
||||
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
|
||||
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(2) BFloat16 {
|
||||
uint16_t x;
|
||||
|
||||
// HIP wants __host__ __device__ tag, CUDA does not
|
||||
#if defined(USE_ROCM) && defined(__HIPCC__)
|
||||
C10_HOST_DEVICE BFloat16() = default;
|
||||
#else
|
||||
BFloat16() = default;
|
||||
#endif
|
||||
|
||||
struct from_bits_t {};
|
||||
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
|
||||
: x(bits) {}
|
||||
/* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
|
||||
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
||||
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
|
||||
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
|
||||
#endif
|
||||
|
||||
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
|
||||
explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
|
||||
#endif
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/BFloat16.h>
|
||||
|
||||
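The rounding-bias trick in detail::round_to_nearest_even above is easiest to see with concrete inputs: adding 0x7FFF plus the lowest kept mantissa bit and then truncating implements round-to-nearest, ties-to-even. A worked sketch (standalone and illustrative; the NaN branch of the original function is omitted):

#include <cstdint>
#include <cstring>
#include <cstdio>

static uint16_t rne_bf16(float src) {
  uint32_t u32;
  std::memcpy(&u32, &src, sizeof(u32));
  uint32_t rounding_bias = ((u32 >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((u32 + rounding_bias) >> 16);
}

int main() {
  // 1.00390625 is exactly halfway between bf16 values 1.0 (0x3F80) and
  // 1.0078125 (0x3F81); the kept mantissa LSB of 1.0 is even, so the tie rounds down.
  std::printf("0x%04X\n", rne_bf16(1.00390625f));  // 0x3F80
  // 1.01171875 is halfway between 0x3F81 and 0x3F82; this tie rounds up to the
  // even pattern 0x3F82 (1.015625).
  std::printf("0x%04X\n", rne_bf16(1.01171875f));  // 0x3F82
  return 0;
}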
@@ -1,28 +1 @@
|
||||
#pragma once
|
||||
#include <cstdint>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
/// Defines the Float4_e2m1fn_x2 type (4-bit floating-point, two elements packed
|
||||
/// into one byte). This is the FP4 dtype from the OCP MX format spec
|
||||
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
|
||||
/// Section 5.3.3)
|
||||
///
|
||||
/// Given two high precision values val0 and val1, here is the
|
||||
/// binary configuration of their packed representation, from MSB to LSB:
|
||||
///
|
||||
/// original value | val1 : val0
|
||||
/// ========================================
|
||||
/// bit index (MSB==7, LSB==0) | 7654 : 3210
|
||||
/// sign/exponent/mantissa | seem : seem
|
||||
///
|
||||
|
||||
namespace c10 {
|
||||
|
||||
struct alignas(1) Float4_e2m1fn_x2 {
|
||||
uint8_t val_;
|
||||
Float4_e2m1fn_x2() = default;
|
||||
C10_HOST_DEVICE explicit Float4_e2m1fn_x2(uint8_t val) : val_(val) {}
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
#include <torch/headeronly/util/Float4_e2m1fn_x2.h>
|
||||
|
||||
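To make the documented nibble layout concrete, here is a small packing helper (pack_fp4_x2 is hypothetical and not part of c10); it follows the val1:val0 ordering described in the header comment, with each nibble laid out as s e e m:

#include <cstdint>

constexpr uint8_t pack_fp4_x2(uint8_t val0, uint8_t val1) {
  // val1 occupies bits 7-4, val0 occupies bits 3-0.
  return static_cast<uint8_t>(((val1 & 0x0F) << 4) | (val0 & 0x0F));
}

// Example: the e2m1 code 0b0101 encodes +3.0 (exponent 10b -> 2^1, mantissa 1
// -> x1.5) and 0b0010 encodes +1.0; packed together they occupy one byte,
// matching the single val_ member of c10::Float4_e2m1fn_x2.
static_assert(pack_fp4_x2(/*val0=*/0b0010, /*val1=*/0b0101) == 0b0101'0010, "");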
@@ -1,274 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value)
|
||||
: x(detail::fp8e4m3fn_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
|
||||
return detail::fp8e4m3fn_to_fp32_value(x);
|
||||
}
|
||||
|
||||
/// Special values helper
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
|
||||
return (x & 0b01111111) == 0b01111111;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn
|
||||
operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn
|
||||
operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn
|
||||
operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(
|
||||
const Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
|
||||
return a + static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
|
||||
return a - static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
|
||||
return a * static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
|
||||
return a / static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
|
||||
return a + static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
|
||||
return a - static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
|
||||
return a * static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
|
||||
return a / static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e4m3fn to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e4m3fn> {
|
||||
public:
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 4;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 3;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -5;
|
||||
static constexpr int min_exponent10 = -1;
|
||||
static constexpr int max_exponent = 8;
|
||||
static constexpr int max_exponent10 = 2;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before = false;
|
||||
|
||||
static constexpr c10::Float8_e4m3fn min() {
|
||||
return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn lowest() {
|
||||
return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn max() {
|
||||
return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn epsilon() {
|
||||
return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn round_error() {
|
||||
return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn quiet_NaN() {
|
||||
return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn denorm_min() {
|
||||
return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
|
||||
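The constants in this numeric_limits specialization follow directly from the e4m3fn encoding (bias 7, no infinities, S.1111.111 reserved for NaN). A standalone decoding sketch (illustrative helper, not part of the diff) that reproduces them:

#include <cmath>
#include <cstdint>
#include <cstdio>

static float decode_e4m3fn(uint8_t v) {
  const int sign = (v >> 7) ? -1 : 1;
  const int exp = (v >> 3) & 0xF;
  const int mant = v & 0x7;
  if (exp == 0xF && mant == 0x7) return NAN;                // 0x7F / 0xFF are NaN
  if (exp == 0) return sign * std::ldexp(mant / 8.0f, -6);  // subnormals
  return sign * std::ldexp(1.0f + mant / 8.0f, exp - 7);    // normals, bias 7
}

int main() {
  std::printf("max        = %g\n", decode_e4m3fn(0x7E));  // 448
  std::printf("min        = %g\n", decode_e4m3fn(0x08));  // 2^-6 = 0.015625
  std::printf("epsilon    = %g\n", decode_e4m3fn(0x20));  // 2^-3 = 0.125
  std::printf("denorm_min = %g\n", decode_e4m3fn(0x01));  // 2^-9 ~= 0.00195312
  return 0;
}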
@@ -1,238 +1 @@
|
||||
#pragma once
|
||||
|
||||
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
|
||||
/// to standard C types and basic arithmetic operations. Note that arithmetic
|
||||
/// operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration:
|
||||
/// s eeee mmm
|
||||
/// 1 sign bit
|
||||
/// 4 exponent bits
|
||||
/// 3 mantissa bits
|
||||
/// bias = 7
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
|
||||
/// and inspired by Half implementation from pytorch/c10/util/Half.h
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
* format, in bit representation.
|
||||
*
|
||||
* @note The implementation doesn't use any floating-point operations.
|
||||
*/
|
||||
inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
|
||||
/*
|
||||
* Extend the fp8 E4M3FN number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
* +---+----+---+-----------------------------+
|
||||
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
|
||||
* +---+----+---+-----------------------------+
|
||||
* Bits 31 27-30 24-26 0-23
|
||||
*
|
||||
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
|
||||
* - zero bits.
|
||||
*/
|
||||
const uint32_t w = (uint32_t)input << 24;
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = w & UINT32_C(0x80000000);
|
||||
/*
|
||||
* Extract mantissa and biased exponent of the input number into the bits 0-30
|
||||
* of the 32-bit word:
|
||||
*
|
||||
* +---+----+---+-----------------------------+
|
||||
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
|
||||
* +---+----+---+-----------------------------+
|
||||
* Bits 31 27-30 24-26 0-23
|
||||
*/
|
||||
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
|
||||
/*
|
||||
* Renorm shift is the number of bits to shift mantissa left to make the
|
||||
* fp8e4m3fn number normalized. If the initial number is normalized, at least
* one of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case
* renorm_shift == 0. If the number is denormalized, renorm_shift > 0. Note
|
||||
* that if we shift denormalized nonsign by renorm_shift, the unit bit of
|
||||
* mantissa will shift into exponent, turning the biased exponent into 1, and
|
||||
* making mantissa normalized (i.e. without leading 1).
|
||||
*/
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
uint32_t renorm_shift = __clz(nonsign);
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
// Note: zero is not a supported input into `__builtin_clz`
|
||||
uint32_t renorm_shift =
|
||||
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
|
||||
#elif defined(_MSC_VER) && !defined(__clang__)
|
||||
unsigned long nonsign_bsr;
|
||||
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
|
||||
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
|
||||
#else
|
||||
// Note: zero is not a supported input into `__builtin_clz`
|
||||
uint32_t renorm_shift =
|
||||
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
|
||||
#endif
|
||||
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
|
||||
/*
|
||||
* Iff fp8e4m3fn number has all exponent and mantissa bits set to 1,
|
||||
* the addition overflows it into bit 31, and the subsequent shift turns the
|
||||
* high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number
|
||||
* is NaN, 0x00000000 otherwise
|
||||
*/
|
||||
const int32_t inf_nan_mask =
|
||||
((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
|
||||
/*
|
||||
* Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
|
||||
* into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
|
||||
* broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
|
||||
* 0xFFFFFFFF if the fp8e4m3fn number was zero (+0.0 or -0.0),
|
||||
* 0x00000000 otherwise
|
||||
*/
|
||||
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
|
||||
/*
|
||||
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
|
||||
* was denormal)
|
||||
* 2. Shift nonsign right by 4 so the exponent (4 bits originally)
|
||||
* becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
|
||||
* bits of the 23-bit mantissa of IEEE single-precision number.
|
||||
* 3. Add 0x78 to the exponent (starting at bit 23) to compensate for the
* difference in exponent bias (0x7F for single-precision number less 0x07
|
||||
* for fp8e4m3fn number).
|
||||
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
|
||||
* account for renormalization. As renorm_shift is less than 0x78, this
|
||||
* can be combined with step 3.
|
||||
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
|
||||
* input was NaN or infinity.
|
||||
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
|
||||
* into zero if the input was zero.
|
||||
* 7. Combine with the sign of the input number.
|
||||
*/
|
||||
uint32_t result = sign |
|
||||
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
||||
inf_nan_mask) &
|
||||
~zero_mask);
|
||||
return fp32_from_bits(result);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
|
||||
/*
|
||||
* Binary representation of 480.0f, which is the first value
|
||||
* not representable in fp8e4m3fn range:
|
||||
* 0 1111 111 - fp8e4m3fn
|
||||
* 0 10000111 11100000000000000000000 - fp32
|
||||
*/
|
||||
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e4m3fn normal range
|
||||
* into denorm representation
|
||||
* magic number: ((127 - 7) + (23 - 3) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
|
||||
uint8_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fp8_max) {
|
||||
// NaN - all exponent and mantissa bits set to 1
|
||||
result = 0x7f;
|
||||
} else {
|
||||
if (f_bits < (UINT32_C(121) << 23)) {
|
||||
// Input number is smaller than 2^(-6), which is the smallest
|
||||
// fp8e4m3fn normal number
|
||||
f_bits =
|
||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint8_t mant_odd = (f_bits >> 20) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 20);
|
||||
}
|
||||
}
|
||||
|
||||
result |= static_cast<uint8_t>(sign >> 24);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e4m3fn {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e4m3fn() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
|
||||
: x(bits) {}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
|
||||
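The rounding performed by fp8e4m3fn_from_fp32_value above is round-to-nearest with ties-to-even, implemented with the same bias-and-truncate trick as the bfloat16 path but shifted by 20 bits. A worked sketch of just the normal-range branch (standalone and illustrative; the NaN, overflow, and subnormal branches are omitted):

#include <cstdint>
#include <cstring>
#include <cstdio>

static uint8_t e4m3fn_round_normal(float f) {
  uint32_t f_bits;
  std::memcpy(&f_bits, &f, sizeof(f_bits));
  const uint32_t sign = f_bits & UINT32_C(0x80000000);
  f_bits ^= sign;
  uint8_t mant_odd = (f_bits >> 20) & 1;            // lowest kept mantissa bit
  f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;  // re-bias + rounding bias
  f_bits += mant_odd;                               // ties go to even
  return static_cast<uint8_t>(f_bits >> 20) | static_cast<uint8_t>(sign >> 24);
}

int main() {
  std::printf("0x%02X\n", e4m3fn_round_normal(1.0f));   // 0x38 == 1.0
  std::printf("0x%02X\n", e4m3fn_round_normal(21.0f));  // 0x5A == 20.0 (tie to even)
  std::printf("0x%02X\n", e4m3fn_round_normal(21.5f));  // 0x5B == 22.0
  return 0;
}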
@@ -1,279 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Float8_fnuz_cvt.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
|
||||
: x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
|
||||
return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
|
||||
}
|
||||
|
||||
/// Special values helper
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
|
||||
return x == 0b10000000;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz
|
||||
operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz
|
||||
operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz
|
||||
operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
|
||||
const Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
|
||||
return a + static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
|
||||
return a - static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
|
||||
return a * static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
|
||||
return a / static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a + static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a - static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a * static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a / static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e4m3fnuz to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e4m3fnuz> {
|
||||
public:
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 4;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 3;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -6;
|
||||
static constexpr int min_exponent10 = -1;
|
||||
static constexpr int max_exponent = 8;
|
||||
static constexpr int max_exponent10 = 2;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before = false;
|
||||
|
||||
static constexpr c10::Float8_e4m3fnuz min() {
|
||||
return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz lowest() {
|
||||
return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz max() {
|
||||
return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz epsilon() {
|
||||
return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz round_error() {
|
||||
return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz infinity() {
|
||||
// NaN (no infinities)
|
||||
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz quiet_NaN() {
|
||||
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz denorm_min() {
|
||||
return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
||||
|
||||
@@ -1,139 +1 @@
|
||||
#pragma once
|
||||
|
||||
/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including
|
||||
/// conversions to standard C types and basic arithmetic operations. Note that
|
||||
/// arithmetic operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration remains the same as Float8_e4m3fn:
|
||||
/// s eeee mmm
|
||||
/// 1 sign bit
|
||||
/// 4 exponent bits
|
||||
/// 3 mantissa bits
|
||||
/// The key differences versus Float8_e4m3fn are:
|
||||
/// bias = 8
|
||||
/// no infinities or negative zero
|
||||
/// NaN only when sign bit is 1, rest all 0s
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
|
||||
/// the existing Float8_e4m3fn implementation.
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#include <iosfwd>
|
||||
#include <ostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
|
||||
/*
|
||||
* Binary representation of 256.0f, which is the first value not representable
|
||||
* (i.e. the first value which would overflow into the sign bit, resulting in
|
||||
* a NaN) in fp8e4m3fnuz range:
|
||||
* 1 0000 000 - fp8e4m3fnuz
|
||||
* 0 10000111 00000000000000000000000 - fp32
|
||||
*/
|
||||
constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range
|
||||
* into denorm representation
|
||||
* magic number: ((127 - 8) + (23 - 3) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
|
||||
uint32_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fnuz_max) {
|
||||
// NaN -- sign bit set to 1, rest 0s.
|
||||
return 0x80;
|
||||
}
|
||||
|
||||
if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) {
|
||||
// Input exponent is less than -7, the smallest e4m3fnuz exponent, so the
|
||||
// number will become subnormal.
|
||||
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
if (result == 0) {
|
||||
// fnuz types don't have negative zero.
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint8_t mant_odd = (f_bits >> 20) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 20);
|
||||
}
|
||||
|
||||
result |= sign >> 24;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e4m3fnuz {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e4m3fnuz() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t)
|
||||
: x(bits) {}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const Float8_e4m3fnuz& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e4m3fnuz-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
||||
|
||||
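A small standalone sketch (illustrative helper, not part of the diff) contrasting this format with Float8_e4m3fn: the fnuz variant moves the bias from 7 to 8, has no infinities or negative zero, and reserves only 0x80 (sign set, everything else zero) for NaN:

#include <cmath>
#include <cstdint>
#include <cstdio>

static float decode_e4m3fnuz(uint8_t v) {
  if (v == 0x80) return NAN;                                // the single NaN encoding
  const int sign = (v >> 7) ? -1 : 1;
  const int exp = (v >> 3) & 0xF;
  const int mant = v & 0x7;
  if (exp == 0) return sign * std::ldexp(mant / 8.0f, -7);  // subnormals, bias 8
  return sign * std::ldexp(1.0f + mant / 8.0f, exp - 8);    // normals, bias 8
}

int main() {
  std::printf("%g\n", decode_e4m3fnuz(0x38));  // 0.5 (the same bits mean 1.0 in e4m3fn)
  std::printf("%g\n", decode_e4m3fnuz(0x7F));  // 240 (max; 0x7F is NaN in e4m3fn)
  return 0;
}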
@@ -1,286 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
#define EXP_WIDTH_FP8 5
|
||||
#define MAN_WIDTH_FP8 2
|
||||
#define EXP_BIAS_FP8 15
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value)
|
||||
: x(detail::fp8e5m2_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
|
||||
return detail::fp8e5m2_to_fp32_value(x);
|
||||
}
|
||||
|
||||
/// Special values helpers
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
|
||||
return (x & 0b01111111) > 0b01111100;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
|
||||
return (x & 0b01111111) == 0b01111100;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2
|
||||
operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2
|
||||
operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2
|
||||
operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(
|
||||
const Float8_e5m2& a,
|
||||
const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator+=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator-=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator*=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator/=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
|
||||
return a + static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
|
||||
return a - static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
|
||||
return a * static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
|
||||
return a / static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
|
||||
return a + static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
|
||||
return a - static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
|
||||
return a * static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
|
||||
return a / static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e5m2 to float.
|
||||
|
||||
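// Editor's sketch (illustrative addition, not part of the original header):
// because of the implicit conversion to float noted above, the built-in
// comparison operators already work on Float8_e5m2 without dedicated
// overloads. The function name is hypothetical.
inline bool float8_e5m2_compare_example() {
  Float8_e5m2 a(1.5f);
  Float8_e5m2 b(2.0f);
  return a < b; // both operands convert to float, so this is 1.5f < 2.0f
}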
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e5m2> {
|
||||
public:
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 3;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 2;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -13;
|
||||
static constexpr int min_exponent10 = -4;
|
||||
static constexpr int max_exponent = 16;
|
||||
static constexpr int max_exponent10 = 4;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before =
|
||||
numeric_limits<float>::tinyness_before;
|
||||
|
||||
static constexpr c10::Float8_e5m2 min() {
|
||||
return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 max() {
|
||||
return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 lowest() {
|
||||
return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 epsilon() {
|
||||
return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 round_error() {
|
||||
return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 infinity() {
|
||||
return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 quiet_NaN() {
|
||||
return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 denorm_min() {
|
||||
return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e5m2.h>
|
||||
|
||||
@@ -1,146 +1 @@
#pragma once

/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeeee mm
/// 1 sign bit
/// 5 exponent bits
/// 2 mantissa bits
/// bias = 15
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by Half implementation from pytorch/c10/util/Half.h

#include <c10/util/Half.h>

namespace c10 {

namespace detail {

/*
 * Convert an 8-bit floating-point number in fp8 E5M2 format, in bit
 * representation, to a 32-bit floating-point number in IEEE single-precision
 * format, in bit representation.
 *
 * @note The implementation doesn't use any floating-point operations.
 */
inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
  /*
   * Extend the fp8 E5M2 number to 32 bits and shift to the
   * upper part of the 32-bit word:
   * +---+----+---+-----------------------------+
   * | S |EEEEE|MM|0000 0000 0000 0000 0000 0000|
   * +---+----+---+-----------------------------+
   * Bits  31 26-30 24-25                   0-23
   *
   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
   * - zero bits.
   */
  uint16_t half_representation = input;
  half_representation <<= 8;
  return fp16_ieee_to_fp32_value(half_representation);
}
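// Editor's sketch (illustrative addition, not part of the original header):
// e5m2 reuses the top eight bits of the IEEE fp16 layout, so decoding is just
// the shift above into a half-precision bit pattern. A quick hand-checked
// spot check of that identity; the function name is hypothetical.
inline bool fp8e5m2_decode_smoke_test() {
  // 0x3C is 0 01111 00: biased exponent 15, zero mantissa, i.e. exactly 1.0.
  return fp8e5m2_to_fp32_value(0x3C) == 1.0f &&
      fp16_ieee_to_fp32_value(uint16_t{0x3C} << 8) == 1.0f;
}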
/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to an
 * 8-bit floating-point number in fp8 E5M2 format, in bit representation.
 */
inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
  /*
   * Binary representation of fp32 infinity
   * 0 11111111 00000000000000000000000
   */
  constexpr uint32_t fp32_inf = UINT32_C(255) << 23;

  /*
   * Binary representation of 65536.0f, which is the first value
   * not representable in fp8e5m2 range:
   * 0 11111 00 - fp8e5m2
   * 0 10001111 00000000000000000000000 - fp32
   */
  constexpr uint32_t fp8_max = UINT32_C(143) << 23;

  /*
   * A mask for converting fp32 numbers lower than fp8e5m2 normal range
   * into denorm representation
   * magic number: ((127 - 15) + (23 - 2) + 1)
   */
  constexpr uint32_t denorm_mask = UINT32_C(134) << 23;

  uint32_t f_bits = fp32_to_bits(f);
  uint8_t result = 0u;

  /*
   * Extract the sign of the input number into the high bit of the 32-bit word:
   *
   * +---+----------------------------------+
   * | S |0000000 00000000 00000000 00000000|
   * +---+----------------------------------+
   * Bits  31                           0-31
   */
  const uint32_t sign = f_bits & UINT32_C(0x80000000);

  /*
   * Set sign bit to 0
   */
  f_bits ^= sign;

  if (f_bits >= fp8_max) {
    // NaN - all exponent and mantissa bits set to 1
    result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
  } else {
    if (f_bits < (UINT32_C(113) << 23)) {
      // Input number is smaller than 2^(-14), which is the smallest
      // fp8e5m2 normal number
      f_bits =
          fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
      result = static_cast<uint8_t>(f_bits - denorm_mask);
    } else {
      // resulting mantissa is odd
      uint32_t mant_odd = (f_bits >> 21) & 1;

      // update exponent, rounding bias part 1
      f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;

      // rounding bias part 2
      f_bits += mant_odd;

      // take the bits!
      result = static_cast<uint8_t>(f_bits >> 21);
    }
  }

  result |= static_cast<uint8_t>(sign >> 24);
  return result;
}
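// Editor's note (illustrative addition, not part of the original header): a
// worked example of the denormal path above. For f = 2^-16, which is below
// the smallest e5m2 normal 2^-14, fp32_from_bits(denorm_mask) is 2^7 = 128.0f,
// so the sum 128 + 2^-16 has fp32 bits 0x43000001; subtracting denorm_mask
// (0x43000000) leaves 0x01, the smallest e5m2 subnormal, which is exactly
// 0.25 * 2^-14 = 2^-16.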
} // namespace detail

struct alignas(1) Float8_e5m2 {
  uint8_t x;

  struct from_bits_t {};
  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }

  Float8_e5m2() = default;

  constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {}
  inline C10_HOST_DEVICE Float8_e5m2(float value);
  inline C10_HOST_DEVICE operator float() const;
  inline C10_HOST_DEVICE bool isnan() const;
  inline C10_HOST_DEVICE bool isinf() const;
};

inline std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) {
  out << (float)value;
  return out;
}

} // namespace c10

#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e5m2.h>

@@ -1,285 +1 @@
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Float8_fnuz_cvt.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
|
||||
: x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
|
||||
return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
|
||||
}
|
||||
|
||||
/// Special values helpers
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
|
||||
return x == 0b10000000;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz
|
||||
operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz
|
||||
operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz
|
||||
operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
|
||||
const Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
|
||||
return a + static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
|
||||
return a - static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
|
||||
return a * static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
|
||||
return a / static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a + static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a - static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a * static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a / static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e5m2fnuz to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e5m2fnuz> {
|
||||
public:
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 3;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 2;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -14;
|
||||
static constexpr int min_exponent10 = -4;
|
||||
static constexpr int max_exponent = 16;
|
||||
static constexpr int max_exponent10 = 4;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before =
|
||||
numeric_limits<float>::tinyness_before;
|
||||
|
||||
static constexpr c10::Float8_e5m2fnuz min() {
|
||||
return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz max() {
|
||||
return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz lowest() {
|
||||
return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz epsilon() {
|
||||
return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz round_error() {
|
||||
return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz infinity() {
|
||||
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
// TODO(future): we are mapping neg_zero to both inf and NaN, this is
|
||||
// surprising and we should figure out what to do about it.
|
||||
static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
|
||||
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz denorm_min() {
|
||||
return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
|
||||
|
||||
@@ -1,138 +1 @@
|
||||
#pragma once
|
||||
|
||||
/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including
|
||||
/// conversions to standard C types and basic arithmetic operations. Note that
|
||||
/// arithmetic operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration remains the same as e5m2:
|
||||
/// s eeeee mm
|
||||
/// 1 sign bit
|
||||
/// 5 exponent bits
|
||||
/// 2 mantissa bits
|
||||
/// The key differences that e5m2fnuz brings are:
|
||||
/// bias = 16
|
||||
/// no infinities or negative zero
|
||||
/// NaN only when sign bit is 1, rest all 0s
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
|
||||
/// the existing Float8_e4m3fn implementation.
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/TypeSafeSignMath.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#include <iosfwd>
|
||||
#include <ostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
|
||||
/*
|
||||
   * Binary representation of 65536.0f, which is the first value not
   * representable (i.e. the first value which would overflow into the sign
   * bit, resulting in a NaN) in fp8e5m2fnuz range:
   * 1 00000 00 - fp8e5m2fnuz
   * 0 10001111 00000000000000000000000 - fp32
   */
|
||||
constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range
|
||||
* into denormalized representation.
|
||||
* magic number: ((127 - 16) + (23 - 2) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
uint32_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fnuz_max) {
|
||||
// NaN -- sign bit set to 1, rest 0s
|
||||
return 0x80;
|
||||
}
|
||||
|
||||
if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) {
|
||||
// Input exponent is less than -15, the smallest e5m2fnuz exponent, so the
|
||||
// number will become subnormal.
|
||||
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
if (result == 0) {
|
||||
// fnuz types don't have negative zero.
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint8_t mant_odd = (f_bits >> 21) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 21);
|
||||
}
|
||||
|
||||
result |= sign >> 24;
|
||||
return result;
|
||||
}
|
||||
|
||||
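// Editor's sketch (illustrative addition, not part of the original header):
// the fnuz encoding has no infinities and no negative zero, so overflow maps
// to the single NaN byte and -0.0f collapses to +0, as described above. The
// function name is hypothetical.
inline bool fp8e5m2fnuz_encode_smoke_test() {
  return fp8e5m2fnuz_from_fp32_value(65536.0f) == 0x80 && // overflow -> NaN byte
      fp8e5m2fnuz_from_fp32_value(-0.0f) == 0x00;         // negative zero is dropped
}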
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e5m2fnuz {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e5m2fnuz() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t)
|
||||
: x(bits) {}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
inline C10_HOST_DEVICE bool isinf() const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const Float8_e5m2fnuz& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e5m2fnuz-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
|
||||
|
||||
@@ -1,112 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
// TODO(#146647): Can we remove the below warning?
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value)
|
||||
: x(detail::fp8e8m0fnu_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const {
|
||||
// TODO(#146647): maybe rewrite without control flow
|
||||
|
||||
// if exponent is zero, need to special case to return 2^-127 instead of zero
|
||||
if (x == 0) {
|
||||
return c10::detail::fp32_from_bits(0x00400000);
|
||||
}
|
||||
|
||||
// if exponent is NaN, need to special case to return properly encoded NaN
|
||||
if (isnan()) {
|
||||
return c10::detail::fp32_from_bits(0x7f800001);
|
||||
}
|
||||
|
||||
// leave sign at 0, set the exponent bits, leave stored mantissa at 0
|
||||
uint32_t res = x << 23;
|
||||
|
||||
return c10::detail::fp32_from_bits(res);
|
||||
}
|
||||
|
||||
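// Editor's sketch (illustrative addition, not part of the original header):
// spot checks of the decode path above, worked out by hand. The function name
// is hypothetical.
inline bool float8_e8m0fnu_decode_smoke_test() {
  // Biased exponent 127 decodes to 2^0 = 1.0; biased exponent 0 decodes to
  // 2^-127, i.e. the fp32 denormal with bit pattern 0x00400000.
  return static_cast<float>(Float8_e8m0fnu(127, Float8_e8m0fnu::from_bits())) == 1.0f &&
      static_cast<float>(Float8_e8m0fnu(0, Float8_e8m0fnu::from_bits())) ==
      c10::detail::fp32_from_bits(0x00400000);
}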
/// Special values helper
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const {
|
||||
return x == 0b11111111;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e8m0fnu to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e8m0fnu> {
|
||||
public:
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = false;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = false;
|
||||
static constexpr auto has_denorm_loss = false;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 1;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 1; // just a 2!
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -126;
|
||||
static constexpr int min_exponent10 = -38;
|
||||
static constexpr int max_exponent = 128;
|
||||
static constexpr int max_exponent10 = 38;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before = false;
|
||||
|
||||
static constexpr c10::Float8_e8m0fnu min() {
|
||||
// 2^-127
|
||||
return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e8m0fnu lowest() {
|
||||
// 2^-127
|
||||
return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e8m0fnu max() {
|
||||
// 254 biased, which is 127 unbiased, so 2^127
|
||||
return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e8m0fnu epsilon() {
|
||||
// according to https://en.cppreference.com/w/cpp/types/numeric_limits, this
|
||||
// is "the difference between 1.0 and the next representable value of the
|
||||
// given floating-point type". The next representable value is 2.0, so the
|
||||
// difference is 1.0 which is 2^0. 0 unbiased is 127 biased.
|
||||
return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e8m0fnu round_error() {
|
||||
// 0.5 in float, which is 2^-1, and -1 + 127 = 126
|
||||
return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e8m0fnu quiet_NaN() {
|
||||
return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e8m0fnu.h>
|
||||
|
||||
@@ -1,120 +1 @@
#pragma once

/// Defines the Float8_e8m0fnu type (8-bit floating-point) including
/// conversions to standard C types
/// Binary configuration:
/// eeeeeeee
/// no sign bits
/// 8 exponent bits
/// no mantissa bits
///
/// This is the E8M0 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.4.1)

#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>

// TODO(#146647): do we need to special case OPENCL?
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif

#include <iosfwd>
#include <ostream>

namespace c10 {

namespace detail {

/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to an
 * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation.
 */
inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) {
  // TODO(#146647): maybe rewrite without control flow

  uint32_t f_bits = c10::detail::fp32_to_bits(f);

  // extract the exponent
  uint32_t exponent = (f_bits >> 23) & 0b11111111;

  // special case float32 NaN and +-inf to map to e8m0 nan
  if (exponent == 0b11111111) {
    return exponent;
  }

  // next, we use guard, round, sticky bits and the LSB to implement round to
  // nearest, with ties to even

  // guard bit - bit 23, or 22 zero-indexed
  uint8_t g = (f_bits & 0x400000) > 0;
  // round bit - bit 22, or 21 zero-indexed
  uint8_t r = (f_bits & 0x200000) > 0;
  // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed
  uint8_t s = (f_bits & 0x1FFFFF) > 0;
  // in casting to e8m0, LSB is the implied mantissa bit. It equals 0 if the
  // original float32 is denormal, and 1 if the original float32 is normal.
  uint8_t lsb = exponent > 0;

  // implement the RNE logic
  bool round_up = false;

  // if g == 0, round down (no-op)
  if (g == 1) {
    if ((r == 1) || (s == 1)) {
      // round up
      round_up = true;
    } else {
      if (lsb == 1) {
        // round up
        round_up = true;
      }
      // if lsb == 0, round down (no-op)
    }
  }

  if (round_up) {
    // adjust exponent
    // note that if exponent was 255 we would have already returned earlier, so
    // we know we can add one safely without running out of bounds
    exponent++;
  }

  return exponent;
}
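// Editor's sketch (illustrative addition, not part of the original header):
// hand-checked examples of the guard/round/sticky rounding above. The
// function name is hypothetical.
inline bool fp8e8m0fnu_rounding_smoke_test() {
  return fp8e8m0fnu_from_fp32_value(5.0f) == 129 && // 1.25 * 2^2: G=0, stays at 2^2
      fp8e8m0fnu_from_fp32_value(6.0f) == 130 &&    // 1.50 * 2^2: tie, implied LSB=1, rounds up to 2^3
      fp8e8m0fnu_from_fp32_value(7.0f) == 130;      // 1.75 * 2^2: G=1, R=1, rounds up to 2^3
}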
} // namespace detail

struct alignas(1) Float8_e8m0fnu {
  uint8_t x;

  struct from_bits_t {};
  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }

  Float8_e8m0fnu() = default;

  constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t)
      : x(bits) {}
  inline C10_HOST_DEVICE Float8_e8m0fnu(float value);
  inline C10_HOST_DEVICE operator float() const;
  inline C10_HOST_DEVICE bool isnan() const;
};

inline std::ostream& operator<<(
    std::ostream& out,
    const Float8_e8m0fnu& value) {
  out << (float)value;
  return out;
}

} // namespace c10

#include <c10/util/Float8_e8m0fnu-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e8m0fnu.h>

@@ -1,350 +1 @@
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#if defined(CL_SYCL_LANGUAGE_VERSION)
|
||||
#include <CL/sycl.hpp> // for SYCL 1.2.1
|
||||
#elif defined(SYCL_LANGUAGE_VERSION)
|
||||
#include <sycl/sycl.hpp> // for SYCL 2020
|
||||
#endif
|
||||
|
||||
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
|
||||
!defined(__APPLE__)
|
||||
#include <ATen/cpu/vec/vec_half.h>
|
||||
#endif
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
#if defined(__aarch64__) && !defined(__CUDACC__)
|
||||
/// Constructors
|
||||
inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {}
|
||||
inline Half::operator float16_t() const {
|
||||
return detail::fp16_from_bits(x);
|
||||
}
|
||||
#else
|
||||
|
||||
inline C10_HOST_DEVICE Half::Half(float value)
|
||||
:
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
x(__half_as_short(__float2half(value)))
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
x(c10::bit_cast<uint16_t>(sycl::half(value)))
|
||||
#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
|
||||
!defined(__APPLE__)
|
||||
x(at::vec::float2half_scalar(value))
|
||||
#else
|
||||
x(detail::fp16_ieee_from_fp32_value(value))
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Half::operator float() const {
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
return __half2float(*reinterpret_cast<const __half*>(&x));
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
return float(c10::bit_cast<sycl::half>(x));
|
||||
#elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
|
||||
!defined(__APPLE__)
|
||||
return at::vec::half2float_scalar(x);
|
||||
#elif defined(__aarch64__) && !defined(__CUDACC__)
|
||||
return detail::native_fp16_to_fp32_value(x);
|
||||
#else
|
||||
return detail::fp16_ieee_to_fp32_value(x);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif /* !defined(__aarch64__) || defined(__CUDACC__) \
|
||||
*/
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
inline C10_HOST_DEVICE Half::Half(const __half& value) {
|
||||
x = *reinterpret_cast<const unsigned short*>(&value);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half::operator __half() const {
|
||||
return *reinterpret_cast<const __half*>(&x);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef SYCL_LANGUAGE_VERSION
|
||||
inline C10_HOST_DEVICE Half::Half(const sycl::half& value) {
|
||||
x = *reinterpret_cast<const unsigned short*>(&value);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half::operator sycl::half() const {
|
||||
return *reinterpret_cast<const sycl::half*>(&x);
|
||||
}
|
||||
#endif
|
||||
|
||||
// CUDA intrinsics
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
|
||||
(defined(__clang__) && defined(__CUDA__))
|
||||
inline __device__ Half __ldg(const Half* ptr) {
|
||||
return __ldg(reinterpret_cast<const __half*>(ptr));
|
||||
}
|
||||
#endif
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half operator-(const Half& a) {
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
|
||||
defined(__HIP_DEVICE_COMPILE__)
|
||||
return __hneg(a);
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
return -c10::bit_cast<sycl::half>(a);
|
||||
#else
|
||||
return -static_cast<float>(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Half a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Half a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Half a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Half a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Half b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Half b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Half b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Half b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Half a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Half a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Half a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Half a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Half b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Half b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Half b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Half b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Half operator+(Half a, int b) {
|
||||
return a + static_cast<Half>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator-(Half a, int b) {
|
||||
return a - static_cast<Half>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator*(Half a, int b) {
|
||||
return a * static_cast<Half>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator/(Half a, int b) {
|
||||
return a / static_cast<Half>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half operator+(int a, Half b) {
|
||||
return static_cast<Half>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator-(int a, Half b) {
|
||||
return static_cast<Half>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator*(int a, Half b) {
|
||||
return static_cast<Half>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator/(int a, Half b) {
|
||||
return static_cast<Half>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) {
|
||||
return a + static_cast<Half>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) {
|
||||
return a - static_cast<Half>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) {
|
||||
return a * static_cast<Half>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) {
|
||||
return a / static_cast<Half>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) {
|
||||
return static_cast<Half>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) {
|
||||
return static_cast<Half>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) {
|
||||
return static_cast<Half>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) {
|
||||
return static_cast<Half>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Half to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Half> {
|
||||
public:
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = true;
|
||||
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
|
||||
static constexpr auto has_denorm_loss =
|
||||
numeric_limits<float>::has_denorm_loss;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = true;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 11;
|
||||
static constexpr int digits10 = 3;
|
||||
static constexpr int max_digits10 = 5;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -13;
|
||||
static constexpr int min_exponent10 = -4;
|
||||
static constexpr int max_exponent = 16;
|
||||
static constexpr int max_exponent10 = 4;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before =
|
||||
numeric_limits<float>::tinyness_before;
|
||||
static constexpr c10::Half min() {
|
||||
return c10::Half(0x0400, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half lowest() {
|
||||
return c10::Half(0xFBFF, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half max() {
|
||||
return c10::Half(0x7BFF, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half epsilon() {
|
||||
return c10::Half(0x1400, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half round_error() {
|
||||
return c10::Half(0x3800, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half infinity() {
|
||||
return c10::Half(0x7C00, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half quiet_NaN() {
|
||||
return c10::Half(0x7E00, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half signaling_NaN() {
|
||||
return c10::Half(0x7D00, c10::Half::from_bits());
|
||||
}
|
||||
static constexpr c10::Half denorm_min() {
|
||||
return c10::Half(0x0001, c10::Half::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
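// Editor's sketch (illustrative addition, not part of the original header):
// the limits above encode the familiar fp16 constants, e.g. max() is
// 0x7BFF == 65504 and epsilon() is 0x1400 == 2^-10. The function name is
// hypothetical.
inline bool c10_half_limits_smoke_test() {
  return static_cast<float>(std::numeric_limits<c10::Half>::max()) == 65504.0f &&
      static_cast<float>(std::numeric_limits<c10::Half>::epsilon()) ==
      0.0009765625f; // 2^-10
}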
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
|
||||
c10/util/Half.h
@@ -1,424 +1,8 @@
#pragma once
|
||||
#include <torch/headeronly/util/Half.h>
|
||||
|
||||
/// Defines the Half type (half-precision floating-point) including conversions
|
||||
/// to standard C types and basic arithmetic operations. Note that arithmetic
|
||||
/// operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32, instead of using CUDA half intrinsics.
|
||||
/// Most uses of this type within ATen are memory bound, including the
|
||||
/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs.
|
||||
/// If you are writing a compute bound kernel, you can use the CUDA half
|
||||
/// intrinsics directly on the Half type from device code.
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#include <cmath>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
// need to keep the following for BC because the APIs in here were exposed
|
||||
// before migrating Half to torch/headeronly
|
||||
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
|
||||
!defined(__APPLE__)
|
||||
#include <ATen/cpu/vec/vec_half.h>
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iosfwd>
|
||||
#include <limits>
|
||||
#include <ostream>
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#if defined(CL_SYCL_LANGUAGE_VERSION)
|
||||
#include <CL/sycl.hpp> // for SYCL 1.2.1
|
||||
#elif defined(SYCL_LANGUAGE_VERSION)
|
||||
#include <sycl/sycl.hpp> // for SYCL 2020
|
||||
#endif
|
||||
|
||||
#if defined(__aarch64__) && !defined(__CUDACC__)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || \
|
||||
defined(_M_IX86)
|
||||
#if defined(__F16C__) && \
|
||||
!(defined(__CUDA_ARCH__) || defined(__CUDACC__) || \
|
||||
defined(__HIP_DEVICE_COMPILE__))
|
||||
#define C10_X86_F16 1
|
||||
#include <immintrin.h> // import conversion ops from f16cintrin.h
|
||||
#endif // defined(__F16C__) && !(defined(__CUDA_ARCH__) || defined(__CUDACC__)
|
||||
// || defined(__HIP_DEVICE_COMPILE__))
|
||||
#endif // __x86_64__ || _M_X64 || __i386 || _M_IX86
|
||||
#endif // __GNUC__ || __clang__
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
* format, in bit representation.
|
||||
*
|
||||
* @note The implementation doesn't use any floating-point operations.
|
||||
*/
|
||||
inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
|
||||
/*
|
||||
* Extend the half-precision floating-point number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
* +---+-----+------------+-------------------+
|
||||
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
|
||||
* +---+-----+------------+-------------------+
|
||||
* Bits 31 26-30 16-25 0-15
|
||||
*
|
||||
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
|
||||
* - zero bits.
|
||||
*/
|
||||
const uint32_t w = (uint32_t)h << 16;
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = w & UINT32_C(0x80000000);
|
||||
/*
|
||||
* Extract mantissa and biased exponent of the input number into the bits 0-30
|
||||
* of the 32-bit word:
|
||||
*
|
||||
* +---+-----+------------+-------------------+
|
||||
* | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
|
||||
* +---+-----+------------+-------------------+
|
||||
* Bits 30 27-31 17-26 0-16
|
||||
*/
|
||||
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
|
||||
/*
|
||||
* Renorm shift is the number of bits to shift mantissa left to make the
|
||||
* half-precision number normalized. If the initial number is normalized, some
|
||||
   * of its high 6 bits (sign == 0 and 5-bit exponent) equal one. In this case
   * renorm_shift == 0. If the number is denormalized, renorm_shift > 0. Note
|
||||
* that if we shift denormalized nonsign by renorm_shift, the unit bit of
|
||||
* mantissa will shift into exponent, turning the biased exponent into 1, and
|
||||
* making mantissa normalized (i.e. without leading 1).
|
||||
*/
|
||||
#ifdef _MSC_VER
|
||||
unsigned long nonsign_bsr;
|
||||
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
|
||||
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
|
||||
#else
|
||||
uint32_t renorm_shift = __builtin_clz(nonsign);
|
||||
#endif
|
||||
renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
|
||||
/*
|
||||
* Iff half-precision number has exponent of 15, the addition overflows
|
||||
* it into bit 31, and the subsequent shift turns the high 9 bits
|
||||
* into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number
|
||||
* had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise
|
||||
*/
|
||||
const int32_t inf_nan_mask =
|
||||
((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
|
||||
/*
|
||||
* Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
|
||||
* into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
|
||||
* broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
|
||||
* 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
|
||||
* 0x00000000 otherwise
|
||||
*/
|
||||
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
|
||||
/*
|
||||
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
|
||||
* was denormal)
|
||||
* 2. Shift nonsign right by 3 so the exponent (5 bits originally)
|
||||
* becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
|
||||
* bits of the 23-bit mantissa of IEEE single-precision number.
|
||||
* 3. Add 0x70 to the exponent (starting at bit 23) to compensate the
|
||||
   * difference in exponent bias (0x7F for single-precision number less 0xF
|
||||
* for half-precision number).
|
||||
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
|
||||
* account for renormalization. As renorm_shift is less than 0x70, this
|
||||
* can be combined with step 3.
|
||||
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
|
||||
* input was NaN or infinity.
|
||||
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
|
||||
* into zero if the input was zero.
|
||||
* 7. Combine with the sign of the input number.
|
||||
*/
|
||||
return sign |
|
||||
((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
|
||||
inf_nan_mask) &
|
||||
~zero_mask);
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a 16-bit floating-point number in IEEE half-precision format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
* format.
|
||||
*
|
||||
* @note The implementation relies on IEEE-like (no assumption about rounding
|
||||
* mode and no operations on denormals) floating-point operations and bitcasts
|
||||
* between integer and floating-point variables.
|
||||
*/
|
||||
C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
|
||||
#ifdef C10_X86_F16
|
||||
return _cvtsh_ss(h);
|
||||
#else
|
||||
/*
|
||||
* Extend the half-precision floating-point number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
* +---+-----+------------+-------------------+
|
||||
* | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
|
||||
* +---+-----+------------+-------------------+
|
||||
* Bits 31 26-30 16-25 0-15
|
||||
*
|
||||
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
|
||||
* - zero bits.
|
||||
*/
|
||||
const uint32_t w = (uint32_t)h << 16;
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = w & UINT32_C(0x80000000);
|
||||
/*
|
||||
* Extract mantissa and biased exponent of the input number into the high bits
|
||||
* of the 32-bit word:
|
||||
*
|
||||
* +-----+------------+---------------------+
|
||||
* |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
|
||||
* +-----+------------+---------------------+
|
||||
* Bits 27-31 17-26 0-16
|
||||
*/
|
||||
const uint32_t two_w = w + w;
|
||||
|
||||
/*
|
||||
* Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
|
||||
* mantissa and exponent of a single-precision floating-point number:
|
||||
*
|
||||
* S|Exponent | Mantissa
|
||||
* +-+---+-----+------------+----------------+
|
||||
* |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
|
||||
* +-+---+-----+------------+----------------+
|
||||
* Bits | 23-31 | 0-22
|
||||
*
|
||||
* Next, there are some adjustments to the exponent:
|
||||
* - The exponent needs to be corrected by the difference in exponent bias
|
||||
* between single-precision and half-precision formats (0x7F - 0xF = 0x70)
|
||||
* - Inf and NaN values in the inputs should become Inf and NaN values after
|
||||
* conversion to the single-precision number. Therefore, if the biased
|
||||
* exponent of the half-precision input was 0x1F (max possible value), the
|
||||
* biased exponent of the single-precision output must be 0xFF (max possible
|
||||
* value). We do this correction in two steps:
|
||||
* - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
|
||||
* below) rather than by 0x70 suggested by the difference in the exponent bias
|
||||
* (see above).
|
||||
* - Then we multiply the single-precision result of exponent adjustment by
|
||||
* 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
|
||||
* necessary exponent adjustment by 0x70 due to difference in exponent bias.
|
||||
* The floating-point multiplication hardware would ensure that Inf and
|
||||
* NaN would retain their value on at least partially IEEE754-compliant
|
||||
* implementations.
|
||||
*
|
||||
* Note that the above operations do not handle denormal inputs (where biased
|
||||
* exponent == 0). However, they also do not operate on denormal inputs, and
|
||||
* do not produce denormal results.
|
||||
*/
|
||||
constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
|
||||
// const float exp_scale = 0x1.0p-112f;
|
||||
constexpr uint32_t scale_bits = (uint32_t)15 << 23;
|
||||
float exp_scale_val = 0;
|
||||
#if defined(_MSC_VER) && defined(__clang__)
|
||||
__builtin_memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
|
||||
#else
|
||||
std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
|
||||
#endif
|
||||
|
||||
const float exp_scale = exp_scale_val;
|
||||
const float normalized_value =
|
||||
fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
|
||||
|
||||
/*
|
||||
* Convert denormalized half-precision inputs into single-precision results
|
||||
* (always normalized). Zero inputs are also handled here.
|
||||
*
|
||||
* In a denormalized number the biased exponent is zero, and mantissa has
|
||||
* non-zero bits. First, we shift the mantissa into bits 0-9 of the 32-bit word.
|
||||
*
|
||||
* zeros | mantissa
|
||||
* +---------------------------+------------+
|
||||
* |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
|
||||
* +---------------------------+------------+
|
||||
* Bits 10-31 0-9
|
||||
*
|
||||
* Now, remember that denormalized half-precision numbers are represented as:
|
||||
* FP16 = mantissa * 2**(-24).
|
||||
* The trick is to construct a normalized single-precision number with the
|
||||
* same mantissa as the half-precision input and with an exponent which would
|
||||
* scale the corresponding mantissa bits to 2**(-24). A normalized
|
||||
* single-precision floating-point number is represented as: FP32 = (1 +
|
||||
* mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased
|
||||
* exponent is 126, a unit change in the mantissa of the input denormalized
|
||||
* half-precision number causes a change of the constructed single-precision
|
||||
* number by 2**(-24), i.e. the same amount.
|
||||
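* Concretely: (1 + m * 2**(-23)) * 2**(126 - 127) = 0.5 + m * 2**(-24),
* so subtracting 0.5 below leaves exactly m * 2**(-24).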
*
|
||||
* The last step is to adjust the bias of the constructed single-precision
|
||||
* number. When the input half-precision number is zero, the constructed
|
||||
* single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
|
||||
* 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed
|
||||
* single-precision number to get the numerical equivalent of the input
|
||||
* half-precision number.
|
||||
*/
|
||||
constexpr uint32_t magic_mask = UINT32_C(126) << 23;
|
||||
constexpr float magic_bias = 0.5f;
|
||||
const float denormalized_value =
|
||||
fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
|
||||
|
||||
/*
|
||||
* - Choose either results of conversion of input as a normalized number, or
|
||||
* as a denormalized number, depending on the input exponent. The variable
|
||||
* two_w contains the input exponent in bits 27-31, therefore if it is smaller than
|
||||
* 2**27, the input is either a denormal number, or zero.
|
||||
* - Combine the result of conversion of exponent and mantissa with the sign
|
||||
* of the input number.
|
||||
*/
|
||||
constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
|
||||
const uint32_t result = sign |
|
||||
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
|
||||
: fp32_to_bits(normalized_value));
|
||||
return fp32_from_bits(result);
|
||||
#endif // C10_X86_F16
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 16-bit floating-point number in IEEE half-precision format, in bit
|
||||
* representation.
|
||||
*
|
||||
* @note The implementation relies on IEEE-like (no assumption about rounding
|
||||
* mode and no operations on denormals) floating-point operations and bitcasts
|
||||
* between integer and floating-point variables.
|
||||
*/
|
||||
inline uint16_t fp16_ieee_from_fp32_value(float f) {
|
||||
#ifdef C10_X86_F16
|
||||
return _cvtss_sh(f, _MM_FROUND_TO_NEAREST_INT);
|
||||
#else
|
||||
// const float scale_to_inf = 0x1.0p+112f;
|
||||
// const float scale_to_zero = 0x1.0p-110f;
|
||||
constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
|
||||
constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
|
||||
float scale_to_inf_val = 0, scale_to_zero_val = 0;
|
||||
std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
|
||||
std::memcpy(
|
||||
&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
|
||||
const float scale_to_inf = scale_to_inf_val;
|
||||
const float scale_to_zero = scale_to_zero_val;
|
||||
|
||||
#if defined(_MSC_VER) && _MSC_VER == 1916
|
||||
float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
|
||||
#else
|
||||
float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
|
||||
#endif
|
||||
|
||||
const uint32_t w = fp32_to_bits(f);
|
||||
const uint32_t shl1_w = w + w;
|
||||
const uint32_t sign = w & UINT32_C(0x80000000);
|
||||
uint32_t bias = shl1_w & UINT32_C(0xFF000000);
|
||||
if (bias < UINT32_C(0x71000000)) {
|
||||
bias = UINT32_C(0x71000000);
|
||||
}
|
||||
|
||||
base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
|
||||
const uint32_t bits = fp32_to_bits(base);
|
||||
const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
|
||||
const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
|
||||
const uint32_t nonsign = exp_bits + mantissa_bits;
|
||||
return static_cast<uint16_t>(
|
||||
(sign >> 16) |
|
||||
(shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
|
||||
#endif // C10_X86_F16
|
||||
}
|
||||
|
||||
#ifdef C10_X86_F16
|
||||
#undef C10_X86_F16
|
||||
#endif // C10_X86_F16
|
||||
|
||||
#if defined(__aarch64__) && !defined(__CUDACC__)
|
||||
inline float16_t fp16_from_bits(uint16_t h) {
|
||||
return c10::bit_cast<float16_t>(h);
|
||||
}
|
||||
|
||||
inline uint16_t fp16_to_bits(float16_t f) {
|
||||
return c10::bit_cast<uint16_t>(f);
|
||||
}
|
||||
|
||||
// According to https://godbolt.org/z/frExdbsWG it would translate to single
|
||||
// fcvt s0, h0
|
||||
inline float native_fp16_to_fp32_value(uint16_t h) {
|
||||
return static_cast<float>(fp16_from_bits(h));
|
||||
}
|
||||
|
||||
inline uint16_t native_fp16_from_fp32_value(float f) {
|
||||
return fp16_to_bits(static_cast<float16_t>(f));
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(2) Half {
|
||||
unsigned short x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
// HIP wants __host__ __device__ tag, CUDA does not
|
||||
#if defined(USE_ROCM)
|
||||
C10_HOST_DEVICE Half() = default;
|
||||
#else
|
||||
Half() = default;
|
||||
#endif
|
||||
|
||||
constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {}
|
||||
#if defined(__aarch64__) && !defined(__CUDACC__)
|
||||
inline Half(float16_t value);
|
||||
inline operator float16_t() const;
|
||||
#else
|
||||
inline C10_HOST_DEVICE Half(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
inline C10_HOST_DEVICE Half(const __half& value);
|
||||
inline C10_HOST_DEVICE operator __half() const;
|
||||
#endif
|
||||
#ifdef SYCL_LANGUAGE_VERSION
|
||||
inline C10_HOST_DEVICE Half(const sycl::half& value);
|
||||
inline C10_HOST_DEVICE operator sycl::half() const;
|
||||
#endif
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const Half& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Half-inl.h> // IWYU pragma: keep
|
||||
|
||||
@ -1,140 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wstring-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion")
|
||||
#endif
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Returns false since we cannot have x < 0 if x is unsigned.
|
||||
template <typename T>
|
||||
inline constexpr bool is_negative(
|
||||
const T& /*x*/,
|
||||
std::true_type /*is_unsigned*/) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if a signed variable x < 0
|
||||
template <typename T>
|
||||
inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) {
|
||||
return x < T(0);
|
||||
}
|
||||
|
||||
/// Returns true if x < 0
|
||||
/// NOTE: Will fail on an unsigned custom type
|
||||
/// For the most part it's possible to fix this if
|
||||
/// the custom type has a constexpr constructor.
|
||||
/// However, notably, c10::Half does not :-(
|
||||
template <typename T>
|
||||
inline constexpr bool is_negative(const T& x) {
|
||||
return is_negative(x, std::is_unsigned<T>());
|
||||
}
|
||||
|
||||
/// Returns the sign of an unsigned variable x as 0, 1
|
||||
template <typename T>
|
||||
inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) {
|
||||
return T(0) < x;
|
||||
}
|
||||
|
||||
/// Returns the sign of a signed variable x as -1, 0, 1
|
||||
template <typename T>
|
||||
inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) {
|
||||
return (T(0) < x) - (x < T(0));
|
||||
}
|
||||
|
||||
/// Returns the sign of x as -1, 0, 1
|
||||
/// NOTE: Will fail on an unsigned custom type
|
||||
/// For the most part it's possible to fix this if
|
||||
/// the custom type has a constexpr constructor.
|
||||
/// However, notably, c10::Half does not :-(
|
||||
template <typename T>
|
||||
inline constexpr int signum(const T& x) {
|
||||
return signum(x, std::is_unsigned<T>());
|
||||
}
|
||||
|
||||
/// Returns true if a and b are not both negative
|
||||
template <typename T, typename U>
|
||||
inline constexpr bool signs_differ(const T& a, const U& b) {
|
||||
return is_negative(a) != is_negative(b);
|
||||
}
|
||||
|
||||
// Suppress sign compare warning when compiling with GCC
|
||||
// as later does not account for short-circuit rule before
|
||||
// raising the warning, see https://godbolt.org/z/Tr3Msnz99
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wsign-compare"
|
||||
#endif
|
||||
|
||||
/// Returns true if x is greater than the greatest value of the type Limit
|
||||
template <typename Limit, typename T>
|
||||
inline constexpr bool greater_than_max(const T& x) {
|
||||
constexpr bool can_overflow =
|
||||
std::numeric_limits<T>::digits > std::numeric_limits<Limit>::digits;
|
||||
return can_overflow && x > (std::numeric_limits<Limit>::max)();
|
||||
}
|
||||
|
||||
#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
/// Returns true if x < lowest(Limit). Standard comparison
|
||||
template <typename Limit, typename T>
|
||||
inline constexpr bool less_than_lowest(
|
||||
const T& x,
|
||||
std::false_type /*limit_is_unsigned*/,
|
||||
std::false_type /*x_is_unsigned*/) {
|
||||
return x < std::numeric_limits<Limit>::lowest();
|
||||
}
|
||||
|
||||
/// Returns false since the limit is signed and therefore includes
|
||||
/// negative values, but x cannot be negative because it is unsigned
|
||||
template <typename Limit, typename T>
|
||||
inline constexpr bool less_than_lowest(
|
||||
const T& /*x*/,
|
||||
std::false_type /*limit_is_unsigned*/,
|
||||
std::true_type /*x_is_unsigned*/) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if x < 0, where 0 is constructed from T.
|
||||
/// Limit is not signed, so its lower value is zero
|
||||
template <typename Limit, typename T>
|
||||
inline constexpr bool less_than_lowest(
|
||||
const T& x,
|
||||
std::true_type /*limit_is_unsigned*/,
|
||||
std::false_type /*x_is_unsigned*/) {
|
||||
return x < T(0);
|
||||
}
|
||||
|
||||
/// Returns false since both types are unsigned
|
||||
template <typename Limit, typename T>
|
||||
inline constexpr bool less_than_lowest(
|
||||
const T& /*x*/,
|
||||
std::true_type /*limit_is_unsigned*/,
|
||||
std::true_type /*x_is_unsigned*/) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Returns true if x is less than the lowest value of type Limit
|
||||
/// NOTE: Will fail on an unsigned custom type
|
||||
/// For the most part it's possible to fix this if
|
||||
/// the custom type has a constexpr constructor.
|
||||
/// However, notably, c10::Half does not :-(
|
||||
template <typename Limit, typename T>
|
||||
inline constexpr bool less_than_lowest(const T& x) {
|
||||
return less_than_lowest<Limit>(
|
||||
x, std::is_unsigned<Limit>(), std::is_unsigned<T>());
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/TypeSafeSignMath.h>
|
||||
|
||||
@ -1,46 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
|
||||
#include <bit>
|
||||
#define C10_HAVE_STD_BIT_CAST 1
|
||||
#else
|
||||
#define C10_HAVE_STD_BIT_CAST 0
|
||||
#endif // __has_include(<bit>) && (defined(__cpp_lib_bit_cast) &&
|
||||
// __cpp_lib_bit_cast >= 201806L)
|
||||
|
||||
namespace c10 {
|
||||
|
||||
#if C10_HAVE_STD_BIT_CAST
|
||||
using std::bit_cast;
|
||||
#else
|
||||
// Implementations of std::bit_cast() from C++ 20.
|
||||
//
|
||||
// This is a less sketchy version of reinterpret_cast.
|
||||
//
|
||||
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
|
||||
// information as well as the source of our implementations.
|
||||
template <class To, class From>
|
||||
C10_HOST_DEVICE std::enable_if_t<
|
||||
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
|
||||
std::is_trivially_copyable_v<To>,
|
||||
To>
|
||||
// constexpr support needs compiler magic
|
||||
bit_cast(const From& src) noexcept {
|
||||
static_assert(
|
||||
std::is_trivially_constructible_v<To>,
|
||||
"This implementation additionally requires "
|
||||
"destination type to be trivially constructible");
|
||||
|
||||
To dst;
|
||||
std::memcpy(&dst, &src, sizeof(To));
|
||||
return dst;
|
||||
}
|
||||
#endif // C10_HAVE_STD_BIT_CAST
|
||||
#undef C10_HAVE_STD_BIT_CAST
|
||||
|
||||
} // namespace c10
|
||||
#include <torch/headeronly/util/bit_cast.h>
|
||||
|
||||
@ -4,531 +4,7 @@
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
#include <thrust/complex.h>
|
||||
#endif
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
|
||||
#endif
|
||||
#if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// c10::complex is an implementation of complex numbers that aims
|
||||
// to work on all devices supported by PyTorch
|
||||
//
|
||||
// Most of the APIs duplicates std::complex
|
||||
// Reference: https://en.cppreference.com/w/cpp/numeric/complex
|
||||
//
|
||||
// [NOTE: Complex Operator Unification]
|
||||
// Operators currently use a mix of std::complex, thrust::complex, and
|
||||
// c10::complex internally. The end state is that all operators will use
|
||||
// c10::complex internally. Until then, there may be some hacks to support all
|
||||
// variants.
|
||||
//
|
||||
//
|
||||
// [Note on Constructors]
|
||||
//
|
||||
// The APIs of constructors are mostly copied from C++ standard:
|
||||
// https://en.cppreference.com/w/cpp/numeric/complex/complex
|
||||
//
|
||||
// Since C++14, all constructors are constexpr in std::complex
|
||||
//
|
||||
// There are three types of constructors:
|
||||
// - initializing from real and imag:
|
||||
// `constexpr complex( const T& re = T(), const T& im = T() );`
|
||||
// - implicitly-declared copy constructor
|
||||
// - converting constructors
|
||||
//
|
||||
// Converting constructors:
|
||||
// - std::complex defines converting constructor between float/double/long
|
||||
// double,
|
||||
// while we define converting constructor between float/double.
|
||||
// - For these converting constructors, upcasting is implicit, downcasting is
|
||||
// explicit.
|
||||
// - We also define explicit casting from std::complex/thrust::complex
|
||||
// - Note that the conversion from thrust is not constexpr, because
|
||||
// thrust does not define them as constexpr ????
|
||||
//
|
||||
//
|
||||
// [Operator =]
|
||||
//
|
||||
// The APIs of operator = are mostly copied from C++ standard:
|
||||
// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
|
||||
//
|
||||
// Since C++20, all operator= are constexpr. Although we are not building with
|
||||
// C++20, we also obey this behavior.
|
||||
//
|
||||
// There are three types of assign operator:
|
||||
// - Assign a real value from the same scalar type
|
||||
// - In std, this is templated as complex& operator=(const T& x)
|
||||
// with specialization `complex& operator=(T x)` for float/double/long
|
||||
// double. Since we only support float and double, one will use `complex&
|
||||
// operator=(T x)`
|
||||
// - Copy assignment operator and converting assignment operator
|
||||
// - There is no specialization of converting assignment operators; which type
|
||||
// is
|
||||
// convertible depends solely on whether the scalar type is convertible
|
||||
//
|
||||
// In addition to the standard assignment, we also provide assignment operators
|
||||
// with std and thrust
|
||||
//
|
||||
//
|
||||
// [Casting operators]
|
||||
//
|
||||
// std::complex does not have casting operators. We define casting operators
|
||||
// casting to std::complex and thrust::complex
|
||||
//
|
||||
//
|
||||
// [Operator ""]
|
||||
//
|
||||
// std::complex has custom literals `i`, `if` and `il` defined in namespace
|
||||
// `std::literals::complex_literals`. We define our own custom literals in the
|
||||
// namespace `c10::complex_literals`. Our custom literals do not follow the
|
||||
// same behavior as in std::complex; instead, we define _if, _id to construct
|
||||
// float/double complex literals.
|
||||
//
|
||||
//
|
||||
// [real() and imag()]
|
||||
//
|
||||
// In C++20, there are two overloads of these functions: one is to return the
|
||||
// real/imag, another is to set real/imag; they are both constexpr. We follow
|
||||
// this design.
|
||||
//
|
||||
//
|
||||
// [Operator +=,-=,*=,/=]
|
||||
//
|
||||
// Since C++20, these operators become constexpr. In our implementation, they
|
||||
// are also constexpr.
|
||||
//
|
||||
// There are two types of such operators: operating with a real number, or
|
||||
// operating with another complex number. For the operating with a real number,
|
||||
// the generic template form has argument type `const T &`, while the overload
|
||||
// for float/double/long double has `T`. We will follow the same type as
|
||||
// float/double/long double in std.
|
||||
//
|
||||
// [Unary operator +-]
|
||||
//
|
||||
// Since C++20, they are constexpr. We also make them constexpr
|
||||
//
|
||||
// [Binary operators +-*/]
|
||||
//
|
||||
// Each operator has three versions (taking + as example):
|
||||
// - complex + complex
|
||||
// - complex + real
|
||||
// - real + complex
|
||||
//
|
||||
// [Operator ==, !=]
|
||||
//
|
||||
// Each operator has three versions (taking == as example):
|
||||
// - complex == complex
|
||||
// - complex == real
|
||||
// - real == complex
|
||||
//
|
||||
// Some of them are removed in C++20, but we decided to keep them
|
||||
//
|
||||
// [Operator <<, >>]
|
||||
//
|
||||
// These are implemented by casting to std::complex
|
||||
//
|
||||
//
|
||||
//
|
||||
// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
|
||||
// because:
|
||||
// - lots of members and functions of c10::Half are not constexpr
|
||||
// - thrust::complex only support float and double
|
||||
|
||||
template <typename T>
|
||||
struct alignas(sizeof(T) * 2) complex {
|
||||
using value_type = T;
|
||||
|
||||
T real_ = T(0);
|
||||
T imag_ = T(0);
|
||||
|
||||
constexpr complex() = default;
|
||||
C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
|
||||
: real_(re), imag_(im) {}
|
||||
template <typename U>
|
||||
explicit constexpr complex(const std::complex<U>& other)
|
||||
: complex(other.real(), other.imag()) {}
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template <typename U>
|
||||
explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
|
||||
: real_(other.real()), imag_(other.imag()) {}
|
||||
// NOTE can not be implemented as follow due to ROCm bug:
|
||||
// explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
|
||||
// complex(other.real(), other.imag()) {}
|
||||
#endif
|
||||
|
||||
// Use SFINAE to specialize casting constructor for c10::complex<float> and
|
||||
// c10::complex<double>
|
||||
template <typename U = T>
|
||||
C10_HOST_DEVICE explicit constexpr complex(
|
||||
const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
|
||||
: real_(other.real_), imag_(other.imag_) {}
|
||||
template <typename U = T>
|
||||
C10_HOST_DEVICE constexpr complex(
|
||||
const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
|
||||
: real_(other.real_), imag_(other.imag_) {}
|
||||
|
||||
constexpr complex<T>& operator=(T re) {
|
||||
real_ = re;
|
||||
imag_ = 0;
|
||||
return *this;
|
||||
}
|
||||
|
||||
constexpr complex<T>& operator+=(T re) {
|
||||
real_ += re;
|
||||
return *this;
|
||||
}
|
||||
|
||||
constexpr complex<T>& operator-=(T re) {
|
||||
real_ -= re;
|
||||
return *this;
|
||||
}
|
||||
|
||||
constexpr complex<T>& operator*=(T re) {
|
||||
real_ *= re;
|
||||
imag_ *= re;
|
||||
return *this;
|
||||
}
|
||||
|
||||
constexpr complex<T>& operator/=(T re) {
|
||||
real_ /= re;
|
||||
imag_ /= re;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
constexpr complex<T>& operator=(const complex<U>& rhs) {
|
||||
real_ = rhs.real();
|
||||
imag_ = rhs.imag();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
constexpr complex<T>& operator+=(const complex<U>& rhs) {
|
||||
real_ += rhs.real();
|
||||
imag_ += rhs.imag();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
constexpr complex<T>& operator-=(const complex<U>& rhs) {
|
||||
real_ -= rhs.real();
|
||||
imag_ -= rhs.imag();
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename U>
|
||||
constexpr complex<T>& operator*=(const complex<U>& rhs) {
|
||||
// (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
|
||||
T a = real_;
|
||||
T b = imag_;
|
||||
U c = rhs.real();
|
||||
U d = rhs.imag();
|
||||
real_ = a * c - b * d;
|
||||
imag_ = a * d + b * c;
|
||||
return *this;
|
||||
}
|
||||
|
||||
#ifdef __APPLE__
|
||||
#define FORCE_INLINE_APPLE __attribute__((always_inline))
|
||||
#else
|
||||
#define FORCE_INLINE_APPLE
|
||||
#endif
|
||||
template <typename U>
|
||||
constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
// (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
|
||||
// the calculation below follows numpy's complex division
|
||||
T a = real_;
|
||||
T b = imag_;
|
||||
U c = rhs.real();
|
||||
U d = rhs.imag();
|
||||
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
// std::abs is already constexpr by gcc
|
||||
auto abs_c = std::abs(c);
|
||||
auto abs_d = std::abs(d);
|
||||
#else
|
||||
auto abs_c = c < 0 ? -c : c;
|
||||
auto abs_d = d < 0 ? -d : d;
|
||||
#endif
|
||||
|
||||
if (abs_c >= abs_d) {
|
||||
if (abs_c == U(0) && abs_d == U(0)) {
|
||||
/* divide by zeros should yield a complex inf or nan */
|
||||
real_ = a / abs_c;
|
||||
imag_ = b / abs_d;
|
||||
} else {
|
||||
auto rat = d / c;
|
||||
auto scl = U(1.0) / (c + d * rat);
|
||||
real_ = (a + b * rat) * scl;
|
||||
imag_ = (b - a * rat) * scl;
|
||||
}
|
||||
} else {
|
||||
auto rat = c / d;
|
||||
auto scl = U(1.0) / (d + c * rat);
|
||||
real_ = (a * rat + b) * scl;
|
||||
imag_ = (b * rat - a) * scl;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
#undef FORCE_INLINE_APPLE
|
||||
|
||||
template <typename U>
|
||||
constexpr complex<T>& operator=(const std::complex<U>& rhs) {
|
||||
real_ = rhs.real();
|
||||
imag_ = rhs.imag();
|
||||
return *this;
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template <typename U>
|
||||
C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
|
||||
real_ = rhs.real();
|
||||
imag_ = rhs.imag();
|
||||
return *this;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename U>
|
||||
explicit constexpr operator std::complex<U>() const {
|
||||
return std::complex<U>(std::complex<T>(real(), imag()));
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
template <typename U>
|
||||
C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
|
||||
return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
|
||||
}
|
||||
#endif
|
||||
|
||||
// consistent with NumPy behavior
|
||||
explicit constexpr operator bool() const {
|
||||
return real() || imag();
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE constexpr T real() const {
|
||||
return real_;
|
||||
}
|
||||
constexpr void real(T value) {
|
||||
real_ = value;
|
||||
}
|
||||
C10_HOST_DEVICE constexpr T imag() const {
|
||||
return imag_;
|
||||
}
|
||||
constexpr void imag(T value) {
|
||||
imag_ = value;
|
||||
}
|
||||
};
|
||||
|
||||
namespace complex_literals {
|
||||
|
||||
constexpr complex<float> operator""_if(long double imag) {
|
||||
return complex<float>(0.0f, static_cast<float>(imag));
|
||||
}
|
||||
|
||||
constexpr complex<double> operator""_id(long double imag) {
|
||||
return complex<double>(0.0, static_cast<double>(imag));
|
||||
}
|
||||
|
||||
constexpr complex<float> operator""_if(unsigned long long imag) {
|
||||
return complex<float>(0.0f, static_cast<float>(imag));
|
||||
}
|
||||
|
||||
constexpr complex<double> operator""_id(unsigned long long imag) {
|
||||
return complex<double>(0.0, static_cast<double>(imag));
|
||||
}
|
||||
|
||||
} // namespace complex_literals
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator+(const complex<T>& val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator-(const complex<T>& val) {
|
||||
return complex<T>(-val.real(), -val.imag());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result += rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result += rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
|
||||
return complex<T>(lhs + rhs.real(), rhs.imag());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result -= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result -= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
|
||||
complex<T> result = -rhs;
|
||||
return result += lhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result *= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result *= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
|
||||
complex<T> result = rhs;
|
||||
return result *= lhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result /= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
|
||||
complex<T> result = lhs;
|
||||
return result /= rhs;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
|
||||
complex<T> result(lhs, T());
|
||||
return result /= rhs;
|
||||
}
|
||||
|
||||
// Define operators between integral scalars and c10::complex. std::complex does
|
||||
// not support this when T is a floating-point number. This is useful because it
|
||||
// saves a lot of "static_cast" calls when operating on a complex and an integer. This
|
||||
// makes the code both less verbose and potentially more efficient.
|
||||
#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
|
||||
typename std::enable_if_t< \
|
||||
std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
|
||||
int> = 0
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
|
||||
return a + static_cast<fT>(b);
|
||||
}
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
|
||||
return static_cast<fT>(a) + b;
|
||||
}
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
|
||||
return a - static_cast<fT>(b);
|
||||
}
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
|
||||
return static_cast<fT>(a) - b;
|
||||
}
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
|
||||
return a * static_cast<fT>(b);
|
||||
}
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
|
||||
return static_cast<fT>(a) * b;
|
||||
}
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
|
||||
return a / static_cast<fT>(b);
|
||||
}
|
||||
|
||||
template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
|
||||
constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
|
||||
return static_cast<fT>(a) / b;
|
||||
}
|
||||
|
||||
#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
|
||||
|
||||
template <typename T>
|
||||
constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
|
||||
return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
|
||||
return (lhs.real() == rhs) && (lhs.imag() == T());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
|
||||
return (lhs == rhs.real()) && (T() == rhs.imag());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename T, typename CharT, typename Traits>
|
||||
std::basic_ostream<CharT, Traits>& operator<<(
|
||||
std::basic_ostream<CharT, Traits>& os,
|
||||
const complex<T>& x) {
|
||||
return (os << static_cast<std::complex<T>>(x));
|
||||
}
|
||||
|
||||
template <typename T, typename CharT, typename Traits>
|
||||
std::basic_istream<CharT, Traits>& operator>>(
|
||||
std::basic_istream<CharT, Traits>& is,
|
||||
complex<T>& x) {
|
||||
std::complex<T> tmp;
|
||||
is >> tmp;
|
||||
x = tmp;
|
||||
return is;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
#include <torch/headeronly/util/complex.h>
|
||||
|
||||
// std functions
|
||||
//
|
||||
@ -594,72 +70,6 @@ constexpr c10::complex<T> conj(const c10::complex<T>& z) {
|
||||
|
||||
} // namespace std
|
||||
|
||||
namespace c10 {
|
||||
|
||||
template <typename T>
|
||||
C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
|
||||
#if defined(__CUDACC__) || defined(__HIPCC__)
|
||||
return static_cast<complex<T>>(thrust::polar(r, theta));
|
||||
#else
|
||||
// std::polar() requires r >= 0, so spell out the explicit implementation to
|
||||
// avoid a branch.
|
||||
return complex<T>(r * std::cos(theta), r * std::sin(theta));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
struct alignas(4) complex<Half> {
|
||||
Half real_;
|
||||
Half imag_;
|
||||
|
||||
// Constructors
|
||||
complex() = default;
|
||||
// Half constructor is not constexpr so the following constructor can't
|
||||
// be constexpr
|
||||
C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
|
||||
: real_(real), imag_(imag) {}
|
||||
C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
|
||||
: real_(value.real()), imag_(value.imag()) {}
|
||||
|
||||
// Conversion operator
|
||||
inline C10_HOST_DEVICE operator c10::complex<float>() const {
|
||||
return {real_, imag_};
|
||||
}
|
||||
|
||||
constexpr C10_HOST_DEVICE Half real() const {
|
||||
return real_;
|
||||
}
|
||||
constexpr C10_HOST_DEVICE Half imag() const {
|
||||
return imag_;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
|
||||
real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
|
||||
imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
|
||||
real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
|
||||
imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
|
||||
return *this;
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
|
||||
auto a = static_cast<float>(real_);
|
||||
auto b = static_cast<float>(imag_);
|
||||
auto c = static_cast<float>(other.real());
|
||||
auto d = static_cast<float>(other.imag());
|
||||
real_ = a * c - b * d;
|
||||
imag_ = a * d + b * c;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
|
||||
#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
|
||||
// math functions are included in a separate file
|
||||
#include <c10/util/complex_math.h> // IWYU pragma: keep
|
||||
|
||||
@ -1,33 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
#include <cstdint>
|
||||
|
||||
namespace c10::detail {
|
||||
|
||||
C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
|
||||
#if defined(__OPENCL_VERSION__)
|
||||
return as_float(w);
|
||||
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
return __uint_as_float((unsigned int)w);
|
||||
#elif defined(__INTEL_COMPILER)
|
||||
return _castu32_f32(w);
|
||||
#else
|
||||
return c10::bit_cast<float>(w);
|
||||
#endif
|
||||
}
|
||||
|
||||
C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
|
||||
#if defined(__OPENCL_VERSION__)
|
||||
return as_uint(f);
|
||||
#elif defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
return (uint32_t)__float_as_uint(f);
|
||||
#elif defined(__INTEL_COMPILER)
|
||||
return _castf32_u32(f);
|
||||
#else
|
||||
return c10::bit_cast<uint32_t>(f);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace c10::detail
|
||||
#include <torch/headeronly/util/floating_point_utils.h>
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/util/flat_hash_map.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/xpu/XPUCachingAllocator.h>
|
||||
@ -20,8 +21,6 @@ constexpr size_t kMinBlockSize = 512;
|
||||
constexpr size_t kSmallSize = 1048576;
|
||||
// "small" allocations are packed in 2 MiB blocks
|
||||
constexpr size_t kSmallBuffer = 2097152;
|
||||
// "large" allocations may be packed in 20 MiB blocks
|
||||
constexpr size_t kLargeBuffer = 20971520;
|
||||
// allocations between 1 and 10 MiB may use kLargeBuffer
|
||||
constexpr size_t kMinLargeAlloc = 10485760;
|
||||
// round up large allocations to 2 MiB
|
||||
|
||||
@ -1346,6 +1346,10 @@ if(BUILD_TEST)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
|
||||
add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
|
||||
add_subdirectory(
|
||||
${TORCH_ROOT}/test/cpp/tensorexpr
|
||||
${CMAKE_BINARY_DIR}/test_tensorexpr
|
||||
)
|
||||
if(USE_DISTRIBUTED)
|
||||
add_subdirectory(${TORCH_ROOT}/test/cpp/c10d ${CMAKE_BINARY_DIR}/test_cpp_c10d)
|
||||
if(NOT WIN32)
|
||||
@ -1767,6 +1771,10 @@ if(USE_ROCM)
|
||||
target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS})
|
||||
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
|
||||
|
||||
if(USE_FBGEMM_GENAI)
|
||||
target_link_libraries(torch_hip PRIVATE fbgemm_genai)
|
||||
endif()
|
||||
|
||||
# Since PyTorch files contain HIP headers, this is also needed to capture the includes.
|
||||
# ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake
|
||||
target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE} ${ROCM_INCLUDE_DIRS})
|
||||
|
||||
@ -362,14 +362,6 @@ function(torch_compile_options libname)
|
||||
# For MS official doc: https://learn.microsoft.com/en-us/cpp/build/reference/zc-preprocessor
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:preprocessor" PARENT_SCOPE)
|
||||
|
||||
if(${MSVC_TOOLSET_VERSION} GREATER_EQUAL 143)
|
||||
# Add /d2implyavx512upperregs- to disable over-aggressive compiler optimization, which caused AVX512 registers to be used on AVX2 machines.
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/145702#issuecomment-2874029459
|
||||
target_compile_options(${libname} PUBLIC $<$<COMPILE_LANGUAGE:CXX>:/d2implyavx512upperregs->)
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
target_compile_options(${libname} PUBLIC
|
||||
$<$<COMPILE_LANGUAGE:CXX>:
|
||||
${MSVC_RUNTIME_LIBRARY_OPTION}
|
||||
|
||||
17
docs/source/_static/js/runllm-widget.js
Normal file
@ -0,0 +1,17 @@
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
var script = document.createElement("script");
|
||||
script.type = "module";
|
||||
script.id = "runllm-widget-script";
|
||||
|
||||
script.src = "https://widget.runllm.com";
|
||||
|
||||
script.setAttribute("version", "stable");
|
||||
script.setAttribute("crossorigin", "true");
|
||||
script.setAttribute("runllm-keyboard-shortcut", "Mod+j");
|
||||
script.setAttribute("runllm-name", "PyTorch");
|
||||
script.setAttribute("runllm-position", "BOTTOM_RIGHT");
|
||||
script.setAttribute("runllm-assistant-id", "834");
|
||||
|
||||
script.async = true;
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
BIN
docs/source/compile/_static/dynamo_summary_diagram.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 424 KiB |
15
docs/source/compile/header_code.py
Normal file
@ -0,0 +1,15 @@
|
||||
import functools
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# to lower notebook execution time while hiding backend="eager"
|
||||
torch.compile = functools.partial(torch.compile, backend="eager")
|
||||
|
||||
# to clear torch logs format
|
||||
os.environ["TORCH_LOGS_FORMAT"] = ""
|
||||
torch._logging._internal.DEFAULT_FORMATTER = (
|
||||
torch._logging._internal._default_formatter()
|
||||
)
|
||||
torch._logging._internal._init_logs()
|
||||
142
docs/source/compile/programming_model.common_graph_breaks.md
Normal file
@ -0,0 +1,142 @@
|
||||
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
|
||||
torch._logging.set_logs(graph_breaks=True)
|
||||
```
|
||||
|
||||
# Common Graph Breaks
|
||||
|
||||
Below are some common graph breaks and some workarounds.
|
||||
|
||||
## Incorrect Code
|
||||
Your code might contain errors (meaning it doesn't execute even without `torch.compile`). In the example below, there's a typo in the `torch.sin` call due to an extra argument. **Always disable `torch.compile` to check if the code runs correctly.**
|
||||
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
y = torch.sin(x, x)
|
||||
return y
|
||||
|
||||
try:
|
||||
fn(torch.ones(3, 3))
|
||||
except Exception as e:
|
||||
pass
|
||||
```
|
||||
|
||||
Dynamo makes a best-effort attempt to hint if a graph break is caused by your code.
|
||||
But it can still sometimes be difficult to tell from the logs whether the graph break is caused by an error in your code,
|
||||
by a more complicated unsupported construct, or by a `torch.compile` bug. To differentiate, we recommend running your code without `torch.compile` to see if you still get the error reported by the graph break.
|
||||
|
||||
## Data-dependent operations
|
||||
|
||||
`torch.compile` graph breaks on data-dependent operations such as data-dependent control flow (if-statements, loops with tensors) and direct tensor data accesses (`.item`, `.data_ptr`).
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
y = x.sum()
|
||||
if y > 0:
|
||||
return x + y.item()
|
||||
return x - y.item()
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
```
|
||||
|
||||
The general workaround for these graph breaks is to avoid doing data-dependent operations. Some specific workarounds are:
|
||||
|
||||
- If your control flow doesn't actually depend on data values, consider modifying your code to perform control flow on constants.
|
||||
|
||||
|
||||
```{code-cell}
|
||||
# old
|
||||
x = torch.randn(3, 3)
|
||||
@torch.compile
|
||||
def fn(y):
|
||||
if x.sum() > 0:
|
||||
return y + x
|
||||
else:
|
||||
return y - x
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
# new
|
||||
x = torch.randn(3, 3)
|
||||
cond = (x.sum() > 0).item()
|
||||
@torch.compile
|
||||
def fn(y):
|
||||
if cond:
|
||||
return y + x
|
||||
else:
|
||||
return y - x
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
```
|
||||
|
||||
- Use higher-order ops like {ref}`cond` in place of data-dependent control flow
|
||||
|
||||
|
||||
```{code-cell}
|
||||
# old
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
if x.sum() > 0:
|
||||
return x + 1
|
||||
return x - 1
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
# new
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
return torch.cond(
|
||||
x.sum() > 0,
|
||||
lambda x: x + 1,
|
||||
lambda x: x - 1,
|
||||
(x,),
|
||||
)
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
```
|
||||
|
||||
- If you have a `.item()` call, try `torch._dynamo.config.capture_scalar_outputs = True`
|
||||
or `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` (see the sketch after this list).
|
||||
- Wrap problematic parts of the function in a custom operator
|
||||
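Below is a minimal sketch of the `.item()` workaround mentioned above. It assumes only the documented `torch._dynamo.config.capture_scalar_outputs` flag; the function and tensor values are illustrative.

```python
import torch

# Allow .item() to be captured into the graph instead of forcing a graph break.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def fn(x):
    s = x.sum().item()  # captured as a data-dependent scalar
    return x + s

print(fn(torch.ones(3, 3)))
```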
|
||||
## Printing and logging
|
||||
|
||||
Printing/logging/issuing warnings will result in a graph break.
|
||||
You can try working around this by using `torch._dynamo.config.reorderable_logging_functions`.
|
||||
This config is used to reorder logging functions so that they are called at the end of the
|
||||
traced function, thus avoiding a graph break.
|
||||
However, the logged contents may differ if, for example, a mutation occurs.
|
||||
|
||||
|
||||
```{code-cell}
|
||||
torch._dynamo.config.reorderable_logging_functions.add(print)
|
||||
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
x += 1
|
||||
print("log!")
|
||||
return torch.sin(x)
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
```
|
||||
75
docs/source/compile/programming_model.compiler_disable.md
Normal file
@ -0,0 +1,75 @@
|
||||
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
|
||||
torch._logging.set_logs(graph_breaks=True, graph_code=True)
|
||||
```
|
||||
|
||||
# Disabling and Suppressing Errors
|
||||
For some model architectures, there are portions of the model which are particularly difficult to compile -
|
||||
either there are many graph breaks, or there are crashes.
|
||||
You may want to explicitly disable these portions of the model which are problematic so that you can apply
|
||||
`torch.compile` to the parts that work. You can do this by using the `@torch.compiler.disable` decorator.
|
||||
When `torch.compile` attempts to call a disabled function, it breaks the graph and skips tracing the disabled function,
|
||||
resuming tracing after the call. By default, all recursive calls made from a disabled function are also disabled.
|
||||
Use the `recursive=False` option to allow compilation for recursive calls.
|
||||
|
||||
```{code-cell}
|
||||
def inner1(x):
|
||||
torch._dynamo.graph_break() # not traced
|
||||
return x + 1 # not traced
|
||||
|
||||
@torch.compiler.disable
|
||||
def outer1(x):
|
||||
x = x + 2 # not traced
|
||||
torch._dynamo.graph_break() # not traced
|
||||
return inner1(x)
|
||||
|
||||
@torch.compile
|
||||
def f(x):
|
||||
x = outer1(x)
|
||||
return x + 4 # traced
|
||||
|
||||
print(f(torch.ones(3)))
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
def inner2(x):
|
||||
torch._dynamo.graph_break() # traced
|
||||
return x + 1 # traced
|
||||
|
||||
@torch.compiler.disable(recursive=False)
|
||||
def outer2(x):
|
||||
x = x + 2 # not traced
|
||||
torch._dynamo.graph_break() # not traced
|
||||
return inner2(x)
|
||||
|
||||
@torch.compile
|
||||
def g(x):
|
||||
x = outer2(x)
|
||||
return x + 4 # traced
|
||||
|
||||
print(g(torch.ones(3)))
|
||||
```
|
||||
|
||||
For example, one can use `torch.compiler.disable` to disable `torch.compile` on sparse architecture in
|
||||
recommendation models, as the sparse arch is difficult to compile.
|
||||
Preprocessing and logging functions are other examples of functions that typically cause
|
||||
a lot of graph breaks and do not get value from being compiled.
|
||||
|
||||
If you are experiencing compiler crashes and you want to continue regardless,
|
||||
you can set `torch._dynamo.config.suppress_errors = True`.
|
||||
When the compiler crashes, we will just skip tracing the function and try again later.
|
||||
**This is not best practice** - it is better to eventually manually add `disable` annotations as necessary.
|
||||
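A minimal sketch of this escape hatch, assuming only the documented `torch._dynamo.config.suppress_errors` flag; the function body is illustrative.

```python
import torch

# Last resort: if compiling a frame raises, fall back to running that frame eagerly.
torch._dynamo.config.suppress_errors = True

@torch.compile
def f(x):
    return torch.sin(x) + 1

print(f(torch.ones(3)))  # runs compiled code when possible, eager otherwise
```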
12
docs/source/compile/programming_model.custom_ops.md
Normal file
@ -0,0 +1,12 @@
|
||||
# Custom Operators
|
||||
|
||||
**Summary:**
|
||||
- Use custom operators to have `torch.compile` treat a function as opaque. `torch.compile` will never trace into the function and Inductor (the backend) will run the function as-is.
|
||||
|
||||
You may wish to use a custom operator in any of the following situations:
|
||||
- Your code calls some C/C++/CUDA code. Dynamo is a Python bytecode interpreter and generally does not know how to handle calls to C/C++/CUDA functions that are bound to Python.
|
||||
- Dynamo and non-strict tracing have trouble tracing through a function and you want it to be ignored by `torch.compile`.
|
||||
|
||||
Please see [the Python custom ops tutorial](https://pytorch.org/tutorials/advanced/python_custom_ops.html#python-custom-ops-tutorial) for more details on how to wrap a Python function into a `torch.compile`-understood custom operator.
|
||||
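As a rough illustration of what that tutorial covers, here is a minimal sketch that wraps a NumPy-based function as a custom operator. It assumes the `torch.library.custom_op` API (available in recent PyTorch releases); the `mylib::numpy_sin` name and the NumPy body are purely illustrative.

```python
import numpy as np
import torch

# Opaque to torch.compile: the compiler records a call to this op
# instead of tracing into the NumPy code.
@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor) -> torch.Tensor:
    return torch.from_numpy(np.sin(x.numpy()))

# The "fake" implementation describes output metadata (shape/dtype)
# so the compiler can reason about the op without running it.
@numpy_sin.register_fake
def _(x):
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x):
    return numpy_sin(x) + 1

print(f(torch.randn(3)))
```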
|
||||
For more advanced use cases, you may wish to use our C++ Custom Operator API; please see [here](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) for more information.
|
||||
167
docs/source/compile/programming_model.dynamo_core_concepts.md
Normal file
@ -0,0 +1,167 @@
|
||||
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
```
|
||||
|
||||
# Dynamo Core Concepts
|
||||
|
||||
**Summary:**
|
||||
|
||||
- Dynamo, `torch.compile`'s frontend, performs **tracing** to capture the semantics of a Python function
|
||||
(and its nested function calls) into a linear sequence of operations (the "(FX) graph"),
|
||||
residual bytecode, and "guards" (a list of conditions under which the graph and bytecode are valid).
|
||||
- Unsupported Python features lead to **graph breaks**, where Dynamo compiles a partial graph acquired from tracing,
|
||||
then runs the unsupported code, then resumes tracing.
|
||||
- Graph breaks may lead to slowness in torch.compile and prevent backend optimization opportunities.
|
||||
If you're not seeing the performance you expect, then check for graph breaks.
|
||||
|
||||
## Dynamo Tracing
|
||||
`torch.compile`'s frontend (Dynamo) is a custom Python bytecode interpreter designed to allow graph compilation
|
||||
in PyTorch programs while retaining the full flexibility of Python. Given a function to be compiled, Dynamo
|
||||
interprets Python bytecode to extract sequences of PyTorch operations into 1 or more FX graphs that may be further optimized by a backend.
|
||||
|
||||

|
||||
|
||||
For example, for the function `f` in the above diagram, Dynamo produces:
|
||||
- a single **FX graph** that takes in the original input plus some additional inputs required by the function.
|
||||
- **Python bytecode** that can be used as a drop-in replacement for `f`. In our example, the bytecode retrieves
|
||||
the additional inputs and passes them to the graph, and also contains unoptimizable Python side effects (the list append).
|
||||
- **guards** that specify the conditions under which the graph and bytecode are valid. Unless otherwise specified,
|
||||
the graph produced by Dynamo specializes on the shapes of input Tensors.
|
||||
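One way (assumed here) to inspect these artifacts yourself is `torch._dynamo.explain`, which summarizes the captured graph(s), the operations in them, and any graph-break reasons. The function below is an illustrative stand-in for the `f` in the diagram.

```python
import torch

def f(x, xs):
    y = x ** 2 / 2
    xs.append(y)  # Python side effect handled outside the FX graph
    return y * 3

# Summarize what Dynamo captured: graph count, ops per graph, break reasons.
explanation = torch._dynamo.explain(f)(torch.randn(3), [])
print(explanation)
```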
|
||||
(programming_model.dynamo_core_concepts.graph_breaks)=
|
||||
|
||||
## Graph Breaks
|
||||
Dynamo traces your code and attempts to capture your PyTorch code into a single computation graph of PyTorch
|
||||
operators (FX graph). However, this is not always possible. When encountering code that can't be traced, a "**graph break**" occurs.
|
||||
In the default `torch.compile` settings, a graph break involves compiling the FX graph that has been determined so far,
|
||||
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.
|
||||
|
||||
Graph breaks are a feature that allows Dynamo to run over arbitrary Python code and carve out functional subgraphs that can each be individually optimized.
|
||||
|
||||
However, it is possible for graph breaks to lead to unexpected slowness in `torch.compile`.
|
||||
If you're not getting the speedups you expect, we recommend checking for graph breaks and removing them.
|
||||
|
||||
Graph breaks may occur on things like:
|
||||
|
||||
- Data-dependent if-statements
|
||||
- Many Python built-in functions
|
||||
- C functions
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
torch._logging.set_logs(graph_breaks=True)
|
||||
```
|
||||
|
||||
Below is an example of a graph break due to calling an unsupported operation `torch.save`:
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile
|
||||
def f(x):
|
||||
y = x ** 2 / 2
|
||||
torch.save(y, "foo.pt") # torch.save is an unsupported operation
|
||||
z = y ** 3 / 6
|
||||
return z
|
||||
|
||||
x = torch.randn(3)
|
||||
print(f(x))
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import os
|
||||
os.remove("foo.pt")
|
||||
```
|
||||
|
||||
The semantics of `torch.compile(f)(x)` are roughly this:
|
||||
|
||||
```python
|
||||
def compiled_f_semantics(x):
|
||||
y = torch.compile(g, fullgraph=True)(x)
|
||||
torch.save(y, "foo.pt")
|
||||
z = torch.compile(h, fullgraph=True)(y)
|
||||
return z
|
||||
|
||||
def g(x):
|
||||
return x ** 2 / 2
|
||||
|
||||
def h(y):
|
||||
return y ** 3 / 6
|
||||
```
|
||||
|
||||
## Guards
|
||||
|
||||
`torch.compile` makes some assumptions about runtime values as we trace through code. During tracing, we generate "guards",
|
||||
which are runtime checks for these assumptions. Guards are run in future calls to the compiled function to determine if we
|
||||
can reuse previously compiled code. Examples of runtime checks are constant values, types, and object IDs.
|
||||
|
||||
Below is an example of generated guards. The `TENSOR_MATCH` guard checks for the input's type, device, dtype, shape, etc.
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
torch._logging.set_logs(guards=True)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
```
|
||||
|
||||
## Recompilations
|
||||
If the guards fail for every instance of previously compiled code, then `torch.compile` must "recompile" the function,
|
||||
requiring the original code to be traced again. In the example below, recompilation is necessary because the guard checking the tensor argument's shape failed.
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
torch._logging.set_logs(recompiles=True)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
print(fn(torch.ones(4, 4)))
|
||||
```
|
||||
|
||||
## Dynamic Shapes
|
||||
|
||||
`torch.compile` initially assumes tensor shapes are static/constant and guards based on these assumptions. By using "dynamic shapes,"
|
||||
we can get `torch.compile` to produce compiled code that can accept tensor inputs with different shapes - we avoid recompiling every time shapes differ.
|
||||
By default, automatic dynamic shapes are enabled (`torch.compile(dynamic=None)`): if a recompilation is triggered by a shape change,
the function is recompiled with dynamic shapes. Dynamic shapes can also be fully enabled (`dynamic=True`) or disabled (`dynamic=False`).
|
||||
|
||||
Below, we enable dynamic shapes and note that we no longer need to recompile.
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import logging
|
||||
torch._logging.set_logs(dynamic=logging.DEBUG, recompiles=True)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile(dynamic=True)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
|
||||
print(fn(torch.ones(3, 3)))
|
||||
print(fn(torch.ones(4, 4)))
|
||||
```
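
If you want a particular input dimension to be treated as dynamic from the very first compilation, one option is `torch._dynamo.mark_dynamic`. A minimal sketch (exact behavior can vary between PyTorch versions):

```python
@torch.compile
def gn(x):
    return x * 2

x = torch.randn(3, 3)
# Mark dim 0 as dynamic so the first compilation already uses a symbolic size.
torch._dynamo.mark_dynamic(x, 0)
print(gn(x))
print(gn(torch.randn(5, 3)))  # dim 0 changed; the compiled code is reused
```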
|
||||
|
||||
For more information on dynamic shapes, see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).

docs/source/compile/programming_model.dynamo_nonstrict_trace.md (new file, 101 lines)
@@ -0,0 +1,101 @@
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
```
|
||||
|
||||
# Use `torch._dynamo.nonstrict_trace`
|
||||
|
||||
**Summary:**
|
||||
- Use `nonstrict_trace` to trace a function with non-strict tracing inside of a `torch.compile`'d region.
|
||||
You may wish to do this because the Dynamo graph breaks on something inside of the function
|
||||
and you are sure that the function is non-strict traceable.
|
||||
|
||||
Consider the following scenario:
|
||||
|
||||
```{code-cell}
|
||||
def get_magic_num():
|
||||
# This explicit graph break call is meant to emulate any kind of Dynamo
|
||||
# graph break, e.g., the function is implemented in C, or uses some python
|
||||
# language feature Dynamo doesn't yet support.
|
||||
torch._dynamo.graph_break()
|
||||
return torch.tensor([42])
|
||||
@torch.compile(fullgraph=True)
|
||||
def func(x):
|
||||
n = get_magic_num()
|
||||
return x + n
|
||||
try:
|
||||
func(torch.rand(10))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
```
|
||||
|
||||
If we run the code above, we'll get an error from Dynamo, because it sees a graph break while the user specified `fullgraph=True`.
|
||||
|
||||
In these situations, if a user still wants to keep `fullgraph=True`, they typically have several options:
|
||||
|
||||
1. The graph break is due to a language feature Dynamo doesn't yet support.
|
||||
In this case, the user either rewrites their code, or files an issue on GitHub.
|
||||
2. The graph break is due to a call to a function implemented in C.
|
||||
In this case, the user can try to use a custom op.
|
||||
The user could also try providing a polyfill (a reference implementation in Python)
|
||||
so that Dynamo can trace through it.
|
||||
3. Worst case scenario -- an internal compiler error. In this case, the user likely has to file an issue on GitHub.
|
||||
|
||||
In addition to all these options, PyTorch does provide an alternative `torch._dynamo.nonstrict_trace`, if the function call that induced the graph break satisfies certain requirements:
|
||||
|
||||
- The requirements of [general non-strict tracing](programming_model.non_strict_tracing_model).
|
||||
- The inputs and outputs must contain either basic types (e.g., `int`, `float`, `list`, `dict`, `torch.Tensor`),
|
||||
or user-defined types that are registered to `torch.utils._pytree`.
|
||||
- The function must be defined outside the `torch.compile`'d region.
|
||||
- Any non-input values read by the function will be treated as constants
|
||||
(e.g., a global tensor), and will not be guarded on.
|
||||
|
||||
When tracing through a call to a `torch._dynamo.nonstrict_trace`'d function, `torch.compile` switches to [non-strict tracing](programming_model.non_strict_tracing_model),
|
||||
and the FX graph will eventually contain all the relevant tensor operations which happened inside that function.
|
||||
|
||||
For the example above, we can use `torch._dynamo.nonstrict_trace` to eliminate the graph break:
|
||||
|
||||
```{code-cell}
|
||||
@torch._dynamo.nonstrict_trace
|
||||
def get_magic_num():
|
||||
# This explicit graph break call is meant to emulate any kind of Dynamo
|
||||
# graph break, e.g., the function is implemented in C, or uses some python
|
||||
# language feature Dynamo doesn't yet support.
|
||||
torch._dynamo.graph_break()
|
||||
return torch.tensor([42])
|
||||
@torch.compile(fullgraph=True)
|
||||
def func(x):
|
||||
n = get_magic_num()
|
||||
return x + n
|
||||
print(func(torch.rand(10)))
|
||||
# No graph break and no error.
|
||||
```
|
||||
|
||||
Note that one can use it inside a `torch.compile`'d region as well:
|
||||
|
||||
```{code-cell}
|
||||
def get_magic_num():
|
||||
# This explicit graph break call is meant to emulate any kind of Dynamo
|
||||
# graph break, e.g., the function is implemented in C, or uses some python
|
||||
# language feature Dynamo doesn't yet support.
|
||||
torch._dynamo.graph_break()
|
||||
return torch.tensor([42])
|
||||
@torch.compile(fullgraph=True)
|
||||
def func(x):
|
||||
n = torch._dynamo.nonstrict_trace(get_magic_num)()
|
||||
return x + n
|
||||
print(func(torch.rand(10)))
|
||||
# No graph break and no error.
|
||||
```
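
The requirements above mention user-defined types registered to `torch.utils._pytree`. As a rough illustration of what such a registration can look like (the `Pair` class and its flatten/unflatten functions are hypothetical, and `torch.utils._pytree` is a private module whose API may change):

```python
import torch
import torch.utils._pytree as pytree

class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b

# Teach pytree how to flatten/unflatten Pair so instances can cross the
# nonstrict_trace boundary.
pytree.register_pytree_node(
    Pair,
    lambda p: ((p.a, p.b), None),         # flatten: (children, context)
    lambda children, _: Pair(*children),  # unflatten
)

def combine(p):
    return p.a + p.b

@torch.compile(fullgraph=True)
def func(x):
    p = Pair(x, x + 1)
    return torch._dynamo.nonstrict_trace(combine)(p)

print(func(torch.rand(10)))
```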

docs/source/compile/programming_model.fullgraph_false.md (new file, 24 lines)
@@ -0,0 +1,24 @@
# Working with `fullgraph=False`
|
||||
While `fullgraph=False` is the default `torch.compile` setting, the semantics of resuming compilation after a graph break are more involved than under `fullgraph=True`.
|
||||
You can find details on the `fullgraph=False` semantics in the subsections.
|
||||
|
||||
The strategy for using `torch.compile(fullgraph=False)` is as follows:
|
||||
|
||||
1. [Determine the ideal location to place `torch.compile`](programming_model.where_to_apply_compile). Normally, it is the highest-level function that doesn’t result in excessive graph breaks.
|
||||
Functions that do a lot of preprocessing or I/O operations are examples of functions that result in many graph breaks and do not significantly benefit from `torch.compile`.
|
||||
a. You can isolate issues by first compiling individual functions/modules before compiling entire models.
|
||||
2. [Apply `torch.compiler.disable` to functions in the compiled region that result in a lot of graph breaks
|
||||
and do not benefit from compilation](programming_model.compiler_disable). In this case, one graph break is better than potentially tens or hundreds (see the sketch after this list).
|
||||
3. [Use `TORCH_LOGS="graph_breaks"` or tlparse to investigate remaining graph breaks.](programming_model.observability)
|
||||
Work around these graph breaks using the same approaches as working around graph breaks under
|
||||
the `fullgraph=True` programming model. Not all graph breaks need to be removed - some may
|
||||
impact performance more than others. The general rule is to focus on graph breaks that are happening during model computation.
|
||||
a. We recommend using `torch.compile(backend='eager')` when debugging graph breaks, for faster debugging iteration times
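
As an example of step 2 above, a function that does Python-heavy bookkeeping and would otherwise graph break many times can be excluded from tracing entirely. This is a minimal sketch; the function names are made up:

```python
import torch

@torch.compiler.disable
def log_stats(x):
    # Runs eagerly; torch.compile does not trace into this function, so its
    # many would-be graph breaks collapse into a single one at the call site.
    print("mean:", x.mean().item())

@torch.compile
def step(x):
    log_stats(x)
    return x.sin() + 1

step(torch.randn(8))
```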
|
||||
|
||||
|
||||
```{toctree}
|
||||
programming_model.where_to_apply_compile
|
||||
programming_model.compiler_disable
|
||||
programming_model.nested_graph_breaks
|
||||
programming_model.skipped_functions
|
||||
```

docs/source/compile/programming_model.fullgraph_true.md (new file, 247 lines)
@@ -0,0 +1,247 @@
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
```
|
||||
|
||||
# Use `fullgraph=True` to Identify and Eliminate Graph Breaks
|
||||
|
||||
Using `torch.compile(fullgraph=False)` (the default) is a good way to get started with `torch.compile`: it supports all Python programs out-of-the-box via the ability to graph break and gives good performance on common cases.
|
||||
|
||||
However, if you're trying to get more performance out of your model, you should explicitly think about what regions of code should be compiled:
|
||||
- We recommend using `torch.compile(fullgraph=True)` to find and eliminate graph breaks in your code.
|
||||
- If you're a library developer (or testing if your code "works" with `torch.compile`), we recommend testing using `torch.compile(fullgraph=True)`.
|
||||
|
||||
`torch.compile(fullgraph=True)` offers stronger guarantees than `fullgraph=False`:
|
||||
we will always capture a single FX graph to be compiled (or error if we cannot due to a graph break).
|
||||
**In particular, you are forced to resolve every graph break that is encountered.**
|
||||
|
||||
There are a number of strategies for resolving a graph break.
|
||||
|
||||
## Strategy 1: Rewrite the unsupported code to use features supported by Dynamo
|
||||
|
||||
Many graph break error messages will give some suggestions on how to rewrite code to avoid the graph break.
|
||||
If the graph break is still difficult to resolve, then please move on to the next strategy
|
||||
or submit an issue to the [PyTorch GitHub repo](https://github.com/pytorch/pytorch/issues).
|
||||
|
||||
More graph break examples and how to resolve them can be found in [Common Graph Breaks](programming_model.common_graph_breaks).
|
||||
|
||||
Example: Dynamo does not support calling `next` on a `list_iterator` object that was an input to the function being compiled.
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile(fullgraph=True)
|
||||
def f(xs):
|
||||
a = next(xs)
|
||||
b = next(xs)
|
||||
return a + b
|
||||
|
||||
xs = [torch.tensor(1.), torch.tensor(2.)]
|
||||
try:
|
||||
out = f(iter(xs))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
```
|
||||
|
||||
Instead, rewrite the compiled function to accept a list.
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile(fullgraph=True)
|
||||
def f_rewritten(xs):
|
||||
it = iter(xs)
|
||||
a = next(it)
|
||||
b = next(it)
|
||||
return a + b
|
||||
|
||||
f_rewritten(xs)
|
||||
```
|
||||
|
||||
## Strategy 2: Pure functions can always be compiled via an escape hatch.
|
||||
|
||||
**Summary**: The space of all Python functions is vast and thus it is impractical for Dynamo to be able to trace
|
||||
through every Python function without graph breaks. For Python functions considered to be "pure"
|
||||
that Dynamo cannot trace through without graph breaks, we provide some escape hatches to attempt
|
||||
to trace through these functions anyway:
|
||||
|
||||
1. Use `custom_op` or `triton_op` on pure triton kernels.
|
||||
2. Use `nonstrict_trace` for pure functions that only use PyTorch Tensor ops.
|
||||
3. Use `custom_op` for all other pure functions.
|
||||
|
||||
A "pure function" is a function with the following properties:
|
||||
|
||||
- Determinism. Given the same inputs, the pure function will always return the same output
|
||||
- No external side effects. A pure function does not have any externally-visible side effects,
|
||||
such as modifying external state or performing I/O operations.
|
||||
Side effects that remain internal to the function are allowed (e.g. mutating intermediate tensors).
|
||||
One notable exception is that mutating `torch.*` ops on function input Tensors are generally allowed.
|
||||
- Explicit input/output. All the input data must be passed through the function parameters and all of the outputs are returned from the function.
|
||||
|
||||
See [Pure Functions](programming_model.non_strict_tracing_model.pure_functions) for examples.
|
||||
|
||||
Dynamo is theoretically able to handle a wide variety of impure functions, but may be lacking coverage for specific
|
||||
Python language features. However, pure functions can always be compiled via an escape hatch.
|
||||
|
||||
If you have a graph break it may be possible to refactor the code around it into a pure function and use an escape hatch that bypasses Dynamo tracing:
|
||||
|
||||
1. Use `torch._dynamo.nonstrict_trace` if you want the Tensor operations in the function to show up in the Dynamo output graph (and therefore be optimizable). `nonstrict_trace` tells Dynamo to use **non-strict tracing**.
|
||||
2. Use custom operators if you want the function to be opaque with respect to `torch.compile` (both the Dynamo frontend and the backend).
|
||||
|
||||
Note that there is nothing preventing these escape hatches from being applied to impure functions,
|
||||
but **we do not provide any soundness guarantees**.
|
||||
|
||||
Example: If Dynamo doesn't support some Python feature or API that is non-strict traceable (e.g. it uses PyTorch operations), [use `torch._dynamo.nonstrict_trace` to capture it instead](programming_model.dynamo_nonstrict_trace).
|
||||
|
||||
```{code-cell}
|
||||
# this is a function that Dynamo doesn't support (due to the graph_break() call).
|
||||
def g(x):
|
||||
y = x.sin()
|
||||
torch._dynamo.graph_break()
|
||||
z = y.sin()
|
||||
return z
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def f(x):
|
||||
w = x.sin()
|
||||
return g(w)
|
||||
|
||||
x = torch.randn(3)
|
||||
try:
|
||||
f(x) # Graph Break: there was a call to torch._dynamo.graph_break()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def f_rewritten(x):
|
||||
w = x.sin()
|
||||
return torch._dynamo.nonstrict_trace(g)(w)
|
||||
f_rewritten(x) # works
|
||||
```
|
||||
|
||||
Example: use [custom operators](programming_model.custom_ops) to create functions that are opaque with respect to `torch.compile`
|
||||
|
||||
```{code-cell}
|
||||
from torch.utils.cpp_extension import load_inline
|
||||
|
||||
# C++ source code for the square operation
|
||||
cpp_source = """
|
||||
torch::Tensor square_cpu(torch::Tensor input) {
|
||||
// Check that input is a CPU tensor
|
||||
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
|
||||
|
||||
// Create output tensor with same shape and dtype as input
|
||||
torch::Tensor output = torch::empty_like(input);
|
||||
|
||||
// Get data pointers
|
||||
float* input_data = input.data_ptr<float>();
|
||||
float* output_data = output.data_ptr<float>();
|
||||
|
||||
// Get total number of elements
|
||||
int64_t numel = input.numel();
|
||||
|
||||
// For loop to compute square of each element
|
||||
for (int64_t i = 0; i < numel; i++) {
|
||||
output_data[i] = input_data[i] * input_data[i];
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
"""
|
||||
|
||||
# Load the extension inline
|
||||
square_module = load_inline(
|
||||
name="square_cpu_kernel",
|
||||
cpp_sources=cpp_source,
|
||||
functions=["square_cpu"],
|
||||
verbose=True
|
||||
)
|
||||
|
||||
def square(x):
|
||||
return square_module.square_cpu(x)
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def f(x):
|
||||
return square(x)
|
||||
|
||||
try:
|
||||
f(torch.randn(3, 3)) # graph break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
# Use torch.library.custom_op to define a new custom operator.
|
||||
# Custom operators are opaque with respect to torch.compile:
|
||||
# that is, torch.compile does not peek into them.
|
||||
|
||||
@torch.library.custom_op("mylib::square", mutates_args=())
|
||||
def square(x: torch.Tensor) -> torch.Tensor:
|
||||
return square_module.square_cpu(x)
|
||||
|
||||
# Use register_fake to add a ``FakeTensor`` kernel for the operator
|
||||
@square.register_fake
|
||||
def _(x):
|
||||
return x.new_empty(x.size())
|
||||
|
||||
print(f(torch.randn(3, 3))) # no graph break
|
||||
```
|
||||
|
||||
For more information on `triton_op` for custom triton kernels, see the
|
||||
[user-defined triton kernel tutorial](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html).
|
||||
|
||||
|
||||
## Strategy 3: Don't compile the code
|
||||
|
||||
Not all code is amenable to being compiled. `torch.compile` is a compiler for Tensor computation;
|
||||
it will not be able to optimize things like disk IO. Try to refactor the code such that the unsupported
|
||||
code is not called in the compiled region.
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile(fullgraph=True)
|
||||
def f(x):
|
||||
y = x ** 2 / 2
|
||||
torch.save(y, "foo.pt")
|
||||
z = y ** 3 / 6
|
||||
return z
|
||||
|
||||
x = torch.randn(3)
|
||||
try:
|
||||
f(x) # Graph Break: torch.save not supported
|
||||
except Exception as e:
|
||||
print(e)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
def f_rewritten(x):
|
||||
y = g(x)
|
||||
torch.save(y, "foo.pt")
|
||||
z = h(y)
|
||||
return z
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def g(x):
|
||||
y = x ** 2 / 2
|
||||
return y
|
||||
|
||||
@torch.compile(fullgraph=True)
|
||||
def h(y):
|
||||
z = y ** 3 / 6
|
||||
return z
|
||||
|
||||
f_rewritten(x)
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import os
|
||||
os.remove("foo.pt")
|
||||
```

docs/source/compile/programming_model.graph_breaks_index.md (new file, 21 lines)
@@ -0,0 +1,21 @@
# Working with Graph Breaks
|
||||
|
||||
As you might remember from [Dynamo Core Concepts](programming_model.dynamo_core_concepts), Dynamo performs a graph break when
|
||||
it encounters code that can't be traced. In the default `torch.compile` settings, Dynamo compiles the FX graph
|
||||
that has been determined up to that point, executes the unsupported code in regular Python, and then resumes tracing.
|
||||
|
||||
Graph breaks enable Dynamo to trace through arbitrary Python code and carve out functional
|
||||
subgraphs that can each be individually optimized.
|
||||
|
||||
However, graph breaks may cause unexpected slowness in `torch.compile`.
|
||||
If you're not seeing the expected speedups, we recommend checking for graph breaks and removing them.
|
||||
|
||||
The following sections outline strategies for addressing graph breaks.
|
||||
|
||||
```{toctree}
|
||||
programming_model.fullgraph_true
|
||||
programming_model.common_graph_breaks
|
||||
programming_model.dynamo_nonstrict_trace
|
||||
programming_model.custom_ops
|
||||
programming_model.fullgraph_false
|
||||
```

docs/source/compile/programming_model.md (new file, 16 lines)
@@ -0,0 +1,16 @@
# torch.compile Programming Model
|
||||
|
||||
The `torch.compile` programming model:
|
||||
1. Clarifies some internal behaviors of `torch.compile` so that one can better predict compiler behavior on user code and
|
||||
2. Provides ways for one to take more fine-grained control over `torch.compile`.
|
||||
|
||||
By understanding the `torch.compile` programming model, one can systematically unblock themselves when encountering issues with `torch.compile`.
|
||||
|
||||
```{toctree}
|
||||
programming_model.dynamo_core_concepts
|
||||
programming_model.graph_breaks_index
|
||||
programming_model.non_strict_tracing_model
|
||||
programming_model.recompilation
|
||||
programming_model.observability
|
||||
programming_model.reporting_issues
|
||||
```

docs/source/compile/programming_model.nested_graph_breaks.md (new file, 191 lines)
@@ -0,0 +1,191 @@
# Nested Graph Breaks
|
||||
|
||||
**Summary:**
|
||||
- Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below
|
||||
- A nested graph break results in {math}`\mathcal O(N)` duplicate graph break behavior
|
||||
|
||||
Recall that when `torch.compile` is applied to a function, any nested function calls are also traced.
|
||||
A **nested graph break** refers to any graph break that happens in a nested function call.
|
||||
|
||||
```python
|
||||
def inner(x):
|
||||
...
|
||||
torch._dynamo.graph_break() # nested graph break
|
||||
...
|
||||
|
||||
@torch.compile
|
||||
def outer(x):
|
||||
...
|
||||
y = inner(x)
|
||||
...
|
||||
```
|
||||
|
||||
The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.
|
||||
|
||||
Recall that in `fullgraph=False`, [graph breaks are handled](programming_model.dynamo_core_concepts.graph_breaks) by compiling the FX graph that has been determined so far,
|
||||
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.
|
||||
Resuming a function is actually a fairly complicated technical feat, so resuming tracing is only supported on top-level functions.
|
||||
|
||||
Given this restriction, we resume tracing after a nested graph break in the following way:
|
||||
|
||||
First, consider the below example where `torch.compile` traces from `f` and traces all the way until the
|
||||
graph break in `inner1` is encountered.
|
||||
|
||||
```python
|
||||
def inner1(x):
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break() # stop tracing due to graph break
|
||||
return x + 2
|
||||
|
||||
def inner2(x):
|
||||
x = x + 4
|
||||
x = inner1(x)
|
||||
x = x + 8
|
||||
|
||||
@torch.compile
|
||||
def f(x):
|
||||
# start tracing from here
|
||||
x = x + 16
|
||||
x = inner2(x)
|
||||
x = x + 32
|
||||
|
||||
f(torch.randn(3))
|
||||
```
|
||||
|
||||
Since we can only resume from top-level functions, we graph break on the `inner2` call in `f`.
|
||||
```python
|
||||
# The semantics of torch.compile(f)(x) is roughly this:
|
||||
def compiled_f_semantics(x):
|
||||
y = x + 16
|
||||
z = inner2(y)
|
||||
return torch.compile(resume_f_semantics)(z)
|
||||
|
||||
def resume_f_semantics(x):
|
||||
return x + 32
|
||||
|
||||
compiled_f_semantics(torch.randn(3))
|
||||
```
|
||||
|
||||
`inner2` is then automatically compiled as a top-level function.
|
||||
We trace all the way until the graph break in `inner1` is encountered again.
|
||||
|
||||
```python
|
||||
def inner1(x):
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break() # stop tracing due to graph break
|
||||
return x + 2
|
||||
|
||||
# this torch.compile is automatically applied
|
||||
@torch.compile
|
||||
def inner2(x):
|
||||
# start tracing from here
|
||||
x = x + 4
|
||||
x = inner1(x)
|
||||
x = x + 8
|
||||
|
||||
def compiled_f_semantics(x):
|
||||
y = x + 16
|
||||
z = inner2(y)
|
||||
return torch.compile(resume_f_semantics)(z)
|
||||
|
||||
def resume_f_semantics(x):
|
||||
return x + 32
|
||||
|
||||
compiled_f_semantics(torch.randn(3))
|
||||
```
|
||||
|
||||
Then we graph break on the `inner1` call in `inner2`.
|
||||
```python
|
||||
def compiled_inner2_semantics(x):
|
||||
y = x + 4
|
||||
z = inner1(y)
|
||||
return torch.compile(resume_inner2_semantics)(z)
|
||||
|
||||
def resume_inner2_semantics(x):
|
||||
return x + 8
|
||||
```
|
||||
|
||||
`inner1` is then automatically compiled as a top-level function.
|
||||
The graph break is from `inner1`, so we handle the graph break normally.
|
||||
```python
|
||||
# this torch.compile is automatically applied
|
||||
@torch.compile
|
||||
def inner1(x):
|
||||
# start tracing from here
|
||||
x = x + 1
|
||||
torch._dynamo.graph_break() # stop tracing due to graph break
|
||||
return x + 2
|
||||
|
||||
def compiled_f_semantics(x):
|
||||
y = x + 16
|
||||
z = compiled_inner2_semantics(y)
|
||||
return torch.compile(resume_f_semantics)(z)
|
||||
|
||||
def resume_f_semantics(x):
|
||||
return x + 32
|
||||
|
||||
def compiled_inner2_semantics(x):
|
||||
y = x + 4
|
||||
z = inner1(y)
|
||||
return torch.compile(resume_inner2_semantics)(z)
|
||||
|
||||
def resume_inner2_semantics(x):
|
||||
return x + 8
|
||||
|
||||
compiled_f_semantics(torch.randn(3))
|
||||
```
|
||||
|
||||
`inner1` is handled normally:
|
||||
|
||||
```python
|
||||
def compiled_inner1_semantics(x):
|
||||
y = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
return torch.compile(resume_inner1_semantics)(y)
|
||||
|
||||
def resume_inner1_semantics(x):
|
||||
return x + 2
|
||||
```
|
||||
|
||||
So the initial code is semantically equivalent to
|
||||
```python
|
||||
def compiled_f_semantics(x):
|
||||
y = x + 16
|
||||
z = compiled_inner2_semantics(y)
|
||||
return torch.compile(resume_f_semantics)(z)
|
||||
|
||||
def resume_f_semantics(x):
|
||||
return x + 32
|
||||
|
||||
def compiled_inner2_semantics(x):
|
||||
y = x + 4
|
||||
z = compiled_inner1_semantics(y)
|
||||
return torch.compile(resume_inner2_semantics)(z)
|
||||
|
||||
def resume_inner2_semantics(x):
|
||||
return x + 8
|
||||
|
||||
def compiled_inner1_semantics(x):
|
||||
y = x + 1
|
||||
torch._dynamo.graph_break()
|
||||
return torch.compile(resume_inner1_semantics)(y)
|
||||
|
||||
def resume_inner1_semantics(x):
|
||||
return x + 2
|
||||
|
||||
compiled_f_semantics(torch.randn(3))
|
||||
```
|
||||
|
||||
Note in particular that we traced 3 top-level functions, and that we traced the same graph break 3 times.
|
||||
**This explains why you may encounter duplicate graph breaks when using `torch.compile`.**
|
||||
|
||||
In summary, nested graph breaks are handled by:
|
||||
- Tracing from the top-level function all the way to the nested graph break
|
||||
- Graph breaking on the top-level function at the call to the second-level function
|
||||
- Compiling the PyTorch ops tracked so far and running the compiled graph
|
||||
- Calling the second-level function, which gets automatically compiled as a top-level function
|
||||
- Resuming tracing after the second-level function call
|
||||
|
||||
Note that the runtime of handling this graph break is {math}`\mathcal O(NK)`, where {math}`N` is the nesting depth,
|
||||
and {math}`K` is the number of instructions from the top-level function to the graph break.
|
||||
We end up tracing {math}`\mathcal O(N^2)` frames, and we trace the same graph break {math}`\mathcal O(N)` times.

docs/source/compile/programming_model.non_strict_tracing_model.md (new file, 204 lines)
@@ -0,0 +1,204 @@
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
```
|
||||
|
||||
# Non-strict Tracing Programming Model
|
||||
|
||||
**Summary:**
|
||||
- **Non-strict tracing** is a way to trace Python code that is less strict than Dynamo, but may result in silent incorrectness.
|
||||
- Non-strict tracing runs a Python function and uses Python and PyTorch’s operator overloading capabilities to record what Tensor operations occurred during execution into a trace.
|
||||
- A function is **non-strict traceable** if it complies with some constraints, namely, that the function is **pure** and does not directly manipulate Tensor.data_ptr().
|
||||
- Non-strict tracing may **specialize** on certain variables and treat them as **constants**, baking the values of the variables into the trace.
|
||||
|
||||
`torch.compile` internals (`make_fx`, AOTDispatcher) use **non-strict tracing**. [`torch._dynamo.nonstrict_trace`](programming_model.dynamo_nonstrict_trace) can also be used in `torch.compile`d code to mark sections of code to be traced with non-strict tracing.
|
||||
Non-strict tracing runs a Python function and uses Python and PyTorch’s operator overloading capabilities to record what Tensor operations occurred during execution into a trace.
|
||||
|
||||
**`make_fx`** is the main entrypoint for non-strict tracing. For the following function, only the first branch is taken when executing on the given input, so the captured graph contains only that branch.
|
||||
|
||||
```{code-cell}
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
def f(x):
|
||||
if x.shape[0] > 2:
|
||||
return x ** 2 / 6
|
||||
else:
|
||||
return x * 3
|
||||
x = torch.randn(3)
|
||||
gm = make_fx(f, tracing_mode="fake")(x)
|
||||
gm.print_readable()
|
||||
```
|
||||
|
||||
Non-strict tracing differs from Dynamo (strict) tracing in that **it is unsafe**, that is, given a function, it captures a graph of Tensor operations that may have different semantics than the original function.
|
||||
Given a Python function, Dynamo Tracing captures a graph of Tensor operations and residual bytecode that when combined give the same semantics as the Python function.
|
||||
|
||||
(programming_model.non_strict_tracing_model.pure_functions)=
|
||||
|
||||
## Pure Functions
|
||||
|
||||
Non-strict tracing is sound only on **pure functions**, and thus only pure functions should be non-strict traced.
|
||||
|
||||
A pure function is a function with the following properties:
|
||||
|
||||
- **Determinism.** Given the same inputs, the pure function will always return the same output.
|
||||
- **No side effects.** A pure function does not have any side effects such as modifying external state or performing I/O operations.
|
||||
- **Explicit input/output.** All the input data must be passed through the function parameters and all of the outputs are returned from the function.
|
||||
|
||||
Here are some examples of impure functions for which the captured graph behaves differently from the original function.
|
||||
|
||||
### Example 1: No explicit input (e.g. accesses global tensor)
|
||||
```{code-cell}
|
||||
var = torch.tensor(1)
|
||||
def function_with_global_access(y):
|
||||
return y + var
|
||||
x = torch.tensor([0, 1, 2])
|
||||
# _allow_non_fake_inputs=True is needed to capture the global variable
|
||||
# for demonstration purposes.
|
||||
gm = make_fx(
|
||||
function_with_global_access, tracing_mode="fake", _allow_non_fake_inputs=True
|
||||
)(x)
|
||||
# Non-strict Tracing captures the value of the global (1.)
|
||||
print("1. call function", function_with_global_access(x))
|
||||
print("1. call graph", gm(x))
|
||||
# However, after changing the global, the captured graph
|
||||
# produces a different result from the original function
|
||||
var = torch.tensor(2)
|
||||
print("2. call function", function_with_global_access(x))
|
||||
print("2. call graph", gm(x))
|
||||
# To capture a graph that can have a varying `var` tensor,
|
||||
# it must be an explicit input:
|
||||
def function_fixed(y, var):
|
||||
return y + var
|
||||
var = torch.tensor(3)
|
||||
gm = make_fx(function_fixed, tracing_mode="fake")(x, var)
|
||||
print("3. call function", function_fixed(x, var))
|
||||
print("3. call graph", gm(x, var))
|
||||
var = torch.tensor(4)
|
||||
print("4. call function", function_fixed(x, var))
|
||||
print("4. call graph", gm(x, var))
|
||||
```
|
||||
|
||||
See [Specialization and Constants](specialization-and-constants) for an explanation of why.
|
||||
|
||||
### Example 2: Side effect (printing)
|
||||
|
||||
```{code-cell}
|
||||
def function_with_side_effect(y):
|
||||
print(y)
|
||||
x = torch.tensor([0, 1, 2])
|
||||
_ = function_with_side_effect(x)
|
||||
```
|
||||
|
||||
Running `function_with_side_effect` in Python prints a Tensor as a side effect.
|
||||
|
||||
```{code-cell}
|
||||
gm = make_fx(function_with_side_effect, tracing_mode="fake")(x)
|
||||
```
|
||||
|
||||
During non-strict tracing, this print occurs during the graph capture.
|
||||
|
||||
```{code-cell}
|
||||
_ = gm(x)
|
||||
```
|
||||
|
||||
The graph does not store a call to the `print` statement, so executing the graph doesn’t print anything.
|
||||
|
||||
### Example 3: Side effect (input list mutation)
|
||||
|
||||
```{code-cell}
|
||||
lst = []
|
||||
def function_with_input_list_mutation(lst):
|
||||
val = lst.pop()
|
||||
return val
|
||||
x = torch.tensor([0, 1, 2])
|
||||
y = torch.tensor([0, 1, 2])
|
||||
# Each time the function is executed, the list shrinks in size
|
||||
lst = [x, y]
|
||||
function_with_input_list_mutation(lst)
|
||||
print("len(lst) after one call", len(lst))
|
||||
function_with_input_list_mutation(lst)
|
||||
print("len(lst) after two calls", len(lst))
|
||||
# With Non-strict Tracing, the length of the list shrinks during
|
||||
# the graph capture but not in invocations of the graph.
|
||||
lst = [x, y]
|
||||
gm = make_fx(function_with_input_list_mutation, tracing_mode="fake")(lst)
|
||||
print("len(lst) after graph capture", len(lst))
|
||||
gm(lst)
|
||||
print("len(lst) after one call to graph", len(lst))
|
||||
gm(lst)
|
||||
print("len(lst) after two calls to graph", len(lst))
|
||||
```
|
||||
|
||||
### No direct data_ptr manipulation
|
||||
Directly manipulating `Tensor.data_ptr` is not non-strict traceable. The intuition behind this is that PyTorch is unable to tell *how* you manipulated the `data_ptr`.
|
||||
|
||||
```{code-cell}
|
||||
import ctypes
|
||||
# Create a tensor with a single element
|
||||
tensor = torch.tensor([42], dtype=torch.int32) # Using int32 for simplicity
|
||||
def function_with_data_ptr(tensor):
|
||||
# Get the data pointer
|
||||
ptr = tensor.data_ptr()
|
||||
# Cast the pointer to a ctypes pointer
|
||||
ctypes_ptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_int32))
|
||||
# Increment the value at the pointer
|
||||
ctypes_ptr.contents.value += 1
|
||||
return tensor
|
||||
try:
|
||||
make_fx(function_with_data_ptr, tracing_mode="fake")(tensor)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
```
|
||||
|
||||
(specialization-and-constants)=
|
||||
## Specialization and Constants
|
||||
|
||||
Non-strict tracing captures a graph that may be specialized on some values. What this means is the captured graph is only valid for these values. We say the graph treats those values as **constant**.
|
||||
|
||||
All non-Tensor variables are treated as constant during Non-strict Tracing:
|
||||
|
||||
```{code-cell}
|
||||
def f(x, y):
|
||||
return x + y
|
||||
x = torch.tensor([0, 1, 2])
|
||||
y = 3.14
|
||||
gm = make_fx(f, tracing_mode="fake")(x, y)
|
||||
gm.print_readable()
|
||||
```
|
||||
|
||||
3.14 is a constant in the graph.
|
||||
|
||||
Non-strict tracing will also specialize on properties of the input Tensors.
|
||||
|
||||
```{code-cell}
|
||||
def f(x):
|
||||
if x.shape[0] > 2:
|
||||
return x ** 2 / 6
|
||||
else:
|
||||
return x * 3
|
||||
x = torch.randn(3)
|
||||
gm = make_fx(f, tracing_mode="fake")(x)
|
||||
gm.print_readable()
|
||||
```
|
||||
|
||||
And it will also specialize on any variables not directly passed into the function:
|
||||
|
||||
```{code-cell}
|
||||
var = torch.tensor(1)
|
||||
def f(x):
|
||||
return x + var
|
||||
x = torch.randn(3)
|
||||
gm = make_fx(f, tracing_mode="fake", _allow_non_fake_inputs=True)(x)  # needed to close over the real tensor `var`
|
||||
gm.print_readable()
|
||||
```

docs/source/compile/programming_model.observability.md (new file, 141 lines)
@@ -0,0 +1,141 @@
# tlparse / TORCH_TRACE
|
||||
|
||||
tlparse / `TORCH_TRACE` are a pair of tools that produce compilation reports that look [like this](https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html).
|
||||
|
||||
Traces are fairly straightforward to collect. To collect a trace, run your model like so:
|
||||
|
||||
```bash
|
||||
TORCH_TRACE="/tmp/tracedir" python foo.py
|
||||
pip install tlparse
|
||||
tlparse /tmp/tracedir
|
||||
```
|
||||
|
||||
This approach works even if you are running a distributed job, providing a trace for each rank.
|
||||
It will open your browser with HTML similar to what’s generated above.
|
||||
If you are making a bug report for a complicated problem that you don’t have a standalone reproduction for,
|
||||
you can still greatly assist PyTorch developers by attaching the trace log generated in `/tmp/tracedir`.
|
||||
|
||||
```{warning}
|
||||
The trace log contains all of your model code.
|
||||
Do not share the trace log if the model you are working on is sensitive. The trace log does NOT contain weights.
|
||||
```
|
||||
|
||||
```{raw} html
|
||||
<style>
|
||||
.red {background-color:#ff0000;}
|
||||
.green {background-color:#00ff00;}
|
||||
.dark-green {background-color:#027f02;}
|
||||
</style>
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. role:: red
|
||||
.. role:: green
|
||||
.. role:: dark-green
|
||||
```
|
||||
|
||||
The output of `tlparse` is primarily aimed at PyTorch developers,
|
||||
and the log format is easy to upload and share on GitHub.
|
||||
However, as a non-PyTorch developer, you can still extract useful information from it.
|
||||
We recommend starting with the inline help text in the report, which explains its contents.
|
||||
Here are some insights you can gain from a `tlparse`:
|
||||
|
||||
- Which model code was compiled? Look at the stack trie.
This is especially useful if you're not familiar with the codebase being compiled!
|
||||
- How many graph breaks / distinct compilation regions are there?
|
||||
(Each distinct compile is its own color coded block like {dark-green}`[0/0]`).
|
||||
Frames that are potentially graph-broken are light green {green}`[2/4]`.
|
||||
If there are a lot of frames, that is suspicious, and suggests that you had some catastrophic graph breaks,
|
||||
or maybe your code isn't a good match for `torch.compile`.
|
||||
- How many times did I recompile a particular frame? Something that recompiled a lot will look like:
|
||||
{dark-green}`[10/0]` {dark-green}`[10/1]` {dark-green}`[10/2]`
|
||||
\- if something is being recompiled a lot, that is very suspicious and worth looking into, even if it isn't the root cause of your problem.
|
||||
- Was there a compilation error? Frames that errored will look like {red}`[0/1]`.
|
||||
- What intermediate compiler products did I generate for a given frame?
|
||||
For example, you can look at the high-level generated FX graph or the generated Triton code.
|
||||
- Is there other relevant information for a particular frame? You can find it in `compilation_metrics`.
|
||||
|
||||
## TORCH_LOGS
|
||||
|
||||
You can use the `TORCH_LOGS` environment variable to selectively enable parts of the `torch.compile` stack to log.
|
||||
`TORCH_LOGS` is in fact the source of logs for `tlparse`. The format of the `TORCH_LOGS` environment variable looks like this:
|
||||
|
||||
```bash
|
||||
TORCH_LOGS="<option1>,<option2>,..." python foo.py
|
||||
```
|
||||
|
||||
You can also programmatically set logging options using `torch._logging.set_logs`:
|
||||
|
||||
```python
|
||||
import logging
|
||||
torch._logging.set_logs(graph_breaks=True, dynamic=logging.DEBUG)
|
||||
```
|
||||
|
||||
The most useful options are:
|
||||
|
||||
- `graph_breaks`: logs locations of graph breaks in user code and the reason for the graph break
|
||||
- `guards`: logs guards that are generated
|
||||
- `recompiles`: logs which function recompiled and the guards that failed, leading to the recompilation
|
||||
- `dynamic`: logs related to dynamic shapes
|
||||
- `output_code`: logs the code generated by Inductor
|
||||
|
||||
Some more helpful `TORCH_LOGS` options include:
|
||||
|
||||
```{eval-rst}
|
||||
.. list-table::
|
||||
:widths: 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Option
|
||||
- Description
|
||||
* - +all
|
||||
- Output debug logs from all ``torch.compile`` components
|
||||
* - +dynamo
|
||||
- Output debug logs from TorchDynamo
|
||||
* - +aot
|
||||
- Output debug logs from AOTAutograd
|
||||
* - +inductor
|
||||
- Output debug logs from TorchInductor
|
||||
* - dynamic
|
||||
- Output logs from dynamic shapes
|
||||
* - graph_code
|
||||
- Output the Python code for the FX graph that Dynamo generated
|
||||
* - graph_sizes
|
||||
- Output the tensor sizes of the FX graph that Dynamo generated
|
||||
* - trace_bytecode
|
||||
- Output the bytecode instructions that Dynamo is tracing through and the symbolic interpreter stack Dynamo is keeping track of
|
||||
* - trace_source
|
||||
- Output the line of code in the original source that Dynamo is currently tracing through
|
||||
* - bytecode
|
||||
- Output Dynamo-generated bytecode
|
||||
* - guards
|
||||
- Output generated guards
|
||||
* - recompiles
|
||||
- Output recompilation reasons (only the first guard check that fails)
|
||||
* - recompiles_verbose
|
||||
- Output all guard checks that fail when a recompilation occurs
|
||||
* - aot_graphs
|
||||
- Output graph generated by AOTAutograd
|
||||
* - aot_joint_graphs
|
||||
- Output the joint forward-backward graph generated by AOTAutograd
|
||||
* - output_code
|
||||
- Output code generated by Inductor
|
||||
* - kernel_code
|
||||
- Output code generated by Inductor on a per-kernel basis
|
||||
* - schedule
|
||||
- Output Inductor scheduling logs
|
||||
* - perf_hints
|
||||
- Output Inductor perf hint logs
|
||||
* - fusion
|
||||
- Output Inductor fusion logs
|
||||
```
|
||||
|
||||
For the full list of options, see [torch.\_logging](https://pytorch.org/docs/stable/logging.html)
|
||||
and [torch.\_logging.set_logs](https://pytorch.org/docs/stable/generated/torch._logging.set_logs.html#torch._logging.set_logs).
|
||||
|
||||
## tlparse vs. TORCH_LOGS
|
||||
|
||||
Generally, we suggest first using `tlparse` when encountering issues.
|
||||
`tlparse` is ideal for debugging large models and gaining a high-level overview of how your model was compiled.
|
||||
On the other hand, `TORCH_LOGS` is preferred for small examples and fine-grained debugging detail,
|
||||
when we already have an idea of which `torch.compile` component is causing the problem.

docs/source/compile/programming_model.recompilation.md (new file, 161 lines)
@@ -0,0 +1,161 @@
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
|
||||
torch._logging.set_logs(recompiles=True)
|
||||
```
|
||||
|
||||
# Dealing with Recompilations
|
||||
|
||||
Recompilations are necessary for `torch.compile` soundness, but can result in significantly increased compile time.
|
||||
Thus, minimizing recompilations while preserving soundness is essential for reducing compile time.
|
||||
|
||||
You can view recompilations and their reasons using tlparse or `TORCH_LOGS=recompiles`.
|
||||
|
||||
## Is Dynamic Shapes Enabled?
|
||||
|
||||
In the below example, we recompile due to mismatched shapes:
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
return x + 1
|
||||
fn(torch.ones(3))
|
||||
fn(torch.ones(4))
|
||||
```
|
||||
|
||||
Make sure that the dynamic option of `torch.compile` is not set to `False`.
|
||||
The default option, `dynamic=None`, will only attempt dynamic shapes after the first compilation.
|
||||
You can set `dynamic=True` to compile as dynamically as possible up front:
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile(dynamic=True)
|
||||
def gn(x):
|
||||
return x + 1
|
||||
gn(torch.ones(3))
|
||||
gn(torch.ones(4))
|
||||
```
|
||||
|
||||
For more information on dynamic shapes, including dealing with errors/recompilations due to
|
||||
dynamic shapes, see [the dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit?tab=t.0#heading=h.fh8zzonyw8ng).
|
||||
|
||||
## Wrapping Constants with Tensors
|
||||
By default, `int` / `float` variables are treated as constants and are guarded on their exact value.
|
||||
In the below example, we have a recompilation for each function call.
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile
|
||||
def fn(x, c):
|
||||
return x + c
|
||||
for i in range(5):
|
||||
fn(torch.ones(i), 0.5 + i)
|
||||
```
|
||||
|
||||
In particular, for LR schedulers, initializing with a constant can lead to recompilations:
|
||||
|
||||
```{code-cell}
|
||||
mod = torch.nn.Linear(3, 3)
|
||||
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
|
||||
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)
|
||||
@torch.compile
|
||||
def gn(inp):
|
||||
opt.zero_grad(True)
|
||||
out = mod(inp).sum()
|
||||
out.backward()
|
||||
opt.step()
|
||||
sched.step()
|
||||
for i in range(5):
|
||||
gn(torch.ones(3, 3))
|
||||
```
|
||||
|
||||
In both examples, we can wrap `float` variables in tensors in order to prevent recompilations.
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
torch._dynamo.reset()
|
||||
```
|
||||
|
||||
```{code-cell}
|
||||
# first example
|
||||
for i in range(5):
|
||||
fn(torch.ones(i), torch.tensor(0.5 + i))
|
||||
# second example
|
||||
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
|
||||
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
|
||||
for i in range(5):
|
||||
gn(torch.ones(3, 3))
|
||||
```
|
||||
|
||||
(programming_model.recompilation.changing_cache_size_limit)=
|
||||
## Changing the Cache Size Limit
|
||||
|
||||
There is a limit to how many times a function can be recompiled,
|
||||
determined by `torch._dynamo.config.cache_size_limit` and `torch._dynamo.config.accumulated_cache_size_limit`
|
||||
(The exact difference between these 2 values is detailed in [`torch/_dynamo/cache_size.py`](https://github.com/pytorch/pytorch/blob/4ce6e6ec8890a3f6ee604c9efb3ff153825ce575/torch/_dynamo/cache_size.py#L14)).
|
||||
If the Dynamo cache limit is hit, then all future compilation attempts **will result in the function being skipped (run eagerly)**.
|
||||
Dynamo will still attempt to use previously compiled bytecode for future function calls, if the guards pass.
|
||||
Note that in the case of a recompilation limit hit, **all nested function calls WILL be skipped**
|
||||
(Dynamo will try to use previously compiled bytecode for the nested functions).
|
||||
Dynamo will also issue a warning containing the affected function and which limit was hit.
|
||||
In the example below, each function call results in a recompile attempt.
|
||||
When we hit the cache size limit (by default, 8), we stop attempting to recompile.
|
||||
(Note that we set `dynamic=False` for demonstration purposes to force recompilation every time).
|
||||
|
||||
```{code-cell}
|
||||
@torch.compile(dynamic=False)
|
||||
def fn(x):
|
||||
return x + 1
|
||||
for i in range(1, 10):
|
||||
# recompile every time due to dynamic=False
|
||||
fn(torch.ones(i))
|
||||
```
|
||||
|
||||
If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit.
|
||||
If the cost of recompilation outweighs the benefit of compilation, then you can consider lowering the cache size limit.
|
||||
|
||||
```{code-cell}
|
||||
torch._dynamo.config.cache_size_limit = 16
|
||||
@torch.compile(dynamic=False)
|
||||
def gn(x):
|
||||
return x + 1
|
||||
for i in range(1, 10):
|
||||
gn(torch.ones(i))
|
||||
```
|
||||
|
||||
## Graph Breaking to Reduce Recompilation Costs
|
||||
If a large graph is recompiling and causing high compile time, you can intentionally introduce
|
||||
a graph break in order to reduce recompilation costs, at the expense of introducing a performance hit.
|
||||
|
||||
```{code-cell}
|
||||
def very_large_function(x):
|
||||
return x + 1
|
||||
|
||||
@torch.compile(dynamic=False)
|
||||
def fn(x, c):
|
||||
y = very_large_function(x) # recompiled every time
|
||||
return y + c
|
||||
|
||||
for i in range(1, 5):
|
||||
fn(torch.ones(3), i)
|
||||
|
||||
@torch.compile(dynamic=False)
|
||||
def gn(x, c):
|
||||
y = very_large_function(x) # compiled only once
|
||||
torch._dynamo.graph_break()
|
||||
return y + c # recompiled every time
|
||||
|
||||
for i in range(1, 5):
|
||||
gn(torch.ones(3), i)
|
||||
```

docs/source/compile/programming_model.reporting_issues.md (new file, 73 lines)
@@ -0,0 +1,73 @@
# Reporting Issues
|
||||
|
||||
If the provided workarounds were not enough to get `torch.compile` working,
|
||||
then you should consider reporting the issue to PyTorch.
|
||||
But there are a few things that you can do to make our lives significantly easier.
|
||||
|
||||
## Ablation
|
||||
|
||||
Check which component of the `torch.compile` stack is causing the issue by using the `backend=` option of `torch.compile`.
|
||||
In particular, try:
|
||||
|
||||
- `torch.compile(fn, backend="eager")`, which only runs TorchDynamo, the graph capture component of `torch.compile`.
|
||||
- `torch.compile(fn, backend="aot_eager")`, which runs TorchDynamo and AOTAutograd, which additionally generates the backward graph during compilation.
|
||||
- `torch.compile(fn, backend="aot_eager_decomp_partition")`, which runs TorchDynamo and AOTAutograd with operator decompositions/partitions.
|
||||
- `torch.compile(fn, backend="inductor")`, which runs TorchDynamo, AOTAutograd, and TorchInductor, the backend ML compiler that generates compiled kernels.
|
||||
|
||||
If you only fail with the Inductor backend, you can additionally test various Inductor modes:
|
||||
|
||||
- `torch.compile(fn, backend="inductor", mode="default")`
|
||||
- `torch.compile(fn, backend="inductor", mode="reduce-overhead")`
|
||||
- `torch.compile(fn, backend="inductor", mode="max-autotune")`
|
||||
|
||||
You can also check if dynamic shapes is causing issues with any backend:
|
||||
|
||||
- `torch.compile(fn, dynamic=True)` (always use dynamic shapes)
|
||||
- `torch.compile(fn, dynamic=False)` (never use dynamic shapes)
|
||||
- `torch.compile(fn, dynamic=None)` (automatic dynamic shapes)
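
A minimal ablation loop might look like the following sketch, where `fn` stands in for your failing function:

```python
import torch

def fn(x):
    return torch.relu(x @ x.T).sum()

x = torch.randn(16, 16)

for backend in ("eager", "aot_eager", "aot_eager_decomp_partition", "inductor"):
    torch._dynamo.reset()  # clear compile caches between ablation runs
    try:
        out = torch.compile(fn, backend=backend)(x)
        print(f"{backend}: ok ({out.item():.4f})")
    except Exception as e:
        print(f"{backend}: failed with {type(e).__name__}: {e}")
```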
|
||||
|
||||
## Bisecting
|
||||
|
||||
Did you try on the latest nightly? Did something work in the past but now no longer works?
|
||||
Can you bisect to determine the first nightly where your issue occurs?
|
||||
Bisecting is especially helpful for performance, accuracy, or compile time regressions,
|
||||
where it is not immediately obvious where the problem originates from.
|
||||
|
||||
## Creating a reproducer
|
||||
|
||||
Creating reproducers is a lot of work, and it is perfectly fine if you do not have the time to do it.
|
||||
However, if you are a motivated user unfamiliar with the internals of `torch.compile`,
|
||||
creating a standalone reproducer can have a huge impact on our ability to fix the bug.
|
||||
Without a reproducer, your bug report must contain enough information for us to identify the root cause of the problem and write a reproducer from scratch.
|
||||
|
||||
Here's a list of useful reproducers, ranked from most to least preferred:
|
||||
|
||||
1. **Self-contained, small reproducer:** A script with no external dependencies, under 100 lines of code, that reproduces the problem when run.
|
||||
2. **Self-contained, large reproducer:** Even if it's large, being self-contained is a huge advantage!
|
||||
3. **Non-self-contained reproducer with manageable dependencies:**
|
||||
For example, if you can reproduce the problem by running a script after `pip install transformers`,
|
||||
that's manageable. We can likely run it and investigate.
|
||||
4. **Non-self-contained reproducer requiring substantial setup:** This might involve downloading datasets,
|
||||
multiple environment setup steps, or specific system library versions requiring a Docker image.
|
||||
The more complex the setup, the harder it is for us to recreate the environment.
|
||||
|
||||
:::{note}
|
||||
Docker simplifies setup but complicates changes to the environment, so it's not a perfect solution, though we'll use it if necessary.
|
||||
:::
|
||||
|
||||
If possible, try to make your reproducer single-process, as those are easier to debug than a multi-process reproducer.
|
||||
|
||||
Additionally, below is a non-exhaustive list of aspects to check in your
|
||||
issue that you can attempt to replicate in your reproducer:
|
||||
|
||||
- **Autograd**. Did you have tensor inputs with `requires_grad=True`? Did you call `backward()` on the output?
|
||||
- **Dynamic shapes**. Did you set `dynamic=True`? Or did you run the test code multiple times with varying shapes?
|
||||
- **Custom operators**. Is there a custom operator involved in the real workflow?
|
||||
Can you replicate some of its important characteristics using the Python custom operator API?
|
||||
- **Configuration**. Did you set all the same configuration?
|
||||
This includes `torch._dynamo.config` and `torch._inductor.config` settings,
|
||||
as well as arguments to `torch.compile` like `backend` / `mode`.
|
||||
- **Context managers**. Did you replicate any active context managers?
|
||||
This could be `torch.no_grad`, automatic mixed precision, `TorchFunctionMode` / `TorchDispatchMode`,
|
||||
activation checkpointing, compiled autograd etc.
|
||||
- **Tensor subclasses**. Is there a tensor subclass involved?

docs/source/compile/programming_model.skipped_functions.md (new file, 199 lines)
@@ -0,0 +1,199 @@
---
|
||||
file_format: mystnb
|
||||
kernelspec:
|
||||
name: python3
|
||||
mystnb:
|
||||
execution_timeout: 30
|
||||
execution_show_tb: True
|
||||
merge_streams: True
|
||||
---
|
||||
|
||||
```{code-cell}
|
||||
:tags: [remove-cell]
|
||||
import torch
|
||||
|
||||
import header_code
|
||||
import logging
|
||||
torch._logging.set_logs(dynamo=logging.DEBUG)
|
||||
```
|
||||
|
||||
# Skipped Functions

**Summary:**
- Sometimes, `torch.compile` completely gives up compiling a function and runs it eagerly instead,
  resulting in potentially lost optimization opportunities.
- There are ways to work around skipped functions in order to re-enable tracing around the problematic code.

Sometimes, `torch.compile` with `fullgraph=False` is unable to resume tracing when it encounters a graph break
or other compiler error. In many of these cases, `torch.compile` will skip compiling the function entirely and run it eagerly.

Note that the skip only applies to the current function and NOT to any nested function calls.
`torch.compile` will still attempt to compile nested calls.

<!-- TODO: fix logging for skipped functions. -->

```{code-cell}
def inner1(x):
    return x + 1
def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)
    torch._dynamo.skip_frame()
    x = inner2(x)
fn(torch.randn(3))
```

In the above example, `torch.compile` traces `fn` (including `inner1`) up until the `skip_frame` call.
Then `fn` is skipped and run eagerly, and `inner1` and `inner2` are compiled when they are called.

Skipping functions may result in lost optimization opportunities,
so it is important to check whether code you want compiled is being skipped, and if so, to work around the skip.
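
If you prefer such problems to fail loudly rather than silently fall back to eager, one option (a sketch, using `torch._dynamo.graph_break()` as a stand-in for a real graph break) is to compile with `fullgraph=True`, which raises an error at the first graph break instead of skipping:

```python
# A sketch: under fullgraph=True, a graph break (a common cause of skipped
# functions) raises an error instead of silently running the function eagerly.
@torch.compile(fullgraph=True)
def fn(x):
    x = x + 1
    torch._dynamo.graph_break()  # errors here under fullgraph=True
    return x + 1

fn(torch.randn(3))  # raises instead of falling back to eager
```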

## Graph Break in a Loop

`torch.compile` cannot resume tracing if a graph break occurs in a loop:

```{code-cell}
@torch.compile
def fn(x):
    for i in range(5):
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    return x
fn(torch.randn(3))
```

In this example, we can avoid skipping by unrolling the loop:

```{code-cell}
@torch.compile
def fn(x):
    def inner(i):
        nonlocal x
        x = x + 1
        if i == 3:
            torch._dynamo.graph_break()
    inner(0)
    inner(1)
    inner(2)
    inner(3)
    inner(4)
    return x
fn(torch.randn(3))
```

In general, resolving the graph break that causes the skip will also resolve the skip.

## Graph Break in a Context Manager

Another common example of an unresumable graph break is a graph break inside most context managers:

```{code-cell}
class CustomCtxManager:
    def __enter__(self):
        pass
    def __exit__(self, exc_type, exc_value, traceback):
        pass

@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
        torch._dynamo.graph_break()
        return x + 1
fn(torch.randn(3))
```

We can avoid skipping by moving the graph break outside of the context manager:

```{code-cell}
@torch.compile
def fn(x):
    with CustomCtxManager():
        x = x + 1
    torch._dynamo.graph_break()
    with CustomCtxManager():
        return x + 1
fn(torch.randn(3))
```

There are some context managers where Dynamo can resume after a graph break.
Some of these can be found in `supported_ctx_manager_classes` in `torch/_dynamo/variables/torch.py`.
In general, any context manager represented by a `ContextWrappingVariable` subclass in
`torch/_dynamo/variables/ctx_manager.py` supports resuming after a graph break. For example:

```{code-cell}
import contextlib
@torch.compile
def fn(x):
    with contextlib.nullcontext():
        with torch.no_grad():
            x = x + 1
            torch._dynamo.graph_break()
            return x + 1
fn(torch.randn(3))
```

## Graph Break in a Try Block

A graph break in a try block cannot be resumed:

```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
        torch._dynamo.graph_break()
        return x + 1
    except Exception as e:
        pass
fn(torch.randn(3))
```

We can avoid skipping by moving the graph break outside of the try block:

```{code-cell}
@torch.compile
def fn(x):
    try:
        x = x + 1
    except Exception as e:
        pass
    torch._dynamo.graph_break()
    try:
        return x + 1
    except Exception as e:
        pass
fn(torch.randn(3))
```

## Hitting a Recompilation Limit

See [Changing the Cache Size Limit](programming_model.recompilation.changing_cache_size_limit).
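
As a minimal sketch (assuming the `torch._dynamo.config.cache_size_limit` knob described in the linked section), raising the limit allows more recompilations before `torch.compile` gives up on a frame and runs it eagerly:

```python
import torch

# A sketch, assuming the cache size limit knob described in the linked section.
# Raising it permits more recompilations of a frame before the frame is skipped.
torch._dynamo.config.cache_size_limit = 64
```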

## Compiler Errors

Some compiler errors result in skipped functions, while others result in a hard error rather than a skipped function.

## Dealing with Skipped Functions

In general, you can resolve a skipped function by fixing the underlying graph break or error that
is causing the function to be skipped.

If the graph break or error causing the skipped function is difficult to fix,
consider isolating it in its own function so that as little code as possible is skipped.

```{code-cell}
def inner1(x):
    return x + 1
def inner2(x):
    return x + 2

@torch.compile
def fn(x):
    x = inner1(x)
    def problematic_code():
        torch._dynamo.skip_frame()
    problematic_code()
    x = inner2(x)
fn(torch.randn(3))
```
@@ -0,0 +1,77 @@
# Where to apply torch.compile?

We recommend applying `torch.compile` to the highest-level function that doesn’t cause excessive problems.
Typically, it is:
- your `train` or `eval` step with the optimizer but without the loop,
- your top-level `nn.Module`,
- or some sub-`nn.Module`s.

`torch.compile` specifically doesn’t handle distributed wrapper modules like DDP or FSDP very well,
so consider applying `torch.compile` to the inner module passed to the wrapper
(see the DDP and FSDP examples below).

```python
# inference
model = ...
model.compile()

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)
```

```python
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
```

```python
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)
```
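
A similar sketch applies to FSDP (assumed here to be `torch.distributed.fsdp.FullyShardedDataParallel`, with process-group setup omitted): compile the inner module, then wrap it.

```python
# FullyShardedDataParallel (a sketch; process group setup omitted)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = ...
model.compile()
model_fsdp = FSDP(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_fsdp(inp)
```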

<!-- TODO add examples for specific model domains, compile(model) vs. model.compile()-->

## `compile(model)` vs `model.compile()`

Due to nuances in how `torch.compile` interacts with `nn.Module` instances,
we advise using the `.compile()` method of `nn.Module` instances if you wish to compile them as
top-level functions. Nested module calls are traced correctly, so
there is no need to call `.compile()` on nested modules.

```python
# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)

# DO THIS
model = MyModel()
model.compile()
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)
model = MyModel()
fn(model, inp)
```