Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-24 23:54:56 +08:00)

Compare commits: flex-flash...newtest-ba (133 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| 19f1f9960d | |||
| fd6655a0f5 | |||
| a7f3bdf550 | |||
| 510e8b4ae0 | |||
| 83ba3f1101 | |||
| 1fad16aacb | |||
| 444e2381d0 | |||
| 6085bf7565 | |||
| 8201dbf4bc | |||
| 26d045bb60 | |||
| 356ac3103a | |||
| d4109a0f99 | |||
| 7ea789ccfb | |||
| 7e8197e34d | |||
| 50eac811a6 | |||
| 4e0f179d0b | |||
| 36e59d9b12 | |||
| fc340d0ca3 | |||
| 53e47af0f7 | |||
| 66ad881fc7 | |||
| 1d3eef27ac | |||
| dd95900cec | |||
| 1cdd665526 | |||
| 7cb2dcd2dd | |||
| e5a81aa7ba | |||
| 3e2aa4b0e3 | |||
| 6646461764 | |||
| f74da2a136 | |||
| d35b27dde5 | |||
| a9dc1566d4 | |||
| 33a1996714 | |||
| ee62177c19 | |||
| 64cbaa876c | |||
| 4516c59f5f | |||
| 8bc843a9ec | |||
| e39a62c70d | |||
| 978e3a9142 | |||
| e2a5c42e7e | |||
| 5116c49b52 | |||
| fecdebe385 | |||
| e136a9175b | |||
| 9a680e14b7 | |||
| 805a102beb | |||
| 6e8d705a22 | |||
| 9c18901bfd | |||
| a29ed5e1ac | |||
| d2792f51b2 | |||
| be71000ff5 | |||
| 3f86076775 | |||
| 1616777cd2 | |||
| 38895c0ac2 | |||
| 310f901a71 | |||
| e11b1cd97e | |||
| b599d91738 | |||
| fd6a6658c3 | |||
| 04973496a8 | |||
| 1548b011ea | |||
| e57a92734d | |||
| 79ff3b320b | |||
| 426f249f20 | |||
| d33a484763 | |||
| a81ffbc5f5 | |||
| 465fe4d9f7 | |||
| 9477af1063 | |||
| dcc36e38bb | |||
| efd78584a8 | |||
| 135762ea20 | |||
| e2ee9cfaa2 | |||
| 06d28de17a | |||
| df9720b8b5 | |||
| 85e74d5ace | |||
| 0450f05658 | |||
| 595a65f5c2 | |||
| 8c6c2e40eb | |||
| 32840d19f9 | |||
| 2040f00112 | |||
| c137f9da0b | |||
| 5e8b95605f | |||
| 8ea86a6e31 | |||
| acad808545 | |||
| c687446374 | |||
| dd22ba09b4 | |||
| c0e0126399 | |||
| e4b123b5e4 | |||
| 5711a8f069 | |||
| b4b71d011e | |||
| 52376b9b6f | |||
| 1371a98b0e | |||
| 2a286cbdf4 | |||
| 7c37b8e1e0 | |||
| ee2649219c | |||
| b0b3e6e48b | |||
| 3967dbedf4 | |||
| 4396b15aa7 | |||
| bb6766053b | |||
| a4fc051c9a | |||
| 5cc6a0abc1 | |||
| 90f13f3b2a | |||
| cb9b74872b | |||
| c964204829 | |||
| 2ac45c2752 | |||
| 83e2ea8135 | |||
| d994027a41 | |||
| cb4f41e125 | |||
| 690fc9cf88 | |||
| eb853e222b | |||
| 06395276e4 | |||
| 8becf646ef | |||
| fa68216ca1 | |||
| 25ef3d315d | |||
| 7e00f2ec9d | |||
| 490cb3f1a4 | |||
| b95cf5c91d | |||
| 5e2ef2a465 | |||
| 9f753f8c0d | |||
| db437690d1 | |||
| 669009bcd1 | |||
| e4e2701429 | |||
| 64cc649275 | |||
| b1fb552974 | |||
| bb62e1f769 | |||
| 327e2ca580 | |||
| 1ebcba4e1b | |||
| 5f7eae697d | |||
| c1722db0f7 | |||
| 8a233d6000 | |||
| bf3ebd7ad4 | |||
| c07bb277a0 | |||
| f89c28cc6b | |||
| 8fedcfa59a | |||
| 6662a76f59 | |||
| 05aade1b6d | |||
| f946b25865 |
@@ -1 +1 @@
11ec6354315768a85da41032535e3b7b99c5f706
f7888497a1eb9e98d4c07537f0d0bcfe180d1363

@@ -103,5 +103,5 @@ fi
# It depends on torch and triton. We don't want to install
# triton and torch from production on Docker CI images
if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
pip_install helion==0.0.10 --no-deps
pip_install helion --no-deps
fi

@@ -1,7 +1,7 @@
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2

# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably

@@ -50,7 +50,7 @@ IPython==8.12.0
#Pinned versions: 8.12.0

myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2

# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs

@@ -194,7 +194,7 @@ ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
ROCBLAS_LIB_DST=lib/rocblas/library
ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH)
ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $ROCBLAS_OTHER_FILES)

# hipblaslt library files
HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library

@@ -627,6 +627,8 @@ test_perf_for_dashboard() {
device=cuda_a10g
elif [[ "${TEST_CONFIG}" == *h100* ]]; then
device=cuda_h100
elif [[ "${TEST_CONFIG}" == *b200* ]]; then
device=cuda_b200
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
device=rocm
fi

@@ -801,6 +803,16 @@ test_dynamo_benchmark() {
if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then
test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@"
elif [[ "${TEST_CONFIG}" == *perf* ]]; then
# TODO (huydhn): Just smoke test some sample models
if [[ "${TEST_CONFIG}" == *b200* ]]; then
if [[ "${suite}" == "huggingface" ]]; then
export TORCHBENCH_ONLY_MODELS="DistillGPT2"
elif [[ "${suite}" == "timm_models" ]]; then
export TORCHBENCH_ONLY_MODELS="inception_v3"
elif [[ "${suite}" == "torchbench" ]]; then
export TORCHBENCH_ONLY_MODELS="hf_Bert"
fi
fi
test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
else
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
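In the B200 smoke-test branch above, each suite is narrowed to a single model via the TORCHBENCH_ONLY_MODELS environment variable before test_single_dynamo_benchmark runs. As a hedged illustration only (the runner script path and flags below are assumptions about the dynamo benchmark harness, not something this diff adds), a comparable single-model perf run looks roughly like:

```bash
# Sketch: run one TorchBench model in perf mode, similar in spirit to the
# CI smoke test that exports TORCHBENCH_ONLY_MODELS="hf_Bert".
python benchmarks/dynamo/torchbench.py \
  --performance --training --amp --device cuda \
  --only hf_Bert \
  --output /tmp/b200_smoke_torchbench.csv
```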
2 .github/ci_commit_pins/audio.txt vendored
@@ -1 +1 @@
bf305f538005f2e900f8850ed57146024a8bc559
9b57c7bd5ad4db093c5bb31c802df9f04d933ac9

2 .github/ci_commit_pins/vllm.txt vendored
@@ -1 +1 @@
ca9e2be3ed6320b51f52f536595cd24e254f8bb2
6a39ba85fe0f2fff9494b5eccea717c93510c230

2 .github/ci_commit_pins/xla.txt vendored
@@ -1 +1 @@
29ae4c76c026185f417a25e841d2cd5e65f087a3
b6a5b82b9948b610fa4c304d0d869c82b8f17db1

4 .github/merge_rules.yaml vendored
@@ -488,6 +488,10 @@
- torch/_dynamo/**
- torch/csrc/dynamo/**
- test/dynamo/**
- test/dynamo_expected_failures/**
- test/dynamo_skips/**
- test/inductor_expected_failures/**
- test/inductor_skips/**
approved_by:
- guilhermeleobas
mandatory_checks_name:

@@ -193,7 +193,7 @@ LIBTORCH_CONTAINER_IMAGES: dict[str, str] = {
"cpu": "libtorch-cxx11-builder:cpu",
}

FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t"]
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]


def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:

@@ -315,6 +315,11 @@ def generate_wheels_matrix(
# TODO: Enable python 3.13t on cpu-s390x
if gpu_arch_type == "cpu-s390x" and python_version == "3.13t":
continue
# TODO: Enable python 3.14 on non linux OSes
if os != "linux" and (
python_version == "3.14" or python_version == "3.14t"
):
continue

if use_split_build and (
arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux"
20 .github/workflows/_linux-test.yml vendored
@@ -96,7 +96,7 @@ jobs:
steps:
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
if: ${{ matrix.runner != 'B200' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if: ${{ !contains(matrix.runner, 'b200') && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
instructions: |

@@ -109,7 +109,7 @@ jobs:
no-sudo: true

- name: Setup Python
if: matrix.runner == 'B200'
if: contains(matrix.runner, 'b200')
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.12'

@@ -117,7 +117,7 @@

- name: Setup Linux
uses: ./.github/actions/setup-linux
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && matrix.runner != 'B200'
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && !contains(matrix.runner, 'b200')

- name: configure aws credentials
if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}

@@ -128,7 +128,7 @@
aws-region: us-east-1

- name: Login to Amazon ECR
if: ${{ inputs.aws-role-to-assume != '' && matrix.runner == 'B200' }}
if: ${{ inputs.aws-role-to-assume != '' && contains(matrix.runner, 'b200') }}
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1

@@ -166,17 +166,17 @@ jobs:
uses: pytorch/test-infra/.github/actions/setup-nvidia@main
with:
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }}

- name: Setup GPU_FLAG for docker run
id: setup-gpu-flag
run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}"
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || matrix.runner == 'B200') }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }}

- name: Setup SCCACHE_SERVER_PORT environment for docker run when on container
id: setup-sscache-port-flag
run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}"
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && matrix.runner != 'B200' }}
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }}

- name: Lock NVIDIA A100 40GB Frequency
run: |

@@ -277,8 +277,8 @@ jobs:
NO_TD: ${{ steps.keep-going.outputs.ci-no-td }}
TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }}
# Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs
SCCACHE_BUCKET: ${{ matrix.runner != 'B200' && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ matrix.runner != 'B200' && 'us-east-1' || '' }}
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}

@@ -403,7 +403,7 @@ jobs:
job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }}

- name: Authenticate with AWS
if: ${{ matrix.runner == 'B200' }}
if: ${{ contains(matrix.runner, 'b200') }}
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
3 .github/workflows/docker-builds.yml vendored
@@ -76,7 +76,8 @@ jobs:
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter,
pytorch-linux-jammy-py3-clang12-executorch,
# Executorch pin needs update
# pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3.12-triton-cpu
]
include:

1226 .github/workflows/generated-linux-binary-manywheel-nightly.yml generated vendored
File diff suppressed because it is too large

154 .github/workflows/inductor-perf-test-b200.yml vendored (new file)
@@ -0,0 +1,154 @@
name: inductor-perf-b200

on:
schedule:
- cron: 0 7 * * 1-6
- cron: 0 7 * * 0
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: true
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf_cuda_b200,inductor_timm_perf_cuda_b200,inductor_torchbench_perf_cuda_b200

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf

build:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
# Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100
# or newer GPUs, so it doesn't benefit much from existing compiler cache
# from trunk. Also use a memory-intensive runner here because memory is
# usually the bottleneck
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_timm_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_torchbench_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
selected-test-configs: ${{ inputs.benchmark_configs }}
build-additional-packages: "vision audio fbgemm torchao"
secrets: inherit

test-periodically:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

test-weekly:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 1440
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

test:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
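Besides the two cron schedules, the new workflow declares workflow_dispatch inputs, so a B200 perf run can also be kicked off manually. A minimal sketch using the GitHub CLI (which inputs you override is up to you; only the input names declared above exist):

```bash
# Manually dispatch the B200 inductor perf workflow with a narrowed set of
# benchmark configs; each -f value maps to a workflow_dispatch input above.
gh workflow run inductor-perf-test-b200.yml \
  --repo pytorch/pytorch \
  -f training=true \
  -f inference=true \
  -f benchmark_configs=inductor_huggingface_perf_cuda_b200
```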
30
.github/workflows/inductor-periodic.yml
vendored
@ -81,21 +81,21 @@ jobs:
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
|
||||
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
9
.github/workflows/nightly.yml
vendored
@ -75,10 +75,11 @@ jobs:
|
||||
repo-owner: pytorch
|
||||
branch: main
|
||||
pin-folder: .github/ci_commit_pins
|
||||
- repo-name: executorch
|
||||
repo-owner: pytorch
|
||||
branch: main
|
||||
pin-folder: .ci/docker/ci_commit_pins
|
||||
# executorch jobs are disabled since it needs some manual work for the hash update
|
||||
# - repo-name: executorch
|
||||
# repo-owner: pytorch
|
||||
# branch: main
|
||||
# pin-folder: .ci/docker/ci_commit_pins
|
||||
- repo-name: triton
|
||||
repo-owner: triton-lang
|
||||
branch: main
|
||||
|
||||
1
.github/workflows/pull.yml
vendored
@ -434,6 +434,7 @@ jobs:
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-py3-clang12-executorch-build:
|
||||
if: false # Docker build needs pin update
|
||||
name: linux-jammy-py3-clang12-executorch
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
|
||||
2
.github/workflows/update-viablestrict.yml
vendored
@ -23,7 +23,7 @@ jobs:
|
||||
with:
|
||||
repository: pytorch/pytorch
|
||||
stable-branch: viable/strict
|
||||
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]'
|
||||
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\", \"linux-aarch64\"]'
|
||||
secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }}
|
||||
clickhouse-url: ${{ secrets.CLICKHOUSE_URL }}
|
||||
clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }}
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
/torch/csrc/autograd/ @albanD @soulitzer
|
||||
/torch/autograd/ @albanD @soulitzer
|
||||
/tools/autograd/ @albanD @soulitzer
|
||||
/torch/header_only_apis.txt @janeyx99
|
||||
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
|
||||
/torch/optim/ @albanD @janeyx99
|
||||
/test/test_public_bindings.py @albanD
|
||||
@ -196,3 +195,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed
|
||||
/torch/utils/_cxx_pytree.py @XuehaiPan
|
||||
/torch/utils/pytree/ @XuehaiPan
|
||||
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
|
||||
|
||||
# Relating to libtorch ABI
|
||||
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
|
||||
/torch/headeronly/ @janeyx99
|
||||
/torch/header_only_apis.txt @janeyx99
|
||||
|
||||
@ -439,6 +439,7 @@ if(USE_ROCM)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
|
||||
_pytorch_rocm_generate_ck_conf()
|
||||
@ -703,21 +704,17 @@ if(USE_MPS)
|
||||
if(CAN_COMPILE_METAL)
|
||||
foreach(SHADER ${native_mps_metal})
|
||||
cmake_path(GET SHADER STEM TGT_STEM)
|
||||
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
|
||||
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
|
||||
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
|
||||
list(APPEND AIR_BASIC ${TGT_BASIC})
|
||||
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
|
||||
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
|
||||
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
|
||||
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
|
||||
endforeach()
|
||||
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
|
||||
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
|
||||
add_custom_command(
|
||||
COMMAND echo "// $$(date)" > metallib_dummy.cpp
|
||||
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
|
||||
DEPENDS kernels_basic.metallib
|
||||
OUTPUT metallib_dummy.cpp
|
||||
COMMENT "Updating metallibs timestamp")
|
||||
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
|
||||
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
|
||||
else()
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
|
||||
foreach(SHADER ${native_mps_metal})
|
||||
|
||||
@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl
|
||||
}
|
||||
|
||||
bool pinned_use_background_threads() override {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::
|
||||
pinned_use_background_threads();
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
|
||||
}
|
||||
}
|
||||
|
||||
template<>
|
||||
void _assert_match<c10::Device, std::optional<c10::Device>>(
|
||||
const c10::Device& original,
|
||||
const std::optional<c10::Device>& compared,
|
||||
const std::string& name) {
|
||||
if (compared) {
|
||||
const c10::Device& expected = compared.value();
|
||||
if (original.type() != expected.type()) {
|
||||
std::stringstream msg;
|
||||
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
// If the expected device doesn't have an index (e.g., just "cuda"),
|
||||
// or if both devices have the same index, consider them equal
|
||||
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
|
||||
std::stringstream msg;
|
||||
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
|
||||
_assert_match(tensor.sym_sizes(), sizes, "sizes");
|
||||
_assert_match(tensor.sym_strides(), strides, "strides");
|
||||
|
||||
@ -367,27 +367,27 @@ void int8pack_mm_kernel_(
|
||||
auto* C_data = C.data_ptr<T>();
|
||||
const auto* S_data = scales.const_data_ptr<T>();
|
||||
|
||||
int M = A.size(0);
|
||||
int N = B.size(0);
|
||||
int K = A.size(1);
|
||||
int lda = A.stride(0);
|
||||
constexpr int BLOCK_M = 4;
|
||||
constexpr int BLOCK_N = 4;
|
||||
int64_t M = A.size(0);
|
||||
int64_t N = B.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 4;
|
||||
|
||||
const int MB = (M + BLOCK_M - 1) / BLOCK_M;
|
||||
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
|
||||
const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M;
|
||||
const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N;
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
|
||||
int mb{0}, nb{0};
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
(void)i;
|
||||
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
const auto* A_ptr = A_data + mb_start * lda;
|
||||
const auto* B_ptr = B_data + nb_start * K;
|
||||
|
||||
@ -526,7 +526,7 @@ namespace {
|
||||
|
||||
|
||||
// we are dealing with packed tensor here. max index is the same as numel.
|
||||
// TODO: to really support input tensor large enought to go beyond int32,
|
||||
// TODO: to really support input tensor large enough to go beyond int32,
|
||||
// we will need to restrict out shared memory usage and adjust the launch
|
||||
// config;
|
||||
AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
|
||||
@ -681,7 +681,7 @@ namespace {
|
||||
const dim3 grid(grid_x, grid_y, grid_z);
|
||||
|
||||
// we are dealing with packed tensor here. max index is the same as numel.
|
||||
// TODO: to really support input tensor large enought to go beyond int32,
|
||||
// TODO: to really support input tensor large enough to go beyond int32,
|
||||
// we will need to restrict out shared memory usage and adjust the launch
|
||||
// config;
|
||||
AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
|
||||
|
||||
@ -1634,6 +1634,9 @@ bool use_fast_accum) {
|
||||
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
@ -1716,6 +1719,9 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
|
||||
// check that the strides are valid, the fn will throw an error if not
|
||||
check_valid_strides_and_return_transposed(mat_a);
|
||||
|
||||
@ -223,7 +223,7 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo
|
||||
class CuFFTConfig {
|
||||
public:
|
||||
|
||||
// Only move semantics is enought for this class. Although we already use
|
||||
// Only move semantics is enough for this class. Although we already use
|
||||
// unique_ptr for the plan, still remove copy constructor and assignment op so
|
||||
// we don't accidentally copy and take perf hit.
|
||||
CuFFTConfig(const CuFFTConfig&) = delete;
|
||||
|
||||
@ -241,6 +241,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
|
||||
Strides tensor_StrideA = make_strides(mat_a.strides());
|
||||
Strides tensor_StrideB = make_strides(mat_b.strides());
|
||||
Strides tensor_StrideOutput = make_strides(out.strides());
|
||||
Strides tensor_ShapeA = make_strides(mat_a.sizes());
|
||||
Strides tensor_ShapeB = make_strides(mat_b.sizes());
|
||||
|
||||
at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>(
|
||||
reinterpret_cast<DtypeA*>(mat_a.data_ptr()),
|
||||
@ -264,6 +266,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
|
||||
tensor_StrideA,
|
||||
tensor_StrideB,
|
||||
tensor_StrideOutput,
|
||||
tensor_ShapeA,
|
||||
tensor_ShapeB,
|
||||
0,
|
||||
0,
|
||||
a_row_major,
|
||||
|
||||
@ -38,18 +38,20 @@ __global__ void prepare_grouped_gemm_data(
|
||||
Strides tensor_StrideA,
|
||||
Strides tensor_StrideB,
|
||||
Strides tensor_StrideOutput,
|
||||
Strides tensor_ShapeA,
|
||||
Strides tensor_ShapeB,
|
||||
int64_t a_scale_stride,
|
||||
int64_t b_scale_stride,
|
||||
bool a_row_major = true,
|
||||
bool b_row_major = false) {
|
||||
int32_t tid = threadIdx.x;
|
||||
int32_t delta = 0;
|
||||
int32_t offset = 0;
|
||||
if (offs != nullptr) {
|
||||
int32_t start = tid == 0 ? 0 : offs[tid - 1];
|
||||
delta = offs[tid] - start;
|
||||
if (K < 0) {
|
||||
CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
|
||||
}
|
||||
offset = offs[tid];
|
||||
delta = offset - start;
|
||||
CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0\n");
|
||||
|
||||
// TMA transfers require global memory tensor addresses to be
|
||||
// aligned to 16 bytes.
|
||||
@ -84,6 +86,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
int64_t lda, ldb, ldoutput;
|
||||
if (M < 0) {
|
||||
// A and output is 2d
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[0] && "expected offset to be less than tensor size\n");
|
||||
M = delta;
|
||||
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
|
||||
ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2];
|
||||
@ -96,6 +99,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput;
|
||||
B_ptrs[tid] = B + tid * tensor_StrideB[0];
|
||||
} else if (N < 0) {
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeB[1] && "expected offset to be less than tensor size\n");
|
||||
N = delta;
|
||||
lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2];
|
||||
ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed
|
||||
@ -108,6 +112,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1];
|
||||
}
|
||||
} else if (K < 0) {
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[1] && offset <= tensor_ShapeB[0] && "expected offset to be less than tensor size\n");
|
||||
// A, B is 2d, output is 3d
|
||||
K = delta;
|
||||
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
|
||||
|
||||
@ -298,6 +298,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
Strides tensor_StrideA = make_strides(mat_a.strides());
|
||||
Strides tensor_StrideB = make_strides(mat_b.strides());
|
||||
Strides tensor_StrideOutput = make_strides(out.strides());
|
||||
Strides tensor_ShapeA = make_strides(mat_a.sizes());
|
||||
Strides tensor_ShapeB = make_strides(mat_b.sizes());
|
||||
|
||||
// scale stride will be used inside the kernel only if needed,
|
||||
// so for 1d scales the "1" assigned here won't be used
|
||||
int64_t a_scale_stride = scale_a.stride(0);
|
||||
@ -325,6 +328,8 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
tensor_StrideA,
|
||||
tensor_StrideB,
|
||||
tensor_StrideOutput,
|
||||
tensor_ShapeA,
|
||||
tensor_ShapeB,
|
||||
a_scale_stride,
|
||||
b_scale_stride);
|
||||
|
||||
|
||||
@ -28,6 +28,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support");
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<Tensor>& running_mean,
|
||||
const std::optional<Tensor>& running_var,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon,
|
||||
Tensor& out,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var,
|
||||
Tensor& reserve) {
|
||||
AT_ERROR("cudnn_batch_norm_out: ATen not compiled with cuDNN support");
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
|
||||
const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
@ -120,7 +136,12 @@ size_t _get_cudnn_batch_norm_reserve_space_size(
|
||||
return reserve_size;
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
// Param `reserve` is a placeholder, just passing an empty tensor.
|
||||
// usage:
|
||||
// auto reserve = torch::empty({0}, torch::device(torch::kCUDA));
|
||||
// at::native::cudnn_batch_norm_out(..., epsilon, output, save_mean, save_var,
|
||||
// reserve);
|
||||
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const std::optional<Tensor>& bias_t_opt,
|
||||
@ -128,7 +149,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
const std::optional<Tensor>& running_var_t_opt,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon) {
|
||||
double epsilon,
|
||||
Tensor& output_t,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var,
|
||||
Tensor& reserve) {
|
||||
// See [Note: hacky wrapper removal for optional tensor]
|
||||
c10::MaybeOwned<Tensor> bias_t_maybe_owned =
|
||||
at::borrow_from_optional_tensor(bias_t_opt);
|
||||
@ -168,9 +193,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
|
||||
training, input->suggest_memory_format(), input->dim());
|
||||
|
||||
auto output_t =
|
||||
at::empty_like(*input, input->options(), input->suggest_memory_format());
|
||||
|
||||
TensorArg output{output_t, "output", 0};
|
||||
|
||||
auto handle = getCudnnHandle();
|
||||
@ -182,15 +204,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
|
||||
Constant one(dataType, 1);
|
||||
Constant zero(dataType, 0);
|
||||
Tensor save_mean, save_var;
|
||||
|
||||
Tensor reserve;
|
||||
|
||||
if (training) {
|
||||
int64_t num_features = input_t.size(1);
|
||||
save_mean = at::empty({num_features}, weight_t.options());
|
||||
save_var = at::empty({num_features}, weight_t.options());
|
||||
|
||||
auto op = CUDNN_BATCHNORM_OPS_BN;
|
||||
size_t workspace_size;
|
||||
AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
|
||||
@ -238,9 +253,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
reserve_size));
|
||||
} else {
|
||||
reserve = at::empty({0}, input->options().dtype(kByte));
|
||||
// This keeps a consistent output with native_batch_norm
|
||||
save_mean = at::empty({0}, weight_t.options());
|
||||
save_var = at::empty({0}, weight_t.options());
|
||||
AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
|
||||
handle,
|
||||
mode,
|
||||
@ -261,10 +273,48 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
// save_mean and save_var can be undefined
|
||||
// If this causes problems, we can initialize them to empty tensors
|
||||
// of the correct type
|
||||
return std::tuple<Tensor, Tensor, Tensor, Tensor>{
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>{
|
||||
output_t, save_mean, save_var, reserve};
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const std::optional<Tensor>& bias_t_opt,
|
||||
const std::optional<Tensor>& running_mean_t_opt,
|
||||
const std::optional<Tensor>& running_var_t_opt,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon) {
|
||||
auto output_t = at::empty_like(
|
||||
input_t, input_t.options(), input_t.suggest_memory_format());
|
||||
Tensor save_mean, save_var, reserve;
|
||||
|
||||
if (training) {
|
||||
int64_t num_features = input_t.size(1);
|
||||
save_mean = at::empty({num_features}, weight_t.options());
|
||||
save_var = at::empty({num_features}, weight_t.options());
|
||||
} else {
|
||||
// This keeps a consistent output with native_batch_norm
|
||||
save_mean = at::empty({0}, weight_t.options());
|
||||
save_var = at::empty({0}, weight_t.options());
|
||||
}
|
||||
|
||||
return cudnn_batch_norm_out(
|
||||
input_t,
|
||||
weight_t,
|
||||
bias_t_opt,
|
||||
running_mean_t_opt,
|
||||
running_var_t_opt,
|
||||
training,
|
||||
exponential_average_factor,
|
||||
epsilon,
|
||||
output_t,
|
||||
save_mean,
|
||||
save_var,
|
||||
reserve);
|
||||
}
|
||||
|
||||
// NB: CuDNN only implements the backward algorithm for batchnorm
|
||||
// in training mode (evaluation mode batchnorm has a different algorithm),
|
||||
// which is why this doesn't accept a 'training' parameter.
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/mkldnn/Matmul.h>
|
||||
|
||||
@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool use_mkldnn_typed_matmul(
|
||||
bool use_mkldnn_bf16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
bool dtype_check = false;
|
||||
if constexpr (std::is_same_v<T, c10::BFloat16>) {
|
||||
#if defined(__aarch64__)
|
||||
if (mkldnn_bf16_device_check_arm()) {
|
||||
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
|
||||
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
|
||||
// inputs, allow it for float as well
|
||||
dtype_check = use_mkldnn_bf16_matmul() &&
|
||||
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
|
||||
}
|
||||
#else
|
||||
dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
|
||||
(mat1.scalar_type() == kBFloat16);
|
||||
if (mkldnn_bf16_device_check_arm()) {
|
||||
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
|
||||
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
|
||||
// inputs, allow it for float as well
|
||||
return (
|
||||
use_mkldnn_bf16_matmul() &&
|
||||
(mat1.scalar_type() == mat2.scalar_type()) &&
|
||||
(!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
|
||||
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
} else
|
||||
#endif
|
||||
} else if constexpr (std::is_same_v<T, c10::Half>) {
|
||||
dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
|
||||
(mat1.scalar_type() == kHalf);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
dtype_check = dtype_check &&
|
||||
(use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
|
||||
(mat1.scalar_type() == kFloat);
|
||||
{
|
||||
return (
|
||||
use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
|
||||
mat2.scalar_type() == kBFloat16 &&
|
||||
(!result.defined() || result.scalar_type() == kBFloat16) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
if (!dtype_check) {
|
||||
return false;
|
||||
}
|
||||
bool size_check =
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
|
||||
dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
|
||||
(!result.defined() || result.scalar_type() == mat1.scalar_type());
|
||||
return dtype_check && size_check;
|
||||
}
|
||||
|
||||
bool use_mkldnn_fp16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
|
||||
mat2.scalar_type() == kHalf &&
|
||||
(!result.defined() || result.scalar_type() == kHalf) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_bf32_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
|
||||
mat2.scalar_type() == kFloat &&
|
||||
(!result.defined() || result.scalar_type() == kFloat) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_tf32_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
|
||||
mat2.scalar_type() == kFloat &&
|
||||
(!result.defined() || result.scalar_type() == kFloat) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
auto mat1_type = mat1.scalar_type();
|
||||
if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
|
||||
return false;
|
||||
}
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
|
||||
return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
|
||||
});
|
||||
return false;
|
||||
return (
|
||||
use_mkldnn_bf16_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_fp16_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_bf32_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_tf32_matmul(mat1, mat2, result));
|
||||
}
|
||||
|
||||
static void _mkldnn_matmul_i8i8i32_with_primitive(
|
||||
|
||||
@ -469,4 +469,94 @@ Tensor _weight_int4pack_mm_xpu(
|
||||
|
||||
return C;
|
||||
}
|
||||
|
||||
Tensor& _int_mm_out_xpu(
|
||||
const Tensor& self,
|
||||
const Tensor& mat2,
|
||||
Tensor& result) {
|
||||
TORCH_CHECK(
|
||||
self.dim() == 2,
|
||||
"Expected self to be of dimension 2 but got ",
|
||||
self.dim());
|
||||
TORCH_CHECK(
|
||||
mat2.dim() == 2,
|
||||
"Expected mat2 to be of dimension 2 but got ",
|
||||
mat2.dim());
|
||||
TORCH_CHECK(
|
||||
self.size(1) == mat2.size(0),
|
||||
"self.size(1) needs to match mat2.size(0) but got ",
|
||||
self.size(1),
|
||||
" and ",
|
||||
mat2.size(0));
|
||||
|
||||
TORCH_CHECK(
|
||||
self.dtype() == at::kChar,
|
||||
"Expected self dtype to be of type int8 but got ",
|
||||
self.dtype());
|
||||
TORCH_CHECK(
|
||||
mat2.dtype() == at::kChar,
|
||||
"Expected mat2 dtype to be of type int8 but got ",
|
||||
mat2.dtype());
|
||||
TORCH_CHECK(
|
||||
result.dtype() == at::kInt,
|
||||
"Expected result dtype to be of type kInt but got ",
|
||||
result.dtype());
|
||||
TORCH_CHECK(
|
||||
result.size(0) == self.size(0),
|
||||
"Expected result.size(0) to be ",
|
||||
self.size(0),
|
||||
" but got ",
|
||||
result.size(0));
|
||||
TORCH_CHECK(
|
||||
result.size(1) == mat2.size(1),
|
||||
"Expected result.size(1) to be ",
|
||||
mat2.size(1),
|
||||
" but got ",
|
||||
result.size(1));
|
||||
|
||||
TORCH_CHECK(
|
||||
result.dim() == 2,
|
||||
"Expected result to be of dimension 2 but got ",
|
||||
result.dim());
|
||||
|
||||
TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
|
||||
|
||||
if (result.numel() == 0 || self.size(1) == 0) {
|
||||
return result.zero_();
|
||||
}
|
||||
|
||||
Tensor bias = at::Tensor();
|
||||
Tensor mat2_scales = at::ones({1}, mat2.options().dtype(at::kFloat));
|
||||
Tensor mat2_zero_points = at::Tensor();
|
||||
auto post_op_args = torch::List<std::optional<at::Scalar>>();
|
||||
|
||||
at::native::onednn::quantized_matmul(
|
||||
self.contiguous(),
|
||||
1.0,
|
||||
0,
|
||||
mat2.contiguous(),
|
||||
mat2_scales,
|
||||
mat2_zero_points,
|
||||
bias,
|
||||
result,
|
||||
1.0,
|
||||
0,
|
||||
result.scalar_type(),
|
||||
/*other*/ std::nullopt,
|
||||
/*other scale*/ 1.0,
|
||||
/*other zp*/ 0,
|
||||
/*binary post op*/ "none",
|
||||
/*binary alpha*/ 1.0,
|
||||
/*post_op_name*/ "none",
|
||||
post_op_args,
|
||||
/*post_op_algorithm*/ "none",
|
||||
/*m2_trans*/ true);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
|
||||
Tensor result =
|
||||
at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
|
||||
return _int_mm_out_xpu(self, mat2, result);
|
||||
}
|
||||
} // namespace at::native
|
||||
|
||||
@ -953,8 +953,7 @@ class BundledShaderLibary : public MetalShaderLibrary {
|
||||
if (C10_UNLIKELY(!library)) {
|
||||
auto device = MPSDevice::getInstance()->device();
|
||||
NSError* error = nil;
|
||||
auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
|
||||
library = [device newLibraryWithData:getSectionData(section_name) error:&error];
|
||||
library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
|
||||
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
|
||||
}
|
||||
return library;
|
||||
|
||||
@ -33,21 +33,15 @@ struct shrink_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardsigmoid_functor {
|
||||
template <typename T>
|
||||
@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardsigmoid, float, float);
|
||||
REGISTER_UNARY_OP(hardsigmoid, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardswish_functor {
|
||||
template <typename T>
|
||||
@ -103,15 +93,11 @@ struct hardswish_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardswish, float, float);
|
||||
REGISTER_UNARY_OP(hardswish, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardswish_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardswish_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct leaky_relu_functor {
|
||||
template <typename T>
|
||||
@ -135,12 +121,8 @@ struct leaky_relu_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
@ -113,18 +113,12 @@ kernel void ampUpdateScale(
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(float);
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
|
||||
#endif
|
||||
|
||||
@ -590,9 +590,7 @@ kernel void attention(
|
||||
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(float);
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
|
||||
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
|
||||
@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(float);
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
|
||||
#endif
|
||||
|
||||
@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor {
|
||||
};
|
||||
|
||||
struct nextafter_functor {
|
||||
#if __METAL_VERSION__ < 310
|
||||
template <typename U>
|
||||
struct bit_type {};
|
||||
template <>
|
||||
struct bit_type<float> {
|
||||
using type = int;
|
||||
};
|
||||
template <>
|
||||
struct bit_type<half> {
|
||||
using type = short;
|
||||
};
|
||||
#endif
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
#if __METAL_VERSION__ >= 310
|
||||
return static_cast<T>(::metal::nextafter(a, b));
|
||||
#else
|
||||
using U = typename bit_type<T>::type;
|
||||
if (a == b) {
|
||||
return a;
|
||||
}
|
||||
if (::metal::isunordered(a, b)) {
|
||||
return NAN;
|
||||
}
|
||||
if (a == 0) {
|
||||
constexpr auto eps = as_type<T>(static_cast<U>(1));
|
||||
return b > 0 ? eps : -eps;
|
||||
}
|
||||
auto bits = as_type<U>(a);
|
||||
(a > 0) ^ (a > b) ? bits++ : bits--;
|
||||
return as_type<T>(bits);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -344,13 +315,6 @@ struct fmod_functor {
|
||||
}
};

// Some helper defines
#if __METAL_VERSION__ >= 310
#define _METAL_310_PLUS(x) x
#else
#define _METAL_310_PLUS(x)
#endif

#define REGISTER_INTEGER_BINARY_OP(NAME) \
  REGISTER_BINARY_OP(NAME, long, long); \
  REGISTER_BINARY_OP(NAME, int, int); \
@@ -370,12 +334,12 @@ struct fmod_functor {
#define REGISTER_FLOAT_BINARY_OP(NAME) \
  REGISTER_BINARY_OP(NAME, float, float); \
  REGISTER_BINARY_OP(NAME, half, half); \
  _METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
  REGISTER_BINARY_OP(NAME, bfloat, bfloat)

#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
  REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
  REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
  _METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
  REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)

REGISTER_FLOAT_BINARY_OP(copysign);
REGISTER_INT2FLOAT_BINARY_OP(copysign);
@@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);

#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
#endif

// Complex binary functions
REGISTER_BINARY_OP(polar, float, float2);

@@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int);
REGISTER_SEARCHSORTED_OP(float, long);
REGISTER_SEARCHSORTED_OP(half, int);
REGISTER_SEARCHSORTED_OP(half, long);
#if __METAL_VERSION__ >= 310
REGISTER_SEARCHSORTED_OP(bfloat, int);
REGISTER_SEARCHSORTED_OP(bfloat, long);
#endif
REGISTER_SEARCHSORTED_OP(char, int);
REGISTER_SEARCHSORTED_OP(char, long);
REGISTER_SEARCHSORTED_OP(uchar, int);

@@ -96,6 +96,4 @@ kernel void col2im_kernel(
INSTANTIATE_COL2IM(bool);
INSTANTIATE_COL2IM(float);
INSTANTIATE_COL2IM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_COL2IM(bfloat);
#endif

@@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short);
REGISTER_CROSS_FUNC(char);
REGISTER_CROSS_FUNC(uchar);
REGISTER_CROSS_FUNC(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_FUNC(bfloat);
#endif

template <typename T, typename U>
kernel void cross(
@@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short);
REGISTER_CROSS_OP(char);
REGISTER_CROSS_OP(uchar);
REGISTER_CROSS_OP(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_OP(bfloat);
#endif

@@ -1,11 +1,9 @@
#include <metal_stdlib>

using metal::max;
#if __METAL_VERSION__ >= 310
bfloat max(bfloat a, bfloat b) {
  return a > b ? a : b;
}
#endif

#define kmaxThreadGroups 32
#define kmaxTensors 32
@@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float);
REGISTER_ADAM_OPS_QUART(float, half);
REGISTER_ADAM_OPS_QUART(half, float);
REGISTER_ADAM_OPS_QUART(half, half);
#if __METAL_VERSION__ >= 310
REGISTER_ADAM_OPS_QUART(float, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, float);
#endif

template <typename T>
inline void sgd_momentum_math(
@@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float);
REGISTER_FUSED_SGD_OP(half);
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_FUSED_SGD_OP(bfloat);
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
#endif

@@ -106,9 +106,7 @@ kernel void polygamma(
    constant int64_t& order [[buffer(2)]], \
    uint id [[thread_position_in_grid]]);

#if __METAL_VERSION__ >= 310
INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat);
#endif
INSTANTIATE_GAMMA_KERNELS(half, half);
INSTANTIATE_GAMMA_KERNELS(float, float);
INSTANTIATE_GAMMA_KERNELS(bool, float);

@@ -76,6 +76,4 @@ INSTANTIATE_IM2COL(float);
INSTANTIATE_IM2COL(float2);
INSTANTIATE_IM2COL(half);
INSTANTIATE_IM2COL(half2);
#if __METAL_VERSION__ >= 310
INSTANTIATE_IM2COL(bfloat);
#endif

@@ -240,9 +240,7 @@ REGISTER_INDEX_OP(put_accumulate, short, short);
REGISTER_INDEX_OP(put_accumulate, char, char);
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
#endif

template <typename StridesT, typename DataT>
kernel void kernel_index_offsets(
@@ -477,10 +475,8 @@ INSTANTIATE_INDEX_COPY(char, long);
INSTANTIATE_INDEX_COPY(uchar, int);
INSTANTIATE_INDEX_COPY(uchar, long);

#if __METAL_VERSION__ >= 310
INSTANTIATE_INDEX_COPY(bfloat, int);
INSTANTIATE_INDEX_COPY(bfloat, long);
#endif
INSTANTIATE_INDEX_COPY(float2, int);
INSTANTIATE_INDEX_COPY(float2, long);
INSTANTIATE_INDEX_COPY(half2, int);

@@ -288,7 +288,6 @@ kernel void layer_norm_looped(
#define instantiate_layer_norm(DTYPE) \
  instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE)

instantiate_layer_norm(float) instantiate_layer_norm(half)
#if __METAL_VERSION__ >= 310
instantiate_layer_norm(bfloat)
#endif
instantiate_layer_norm(float);
instantiate_layer_norm(half);
instantiate_layer_norm(bfloat);

@@ -635,9 +635,7 @@ kernel void applyPivots(

INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_NAIVE_MM(bfloat);
#endif

// Integral MM
INSTANTIATE_NAIVE_MM(short);

@ -48,3 +48,14 @@ struct PoolingBackwardParams {
|
||||
::c10::metal::array<idx_type_t, N> grad_output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
};
|
||||
|
||||
template <unsigned N = 5, typename idx_type_t = int32_t>
|
||||
struct MaxUnpoolingParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
::c10::metal::array<idx_type_t, N> input_sizes;
|
||||
::c10::metal::array<idx_type_t, N> input_strides;
|
||||
::c10::metal::array<idx_type_t, N> output_sizes;
|
||||
::c10::metal::array<idx_type_t, N> output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
};
|
||||
|
||||
@ -168,6 +168,16 @@ PoolOffsets find_pool_offsets(
|
||||
leading_dims,
|
||||
return_indices,
|
||||
tid);
|
||||
case 3:
|
||||
return find_pool_offsets_dim_specific<3>(
|
||||
output_sizes,
|
||||
output_strides,
|
||||
indices_strides,
|
||||
input_strides,
|
||||
pooling_dim_indices,
|
||||
leading_dims,
|
||||
return_indices,
|
||||
tid);
|
||||
}
|
||||
return PoolOffsets();
|
||||
}
|
||||
@ -292,6 +302,68 @@ kernel void max_pool_backward(
|
||||
pooling_dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void max_unpool_impl(
|
||||
device T* output,
|
||||
T input_element,
|
||||
int32_t input_index,
|
||||
constant int32_t* output_sizes,
|
||||
constant int32_t* output_strides,
|
||||
int32_t pooling_dims) {
|
||||
int32_t size_prod = 1;
|
||||
int32_t pool_offset = 0;
|
||||
|
||||
for (auto dim = pooling_dims - 1; dim >= 0; dim--) {
|
||||
auto next_size_prod = output_sizes[dim] * size_prod;
|
||||
pool_offset +=
|
||||
output_strides[dim] * ((input_index % next_size_prod) / size_prod);
|
||||
size_prod *= output_sizes[dim];
|
||||
}
|
||||
|
||||
output[pool_offset] = input_element;
|
||||
}
|
||||
|
||||
// Kernel computes one element of the grad input per kernel call.
|
||||
template <typename T>
|
||||
kernel void max_unpool(
|
||||
device T* output [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
constant int64_t* indices [[buffer(2)]],
|
||||
constant MaxUnpoolingParams<5>& params [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
auto pooling_dims = params.pooling_dims;
|
||||
auto dims = params.dims;
|
||||
auto input_sizes = params.input_sizes.data();
|
||||
auto input_strides = params.input_strides.data();
|
||||
auto output_sizes = params.output_sizes.data();
|
||||
auto output_strides = params.output_strides.data();
|
||||
auto indices_strides = params.indices_strides.data();
|
||||
|
||||
auto leading_dims = dims - pooling_dims;
|
||||
|
||||
// NOTE: Since we're doing unpooling, the variable names "input" and "output"
|
||||
// are reversed compared to the pooling operations. So in `find_pool_offsets`,
|
||||
// we need to map "input" -> "output" and "output" -> "input".
|
||||
PoolOffsets offsets = find_pool_offsets(
|
||||
/*output_sizes=*/input_sizes,
|
||||
/*output_strides=*/input_strides,
|
||||
indices_strides,
|
||||
/*input_strides=*/output_strides,
|
||||
/*pooling_dim_indices=*/nullptr,
|
||||
dims,
|
||||
leading_dims,
|
||||
/*return_indices=*/true,
|
||||
tid);
|
||||
|
||||
max_unpool_impl<T>(
|
||||
output + offsets.input_leading,
|
||||
input[offsets.output],
|
||||
indices[offsets.indices],
|
||||
output_sizes + leading_dims,
|
||||
output_strides + leading_dims,
|
||||
pooling_dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AvgPoolIterBounds {
|
||||
T start;
|
||||
@ -428,18 +500,25 @@ kernel void avg_pool(
|
||||
params.divisor_override);
|
||||
}
|
||||
|
||||
#define REGISTER_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
#define REGISTER_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("max_unpool_" #DTYPE)]] kernel void max_unpool<DTYPE>( \
|
||||
device DTYPE * output [[buffer(0)]], \
|
||||
constant DTYPE * input [[buffer(1)]], \
|
||||
constant int64_t* indices [[buffer(2)]], \
|
||||
constant MaxUnpoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
|
||||
@ -453,6 +532,7 @@ kernel void avg_pool(
|
||||
|
||||
REGISTER_POOL_OP(float);
|
||||
REGISTER_POOL_OP(half);
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_POOL_OP(int);
|
||||
REGISTER_POOL_OP(long);
|
||||
REGISTER_POOL_OP(short);
|
||||
@ -462,8 +542,4 @@ REGISTER_POOL_OP(bool);
|
||||
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(float);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(half);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -197,12 +197,10 @@ INSTANTIATE_INT4MV(float, 128);
|
||||
INSTANTIATE_INT4MV(half, 128);
|
||||
INSTANTIATE_INT4MV(float, 256);
|
||||
INSTANTIATE_INT4MV(half, 256);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_INT4MV(bfloat, 32);
|
||||
INSTANTIATE_INT4MV(bfloat, 64);
|
||||
INSTANTIATE_INT4MV(bfloat, 128);
|
||||
INSTANTIATE_INT4MV(bfloat, 256);
|
||||
#endif
|
||||
|
||||
// ------------------------------ int8 MM For M >= 12 ------------------------------------
|
||||
/**
|
||||
@ -234,12 +232,10 @@ template <> struct BlockType<half> {
|
||||
using simdgroup_type8x8 = simdgroup_half8x8;
|
||||
using type4 = half4;
|
||||
};
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <> struct BlockType<bfloat> {
|
||||
using simdgroup_type8x8 = simdgroup_bfloat8x8;
|
||||
using type4 = bfloat4;
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) {
|
||||
@ -490,9 +486,7 @@ kernel void kernel_mul_mm<DTYPE, WDTYPE, DEQUANT_FUNC>( \
|
||||
|
||||
INSTANTIATE_MM(float, char, get_scale_zero_q8);
|
||||
INSTANTIATE_MM(half, char, get_scale_zero_q8);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MM(bfloat, char, get_scale_zero_q8);
|
||||
#endif
|
||||
// ------------------------------ int8 MM For M < 12 ------------------------------------
|
||||
/* Matrix vector multiplication, used for small M size for matrix multiplication as well.
|
||||
|
||||
@ -646,6 +640,4 @@ kernel void kernel_mul_mv<DTYPE>(
|
||||
|
||||
INSTANTIATE_MV(float);
|
||||
INSTANTIATE_MV(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MV(bfloat);
|
||||
#endif
|
||||
|
||||
@ -192,6 +192,4 @@ template <typename T>
|
||||
|
||||
instantiate_rms(float)
|
||||
instantiate_rms(half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
instantiate_rms(bfloat)
|
||||
#endif // clang-format on
|
||||
|
||||
@ -23,6 +23,4 @@ kernel void renorm(
|
||||
|
||||
REGISTER_RENORM_OP(float);
|
||||
REGISTER_RENORM_OP(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_RENORM_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -25,379 +25,6 @@ struct LogAddExp {
|
||||
};
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ < 310
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct CumMinOp {
|
||||
static acc_t apply(acc_t a, acc_t b) {
|
||||
return metal::min(a, b);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return static_cast<acc_t>(
|
||||
metal::is_floating_point_v<T> ? metal::numeric_limits<T>::infinity()
|
||||
: metal::numeric_limits<T>::max());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct CumMaxOp {
|
||||
static acc_t apply(acc_t a, acc_t b) {
|
||||
return metal::max(a, b);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return static_cast<acc_t>(
|
||||
metal::is_floating_point_v<T> ? -metal::numeric_limits<T>::infinity()
|
||||
: metal::numeric_limits<T>::lowest());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct LogCumSumExpOp {
|
||||
static acc_t apply(acc_t x, acc_t y) {
|
||||
return LogAddExp{}(x, y);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return -metal::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
};
|
||||
|
||||
// Inclusive scan along innermost dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_contiguous_innermost_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant uint& num_rows [[buffer(2)]],
|
||||
constant uint& row_size [[buffer(3)]],
|
||||
uint row [[thread_position_in_grid]]) {
|
||||
if (row >= num_rows)
|
||||
return;
|
||||
|
||||
const uint offset = row * row_size;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
|
||||
for (uint col = 0; col < row_size; col++) {
|
||||
T val = input[offset + col];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[offset + col] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan along outer dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_contiguous_outer_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant uint& num_orows [[buffer(2)]],
|
||||
constant uint& num_irows [[buffer(3)]],
|
||||
constant uint& row_size [[buffer(4)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const uint orow = thread_index / num_irows;
|
||||
const uint irow = thread_index % num_irows;
|
||||
|
||||
if (orow >= num_orows)
|
||||
return;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
|
||||
const uint idx_base = orow * row_size * num_irows + irow;
|
||||
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
|
||||
T val = input[idx];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[idx] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan with indices along innermost dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_contiguous_innermost_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant uint& num_rows [[buffer(3)]],
|
||||
constant uint& row_size [[buffer(4)]],
|
||||
uint row [[thread_position_in_grid]]) {
|
||||
if (row >= num_rows)
|
||||
return;
|
||||
|
||||
const uint offset = row * row_size;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
|
||||
for (uint col = 0; col < row_size; col++) {
|
||||
T val = input[offset + col];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = col;
|
||||
}
|
||||
values[offset + col] = static_cast<T>(accumulator);
|
||||
indices[offset + col] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan with indices along outer dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_contiguous_outer_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant uint& num_orows [[buffer(3)]],
|
||||
constant uint& num_irows [[buffer(4)]],
|
||||
constant uint& row_size [[buffer(5)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const uint orow = thread_index / num_irows;
|
||||
const uint irow = thread_index % num_irows;
|
||||
|
||||
if (orow >= num_orows)
|
||||
return;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
|
||||
const uint idx_base = orow * row_size * num_irows + irow;
|
||||
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
|
||||
T val = input[idx];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = col;
|
||||
}
|
||||
values[idx] = static_cast<T>(accumulator);
|
||||
indices[idx] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Shared utility functions for strided kernels
|
||||
inline long calculate_non_scan_elements(
|
||||
constant long* sizes,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long total = 1;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
total *= sizes[i];
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
inline void thread_index_to_coordinates(
|
||||
uint index,
|
||||
int pos[c10::metal::max_ndim],
|
||||
constant long* sizes,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long remaining_index = index;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
pos[i] = remaining_index % sizes[i];
|
||||
remaining_index /= sizes[i];
|
||||
} else {
|
||||
pos[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline long calculate_base_offset(
|
||||
int pos[c10::metal::max_ndim],
|
||||
constant long* strides,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long offset = 0;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
offset += pos[i] * strides[i];
|
||||
}
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
// Generic strided scan kernel
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_strided(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant long* sizes [[buffer(2)]],
|
||||
constant long* input_strides [[buffer(3)]],
|
||||
constant long* output_strides [[buffer(4)]],
|
||||
constant uint& ndim [[buffer(5)]],
|
||||
constant uint& scan_dim [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const long total_non_scan_elements =
|
||||
calculate_non_scan_elements(sizes, ndim, scan_dim);
|
||||
if (thread_index >= total_non_scan_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos[c10::metal::max_ndim];
|
||||
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
|
||||
|
||||
const long input_base_offset =
|
||||
calculate_base_offset(pos, input_strides, ndim, scan_dim);
|
||||
const long output_base_offset =
|
||||
calculate_base_offset(pos, output_strides, ndim, scan_dim);
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
const long scan_size = sizes[scan_dim];
|
||||
const long input_scan_stride = input_strides[scan_dim];
|
||||
const long output_scan_stride = output_strides[scan_dim];
|
||||
|
||||
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
|
||||
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
|
||||
const long output_offset =
|
||||
output_base_offset + scan_idx * output_scan_stride;
|
||||
|
||||
T val = input[input_offset];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[output_offset] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Generic strided scan with indices kernel
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_strided(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant long* sizes [[buffer(3)]],
|
||||
constant long* input_strides [[buffer(4)]],
|
||||
constant long* values_strides [[buffer(5)]],
|
||||
constant long* indices_strides [[buffer(6)]],
|
||||
constant uint& ndim [[buffer(7)]],
|
||||
constant uint& scan_dim [[buffer(8)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const long total_non_scan_elements =
|
||||
calculate_non_scan_elements(sizes, ndim, scan_dim);
|
||||
if (thread_index >= total_non_scan_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos[c10::metal::max_ndim];
|
||||
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
|
||||
|
||||
const long input_base_offset =
|
||||
calculate_base_offset(pos, input_strides, ndim, scan_dim);
|
||||
const long values_base_offset =
|
||||
calculate_base_offset(pos, values_strides, ndim, scan_dim);
|
||||
const long indices_base_offset =
|
||||
calculate_base_offset(pos, indices_strides, ndim, scan_dim);
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
const long scan_size = sizes[scan_dim];
|
||||
const long input_scan_stride = input_strides[scan_dim];
|
||||
const long values_scan_stride = values_strides[scan_dim];
|
||||
const long indices_scan_stride = indices_strides[scan_dim];
|
||||
|
||||
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
|
||||
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
|
||||
const long values_offset =
|
||||
values_base_offset + scan_idx * values_scan_stride;
|
||||
const long indices_offset =
|
||||
indices_base_offset + scan_idx * indices_scan_stride;
|
||||
|
||||
T val = input[input_offset];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = scan_idx;
|
||||
}
|
||||
values[values_offset] = static_cast<T>(accumulator);
|
||||
indices[indices_offset] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \
|
||||
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
|
||||
scan_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant uint & num_rows [[buffer(2)]], \
|
||||
constant uint & row_size [[buffer(3)]], \
|
||||
uint row [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
|
||||
scan_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant uint & num_orows [[buffer(2)]], \
|
||||
constant uint & num_irows [[buffer(3)]], \
|
||||
constant uint & row_size [[buffer(4)]], \
|
||||
uint thread_index [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
|
||||
scan_strided<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant long* sizes [[buffer(2)]], \
|
||||
constant long* input_strides [[buffer(3)]], \
|
||||
constant long* output_strides [[buffer(4)]], \
|
||||
constant uint& ndim [[buffer(5)]], \
|
||||
constant uint& scan_dim [[buffer(6)]], \
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \
|
||||
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
|
||||
scan_with_indices_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * values [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant uint& num_rows [[buffer(3)]], \
|
||||
constant uint& row_size [[buffer(4)]], \
|
||||
uint row [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
|
||||
scan_with_indices_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * values [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant uint& num_orows [[buffer(3)]], \
|
||||
constant uint& num_irows [[buffer(4)]], \
|
||||
constant uint& row_size [[buffer(5)]], \
|
||||
uint thread_index [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
|
||||
scan_with_indices_strided<DTYPE, OP_CLASS<DTYPE>>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * values [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant long* sizes [[buffer(3)]], \
|
||||
constant long* input_strides [[buffer(4)]], \
|
||||
constant long* values_strides [[buffer(5)]], \
|
||||
constant long* indices_strides [[buffer(6)]], \
|
||||
constant uint& ndim [[buffer(7)]], \
|
||||
constant uint& scan_dim [[buffer(8)]], \
|
||||
uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
// Simple scan operations
|
||||
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float);
|
||||
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half);
|
||||
|
||||
// Scan operations with indices
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool);
|
||||
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
|
||||
|
||||
#else // __METAL_VERSION__ >= 310
|
||||
|
||||
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
|
||||
|
||||
// The reminder of this file contains cummin and cummax implementations adapted
|
||||
@ -1159,5 +786,3 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4);
|
||||
|
||||
#endif
|
||||
|
||||
@ -89,6 +89,4 @@ REGISTER_SPECIAL(short, float);
|
||||
REGISTER_SPECIAL(int, float);
|
||||
REGISTER_SPECIAL(long, float);
|
||||
REGISTER_SPECIAL(half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_SPECIAL(bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
@ -100,9 +100,7 @@ kernel void triul(
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half, int);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float2, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half2, int);
|
||||
|
||||
@ -556,11 +556,9 @@ REGISTER_UNARY_OP(abs, half, half);
|
||||
REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(neg, bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(abs, bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_UNARY_KERNELS2(half, half);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, float);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, bool);
|
||||
@ -600,6 +598,4 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float);
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float);
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat);
|
||||
#endif
|
||||
|
||||
@ -70,6 +70,4 @@ kernel void unfold_backward(
|
||||
|
||||
INSTANTIATE_UNFOLD_BACKWARD(float);
|
||||
INSTANTIATE_UNFOLD_BACKWARD(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNFOLD_BACKWARD(bfloat);
|
||||
#endif
|
||||
|
||||
@ -852,6 +852,4 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
|
||||
INSTANTIATE_UPSAMPLE_3D(uchar);
|
||||
INSTANTIATE_UPSAMPLE_ALL(float);
|
||||
INSTANTIATE_UPSAMPLE_ALL(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UPSAMPLE_ALL(bfloat);
|
||||
#endif
|
||||
|
||||
@ -21,6 +21,8 @@
|
||||
#include <ATen/ops/max_pool2d_with_indices_native.h>
|
||||
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
|
||||
#include <ATen/ops/max_pool3d_with_indices_native.h>
|
||||
#include <ATen/ops/max_unpool2d_native.h>
|
||||
#include <ATen/ops/max_unpool3d_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
@ -492,6 +494,60 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
|
||||
});
|
||||
}
|
||||
|
||||
static void max_unpool_out_mps_template(const Tensor& input,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size_,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
Tensor& output,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
auto dims = input.dim();
|
||||
auto leading_dims = input.dim() - pooling_dims;
|
||||
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
std::vector<int64_t> output_size(dims);
|
||||
for (int dim : c10::irange(leading_dims)) {
|
||||
output_size[dim] = input.sizes()[dim];
|
||||
}
|
||||
for (int dim : c10::irange(pooling_dims)) {
|
||||
output_size[leading_dims + dim] = output_size_[dim];
|
||||
}
|
||||
|
||||
output.resize_(output_size, memory_format);
|
||||
output.fill_(0);
|
||||
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
const auto numThreads = input.numel();
|
||||
MaxUnpoolingParams<5> params;
|
||||
|
||||
params.dims = dims;
|
||||
params.pooling_dims = pooling_dims;
|
||||
|
||||
for (const auto dim : c10::irange(dims)) {
|
||||
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
|
||||
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
|
||||
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
|
||||
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
|
||||
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
|
||||
}
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
auto PSO = lib.getPipelineStateForFunc("max_unpool_" + scalarToMetalTypeString(input));
|
||||
|
||||
getMPSProfiler().beginProfileKernel(PSO, op_name, {input});
|
||||
[computeEncoder setComputePipelineState:PSO];
|
||||
mtl_setArgs(computeEncoder, output, input, indices, params);
|
||||
|
||||
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
|
||||
getMPSProfiler().endProfileKernel(PSO);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static void avg_pool2d_template(const Tensor& input,
|
||||
const Tensor& output,
|
||||
const std::optional<Tensor>& grad_output_opt,
|
||||
@ -896,6 +952,68 @@ Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output,
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
Tensor& max_unpooling2d_forward_out_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
Tensor& output) {
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
/*stride=*/{},
|
||||
/*padding=*/{},
|
||||
output,
|
||||
/*pooling_dims=*/2,
|
||||
"max_unpool2d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor max_unpooling2d_forward_mps(const Tensor& self, const Tensor& indices, IntArrayRef output_size) {
|
||||
auto output = at::empty({0}, self.options());
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
/*stride=*/{},
|
||||
/*padding=*/{},
|
||||
output,
|
||||
/*pooling_dims=*/2,
|
||||
"max_unpool2d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor& max_unpooling3d_forward_out_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
Tensor& output) {
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
stride,
|
||||
padding,
|
||||
output,
|
||||
/*pooling_dims=*/3,
|
||||
"max_unpool3d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor max_unpooling3d_forward_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding) {
|
||||
auto output = at::empty({0}, self.options());
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
stride,
|
||||
padding,
|
||||
output,
|
||||
/*pooling_dims=*/3,
|
||||
"max_unpool3d");
|
||||
return output;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
|
||||
(const Tensor& input,
|
||||
int64_t kH,
|
||||
|
||||
@ -719,6 +719,7 @@
|
||||
dispatch:
|
||||
CPU, CUDA: all_out
|
||||
MPS: all_out_mps
|
||||
MTIA: all_out_mtia
|
||||
|
||||
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
@ -808,6 +809,7 @@
|
||||
CPU, Meta: arange_out
|
||||
CUDA: arange_cuda_out
|
||||
MPS: arange_mps_out
|
||||
MTIA: arange_mtia_out
|
||||
cpp_no_default_args: ['step']
|
||||
|
||||
# This function is a temporary hack to allow tracing of arange like constructs with dynamic
|
||||
@ -1889,7 +1891,10 @@
|
||||
- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: cudnn_batch_norm
|
||||
autogen: cudnn_batch_norm.out
|
||||
|
||||
- func: cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
|
||||
dispatch:
|
||||
CUDA: cudnn_batch_norm_out
|
||||
|
||||
# NB: You can only use this if you used cudnn_batch_norm training=True
|
||||
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
|
||||
@ -4182,11 +4187,13 @@
|
||||
dispatch:
|
||||
CPU: _int_mm_cpu
|
||||
CUDA: _int_mm_cuda
|
||||
XPU: _int_mm_xpu
|
||||
|
||||
- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: _int_mm_out_cpu
|
||||
CUDA: _int_mm_out_cuda
|
||||
XPU: _int_mm_out_xpu
|
||||
|
||||
- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
|
||||
dispatch:
|
||||
@ -7124,18 +7131,21 @@
|
||||
dispatch:
|
||||
CPU: _scaled_mm_cpu
|
||||
CUDA: _scaled_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: _scaled_mm_out_cpu
|
||||
CUDA: _scaled_mm_out_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
|
||||
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_grouped_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
|
||||
variants: function
|
||||
@ -10487,6 +10497,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_
|
||||
CUDA: foreach_tensor_add_scalar_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_scalar_kernel_mtia_
|
||||
autogen: _foreach_add.Scalar_out
|
||||
|
||||
- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
|
||||
@ -10495,6 +10506,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow
|
||||
CUDA: foreach_tensor_add_list_kernel_cuda
|
||||
MTIA: foreach_tensor_add_list_kernel_mtia
|
||||
|
||||
- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10502,6 +10514,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_
|
||||
CUDA: foreach_tensor_add_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_list_kernel_mtia_
|
||||
autogen: _foreach_add.List_out
|
||||
|
||||
- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
|
||||
@ -10532,6 +10545,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_
|
||||
CUDA: foreach_tensor_add_tensor_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_tensor_kernel_mtia_
|
||||
autogen: _foreach_add.Tensor_out
|
||||
|
||||
- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
@ -10592,6 +10606,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_scalar_kernel_mtia_
|
||||
autogen: _foreach_mul.Scalar_out
|
||||
|
||||
- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
|
||||
@ -10600,6 +10615,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow
|
||||
CUDA: foreach_tensor_mul_list_kernel_cuda
|
||||
MTIA: foreach_tensor_mul_list_kernel_mtia
|
||||
|
||||
- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10607,6 +10623,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_list_kernel_mtia_
|
||||
autogen: _foreach_mul.List_out
|
||||
|
||||
- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
|
||||
@ -10630,6 +10647,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow
|
||||
CUDA: foreach_tensor_mul_tensor_kernel_cuda
|
||||
MTIA: foreach_tensor_mul_tensor_kernel_mtia
|
||||
|
||||
- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10637,6 +10655,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_tensor_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_tensor_kernel_mtia_
|
||||
autogen: _foreach_mul.Tensor_out
|
||||
|
||||
- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
@ -10933,6 +10952,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda
|
||||
MTIA: foreach_tensor_addcmul_scalar_mtia
|
||||
|
||||
- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10954,6 +10974,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda_
|
||||
MTIA: foreach_tensor_addcmul_scalar_mtia_
|
||||
autogen: _foreach_addcmul.Scalar_out
|
||||
|
||||
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
|
||||
@ -10978,6 +10999,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_abs_slow
|
||||
CUDA: foreach_tensor_abs_cuda
|
||||
MTIA: foreach_tensor_abs_mtia
|
||||
|
||||
- func: _foreach_abs_(Tensor(a!)[] self) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10985,6 +11007,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_abs_slow_
|
||||
CUDA: foreach_tensor_abs_cuda_
|
||||
MTIA: foreach_tensor_abs_mtia_
|
||||
autogen: _foreach_abs.out
|
||||
|
||||
- func: _foreach_acos(Tensor[] self) -> Tensor[]
|
||||
@ -11319,6 +11342,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_norm_slow
|
||||
CUDA: foreach_tensor_norm_cuda
|
||||
MTIA: foreach_tensor_norm_mtia
|
||||
autogen: _foreach_norm.Scalar_out
|
||||
|
||||
- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
|
||||
@ -11491,6 +11515,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_sqrt_slow_
|
||||
CUDA: foreach_tensor_sqrt_cuda_
|
||||
MTIA: foreach_tensor_sqrt_mtia_
|
||||
autogen: _foreach_sqrt.out
|
||||
|
||||
- func: _foreach_tan(Tensor[] self) -> Tensor[]
|
||||
@ -11552,6 +11577,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_
|
||||
CUDA: foreach_tensor_copy_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_copy_list_kernel_mtia_
|
||||
autogen: _foreach_copy.out
|
||||
|
||||
- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
|
||||
@ -11559,6 +11585,7 @@
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _foreach_copy
|
||||
MTIA: foreach_tensor_copy_list_kernel_mtia
|
||||
|
||||
- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
|
||||
dispatch:
|
||||
@ -12476,24 +12503,28 @@
|
||||
dispatch:
|
||||
CPU: max_unpooling2d_forward_out_cpu
|
||||
CUDA: max_unpooling2d_forward_out_cuda
|
||||
MPS: max_unpooling2d_forward_out_mps
|
||||
|
||||
- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling2d_forward_cpu
|
||||
CUDA: max_unpooling2d_forward_cuda
|
||||
MPS: max_unpooling2d_forward_mps
|
||||
|
||||
- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling3d_forward_out_cpu
|
||||
CUDA: max_unpooling3d_forward_out_cuda
|
||||
MPS: max_unpooling3d_forward_out_mps
|
||||
|
||||
- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling3d_forward_cpu
|
||||
CUDA: max_unpooling3d_forward_cuda
|
||||
MPS: max_unpooling3d_forward_mps
|
||||
|
||||
- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
|
||||
--api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -11,7 +11,27 @@ endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
|
||||
--api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_SPLITKV kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api fwd_appendkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_APPENDKV kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -19,15 +39,29 @@ if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
# Generate the files for both fwd and bwd
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
# Generate the files for both fwd, fwd_splitkv, fwd_appendkv, and bwd
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_SPLITKV kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_APPENDKV kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -44,6 +78,22 @@ if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd_splitkv pass")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd appendkv pass")
|
||||
endif()
|
||||
|
||||
# Change make_kernel to make_kernel_pt for bwd
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt"
|
||||
|
||||
@ -21,6 +21,8 @@ while IFS= read -r file; do
|
||||
if [ -f "$file" ]; then
|
||||
# Use sed to replace "make_kernel" with "make_kernel_pt" in place
|
||||
sed -i 's/make_kernel/make_kernel_pt/g' "$file"
|
||||
sed -i 's/\#include \"fmha_fwd.hpp\"/\#include \"fmha_fwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
|
||||
sed -i 's/\#include \"fmha_bwd.hpp\"/\#include \"fmha_bwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
|
||||
echo "Updated: $file"
|
||||
else
|
||||
echo "Skipping: $file (not found)"
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
// keep sync with BlockAttentionBiasEnum
|
||||
enum class bias_enum
|
||||
{
|
||||
no_bias = 0,
|
||||
elementwise_bias = 1,
|
||||
alibi = 2,
|
||||
};
|
||||
|
||||
struct bias_info
|
||||
{
|
||||
bias_enum type;
|
||||
/*
|
||||
* simple dispatch logic
|
||||
*
|
||||
* if type == elementwise_bias:
|
||||
* if rank_info == 0:
|
||||
* bias is 1*1*s*s
|
||||
* elif rank_info == 1:
|
||||
* bias is 1*h*s*s
|
||||
* elif rank_info == 2:
|
||||
* bias is b*h*s*s
|
||||
*
|
||||
* elif type == alibi:
|
||||
* if rank_info == 0:
|
||||
* alibi in 1*h
|
||||
* elif rank_info == 1:
|
||||
* alibi in b*h
|
||||
*/
|
||||
int rank_info;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == bias_enum::no_bias)
|
||||
os << "n";
|
||||
else if(type == bias_enum::elementwise_bias)
|
||||
{
|
||||
os << "e";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
else if(type == bias_enum::alibi)
|
||||
{
|
||||
os << "alibi";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bias_info decode(std::string str)
|
||||
{
|
||||
bias_info info{bias_enum::no_bias, 0};
|
||||
if(str == "0" || str == "n")
|
||||
{
|
||||
info.type = bias_enum::no_bias;
|
||||
}
|
||||
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
|
||||
str.compare(0, 11, "elementwise") == 0)
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
|
||||
str.compare(0, 5, "alibi") == 0)
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
||||
{
|
||||
bi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@ -1,457 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <mask.hpp>
|
||||
#include <bias.hpp>
|
||||
#include <launch_kernel_pt.hpp>
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaBwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaBwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using GemmDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::half_t;
|
||||
using OGradDataType = ck_tile::half_t;
|
||||
using QGradDataType = ck_tile::half_t;
|
||||
using KGradDataType = ck_tile::half_t;
|
||||
using VGradDataType = ck_tile::half_t;
|
||||
using BiasGradDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using GemmDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using OGradDataType = ck_tile::bf16_t;
|
||||
using QGradDataType = ck_tile::bf16_t;
|
||||
using KGradDataType = ck_tile::bf16_t;
|
||||
using VGradDataType = ck_tile::bf16_t;
|
||||
using BiasGradDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args, some will passed to karg, some will used to compute grids/blocks
|
||||
struct fmha_bwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* o_ptr;
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
void* d_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* dq_ptr;
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t max_seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
float scale;
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi: with per-batch slopes (b*h) set this to h, with shared slopes (1*h) set this to 0
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
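// A minimal usage sketch (illustrative only, not from the original header): drop_seed_offset
// is a variant, so callers supply either host-side philox seed/offset values or device
// pointers to them. The pointer parameters of this helper are hypothetical example inputs.
inline void example_set_drop_seed_offset(fmha_bwd_args& args,
                                         const void* seed_dev_ptr,
                                         const void* offset_dev_ptr,
                                         bool seed_on_device)
{
    if(seed_on_device)
        args.drop_seed_offset = std::make_pair(seed_dev_ptr, offset_dev_ptr); // device-side seed/offset
    else
        args.drop_seed_offset = std::make_pair(uint64_t{1234}, uint64_t{0});  // host-side seed/offset
}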
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdOGradDotOKernel>
|
||||
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_lsed);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdConvertQGradKernel>
|
||||
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dot_do_o_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_convert_dq_get_name_();
|
||||
|
||||
// This is the public API; it will be generated by a script
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0: no bias, 1: elementwise bias, 2: alibi; keep in sync with BlockAttentionBiasEnum
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
template <int Version = 2>
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
@ -1,824 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include <bias.hpp>
|
||||
#include <mask.hpp>
|
||||
#include <rotary.hpp>
|
||||
#include <launch_kernel_pt.hpp>
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaFwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdBf8
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8Fp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp8>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf8>
|
||||
{
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args; some are passed to the kernel args (kargs), some are used to compute grids/blocks
|
||||
struct fmha_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi: with per-batch slopes (b*h) set this to h, with shared slopes (1*h) set this to 0
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
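// A minimal sketch of the stride_bias convention noted above for alibi (this helper is
// illustrative and not part of the original header): a per-batch (b, h) slope tensor
// advances by nhead per row, while a shared (1, h) tensor uses stride 0 so every batch
// reads the same row.
inline ck_tile::index_t example_alibi_stride_bias(ck_tile::index_t nhead, bool per_batch_slopes)
{
    return per_batch_slopes ? nhead : ck_tile::index_t{0};
}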
|
||||
|
||||
struct fmha_fwd_splitkv_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* lse_acc_ptr;
|
||||
void* o_acc_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
|
||||
// nullptr.
|
||||
|
||||
const void* cache_batch_idx;
|
||||
|
||||
// the real seqlen_q & seqlen_k are decided as follows:
|
||||
// batch mode: seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k
|
||||
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
//
|
||||
// batch mode (kvcache):
|
||||
// seqlen_q = kargs.seqlen_q
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// group mode (kvcache):
|
||||
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
|
||||
//
|
||||
// when is_gappy=true:
|
||||
// seqlen_k = kargs.seqlen_k_ptr[b]
|
||||
// seqstart_k_ptr[b] now store local offset of each batch
|
||||
//
|
||||
// when is_gappy=false:
|
||||
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
|
||||
// or kargs.seqlen_k_ptr[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
ck_tile::index_t num_splits;
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi: with per-batch slopes (b*h) set this to h, with shared slopes (1*h) set this to 0
|
||||
ck_tile::index_t stride_o_acc;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
};
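// A minimal host-side sketch of the group-mode length rules spelled out in the comment
// above (illustrative only, not from the original header). It assumes the seqstart/seqlen
// arrays hold ck_tile::index_t values and are host-accessible, which is an assumption made
// purely for this example.
inline std::pair<ck_tile::index_t, ck_tile::index_t>
example_splitkv_group_mode_seqlens(const fmha_fwd_splitkv_args& args, ck_tile::index_t b)
{
    const auto* seqstart_q = static_cast<const ck_tile::index_t*>(args.seqstart_q_ptr);
    const auto* seqstart_k = static_cast<const ck_tile::index_t*>(args.seqstart_k_ptr);
    const auto* seqlen_k   = static_cast<const ck_tile::index_t*>(args.seqlen_k_ptr);

    const ck_tile::index_t len_q = seqstart_q[b + 1] - seqstart_q[b];
    // when seqlen_k_ptr is given (and always when is_gappy == true) it supplies the
    // per-batch key length; otherwise the length comes from consecutive seqstart_k entries
    const ck_tile::index_t len_k =
        (seqlen_k != nullptr) ? seqlen_k[b] : seqstart_k[b + 1] - seqstart_k[b];
    return {len_q, len_k};
}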
|
||||
|
||||
struct fmha_fwd_appendkv_args
|
||||
{
|
||||
void* q_ptr;
|
||||
void* k_ptr;
|
||||
const void* knew_ptr;
|
||||
void* v_ptr;
|
||||
const void* vnew_ptr;
|
||||
|
||||
const void* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_knew;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
|
||||
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
|
||||
ck_tile::index_t rotary_dim;
|
||||
bool has_mask;
|
||||
|
||||
void* block_table_ptr;
|
||||
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_knew;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_vnew;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_knew;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_vnew;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_knew;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_vnew;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel>
|
||||
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
dim3 grids = FmhaKernel::GridSize(
|
||||
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 grids =
|
||||
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(Kernel::kIsGroupMode)
|
||||
{
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.is_gappy,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_o_acc,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_k, // only used for paged-kvcache
|
||||
args.batch_stride_v, // only used for paged-kvcache
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.batch,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_splits,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.cache_batch_idx,
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_o_acc,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(
|
||||
args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
|
||||
{
|
||||
return Kernel::MakeKargs(args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return Kernel::MakeKargs(args.lse_acc_ptr,
|
||||
args.o_acc_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
args.scale_o,
|
||||
args.stride_o_acc,
|
||||
args.stride_o,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename Kernel>
|
||||
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.knew_ptr,
|
||||
args.v_ptr,
|
||||
args.vnew_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k_ptr,
|
||||
args.seqlen_knew,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.has_mask,
|
||||
args.block_table_ptr,
|
||||
args.batch_stride_block_table,
|
||||
args.page_block_size,
|
||||
args.cache_batch_idx,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
args.stride_v,
|
||||
args.stride_vnew,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_knew,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_vnew,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_knew,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_vnew);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
|
||||
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kIsPagedKV_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_splitkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kN1_,
|
||||
bool kStoreLse_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
bool kPadS_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_splitkv_combine_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr bool kStoreLse = kStoreLse_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_();
|
||||
|
||||
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
ck_tile::index_t kTileSizeS_,
|
||||
ck_tile::index_t kTileSizeSk_,
|
||||
ck_tile::index_t kTileSizeD_,
|
||||
ck_tile::index_t kTileSizeDv_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
bool kPadS_,
|
||||
bool kPadSk_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
|
||||
bool kIsPagedKV_>
|
||||
struct fmha_fwd_appendkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
|
||||
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
|
||||
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
|
||||
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSk = kPadSk_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr auto RotaryEnum = RotaryEnum_;
|
||||
static constexpr bool kIsPagedKV = kIsPagedKV_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
|
||||
|
||||
// This is the public API; it will be generated by a script
|
||||
struct fmha_fwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0: no bias, 1: elementwise bias, 2: alibi; keep in sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_splitkv_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0: no bias, 1: elementwise bias, 2: alibi; keep in sync with BlockAttentionBiasEnum
|
||||
bool has_lse;
|
||||
bool do_fp8_static_quant;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
|
||||
fmha_fwd_splitkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
struct fmha_fwd_appendkv_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_v_rowmajor;
|
||||
rope_enum rope_type;
|
||||
};
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
const ck_tile::stream_config&);
|
||||
@ -1,157 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
    no_mask = 0,
    mask_top_left,
    mask_bottom_right,
    window_generic,
};
|
||||
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
os << "n";
|
||||
else if(type == mask_enum::mask_top_left)
|
||||
os << "t(" << left << ":" << right << ")";
|
||||
else if(type == mask_enum::mask_bottom_right)
|
||||
os << "b(" << left << ":" << right << ")";
|
||||
else
|
||||
{
|
||||
os << "g(" << y << ":" << x << ")";
|
||||
}
|
||||
}
|
||||
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
||||
{
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "xt" || t == "xb")
|
||||
{
|
||||
// xformer style sliding window attn from top-left
|
||||
ck_tile::index_t window_size = atoi(v.c_str());
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = left_size;
|
||||
tmp.right = right_size;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto found_1 = v.find(",");
|
||||
if(found_1 == std::string::npos)
|
||||
{
|
||||
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
|
||||
assert(0);
|
||||
}
|
||||
tmp.type = mask_enum::window_generic;
|
||||
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
|
||||
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
|
||||
// TODO: some validation
|
||||
if(t == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
tmp.y = v0;
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
|
||||
assert(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto set_causal_top_left = [&]() {
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
};
|
||||
auto set_causal_bottom_right = [&]() {
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
};
|
||||
if(str == "t")
|
||||
set_causal_top_left();
|
||||
else if(str == "b")
|
||||
set_causal_bottom_right();
|
||||
else
|
||||
{
|
||||
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
|
||||
if(tmp.type == mask_enum::mask_top_left)
|
||||
{
|
||||
set_causal_top_left();
|
||||
}
|
||||
else if(tmp.type == mask_enum::mask_bottom_right)
|
||||
{
|
||||
set_causal_bottom_right();
|
||||
}
|
||||
}
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
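// Two representative inputs to mask_info::decode above (illustrative only, not from the
// original header); the sequence lengths are arbitrary example values.
inline void example_mask_decode()
{
    // "t" selects the causal (top-left) mask: y = seqlen_q, x = 1, left = -1, right = 0
    mask_info causal = mask_info::decode("t", /*seqlen_q=*/128, /*seqlen_k=*/256);

    // "b:64,0" selects a bottom-right mask with a 64-token left window and no right window
    mask_info windowed = mask_info::decode("b:64,0", /*seqlen_q=*/128, /*seqlen_k=*/256);

    (void)causal;   // operator<< would print "t(-1:0)"
    (void)windowed; // operator<< would print "b(64:0)"
}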
|
||||
@ -22,6 +22,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
|
||||
dtype,
|
||||
false, // is_group_mode
|
||||
true, // is_v_rowmajor
|
||||
false, // has_logits_soft_cap
|
||||
mask.type,
|
||||
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
|
||||
has_lse,
|
||||
@ -85,6 +86,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
ck_tile::index_t stride_attn_bias = 0;
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
ck_tile::index_t nhead_stride_bias = 0;
|
||||
|
||||
if (attn_bias_.has_value()) {
|
||||
auto a_b = attn_bias_.value();
|
||||
CHECK_DEVICE(a_b);
|
||||
@ -94,7 +96,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
nhead_stride_bias = a_b.stride(1);
|
||||
batch_stride_bias = a_b.stride(0);
|
||||
}
|
||||
|
||||
return fmha_fwd_args{q.data_ptr(),
|
||||
k.data_ptr(),
|
||||
v.data_ptr(),
|
||||
@ -116,6 +117,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
softmax_scale, // scale_s
|
||||
1, // scale_p
|
||||
1, // scale_o
|
||||
0.0f, // logits_soft_cap
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
@ -139,6 +141,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
-1, // min_seqlen_q
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
drop_seed_offset};
|
||||
|
||||
@ -20,6 +20,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
|
||||
dtype,
|
||||
true, // is_group_mode
|
||||
true, // is_v_rowmajor
|
||||
false, // has_logits_soft_cap
|
||||
mask.type,
|
||||
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
|
||||
has_lse,
|
||||
@ -117,6 +118,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
||||
softmax_scale, // scale_s
|
||||
1, // scale_p
|
||||
1, // scale_o
|
||||
0.0f, // logits_soft_cap
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
@ -140,6 +142,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
-1, // min_seqlen_q
|
||||
p_dropout,
|
||||
has_dropout_randval,
|
||||
drop_seed_offset};
|
||||
|
||||
@ -1,84 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <tuple>
|
||||
|
||||
// keep in sync with RotaryEmbeddingEnum
enum class rope_enum
{
    none         = 0,
    interleaved  = 1,
    half_rotated = 2,
};
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
generate_rotary_cos_sin(ck_tile::index_t seqlen,
|
||||
ck_tile::index_t rotary_dim,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
// return dummy tensors if we won't apply RoPE at all
|
||||
if(rotary_dim <= 0)
|
||||
{
|
||||
ck_tile::HostTensor<DataType> dummy({1, 1});
|
||||
return std::make_tuple(dummy, dummy);
|
||||
}
|
||||
|
||||
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen * 2;
|
||||
const ck_tile::index_t num_cols = rotary_dim / 2;
|
||||
|
||||
using std::begin, std::end;
|
||||
|
||||
ck_tile::HostTensor<float> angle({num_rows, num_cols});
|
||||
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
|
||||
|
||||
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::cos(origin_value));
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::sin(origin_value));
|
||||
});
|
||||
|
||||
return std::make_tuple(cos, sin);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
|
||||
const ck_tile::HostTensor<DataType>& sin,
|
||||
ck_tile::index_t seqlen_offset,
|
||||
ck_tile::index_t seqlen)
|
||||
{
|
||||
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
|
||||
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
|
||||
|
||||
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen;
|
||||
const ck_tile::index_t num_cols = cos.get_length(1);
|
||||
|
||||
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
|
||||
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
|
||||
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
return std::make_tuple(cos_pt, sin_pt);
|
||||
}
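// A minimal sketch of the typical host-side flow for the two helpers above (illustrative
// only, not from the original header): generate cos/sin tables for the full 2*seqlen range,
// then slice the window starting at the cached length. All sizes here are example values.
inline void example_rotary_usage()
{
    constexpr ck_tile::index_t seqlen     = 1024;
    constexpr ck_tile::index_t rotary_dim = 64;

    auto [cos, sin] = generate_rotary_cos_sin<ck_tile::half_t>(seqlen, rotary_dim, /*seed=*/42u);

    // take the cos/sin rows for positions [cache_len, cache_len + new_len)
    const ck_tile::index_t cache_len = 512;
    const ck_tile::index_t new_len   = 16;
    auto [cos_slice, sin_slice] = slice_rotary_cos_sin(cos, sin, cache_len, new_len);

    (void)cos_slice;
    (void)sin_slice;
}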
|
||||
@ -5,6 +5,12 @@ import os
|
||||
import sys
|
||||
|
||||
|
||||
# Run only this selected group of models; leave this empty to run everything
|
||||
TORCHBENCH_ONLY_MODELS = [
|
||||
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
|
||||
]
|
||||
|
||||
|
||||
# Note - hf and timm have their own version of this, torchbench does not
|
||||
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
|
||||
def model_names(filename: str) -> set[str]:
|
||||
@ -17,6 +23,8 @@ def model_names(filename: str) -> set[str]:
|
||||
if len(line_parts) == 1:
|
||||
line_parts = line.split(",")
|
||||
model_name = line_parts[0]
|
||||
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
|
||||
continue
|
||||
names.add(model_name)
|
||||
return names
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import copy
|
||||
import csv
|
||||
import dataclasses
|
||||
import functools
|
||||
import gc
|
||||
import importlib
|
||||
import itertools
|
||||
import json
|
||||
@ -2387,6 +2388,7 @@ class BenchmarkRunner:
|
||||
)
|
||||
|
||||
def warmup(fn, model, example_inputs, mode, niters=10):
|
||||
gc.collect()
|
||||
peak_mem = 0
|
||||
start_stats = get_dynamo_stats()
|
||||
try:
|
||||
@ -2548,6 +2550,7 @@ class BenchmarkRunner:
|
||||
return experiment(*self.maybe_cast(model, example_inputs))
|
||||
|
||||
def warmup(fn, model, example_inputs, mode, niters=5):
|
||||
gc.collect()
|
||||
peak_mem = 0
|
||||
start_stats = get_dynamo_stats()
|
||||
try:
|
||||
|
||||
@ -106,6 +106,11 @@ finally:
|
||||
# on A100 GPUs - 40 GB.
|
||||
BATCH_SIZE_KNOWN_MODELS = {}
|
||||
|
||||
# Run only this selected group of models; leave this empty to run everything
|
||||
TORCHBENCH_ONLY_MODELS = [
|
||||
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
|
||||
]
|
||||
|
||||
|
||||
# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
|
||||
# Get the list of models and their batch sizes
|
||||
@ -116,6 +121,8 @@ with open(MODELS_FILENAME) as fh:
|
||||
lines = [line.rstrip() for line in lines]
|
||||
for line in lines:
|
||||
model_name, batch_size = line.split(",")
|
||||
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
|
||||
continue
|
||||
batch_size = int(batch_size)
|
||||
BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
|
||||
assert len(BATCH_SIZE_KNOWN_MODELS)
|
||||
|
||||
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1009000000,0.1
|
||||
|
||||
|
||||
|
||||
@ -82,7 +82,7 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
|
||||
|
||||
|
||||
|
||||
basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1
|
||||
basic_NestedModule_eager,compile_time_instruction_count,8787000000,0.1
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -39,13 +39,20 @@ finally:
|
||||
from timm.models import create_model
|
||||
|
||||
TIMM_MODELS = {}
|
||||
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
|
||||
|
||||
# Run only this selected group of models; leave this empty to run everything
|
||||
TORCHBENCH_ONLY_MODELS = [
|
||||
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
|
||||
]
|
||||
|
||||
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
|
||||
with open(filename) as fh:
|
||||
lines = fh.readlines()
|
||||
lines = [line.rstrip() for line in lines]
|
||||
for line in lines:
|
||||
model_name, batch_size = line.split(" ")
|
||||
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
|
||||
continue
|
||||
TIMM_MODELS[model_name] = int(batch_size)
|
||||
|
||||
|
||||
|
||||
@ -599,6 +599,7 @@ libtorch_nativert_sources = [
|
||||
"torch/nativert/graph/GraphSignature.cpp",
|
||||
"torch/nativert/graph/Serialization.cpp",
|
||||
"torch/nativert/graph/TensorMeta.cpp",
|
||||
"torch/nativert/graph/GraphUtils.cpp",
|
||||
"torch/nativert/executor/DelegateExecutor.cpp",
|
||||
"torch/nativert/executor/Placement.cpp",
|
||||
"torch/nativert/executor/ExecutionPlanner.cpp",
|
||||
|
||||
@ -45,7 +45,7 @@ size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) {
|
||||
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart);
|
||||
const size_t interval_end =
|
||||
63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd);
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
interval_end - interval_start == kRoundUpPowerOfTwoIntervals,
|
||||
"kRoundUpPowerOfTwoIntervals mismatch");
|
||||
|
||||
@ -64,7 +64,7 @@ size_t AcceleratorAllocatorConfig::parseMaxSplitSize(
|
||||
std::numeric_limits<size_t>::max() / kMB;
|
||||
|
||||
size_t val_env = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
val_env >= min_allowed_split_size_mb,
|
||||
"CachingAllocator option max_split_size_mb too small, must be >= ",
|
||||
min_allowed_split_size_mb);
|
||||
@ -83,7 +83,7 @@ size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize(
|
||||
std::numeric_limits<size_t>::max() / kMB;
|
||||
|
||||
size_t val_env = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
val_env >= min_allowed_split_size_mb,
|
||||
"CachingAllocator option max_non_split_rounding_mb too small, must be >= ",
|
||||
min_allowed_split_size_mb);
|
||||
@ -98,7 +98,7 @@ size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold(
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
double val_env = tokenizer.toDouble(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
val_env > 0 && val_env < 1.0,
|
||||
"garbage_collect_threshold is invalid, set it in (0.0, 1.0)");
|
||||
garbage_collection_threshold_ = val_env;
|
||||
@ -119,7 +119,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
|
||||
size_t value_index = i;
|
||||
tokenizer.checkToken(++i, ":");
|
||||
size_t value = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
value == 0 || llvm::isPowerOf2_64(value),
|
||||
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ");
|
||||
|
||||
@ -133,7 +133,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
|
||||
value);
|
||||
} else {
|
||||
size_t boundary = tokenizer.toSizeT(value_index);
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(boundary),
|
||||
"For roundups, the intervals have to be power of 2 ");
|
||||
|
||||
@ -163,7 +163,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
|
||||
"Expected closing bracket ']' in ConfigTokenizer but reached end of config");
|
||||
} else { // Keep this for backwards compatibility
|
||||
size_t value = tokenizer.toSizeT(i);
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(value),
|
||||
"For roundups, the divisions has to be power of 2 ");
|
||||
std::fill(
|
||||
|
||||
@ -76,7 +76,7 @@ class ConfigTokenizer {
|
||||
} else if (token == "False") {
|
||||
return false;
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Expected 'True' or 'False' at index ",
|
||||
i,
|
||||
|
||||
@@ -1,119 +1,389 @@
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/llvmMathExtras.h>

#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#endif

#include <cuda_runtime_api.h>

namespace c10::cuda::CUDACachingAllocator {

size_t CUDAAllocatorConfig::parseAllocatorConfig(
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
constexpr size_t kRoundUpPowerOfTwoIntervals = 16;

CUDAAllocatorConfig::CUDAAllocatorConfig()
: m_max_split_size(std::numeric_limits<size_t>::max()),
m_max_non_split_rounding_size(kLargeBuffer),
m_garbage_collection_threshold(0),
m_pinned_num_register_threads(1),
m_expandable_segments(false),
#if CUDA_VERSION >= 12030
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::UNSPECIFIED),
#else
m_expandable_segments_handle_type(
Expandable_Segments_Handle_Type::POSIX_FD),
#endif
m_release_lock_on_cudamalloc(false),
m_pinned_use_cuda_host_register(false),
m_pinned_use_background_threads(false) {
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
}

size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) {
size_t log_size = (63 - llvm::countLeadingZeros(size));

// Our intervals start at 1MB and end at 64GB
const size_t interval_start =
63 - llvm::countLeadingZeros(static_cast<size_t>(1048576));
const size_t interval_end =
63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736));
TORCH_CHECK(
(interval_end - interval_start == kRoundUpPowerOfTwoIntervals),
"kRoundUpPowerOfTwoIntervals mismatch");

int index = static_cast<int>(log_size) - static_cast<int>(interval_start);

index = std::max(0, index);
index = std::min(index, static_cast<int>(kRoundUpPowerOfTwoIntervals) - 1);
return instance().m_roundup_power2_divisions[index];
}

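The roundup_power2_divisions(size) helper above buckets a request by its floor(log2) into 16 intervals running from 1 MiB (2^20) up to 64 GiB (2^36), clamping anything outside that range to the first or last bucket. A minimal standalone sketch of that bucketing, assuming GCC/Clang's __builtin_clzll stands in for llvm::countLeadingZeros, with the helper name bucket_index made up for illustration:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Index of the 1 MiB .. 64 GiB bucket that `size` falls into, clamped to the
// 16 entries of m_roundup_power2_divisions (size must be non-zero).
static int bucket_index(uint64_t size) {
  int log_size = 63 - __builtin_clzll(size); // floor(log2(size))
  constexpr int interval_start = 20;         // log2(1 MiB)
  constexpr int intervals = 16;              // 1 MiB .. 64 GiB
  int index = log_size - interval_start;
  index = std::max(0, index);
  index = std::min(index, intervals - 1);
  return index;
}

int main() {
  std::printf("%d\n", bucket_index(4ull << 20)); // 4 MiB -> bucket 2
  std::printf("%d\n", bucket_index(512));        // 512 B -> clamped to bucket 0
  std::printf("%d\n", bucket_index(1ull << 40)); // 1 TiB -> clamped to bucket 15
  return 0;
}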
void CUDAAllocatorConfig::lexArgs(
|
||||
const std::string& env,
|
||||
std::vector<std::string>& config) {
|
||||
std::vector<char> buf;
|
||||
|
||||
for (char ch : env) {
|
||||
if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
|
||||
if (!buf.empty()) {
|
||||
config.emplace_back(buf.begin(), buf.end());
|
||||
buf.clear();
|
||||
}
|
||||
config.emplace_back(1, ch);
|
||||
} else if (ch != ' ') {
|
||||
buf.emplace_back(ch);
|
||||
}
|
||||
}
|
||||
if (!buf.empty()) {
|
||||
config.emplace_back(buf.begin(), buf.end());
|
||||
}
|
||||
}
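For reference, lexArgs above splits only on ',', ':', '[' and ']', keeps each separator as its own token, and drops spaces. A standalone re-implementation of that rule (illustration only, not the private class method), printing the token stream for a typical setting:

#include <iostream>
#include <string>
#include <vector>

// Same tokenization rule as lexArgs: separators become their own tokens.
std::vector<std::string> lex(const std::string& env) {
  std::vector<std::string> out;
  std::string buf;
  for (char ch : env) {
    if (ch == ',' || ch == ':' || ch == '[' || ch == ']') {
      if (!buf.empty()) {
        out.push_back(buf);
        buf.clear();
      }
      out.emplace_back(1, ch);
    } else if (ch != ' ') {
      buf.push_back(ch);
    }
  }
  if (!buf.empty()) {
    out.push_back(buf);
  }
  return out;
}

int main() {
  // Prints: roundup_power2_divisions : [ 256 : 1 , > : 8 ]
  for (const auto& tok : lex("roundup_power2_divisions:[256:1,>:8]")) {
    std::cout << tok << ' ';
  }
  std::cout << '\n';
  return 0;
}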
|
||||
|
||||
void CUDAAllocatorConfig::consumeToken(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i,
|
||||
const char c) {
|
||||
TORCH_CHECK(
|
||||
i < config.size() && config[i] == std::string(1, c),
|
||||
"Error parsing CachingAllocator settings, expected ",
|
||||
c,
|
||||
"");
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseMaxSplitSize(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
constexpr int mb = 1024 * 1024;
|
||||
if (++i < config.size()) {
|
||||
size_t val1 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
val1 > kLargeBuffer / mb,
|
||||
"CachingAllocator option max_split_size_mb too small, must be > ",
|
||||
kLargeBuffer / mb,
|
||||
"");
|
||||
val1 = std::max(val1, kLargeBuffer / mb);
|
||||
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
|
||||
m_max_split_size = val1 * 1024 * 1024;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error, expecting max_split_size_mb value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
constexpr int mb = 1024 * 1024;
|
||||
if (++i < config.size()) {
|
||||
size_t val1 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
val1 > kLargeBuffer / mb,
|
||||
"CachingAllocator option max_non_split_rounding_mb too small, must be > ",
|
||||
kLargeBuffer / mb,
|
||||
"");
|
||||
val1 = std::max(val1, kLargeBuffer / mb);
|
||||
val1 = std::min(val1, (std::numeric_limits<size_t>::max() / mb));
|
||||
m_max_non_split_rounding_size = val1 * 1024 * 1024;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
double val1 = stod(config[i]);
|
||||
TORCH_CHECK(
|
||||
val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", "");
|
||||
TORCH_CHECK(
|
||||
val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", "");
|
||||
m_garbage_collection_threshold = val1;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting garbage_collection_threshold value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
bool first_value = true;
|
||||
|
||||
if (++i < config.size()) {
|
||||
if (std::string_view(config[i]) == "[") {
|
||||
size_t last_index = 0;
|
||||
// NOLINTNEXTLINE(bugprone-inc-dec-in-conditions)
|
||||
while (++i < config.size() && std::string_view(config[i]) != "]") {
|
||||
const std::string& val1 = config[i];
|
||||
size_t val2 = 0;
|
||||
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
val2 = stoi(config[i]);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error parsing roundup_power2_divisions value", "");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
val2 == 0 || llvm::isPowerOf2_64(val2),
|
||||
"For roundups, the divisions has to be power of 2 or 0 to disable roundup ",
|
||||
"");
|
||||
|
||||
if (std::string_view(val1) == ">") {
|
||||
std::fill(
|
||||
std::next(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
static_cast<std::vector<unsigned long>::difference_type>(
|
||||
last_index)),
|
||||
m_roundup_power2_divisions.end(),
|
||||
val2);
|
||||
} else {
|
||||
size_t val1_long = stoul(val1);
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(val1_long),
|
||||
"For roundups, the intervals have to be power of 2 ",
|
||||
"");
|
||||
|
||||
size_t index = 63 - llvm::countLeadingZeros(val1_long);
|
||||
index = std::max((size_t)0, index);
|
||||
index = std::min(index, m_roundup_power2_divisions.size() - 1);
|
||||
|
||||
if (first_value) {
|
||||
std::fill(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
std::next(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
static_cast<std::vector<unsigned long>::difference_type>(
|
||||
index)),
|
||||
val2);
|
||||
first_value = false;
|
||||
}
|
||||
if (index < m_roundup_power2_divisions.size()) {
|
||||
m_roundup_power2_divisions[index] = val2;
|
||||
}
|
||||
last_index = index;
|
||||
}
|
||||
|
||||
if (std::string_view(config[i + 1]) != "]") {
|
||||
consumeToken(config, ++i, ',');
|
||||
}
|
||||
}
|
||||
} else { // Keep this for backwards compatibility
|
||||
size_t val1 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(val1),
|
||||
"For roundups, the divisions has to be power of 2 ",
|
||||
"");
|
||||
std::fill(
|
||||
m_roundup_power2_divisions.begin(),
|
||||
m_roundup_power2_divisions.end(),
|
||||
val1);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", "");
|
||||
}
|
||||
return i;
|
||||
}
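Taken together, the parsers above consume a comma-separated list of key:value options, with roundup_power2_divisions optionally taking a bracketed list of interval:divisions pairs in which '>' covers everything above the last listed interval. A hedged usage sketch through the setAllocatorSettings helper that appears later in this diff (the same string is normally supplied via the PYTORCH_CUDA_ALLOC_CONF environment variable; the particular values are illustrative only):

#include <c10/cuda/CUDAAllocatorConfig.h>

int main() {
  // 128 MiB split threshold, garbage collection above 80% of the allowed
  // fraction, and progressively finer power-of-two rounding for larger sizes.
  c10::cuda::CUDACachingAllocator::setAllocatorSettings(
      "max_split_size_mb:128,"
      "garbage_collection_threshold:0.8,"
      "roundup_power2_divisions:[256:1,512:2,1024:4,>:8]");
  return 0;
}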
|
||||
|
||||
size_t CUDAAllocatorConfig::parseAllocatorConfig(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i,
|
||||
bool& used_cudaMallocAsync) {
|
||||
// For ease of maintenance and understanding, the CUDA and ROCm
|
||||
// implementations of this function are separated. This avoids having many
|
||||
// #ifdef's throughout.
|
||||
#ifdef USE_ROCM
|
||||
// Ease burden on ROCm users by allowing either cuda or hip tokens.
|
||||
// cuda token is broken up to prevent hipify matching it.
|
||||
#define PYTORCH_TOKEN1 \
|
||||
"cud" \
|
||||
"aMallocAsync"
|
||||
#define PYTORCH_TOKEN2 "hipMallocAsync"
|
||||
tokenizer.checkToken(++i, ":");
|
||||
i++; // Move to the value after the colon
|
||||
TORCH_CHECK_VALUE(
|
||||
((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
|
||||
(tokenizer[i] == PYTORCH_TOKEN2)),
|
||||
"Unknown allocator backend, "
|
||||
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
|
||||
if (m_is_allocator_loaded) {
|
||||
bool aync_allocator_at_runtime = (tokenizer[i] != "native");
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
aync_allocator_at_runtime == m_use_async_allocator,
|
||||
"Allocator async backend parsed at runtime != allocator async backend parsed at load time, ",
|
||||
aync_allocator_at_runtime,
|
||||
((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) ||
|
||||
(config[i] == PYTORCH_TOKEN2)),
|
||||
"Unknown allocator backend, "
|
||||
"options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2);
|
||||
used_cudaMallocAsync =
|
||||
(config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
config[i] == get()->name() ||
|
||||
(config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2),
|
||||
"Allocator backend parsed at runtime != "
|
||||
"allocator backend parsed at load time, ",
|
||||
config[i],
|
||||
" != ",
|
||||
m_use_async_allocator);
|
||||
get()->name());
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error parsing backend value", "");
|
||||
}
|
||||
m_use_async_allocator =
|
||||
(tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2);
|
||||
// CUDA allocator is always loaded at the start of the program
|
||||
m_is_allocator_loaded = true;
|
||||
|
||||
#if defined(CUDA_VERSION)
|
||||
if (m_use_async_allocator) {
|
||||
#if CUDA_VERSION >= 11040
|
||||
int version = 0;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||
TORCH_CHECK(
|
||||
version >= 11040,
|
||||
"backend:cudaMallocAsync requires CUDA runtime "
|
||||
"11.4 or newer, but cudaDriverGetVersion returned ",
|
||||
version);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"backend:cudaMallocAsync requires PyTorch to be built with "
|
||||
"CUDA 11.4 or newer, but CUDA_VERSION is ",
|
||||
CUDA_VERSION);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
return i;
|
||||
#undef PYTORCH_TOKEN1
|
||||
#undef PYTORCH_TOKEN2
|
||||
#else // USE_ROCM
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
((config[i] == "native") || (config[i] == "cudaMallocAsync")),
|
||||
"Unknown allocator backend, "
|
||||
"options are native and cudaMallocAsync");
|
||||
used_cudaMallocAsync = (config[i] == "cudaMallocAsync");
|
||||
if (used_cudaMallocAsync) {
|
||||
#if CUDA_VERSION >= 11040
|
||||
int version = 0;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&version));
|
||||
TORCH_CHECK(
|
||||
version >= 11040,
|
||||
"backend:cudaMallocAsync requires CUDA runtime "
|
||||
"11.4 or newer, but cudaDriverGetVersion returned ",
|
||||
version);
|
||||
#else
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"backend:cudaMallocAsync requires PyTorch to be built with "
|
||||
"CUDA 11.4 or newer, but CUDA_VERSION is ",
|
||||
CUDA_VERSION);
|
||||
#endif
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
config[i] == get()->name(),
|
||||
"Allocator backend parsed at runtime != "
|
||||
"allocator backend parsed at load time");
|
||||
} else {
|
||||
TORCH_CHECK(false, "Error parsing backend value", "");
|
||||
}
|
||||
return i;
|
||||
#endif // USE_ROCM
|
||||
}
|
||||
|
||||
void CUDAAllocatorConfig::parseArgs(const std::string& env) {
|
||||
void CUDAAllocatorConfig::parseArgs(const std::optional<std::string>& env) {
|
||||
// If empty, set the default values
|
||||
m_max_split_size = std::numeric_limits<size_t>::max();
|
||||
m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0);
|
||||
m_garbage_collection_threshold = 0;
|
||||
bool used_cudaMallocAsync = false;
|
||||
bool used_native_specific_option = false;
|
||||
|
||||
c10::CachingAllocator::ConfigTokenizer tokenizer(env);
|
||||
for (size_t i = 0; i < tokenizer.size(); i++) {
|
||||
const auto& key = tokenizer[i];
|
||||
if (key == "backend") {
|
||||
i = parseAllocatorConfig(tokenizer, i);
|
||||
if (!env.has_value()) {
|
||||
return;
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(m_last_allocator_settings_mutex);
|
||||
m_last_allocator_settings = env.value();
|
||||
}
|
||||
|
||||
std::vector<std::string> config;
|
||||
lexArgs(env.value(), config);
|
||||
|
||||
for (size_t i = 0; i < config.size(); i++) {
|
||||
std::string_view config_item_view(config[i]);
|
||||
if (config_item_view == "max_split_size_mb") {
|
||||
i = parseMaxSplitSize(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "max_non_split_rounding_mb") {
|
||||
i = parseMaxNonSplitRoundingSize(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "garbage_collection_threshold") {
|
||||
i = parseGarbageCollectionThreshold(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "roundup_power2_divisions") {
|
||||
i = parseRoundUpPower2Divisions(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "backend") {
|
||||
i = parseAllocatorConfig(config, i, used_cudaMallocAsync);
|
||||
} else if (config_item_view == "expandable_segments") {
|
||||
used_native_specific_option = true;
|
||||
consumeToken(config, ++i, ':');
|
||||
++i;
|
||||
TORCH_CHECK(
|
||||
i < config.size() &&
|
||||
(std::string_view(config[i]) == "True" ||
|
||||
std::string_view(config[i]) == "False"),
|
||||
"Expected a single True/False argument for expandable_segments");
|
||||
config_item_view = config[i];
|
||||
m_expandable_segments = (config_item_view == "True");
|
||||
} else if (
|
||||
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
|
||||
// use, accept both. We must break up the string to prevent hipify here.
|
||||
key == "release_lock_on_hipmalloc" ||
|
||||
key ==
|
||||
config_item_view == "release_lock_on_hipmalloc" ||
|
||||
config_item_view ==
|
||||
"release_lock_on_c"
|
||||
"udamalloc") {
|
||||
used_native_specific_option = true;
|
||||
tokenizer.checkToken(++i, ":");
|
||||
m_release_lock_on_cudamalloc = tokenizer.toBool(++i);
|
||||
consumeToken(config, ++i, ':');
|
||||
++i;
|
||||
TORCH_CHECK(
|
||||
i < config.size() &&
|
||||
(std::string_view(config[i]) == "True" ||
|
||||
std::string_view(config[i]) == "False"),
|
||||
"Expected a single True/False argument for release_lock_on_cudamalloc");
|
||||
config_item_view = config[i];
|
||||
m_release_lock_on_cudamalloc = (config_item_view == "True");
|
||||
} else if (
|
||||
// ROCm build's hipify step will change "cuda" to "hip", but for ease of
|
||||
// use, accept both. We must break up the string to prevent hipify here.
|
||||
key == "pinned_use_hip_host_register" ||
|
||||
key ==
|
||||
config_item_view == "pinned_use_hip_host_register" ||
|
||||
config_item_view ==
|
||||
"pinned_use_c"
|
||||
"uda_host_register") {
|
||||
i = parsePinnedUseCudaHostRegister(tokenizer, i);
|
||||
i = parsePinnedUseCudaHostRegister(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (key == "pinned_num_register_threads") {
|
||||
i = parsePinnedNumRegisterThreads(tokenizer, i);
|
||||
} else if (config_item_view == "pinned_num_register_threads") {
|
||||
i = parsePinnedNumRegisterThreads(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else if (config_item_view == "pinned_use_background_threads") {
|
||||
i = parsePinnedUseBackgroundThreads(config, i);
|
||||
used_native_specific_option = true;
|
||||
} else {
|
||||
const auto& keys =
|
||||
c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys();
|
||||
TORCH_CHECK(
|
||||
keys.find(key) != keys.end(),
|
||||
"Unrecognized key '",
|
||||
key,
|
||||
"' in Accelerator allocator config.");
|
||||
i = tokenizer.skipKey(i);
|
||||
false, "Unrecognized CachingAllocator option: ", config_item_view);
|
||||
}
|
||||
|
||||
if (i + 1 < tokenizer.size()) {
|
||||
tokenizer.checkToken(++i, ",");
|
||||
if (i + 1 < config.size()) {
|
||||
consumeToken(config, ++i, ',');
|
||||
}
|
||||
}
|
||||
|
||||
if (m_use_async_allocator && used_native_specific_option) {
|
||||
if (used_cudaMallocAsync && used_native_specific_option) {
|
||||
TORCH_WARN(
|
||||
"backend:cudaMallocAsync ignores max_split_size_mb,"
|
||||
"roundup_power2_divisions, and garbage_collect_threshold.");
|
||||
@ -121,33 +391,64 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) {
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
m_pinned_use_cuda_host_register = tokenizer.toBool(++i);
|
||||
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
(config[i] == "True" || config[i] == "False"),
|
||||
"Expected a single True/False argument for pinned_use_cuda_host_register");
|
||||
m_pinned_use_cuda_host_register = (config[i] == "True");
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting pinned_use_cuda_host_register value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
tokenizer.checkToken(++i, ":");
|
||||
size_t val2 = tokenizer.toSizeT(++i);
|
||||
TORCH_CHECK_VALUE(
|
||||
llvm::isPowerOf2_64(val2),
|
||||
"Number of register threads has to be power of 2 ",
|
||||
"");
|
||||
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
|
||||
TORCH_CHECK_VALUE(
|
||||
val2 <= maxThreads,
|
||||
"Number of register threads should be less than or equal to " +
|
||||
std::to_string(maxThreads),
|
||||
"");
|
||||
m_pinned_num_register_threads = val2;
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
size_t val2 = stoi(config[i]);
|
||||
TORCH_CHECK(
|
||||
llvm::isPowerOf2_64(val2),
|
||||
"Number of register threads has to be power of 2 ",
|
||||
"");
|
||||
auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
|
||||
TORCH_CHECK(
|
||||
val2 <= maxThreads,
|
||||
"Number of register threads should be less than or equal to " +
|
||||
std::to_string(maxThreads),
|
||||
"");
|
||||
m_pinned_num_register_threads = val2;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting pinned_num_register_threads value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig)
|
||||
size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i) {
|
||||
consumeToken(config, ++i, ':');
|
||||
if (++i < config.size()) {
|
||||
TORCH_CHECK(
|
||||
(config[i] == "True" || config[i] == "False"),
|
||||
"Expected a single True/False argument for pinned_use_background_threads");
|
||||
m_pinned_use_background_threads = (config[i] == "True");
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false, "Error, expecting pinned_use_background_threads value", "");
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
// General caching allocator utilities
|
||||
void setAllocatorSettings(const std::string& env) {
|
||||
CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str());
|
||||
}
|
||||
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/AllocatorConfig.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAMacros.h>
|
||||
#include <c10/util/Deprecated.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/env.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace c10::cuda::CUDACachingAllocator {
|
||||
|
||||
enum class Expandable_Segments_Handle_Type : int {
|
||||
@ -18,28 +22,21 @@ enum class Expandable_Segments_Handle_Type : int {
|
||||
// Environment config parser
|
||||
class C10_CUDA_API CUDAAllocatorConfig {
|
||||
public:
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.")
|
||||
static size_t max_split_size() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size();
|
||||
return instance().m_max_split_size;
|
||||
}
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.")
|
||||
static double garbage_collection_threshold() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
garbage_collection_threshold();
|
||||
return instance().m_garbage_collection_threshold;
|
||||
}
|
||||
|
||||
static bool expandable_segments() {
|
||||
bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
use_expandable_segments();
|
||||
#ifndef PYTORCH_C10_DRIVER_API_SUPPORTED
|
||||
if (enabled) {
|
||||
if (instance().m_expandable_segments) {
|
||||
TORCH_WARN_ONCE("expandable_segments not supported on this platform")
|
||||
}
|
||||
return false;
|
||||
#else
|
||||
return enabled;
|
||||
return instance().m_expandable_segments;
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -65,11 +62,8 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
return instance().m_pinned_num_register_threads;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.")
|
||||
static bool pinned_use_background_threads() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
pinned_use_background_threads();
|
||||
return instance().m_pinned_use_background_threads;
|
||||
}
|
||||
|
||||
static size_t pinned_max_register_threads() {
|
||||
@ -79,105 +73,92 @@ class C10_CUDA_API CUDAAllocatorConfig {
|
||||
return 128;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
|
||||
static size_t roundup_power2_divisions(size_t size) {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
roundup_power2_divisions(size);
|
||||
}
|
||||
// This is used to round-up allocation size to nearest power of 2 divisions.
|
||||
// More description below in function roundup_power2_next_division
|
||||
// As an example, if we want 4 divisions between 2's power, this can be done
|
||||
// using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
|
||||
static size_t roundup_power2_divisions(size_t size);
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.")
|
||||
static std::vector<size_t> roundup_power2_divisions() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
roundup_power2_divisions();
|
||||
return instance().m_roundup_power2_divisions;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.")
|
||||
static size_t max_non_split_rounding_size() {
|
||||
return c10::CachingAllocator::AcceleratorAllocatorConfig::
|
||||
max_non_split_rounding_size();
|
||||
return instance().m_max_non_split_rounding_size;
|
||||
}
|
||||
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.")
|
||||
static std::string last_allocator_settings() {
|
||||
return c10::CachingAllocator::getAllocatorSettings();
|
||||
}
|
||||
|
||||
static bool use_async_allocator() {
|
||||
return instance().m_use_async_allocator;
|
||||
}
|
||||
|
||||
static const std::unordered_set<std::string>& getKeys() {
|
||||
return keys_;
|
||||
std::lock_guard<std::mutex> lock(
|
||||
instance().m_last_allocator_settings_mutex);
|
||||
return instance().m_last_allocator_settings;
|
||||
}
|
||||
|
||||
static CUDAAllocatorConfig& instance() {
|
||||
static CUDAAllocatorConfig* s_instance = ([]() {
|
||||
auto inst = new CUDAAllocatorConfig();
|
||||
auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF");
|
||||
if (!env.has_value()) {
|
||||
// For backward compatibility, check for the old environment variable
|
||||
// PYTORCH_CUDA_ALLOC_CONF.
|
||||
env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
|
||||
}
|
||||
auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
|
||||
#ifdef USE_ROCM
|
||||
// convenience for ROCm users, allow alternative HIP token
|
||||
if (!env.has_value()) {
|
||||
env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
|
||||
}
|
||||
#endif
|
||||
if (env.has_value()) {
|
||||
inst->parseArgs(env.value());
|
||||
}
|
||||
inst->parseArgs(env);
|
||||
return inst;
|
||||
})();
|
||||
return *s_instance;
|
||||
}
|
||||
|
||||
void parseArgs(const std::string& env);
|
||||
void parseArgs(const std::optional<std::string>& env);
|
||||
|
||||
private:
|
||||
CUDAAllocatorConfig() = default;
|
||||
CUDAAllocatorConfig();
|
||||
|
||||
size_t parseAllocatorConfig(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
static void lexArgs(const std::string& env, std::vector<std::string>& config);
|
||||
static void consumeToken(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i,
|
||||
const char c);
|
||||
size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i);
|
||||
size_t parseMaxNonSplitRoundingSize(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parseGarbageCollectionThreshold(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parseRoundUpPower2Divisions(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parseAllocatorConfig(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i,
|
||||
bool& used_cudaMallocAsync);
|
||||
size_t parsePinnedUseCudaHostRegister(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parsePinnedNumRegisterThreads(
|
||||
const c10::CachingAllocator::ConfigTokenizer& tokenizer,
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
size_t parsePinnedUseBackgroundThreads(
|
||||
const std::vector<std::string>& config,
|
||||
size_t i);
|
||||
|
||||
std::atomic<size_t> m_pinned_num_register_threads{1};
|
||||
std::atomic<Expandable_Segments_Handle_Type> m_expandable_segments_handle_type
|
||||
#if CUDA_VERSION >= 12030
|
||||
{Expandable_Segments_Handle_Type::UNSPECIFIED};
|
||||
#else
|
||||
{Expandable_Segments_Handle_Type::POSIX_FD};
|
||||
#endif
|
||||
std::atomic<bool> m_release_lock_on_cudamalloc{false};
|
||||
std::atomic<bool> m_pinned_use_cuda_host_register{false};
|
||||
std::atomic<bool> m_use_async_allocator{false};
|
||||
std::atomic<bool> m_is_allocator_loaded{false};
|
||||
inline static std::unordered_set<std::string> keys_{
|
||||
"backend",
|
||||
// keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
|
||||
// NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
|
||||
"release_lock_on_cud"
|
||||
"amalloc",
|
||||
"pinned_use_cud"
|
||||
"a_host_register",
|
||||
// NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
|
||||
"release_lock_on_hipmalloc",
|
||||
"pinned_use_hip_host_register",
|
||||
"pinned_num_register_threads"};
|
||||
std::atomic<size_t> m_max_split_size;
|
||||
std::atomic<size_t> m_max_non_split_rounding_size;
|
||||
std::vector<size_t> m_roundup_power2_divisions;
|
||||
std::atomic<double> m_garbage_collection_threshold;
|
||||
std::atomic<size_t> m_pinned_num_register_threads;
|
||||
std::atomic<bool> m_expandable_segments;
|
||||
std::atomic<Expandable_Segments_Handle_Type>
|
||||
m_expandable_segments_handle_type;
|
||||
std::atomic<bool> m_release_lock_on_cudamalloc;
|
||||
std::atomic<bool> m_pinned_use_cuda_host_register;
|
||||
std::atomic<bool> m_pinned_use_background_threads;
|
||||
std::string m_last_allocator_settings;
|
||||
std::mutex m_last_allocator_settings_mutex;
|
||||
};
|
||||
|
||||
// Keep this for backwards compatibility
|
||||
using c10::CachingAllocator::setAllocatorSettings;
|
||||
// General caching allocator utilities
|
||||
C10_CUDA_API void setAllocatorSettings(const std::string& env);
|
||||
|
||||
} // namespace c10::cuda::CUDACachingAllocator
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
||||
#include <c10/core/impl/GPUTrace.h>
|
||||
#include <c10/cuda/CUDAAllocatorConfig.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
#include <c10/cuda/CUDAFunctions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -63,6 +64,10 @@ namespace cuda::CUDACachingAllocator {
|
||||
using namespace c10::CachingAllocator;
|
||||
using namespace c10::CachingDeviceAllocator;
|
||||
|
||||
// Included here as this is externally used in CUDAAllocatorConfig
|
||||
const size_t kLargeBuffer =
|
||||
20971520; // "large" allocations may be packed in 20 MiB blocks
|
||||
|
||||
namespace Native {
|
||||
|
||||
//
|
||||
@ -368,14 +373,12 @@ struct ExpandableSegment {
|
||||
ExpandableSegment(
|
||||
c10::DeviceIndex device,
|
||||
std::optional<cudaStream_t> stream,
|
||||
size_t address_space_size,
|
||||
size_t segment_size,
|
||||
std::vector<c10::DeviceIndex> peers)
|
||||
: device_(device),
|
||||
stream_(stream),
|
||||
// 2MB for small pool, 20MB for large pool
|
||||
segment_size_(segment_size),
|
||||
max_handles_(numSegments(address_space_size)),
|
||||
peers_(std::move(peers)) {
|
||||
cudaDeviceProp prop{};
|
||||
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_));
|
||||
@ -544,11 +547,7 @@ struct ExpandableSegment {
|
||||
ShareHeader header{};
|
||||
buf.read((char*)&header, sizeof(ShareHeader));
|
||||
auto segment = std::make_unique<ExpandableSegment>(
|
||||
device,
|
||||
std::nullopt,
|
||||
header.num_handles * header.segment_size,
|
||||
header.segment_size,
|
||||
std::move(peers));
|
||||
device, std::nullopt, header.segment_size, std::move(peers));
|
||||
// older build setups (e.g. multiwheels) do not have this syscall, added 2020
|
||||
// but the kernel on the system might still support it.
|
||||
#ifndef SYS_pidfd_open
|
||||
@ -746,7 +745,6 @@ struct ExpandableSegment {
|
||||
ExpandableSegment(
|
||||
c10::DeviceIndex device,
|
||||
std::optional<cudaStream_t> stream,
|
||||
size_t address_space_size,
|
||||
size_t segment_size,
|
||||
std::vector<c10::DeviceIndex> peers) {
|
||||
TORCH_INTERNAL_ASSERT(false, "expandable segment not supported");
|
||||
@ -1225,7 +1223,7 @@ class DeviceCachingAllocator {
|
||||
DeviceCachingAllocator()
|
||||
: large_blocks(/*small=*/false), small_blocks(/*small=*/true) {
|
||||
stats.max_split_size =
|
||||
static_cast<int64_t>(AcceleratorAllocatorConfig::max_split_size());
|
||||
static_cast<int64_t>(CUDAAllocatorConfig::max_split_size());
|
||||
context_recorder_.store(nullptr);
|
||||
}
|
||||
|
||||
@ -1350,8 +1348,7 @@ class DeviceCachingAllocator {
|
||||
// Do garbage collection if the flag is set.
|
||||
if (C10_UNLIKELY(
|
||||
set_fraction &&
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() >
|
||||
0.0)) {
|
||||
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
|
||||
garbage_collect_cached_blocks(context);
|
||||
}
|
||||
// Attempt allocate
|
||||
@ -1603,7 +1600,7 @@ class DeviceCachingAllocator {
|
||||
stats.active_bytes[stat_type].increase(block->size);
|
||||
stats.requested_bytes[stat_type].increase(block->requested_size);
|
||||
});
|
||||
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
if (block->size >= CUDAAllocatorConfig::max_split_size())
|
||||
stats.oversize_allocations.increase(1);
|
||||
|
||||
auto allocated_bytes_gauge =
|
||||
@ -1654,7 +1651,7 @@ class DeviceCachingAllocator {
|
||||
block->pool->owner_MempoolId(),
|
||||
context ? context : block->context_when_allocated);
|
||||
|
||||
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
if (block->size >= CUDAAllocatorConfig::max_split_size())
|
||||
stats.oversize_allocations.decrease(1);
|
||||
|
||||
if (!block->stream_uses.empty()) {
|
||||
@ -2203,8 +2200,7 @@ class DeviceCachingAllocator {
|
||||
if (size < kMinBlockSize) {
|
||||
return kMinBlockSize;
|
||||
} else {
|
||||
auto divisions =
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions(size);
|
||||
auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size);
|
||||
if (divisions > 1 && size > (kMinBlockSize * divisions)) {
|
||||
return roundup_power2_next_division(size, divisions);
|
||||
} else {
|
||||
@ -2420,19 +2416,8 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
auto segment_size = pool->is_small ? kSmallBuffer : kLargeBuffer;
|
||||
cudaDeviceProp prop{};
|
||||
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
|
||||
// we allocate enough address space for 1 1/8 the total memory on the GPU.
|
||||
// This allows for some cases where we have to unmap pages earlier in the
|
||||
// segment to put them at the end.
|
||||
size_t address_space_size = prop.totalGlobalMem + prop.totalGlobalMem / 8;
|
||||
|
||||
expandable_segments_.emplace_back(new ExpandableSegment(
|
||||
device,
|
||||
stream,
|
||||
address_space_size,
|
||||
segment_size,
|
||||
devices_with_peer_access_));
|
||||
device, stream, segment_size, devices_with_peer_access_));
|
||||
|
||||
ExpandableSegment* es = expandable_segments_.back();
|
||||
Block* candidate = new Block(device, stream, es->size(), pool, es->ptr());
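One side of this hunk sizes the reserved virtual address space at 1 1/8 of the device's total memory, per the comment, so that pages unmapped earlier in a segment can be re-placed at the end. A quick illustration of that arithmetic, with an assumed 80 GiB device (the figure is an example, not something taken from this diff):

#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t total_global_mem = 80ull << 30; // assumed 80 GiB device
  const uint64_t address_space = total_global_mem + total_global_mem / 8;
  // Prints 90.0: the segment reserves ~90 GiB of address space.
  std::printf("%.1f GiB\n", address_space / double(1ull << 30));
  return 0;
}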
|
||||
@ -2694,7 +2679,7 @@ class DeviceCachingAllocator {
|
||||
if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) {
|
||||
return remaining >= kMinBlockSize;
|
||||
} else {
|
||||
return (size < AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
return (size < CUDAAllocatorConfig::max_split_size()) &&
|
||||
(remaining > kSmallSize);
|
||||
}
|
||||
}
|
||||
@ -2714,7 +2699,7 @@ class DeviceCachingAllocator {
|
||||
|
||||
if (C10_UNLIKELY(
|
||||
set_fraction &&
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) {
|
||||
CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) {
|
||||
// Track block reuse interval only when garbage collection is enabled.
|
||||
++pool.get_free_blocks_call_count;
|
||||
}
|
||||
@ -2756,13 +2741,13 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
|
||||
// Do not return an oversized block for a large request
|
||||
if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()))
|
||||
if ((p.size() < CUDAAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >= CUDAAllocatorConfig::max_split_size()))
|
||||
return false;
|
||||
// Allow oversized block size to be rounded up but within a limit
|
||||
if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
if ((p.size() >= CUDAAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >=
|
||||
p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size()))
|
||||
p.size() + CUDAAllocatorConfig::max_non_split_rounding_size()))
|
||||
return false;
|
||||
p.block = *it;
|
||||
pool.blocks.erase(it);
|
||||
@ -2785,7 +2770,7 @@ class DeviceCachingAllocator {
|
||||
// therefore should be of less overheads.
|
||||
|
||||
size_t gc_threshold = static_cast<size_t>(
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold() *
|
||||
CUDAAllocatorConfig::garbage_collection_threshold() *
|
||||
static_cast<double>(allowed_memory_maximum));
|
||||
// No need to trigger GC yet
|
||||
if (total_allocated_memory <= gc_threshold) {
|
||||
@ -2933,7 +2918,7 @@ class DeviceCachingAllocator {
|
||||
stats.segment[stat_type].increase(1);
|
||||
stats.reserved_bytes[stat_type].increase(size);
|
||||
});
|
||||
if (size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
if (size >= CUDAAllocatorConfig::max_split_size())
|
||||
stats.oversize_segments.increase(1);
|
||||
auto reserved_bytes_gauge =
|
||||
STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes);
|
||||
@ -2962,7 +2947,7 @@ class DeviceCachingAllocator {
|
||||
bool release_available_cached_blocks(
|
||||
const AllocParams& p,
|
||||
const std::shared_ptr<GatheredContext>& context) {
|
||||
if (AcceleratorAllocatorConfig::max_split_size() ==
|
||||
if (CUDAAllocatorConfig::max_split_size() ==
|
||||
std::numeric_limits<size_t>::max())
|
||||
return false;
|
||||
BlockPool& pool = *p.pool;
|
||||
@ -2970,8 +2955,8 @@ class DeviceCachingAllocator {
|
||||
// because of std::unique_ptr, block cannot be trivially copied
|
||||
// Use constructor for search key.
|
||||
Block key(p.search_key.device, p.search_key.stream, p.search_key.size);
|
||||
key.size = (key.size < AcceleratorAllocatorConfig::max_split_size())
|
||||
? AcceleratorAllocatorConfig::max_split_size()
|
||||
key.size = (key.size < CUDAAllocatorConfig::max_split_size())
|
||||
? CUDAAllocatorConfig::max_split_size()
|
||||
: key.size;
|
||||
auto it = pool.blocks.lower_bound(&key);
|
||||
if (it == pool.blocks.end() || (*it)->stream != p.stream() ||
|
||||
@ -2984,7 +2969,7 @@ class DeviceCachingAllocator {
|
||||
--it; // Back up one item. Now on the largest block for the correct
|
||||
// stream
|
||||
while ((totalReleased < key.size) &&
|
||||
((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) &&
|
||||
((*it)->size >= CUDAAllocatorConfig::max_split_size()) &&
|
||||
((*it)->stream == p.stream())) {
|
||||
auto cur = it;
|
||||
bool is_first = cur == pool.blocks.begin();
|
||||
@ -3109,7 +3094,7 @@ class DeviceCachingAllocator {
|
||||
stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)]
|
||||
.current);
|
||||
|
||||
if (block->size >= AcceleratorAllocatorConfig::max_split_size())
|
||||
if (block->size >= CUDAAllocatorConfig::max_split_size())
|
||||
stats.oversize_segments.decrease(1);
|
||||
pool->blocks.erase(block);
|
||||
delete block;
|
||||
@ -3736,8 +3721,8 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
|
||||
auto& md = result.config_metadata;
|
||||
md.garbage_collection_threshold =
|
||||
AcceleratorAllocatorConfig::garbage_collection_threshold();
|
||||
md.max_split_size = AcceleratorAllocatorConfig::max_split_size();
|
||||
CUDAAllocatorConfig::garbage_collection_threshold();
|
||||
md.max_split_size = CUDAAllocatorConfig::max_split_size();
|
||||
md.pinned_num_register_threads =
|
||||
CUDAAllocatorConfig::pinned_num_register_threads();
|
||||
md.expandable_segments = CUDAAllocatorConfig::expandable_segments();
|
||||
@ -3745,10 +3730,9 @@ class NativeCachingAllocator : public CUDAAllocator {
|
||||
CUDAAllocatorConfig::release_lock_on_cudamalloc();
|
||||
md.pinned_use_host_register =
|
||||
CUDAAllocatorConfig::pinned_use_cuda_host_register();
|
||||
md.last_allocator_settings =
|
||||
AcceleratorAllocatorConfig::last_allocator_settings();
|
||||
md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings();
|
||||
md.roundup_power2_divisions =
|
||||
AcceleratorAllocatorConfig::roundup_power2_divisions();
|
||||
CUDAAllocatorConfig::roundup_power2_divisions();
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -4126,10 +4110,49 @@ CUDAAllocator* allocator();
|
||||
} // namespace CudaMallocAsync
|
||||
|
||||
struct BackendStaticInitializer {
|
||||
// Parses env for backend at load time, duplicating some logic from
|
||||
// CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at
|
||||
// runtime). Defers verbose exceptions and error checks, including Cuda
|
||||
// version checks, to CUDAAllocatorConfig's runtime doublecheck. If this
|
||||
// works, maybe we should move all of CUDAAllocatorConfig here?
|
||||
CUDAAllocator* parseEnvForBackend() {
|
||||
// If the environment variable is set, we use the CudaMallocAsync allocator.
|
||||
if (CUDAAllocatorConfig::use_async_allocator()) {
|
||||
return CudaMallocAsync::allocator();
|
||||
auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF");
|
||||
#ifdef USE_ROCM
|
||||
// convenience for ROCm users to allow either CUDA or HIP env var
|
||||
if (!val.has_value()) {
|
||||
val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF");
|
||||
}
|
||||
#endif
|
||||
if (val.has_value()) {
|
||||
const std::string& config = val.value();
|
||||
|
||||
std::regex exp("[\\s,]+");
|
||||
std::sregex_token_iterator it(config.begin(), config.end(), exp, -1);
|
||||
std::sregex_token_iterator end;
|
||||
std::vector<std::string> options(it, end);
|
||||
|
||||
for (auto option : options) {
|
||||
std::regex exp2("[:]+");
|
||||
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
|
||||
std::sregex_token_iterator end2;
|
||||
std::vector<std::string> kv(it2, end2);
|
||||
if (kv.size() >= 2) {
|
||||
if (kv[0] == "backend") {
|
||||
#ifdef USE_ROCM
|
||||
// convenience for ROCm users to allow either CUDA or HIP env var
|
||||
if (kv[1] ==
|
||||
"cud"
|
||||
"aMallocAsync" ||
|
||||
kv[1] == "hipMallocAsync")
|
||||
#else
|
||||
if (kv[1] == "cudaMallocAsync")
|
||||
#endif
|
||||
return CudaMallocAsync::allocator();
|
||||
if (kv[1] == "native")
|
||||
return &Native::allocator;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Native::allocator;
|
||||
}
|
||||
|
||||
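The static initializer above only needs to spot backend:cudaMallocAsync (or the HIP spelling under ROCm) in the environment before the allocator singleton is chosen; everything else is re-validated later by CUDAAllocatorConfig. A hedged sketch of checking which backend won, assuming the process was launched with PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync already set:

#include <c10/cuda/CUDACachingAllocator.h>
#include <iostream>

int main() {
  // name() reports the backend picked at load time, e.g. "native" or
  // "cudaMallocAsync". Selection happens before main(), so the variable must
  // be set in the launching environment rather than with setenv() here.
  std::cout << c10::cuda::CUDACachingAllocator::get()->name() << std::endl;
  return 0;
}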
@@ -1,7 +1,6 @@
#pragma once

#include <c10/core/CachingDeviceAllocator.h>
#include <c10/cuda/CUDAAllocatorConfig.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
@@ -50,9 +49,10 @@ namespace c10::cuda::CUDACachingAllocator {

// Preserved only for BC reasons
// NOLINTNEXTLINE(misc-unused-using-decls)
using c10::CachingAllocator::kLargeBuffer;
using c10::CachingDeviceAllocator::DeviceStats;

extern const size_t kLargeBuffer;

typedef std::shared_ptr<GatheredContext> (*CreateContextFn)();

// Struct containing info of an allocation block (i.e. a fractional part of a

@@ -85,7 +85,6 @@ struct AtomicType<uchar> {
}
};

#if __METAL_VERSION__ >= 310
template <>
struct AtomicType<bfloat> {
using type = ::metal::atomic<uint>;
@@ -93,7 +92,6 @@ struct AtomicType<bfloat> {
atomic_add_helper<bfloat>(data, offset, value);
}
};
#endif

// Metal supports atomic_store_explicit for bools, but
// sizeof(::metal::atomic_bool) is 4 Therefore it could not be used to

@@ -9,7 +9,6 @@
#define C10_METAL_CONSTEXPR constexpr
#endif

#if !defined(__METAL__) || __METAL_VERSION__ >= 310
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
_(Byte, 0) \
_(Char, 1) \
@@ -22,19 +21,6 @@
_(ComplexFloat, 9) \
_(Bool, 11) \
_(BFloat16, 15)
#else
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
_(Byte, 0) \
_(Char, 1) \
_(Short, 2) \
_(Int, 3) \
_(Long, 4) \
_(Half, 5) \
_(Float, 6) \
_(ComplexHalf, 8) \
_(ComplexFloat, 9) \
_(Bool, 11)
#endif

namespace c10 {
namespace metal {

@@ -186,10 +186,8 @@ inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
return cast_to<T>(val_at_offs<float>(ptr, offs));
case ScalarType::Half:
return cast_to<T>(val_at_offs<half>(ptr, offs));
#if __METAL_VERSION__ >= 310
case ScalarType::BFloat16:
return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
#endif
// Complex
case ScalarType::ComplexHalf:
return cast_to<T>(val_at_offs<half2>(ptr, offs));

@ -15,12 +15,10 @@ struct simd_type {
|
||||
template <typename T>
|
||||
using simd_type_t = typename simd_type<T>::t;
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct simd_type<bfloat> {
|
||||
using t = float;
|
||||
};
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
@ -140,7 +138,7 @@ template <
|
||||
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
|
||||
const auto rc = simd_min(val);
|
||||
const auto vote = ::metal::simd_ballot(val == rc);
|
||||
return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
|
||||
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
|
||||
}
|
||||
|
||||
template <
|
||||
@ -149,7 +147,7 @@ template <
|
||||
inline ::c10::metal::pair<T, ushort> simd_argmin(T val) {
|
||||
const auto rc = simd_min(val);
|
||||
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
|
||||
return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
|
||||
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
|
||||
}
|
||||
|
||||
template <
|
||||
@ -158,7 +156,7 @@ template <
|
||||
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
|
||||
const auto rc = simd_max(val);
|
||||
const auto vote = ::metal::simd_ballot(val == rc);
|
||||
return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
|
||||
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
|
||||
}
|
||||
|
||||
template <
|
||||
@ -167,7 +165,7 @@ template <
|
||||
inline ::c10::metal::pair<T, ushort> simd_argmax(T val) {
|
||||
const auto rc = simd_max(val);
|
||||
const auto vote = ::metal::simd_ballot(val == rc || ::metal::isnan(val));
|
||||
return {rc, ::metal::ctz(static_cast<ushort>(static_cast<ulong>(vote)))};
|
||||
return {rc, static_cast<ushort>(::metal::ctz(static_cast<ulong>(vote)))};
|
||||
}
|
||||
|
||||
template <typename ARG_T, typename IDX_T>
|
||||
@ -303,30 +301,58 @@ float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int threadgroup_argmax(threadgroup T* data, unsigned size) {
|
||||
// TODO: This should be moved to the callee
|
||||
template <typename ARG_T, typename IDX_T>
|
||||
IDX_T threadgroup_argmax(
|
||||
threadgroup ARG_T* arg_data,
|
||||
threadgroup IDX_T* idx_data,
|
||||
ARG_T val,
|
||||
IDX_T idx_val,
|
||||
unsigned idx,
|
||||
unsigned size) {
|
||||
auto rc = simd_argmax(val, idx_val);
|
||||
if (size <= simdgroup_size) {
|
||||
return rc.second;
|
||||
}
|
||||
if (idx % simdgroup_size == 0) {
|
||||
arg_data[idx / simdgroup_size] = rc.first;
|
||||
idx_data[idx / simdgroup_size] = rc.second;
|
||||
}
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
int rc = 0;
|
||||
for (unsigned idx = 1; idx < size; ++idx) {
|
||||
if (data[idx] > data[rc]) {
|
||||
rc = idx;
|
||||
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
|
||||
auto rc1 = simd_argmax(arg_data[idx], idx_data[idx]);
|
||||
if (idx == 0) {
|
||||
idx_data[0] = rc1.second;
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
return idx_data[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int threadgroup_argmin(threadgroup T* data, unsigned size) {
|
||||
// TODO: This should be moved to the callee
|
||||
template <typename ARG_T, typename IDX_T>
|
||||
IDX_T threadgroup_argmin(
|
||||
threadgroup ARG_T* arg_data,
|
||||
threadgroup IDX_T* idx_data,
|
||||
ARG_T val,
|
||||
IDX_T idx_val,
|
||||
unsigned idx,
|
||||
unsigned size) {
|
||||
auto rc = simd_argmin(val, idx_val);
|
||||
if (size <= simdgroup_size) {
|
||||
return rc.second;
|
||||
}
|
||||
if (idx % simdgroup_size == 0) {
|
||||
arg_data[idx / simdgroup_size] = rc.first;
|
||||
idx_data[idx / simdgroup_size] = rc.second;
|
||||
}
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
int rc = 0;
|
||||
for (unsigned idx = 1; idx < size; ++idx) {
|
||||
if (data[idx] < data[rc]) {
|
||||
rc = idx;
|
||||
if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) {
|
||||
auto rc1 = simd_argmin(arg_data[idx], idx_data[idx]);
|
||||
if (idx == 0) {
|
||||
idx_data[0] = rc1.second;
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
|
||||
return idx_data[0];
|
||||
}
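The rewritten threadgroup_argmax/threadgroup_argmin above reduce in two stages: each simdgroup computes a partial (value, index) pair with simd_argmax/simd_argmin, lane 0 of every group spills it to threadgroup memory, and after a barrier the partials are reduced again. A CPU-side analogy of that pattern in C++ (illustration only; ties resolve to the lowest index, mirroring ctz() on the ballot):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

constexpr std::size_t kSimdgroupSize = 32;

// Stage 1: per-"simdgroup" argmax partials; stage 2: reduce the partials.
std::size_t two_stage_argmax(const std::vector<float>& data) {
  std::vector<std::pair<float, std::size_t>> partials;
  for (std::size_t base = 0; base < data.size(); base += kSimdgroupSize) {
    const std::size_t end = std::min(base + kSimdgroupSize, data.size());
    std::size_t best = base;
    for (std::size_t i = base; i < end; ++i) {
      if (data[i] > data[best]) {
        best = i;
      }
    }
    partials.emplace_back(data[best], best);
  }
  std::size_t winner = 0;
  for (std::size_t g = 1; g < partials.size(); ++g) {
    if (partials[g].first > partials[winner].first) {
      winner = g;
    }
  }
  return partials[winner].second;
}

int main() {
  const std::vector<float> v{1.f, 7.f, 3.f, 7.f, 2.f};
  std::printf("%zu\n", two_stage_argmax(v)); // prints 1 (first of the tied 7s)
  return 0;
}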
|
||||
|
||||
} // namespace metal
|
||||
|
||||
@ -24,14 +24,12 @@ struct vectypes<half> {
|
||||
using type2 = half2;
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct vectypes<bfloat> {
|
||||
using type4 = bfloat4;
|
||||
using type3 = bfloat3;
|
||||
using type2 = bfloat2;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct vectypes<short> {
|
||||
@ -79,12 +77,10 @@ struct OpMathType<uchar> {
|
||||
using type = int;
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct OpMathType<bfloat> {
|
||||
using type = float;
|
||||
};
|
||||
#endif
|
||||
|
||||
// Type promotion structure for higher precision accumulation
|
||||
template <typename T>
|
||||
@ -98,13 +94,11 @@ struct AccumulationType<half> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
// Specialization for bfloat - promote to float for accumulation
|
||||
template <>
|
||||
struct AccumulationType<bfloat> {
|
||||
using type = float;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@ -130,7 +124,6 @@ min(T a, U b) {
|
||||
return ::metal::min(a, static_cast<T>(b));
|
||||
}
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
inline bfloat min(bfloat a, bfloat b) {
|
||||
return bfloat(
|
||||
@ -142,7 +135,6 @@ inline bfloat max(bfloat a, bfloat b) {
|
||||
return bfloat(
|
||||
::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
using vec2type_t = typename detail::vectypes<T>::type2;
|
||||
|
||||
@ -1,274 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value)
|
||||
: x(detail::fp8e4m3fn_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
|
||||
return detail::fp8e4m3fn_to_fp32_value(x);
|
||||
}
|
||||
|
||||
/// Special values helper
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
|
||||
return (x & 0b01111111) == 0b01111111;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn
|
||||
operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn
|
||||
operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn
|
||||
operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(
|
||||
const Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
|
||||
Float8_e4m3fn& a,
|
||||
const Float8_e4m3fn& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
|
||||
return a + static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
|
||||
return a - static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
|
||||
return a * static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
|
||||
return a / static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
|
||||
return a + static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
|
||||
return a - static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
|
||||
return a * static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
|
||||
return a / static_cast<Float8_e4m3fn>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
|
||||
return static_cast<Float8_e4m3fn>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e4m3fn to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e4m3fn> {
|
||||
public:
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 4;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 3;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -5;
|
||||
static constexpr int min_exponent10 = -1;
|
||||
static constexpr int max_exponent = 8;
|
||||
static constexpr int max_exponent10 = 2;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before = false;
|
||||
|
||||
static constexpr c10::Float8_e4m3fn min() {
|
||||
return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn lowest() {
|
||||
return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn max() {
|
||||
return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn epsilon() {
|
||||
return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn round_error() {
|
||||
return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn quiet_NaN() {
|
||||
return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fn denorm_min() {
|
||||
return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
|
||||
}
|
||||
};
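// [Editor's note -- illustrative only.] Decoded, the bit patterns above are:
//   min()         0x08 -> 2^-6  = 0.015625      (smallest normal)
//   max()         0x7E -> 448.0                 (0x7F is NaN; e4m3fn has no inf)
//   lowest()      0xFE -> -448.0
//   epsilon()     0x20 -> 2^-3  = 0.125
//   round_error() 0x30 -> 2^-1  = 0.5
//   denorm_min()  0x01 -> 2^-9  = 0.001953125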
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
|
||||
@ -1,238 +1 @@
|
||||
#pragma once
|
||||
|
||||
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
|
||||
/// to standard C types and basic arithmetic operations. Note that arithmetic
|
||||
/// operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration:
|
||||
/// s eeee mmm
|
||||
/// 1 sign bit
|
||||
/// 4 exponent bits
|
||||
/// 3 mantissa bits
|
||||
/// bias = 7
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
|
||||
/// and inspired by Half implementation from pytorch/c10/util/Half.h
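///
/// [Editor's note, worked example] Under this layout the byte 0x4D decodes as
/// sign 0, exponent 1001 (9), mantissa 101, i.e. 2^(9 - 7) * (1 + 5/8) = 6.5.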
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
* format, in bit representation.
|
||||
*
|
||||
* @note The implementation doesn't use any floating-point operations.
|
||||
*/
|
||||
inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
|
||||
/*
|
||||
* Extend the fp8 E4M3FN number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
* +---+----+---+-----------------------------+
|
||||
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
|
||||
* +---+----+---+-----------------------------+
|
||||
* Bits 31 27-30 24-26 0-23
|
||||
*
|
||||
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
|
||||
* - zero bits.
|
||||
*/
|
||||
const uint32_t w = (uint32_t)input << 24;
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = w & UINT32_C(0x80000000);
|
||||
/*
|
||||
* Extract mantissa and biased exponent of the input number into the bits 0-30
|
||||
* of the 32-bit word:
|
||||
*
|
||||
* +---+----+---+-----------------------------+
|
||||
* | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
|
||||
* +---+----+---+-----------------------------+
|
||||
* Bits 31 27-30 24-26 0-23
|
||||
*/
|
||||
const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
|
||||
/*
* Renorm shift is the number of bits to shift the mantissa left to make the
* fp8 number normalized. If the initial number is normalized, at least one of
* its high 5 bits (sign == 0, plus the 4-bit exponent) is set, and in this
* case renorm_shift == 0. If the number is denormalized, renorm_shift > 0.
* Note that if we shift a denormalized nonsign by renorm_shift, the unit bit
* of the mantissa shifts into the exponent, turning the biased exponent into
* 1 and making the mantissa normalized (i.e. without the leading 1).
*/
|
||||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
uint32_t renorm_shift = __clz(nonsign);
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
// Note: zero is not a supported input into `__builtin_clz`
|
||||
uint32_t renorm_shift =
|
||||
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
|
||||
#elif defined(_MSC_VER) && !defined(__clang__)
|
||||
unsigned long nonsign_bsr;
|
||||
_BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
|
||||
uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
|
||||
#else
|
||||
// Note: zero is not a supported input into `__builtin_clz`
|
||||
uint32_t renorm_shift =
|
||||
nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
|
||||
#endif
|
||||
renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
|
||||
/*
* Iff the fp8e4m3fn number has all exponent and mantissa bits set to 1,
* the addition overflows into bit 31, and the subsequent shift turns the
* high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn
* number is NaN, and 0x00000000 otherwise.
*/
|
||||
const int32_t inf_nan_mask =
|
||||
((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
|
||||
/*
* Iff nonsign is 0, the subtraction nonsign - 1 wraps around to 0xFFFFFFFF,
* turning bit 31 into 1. Otherwise, bit 31 remains 0. The signed shift right
* by 31 broadcasts bit 31 into all bits of zero_mask. Thus zero_mask ==
* 0xFFFFFFFF if the fp8 number was zero (+0.0 or -0.0), and 0x00000000
* otherwise.
*/
|
||||
const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
|
||||
/*
|
||||
* 1. Shift nonsign left by renorm_shift to normalize it (if the input
|
||||
* was denormal)
|
||||
* 2. Shift nonsign right by 4 so the exponent (4 bits originally)
|
||||
* becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
|
||||
* bits of the 23-bit mantissa of IEEE single-precision number.
|
||||
* 3. Add 0x78 to the exponent (starting at bit 23) to compensate for the
* difference in exponent bias (0x7F for the single-precision number less
* 0x07 for the fp8e4m3fn number).
|
||||
* 4. Subtract renorm_shift from the exponent (starting at bit 23) to
|
||||
* account for renormalization. As renorm_shift is less than 0x78, this
|
||||
* can be combined with step 3.
|
||||
* 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
|
||||
* input was NaN or infinity.
|
||||
* 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
|
||||
* into zero if the input was zero.
|
||||
* 7. Combine with the sign of the input number.
|
||||
*/
|
||||
uint32_t result = sign |
|
||||
((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
|
||||
inf_nan_mask) &
|
||||
~zero_mask);
|
||||
return fp32_from_bits(result);
|
||||
}
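// [Editor's sketch -- illustrative only, not part of the original header.]
// A plain arithmetic decoder for the same e4m3fn layout (1 sign, 4 exponent,
// 3 mantissa bits, bias 7). It exists purely to clarify what the branch-free
// bit manipulation above computes; NaN payloads aside, the two agree for
// every input byte.
inline float fp8e4m3fn_to_fp32_value_reference(uint8_t input) {
  const uint32_t sign = input >> 7;
  const uint32_t exponent = (input >> 3) & 0xF;
  const uint32_t mantissa = input & 0x7;
  float result;
  if (exponent == 0xF && mantissa == 0x7) {
    result = NAN; // only 0x7F/0xFF are NaN; e4m3fn has no infinities
  } else if (exponent == 0) {
    result = std::ldexp(static_cast<float>(mantissa), -9); // subnormal: (m / 8) * 2^-6
  } else {
    result = std::ldexp(1.0f + mantissa / 8.0f, static_cast<int>(exponent) - 7);
  }
  return sign ? -result : result;
}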
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
|
||||
/*
|
||||
* Binary representation of 480.0f, which is the first value
|
||||
* not representable in fp8e4m3fn range:
|
||||
* 0 1111 111 - fp8e4m3fn
|
||||
* 0 10000111 11100000000000000000000 - fp32
|
||||
*/
|
||||
constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e4m3fn normal range
|
||||
* into denorm representation
|
||||
* magic number: ((127 - 7) + (23 - 3) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
|
||||
uint8_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fp8_max) {
|
||||
// NaN - all exponent and mantissa bits set to 1
|
||||
result = 0x7f;
|
||||
} else {
|
||||
if (f_bits < (UINT32_C(121) << 23)) {
|
||||
// Input number is smaller than 2^(-6), which is the smallest
|
||||
// fp8e4m3fn normal number
|
||||
f_bits =
|
||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint8_t mant_odd = (f_bits >> 20) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 20);
|
||||
}
|
||||
}
|
||||
|
||||
result |= static_cast<uint8_t>(sign >> 24);
|
||||
return result;
|
||||
}
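// [Editor's sketch -- illustrative only, not part of the original header.]
// Minimal round trip through the two converters above. 0.3f is not exactly
// representable in e4m3fn and is expected to round to the nearest encodable
// value, 0.3125f (bit pattern 0x2A), under the round-to-nearest-even logic
// implemented in fp8e4m3fn_from_fp32_value.
inline void fp8e4m3fn_round_trip_demo() {
  const float x = 0.3f;
  const uint8_t bits = fp8e4m3fn_from_fp32_value(x);
  const float back = fp8e4m3fn_to_fp32_value(bits);
  std::cout << x << " -> 0x" << std::hex << static_cast<int>(bits) << std::dec
            << " -> " << back << std::endl;
}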
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e4m3fn {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e4m3fn() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
|
||||
: x(bits) {}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fn(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
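// [Editor's sketch -- illustrative only, not part of the original header.]
// Typical use of the type: construct from float, let the explicit casts carry
// the arithmetic out in fp32, and print through the operator<< above. (The
// float constructor and conversion operator are defined in the -inl.h header
// included below.)
inline void float8_e4m3fn_usage_demo() {
  Float8_e4m3fn a(1.5f);
  Float8_e4m3fn b(0.25f);
  float sum = static_cast<float>(a) + static_cast<float>(b); // computed in fp32
  std::cout << a << " + " << b << " = " << sum << std::endl;  // 1.5 + 0.25 = 1.75
}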
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/Float8_e4m3fn.h>
|
||||
|
||||
@ -1,279 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Float8_fnuz_cvt.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
|
||||
: x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
|
||||
return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
|
||||
}
|
||||
|
||||
/// Special values helper
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
|
||||
return x == 0b10000000;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz
|
||||
operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz
|
||||
operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz
|
||||
operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
|
||||
const Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
|
||||
Float8_e4m3fnuz& a,
|
||||
const Float8_e4m3fnuz& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
|
||||
return a + static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
|
||||
return a - static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
|
||||
return a * static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
|
||||
return a / static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a + static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a - static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a * static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
|
||||
return a / static_cast<Float8_e4m3fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
|
||||
return static_cast<Float8_e4m3fnuz>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e4m3fnuz to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e4m3fnuz> {
|
||||
public:
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 4;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 3;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -6;
|
||||
static constexpr int min_exponent10 = -1;
|
||||
static constexpr int max_exponent = 8;
|
||||
static constexpr int max_exponent10 = 2;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before = false;
|
||||
|
||||
static constexpr c10::Float8_e4m3fnuz min() {
|
||||
return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz lowest() {
|
||||
return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz max() {
|
||||
return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz epsilon() {
|
||||
return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz round_error() {
|
||||
return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz infinity() {
|
||||
// NaN (no infinities)
|
||||
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz quiet_NaN() {
|
||||
return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e4m3fnuz denorm_min() {
|
||||
return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
};
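// [Editor's note -- illustrative only.] Decoded under the e4m3fnuz bias of 8:
//   min()         0x08 -> 2^-7  = 0.0078125     (smallest normal)
//   max()         0x7F -> 240.0
//   lowest()      0xFF -> -240.0
//   epsilon()     0x28 -> 2^-3  = 0.125
//   denorm_min()  0x01 -> 2^-10 = 0.0009765625
//   quiet_NaN()   0x80 (sign bit only; the format has no infinities)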
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
||||
|
||||
@ -1,139 +1 @@
|
||||
#pragma once
|
||||
|
||||
/// Defines the Float8_e4m3fnuz type (8-bit floating-point) including
|
||||
/// conversions to standard C types and basic arithmetic operations. Note that
|
||||
/// arithmetic operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration remains the same as Float8_e4m3fn:
|
||||
/// s eeee mmm
|
||||
/// 1 sign bit
|
||||
/// 4 exponent bits
|
||||
/// 3 mantissa bits
|
||||
/// The key differences versus Float8_e4m3fn are:
|
||||
/// bias = 8
|
||||
/// no infinities or negative zero
|
||||
/// NaN only when sign bit is 1, rest all 0s
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
|
||||
/// the existing Float8_e4m3fn implementation.
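///
/// [Editor's note, worked example] With bias 8, the byte 0x4D (sign 0,
/// exponent 1001, mantissa 101) decodes as 2^(9 - 8) * (1 + 5/8) = 3.25,
/// half of the 6.5 the same byte denotes in Float8_e4m3fn.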
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#include <iosfwd>
|
||||
#include <ostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) {
|
||||
/*
|
||||
* Binary representation of 256.0f, which is the first value not representable
|
||||
* (i.e. the first value which would overflow into the sign bit, resulting in
* a NaN) in the fp8e4m3fnuz range:
|
||||
* 1 0000 000 - fp8e4m3fnuz
|
||||
* 0 10000111 00000000000000000000000 - fp32
|
||||
*/
|
||||
constexpr uint32_t fnuz_max = UINT32_C(0x87) << 23;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range
|
||||
* into denorm representation
|
||||
* magic number: ((127 - 8) + (23 - 3) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
|
||||
uint32_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fnuz_max) {
|
||||
// NaN -- sign bit set to 1, rest 0s.
|
||||
return 0x80;
|
||||
}
|
||||
|
||||
if (f_bits < (UINT32_C(0x78) << 23) /* 2^-7 in float32 */) {
|
||||
// Input exponent is less than -7, the smallest e4m3fnuz exponent, so the
|
||||
// number will become subnormal.
|
||||
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
if (result == 0) {
|
||||
// fnuz types don't have negative zero.
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint8_t mant_odd = (f_bits >> 20) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(8 - 127) << 23) + 0x7FFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 20);
|
||||
}
|
||||
|
||||
result |= sign >> 24;
|
||||
return result;
|
||||
}
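// [Editor's note -- illustrative only.] Two consequences of the fnuz encoding
// worth spelling out: every out-of-range magnitude (|x| >= 256.0f, as well as
// any NaN input) collapses to the single NaN byte 0x80, and anything that
// rounds to -0.0f is returned as +0.0f, since the sign-bit-only pattern is
// reserved for NaN and fnuz formats have no negative zero.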
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e4m3fnuz {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e4m3fnuz() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e4m3fnuz(uint8_t bits, from_bits_t)
|
||||
: x(bits) {}
|
||||
inline C10_HOST_DEVICE Float8_e4m3fnuz(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const Float8_e4m3fnuz& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e4m3fnuz-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/Float8_e4m3fnuz.h>
|
||||
|
||||
@ -1,286 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
#define EXP_WIDTH_FP8 5
|
||||
#define MAN_WIDTH_FP8 2
|
||||
#define EXP_BIAS_FP8 15
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value)
|
||||
: x(detail::fp8e5m2_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
|
||||
return detail::fp8e5m2_to_fp32_value(x);
|
||||
}
|
||||
|
||||
/// Special values helpers
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
|
||||
return (x & 0b01111111) > 0b01111100;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
|
||||
return (x & 0b01111111) == 0b01111100;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2
|
||||
operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2
|
||||
operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2
|
||||
operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(
|
||||
const Float8_e5m2& a,
|
||||
const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator+=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator-=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator*=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2& operator/=(
|
||||
Float8_e5m2& a,
|
||||
const Float8_e5m2& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
|
||||
return a + static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
|
||||
return a - static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
|
||||
return a * static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
|
||||
return a / static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
|
||||
return a + static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
|
||||
return a - static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
|
||||
return a * static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
|
||||
return a / static_cast<Float8_e5m2>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
|
||||
return static_cast<Float8_e5m2>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e5m2 to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e5m2> {
|
||||
public:
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = true;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 3;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 2;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -13;
|
||||
static constexpr int min_exponent10 = -4;
|
||||
static constexpr int max_exponent = 16;
|
||||
static constexpr int max_exponent10 = 4;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before =
|
||||
numeric_limits<float>::tinyness_before;
|
||||
|
||||
static constexpr c10::Float8_e5m2 min() {
|
||||
return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 max() {
|
||||
return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 lowest() {
|
||||
return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 epsilon() {
|
||||
return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 round_error() {
|
||||
return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 infinity() {
|
||||
return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 quiet_NaN() {
|
||||
return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2 denorm_min() {
|
||||
return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
|
||||
}
|
||||
};
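// [Editor's note -- illustrative only.] Decoded, the bit patterns above are:
//   min()         0x04 -> 2^-14 ~= 6.1035e-05   (smallest normal)
//   max()         0x7B -> 57344.0
//   lowest()      0xFB -> -57344.0
//   epsilon()     0x34 -> 2^-2  = 0.25
//   round_error() 0x38 -> 2^-1  = 0.5
//   infinity()    0x7C, quiet_NaN() 0x7F
//   denorm_min()  0x01 -> 2^-16 ~= 1.5259e-05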
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e5m2.h>
|
||||
|
||||
@ -1,146 +1 @@
|
||||
#pragma once
|
||||
|
||||
/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
|
||||
/// to standard C types and basic arithmetic operations. Note that arithmetic
|
||||
/// operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration:
|
||||
/// s eeeee mm
|
||||
/// 1 sign bit
|
||||
/// 5 exponent bits
|
||||
/// 2 mantissa bits
|
||||
/// bias = 15
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
|
||||
/// and inspired by Half implementation from pytorch/c10/util/Half.h
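///
/// [Editor's note, worked example] Under this layout the byte 0x4D decodes as
/// sign 0, exponent 10011 (19), mantissa 01, i.e. 2^(19 - 15) * (1 + 1/4) = 20.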
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 8-bit floating-point number in fp8 E5M2 format, in bit
|
||||
* representation, to a 32-bit floating-point number in IEEE single-precision
|
||||
* format, in bit representation.
|
||||
*
|
||||
* @note The implementation doesn't use any floating-point operations.
|
||||
*/
|
||||
inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
|
||||
/*
|
||||
* Extend the fp8 E5M2 number to 32 bits and shift to the
|
||||
* upper part of the 32-bit word:
|
||||
* +---+-----+--+-----------------------------+
* | S |EEEEE|MM|0000 0000 0000 0000 0000 0000|
* +---+-----+--+-----------------------------+
* Bits 31 26-30 24-25 0-23
|
||||
*
|
||||
* S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
|
||||
* - zero bits.
|
||||
*/
|
||||
uint16_t half_representation = input;
|
||||
half_representation <<= 8;
|
||||
return fp16_ieee_to_fp32_value(half_representation);
|
||||
}
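// [Editor's sketch -- illustrative only, not part of the original header.]
// e5m2 is simply the upper byte of IEEE binary16 (1 sign, 5 exponent,
// 2 mantissa bits, bias 15), which is why the conversion above just widens
// the bits and reuses the fp16 decoder. An equivalent arithmetic decoder,
// for clarity (assumes <cmath> is visible via c10/util/Half.h):
inline float fp8e5m2_to_fp32_value_reference(uint8_t input) {
  const uint32_t sign = input >> 7;
  const uint32_t exponent = (input >> 2) & 0x1F;
  const uint32_t mantissa = input & 0x3;
  float result;
  if (exponent == 0x1F) {
    // Unlike e4m3fn, e5m2 keeps IEEE-style infinities and NaNs.
    result = (mantissa == 0) ? INFINITY : NAN;
  } else if (exponent == 0) {
    result = std::ldexp(static_cast<float>(mantissa), -16); // subnormal: (m / 4) * 2^-14
  } else {
    result = std::ldexp(1.0f + mantissa / 4.0f, static_cast<int>(exponent) - 15);
  }
  return sign ? -result : result;
}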
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
|
||||
/*
|
||||
* Binary representation of fp32 infinity
|
||||
* 0 11111111 00000000000000000000000
|
||||
*/
|
||||
constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
|
||||
|
||||
/*
|
||||
* Binary representation of 65536.0f, which is the first value
|
||||
* not representable in fp8e5m2 range:
|
||||
* 0 11111 00 - fp8e5m2
|
||||
* 0 10001111 00000000000000000000000 - fp32
|
||||
*/
|
||||
constexpr uint32_t fp8_max = UINT32_C(143) << 23;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e5m2 normal range
|
||||
* into denorm representation
|
||||
* magic number: ((127 - 15) + (23 - 2) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
uint8_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fp8_max) {
|
||||
// NaN - all exponent and mantissa bits set to 1
|
||||
result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
|
||||
} else {
|
||||
if (f_bits < (UINT32_C(113) << 23)) {
|
||||
// Input number is smaller than 2^(-14), which is the smallest
|
||||
// fp8e5m2 normal number
|
||||
f_bits =
|
||||
fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint32_t mant_odd = (f_bits >> 21) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 21);
|
||||
}
|
||||
}
|
||||
|
||||
result |= static_cast<uint8_t>(sign >> 24);
|
||||
return result;
|
||||
}
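// [Editor's note -- illustrative only.] Overflow handling differs from the
// e4m3fn encoder: any finite input with |x| >= 65536.0f, and +/-infinity
// itself, maps to the (signed) infinity encoding 0x7C/0xFC rather than
// saturating to max(), while NaN inputs keep a NaN encoding (0x7F/0xFF).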
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e5m2 {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e5m2() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits) {}
|
||||
inline C10_HOST_DEVICE Float8_e5m2(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
inline C10_HOST_DEVICE bool isinf() const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/Float8_e5m2.h>
|
||||
|
||||
@ -1,285 +1 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Float8_fnuz_cvt.h>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
|
||||
/// Constructors
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value)
|
||||
: x(detail::fp8e5m2fnuz_from_fp32_value(value)) {}
|
||||
|
||||
/// Implicit conversions
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const {
|
||||
return detail::fp8_fnuz_to_fp32_value<5, 2>(x);
|
||||
}
|
||||
|
||||
/// Special values helpers
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const {
|
||||
return x == 0b10000000;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Arithmetic
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz
|
||||
operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
|
||||
return static_cast<float>(a) + static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz
|
||||
operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
|
||||
return static_cast<float>(a) - static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz
|
||||
operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) {
|
||||
return static_cast<float>(a) * static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(
|
||||
const Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) {
|
||||
return -static_cast<float>(a);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a + b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a - b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a * b;
|
||||
return a;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=(
|
||||
Float8_e5m2fnuz& a,
|
||||
const Float8_e5m2fnuz& b) {
|
||||
a = a / b;
|
||||
return a;
|
||||
}
|
||||
|
||||
/// Arithmetic with floats
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) {
|
||||
return static_cast<float>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) {
|
||||
return static_cast<float>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) {
|
||||
return static_cast<float>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<float>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) {
|
||||
return a + static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) {
|
||||
return a - static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) {
|
||||
return a * static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<float>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a += static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a -= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a *= static_cast<float>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) {
|
||||
return a /= static_cast<float>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with doubles
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) {
|
||||
return static_cast<double>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) {
|
||||
return static_cast<double>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) {
|
||||
return static_cast<double>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return static_cast<double>(a) / b;
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) {
|
||||
return a + static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) {
|
||||
return a - static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) {
|
||||
return a * static_cast<double>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
return a / static_cast<double>(b);
|
||||
}
|
||||
|
||||
/// Arithmetic with ints
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) {
|
||||
return a + static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) {
|
||||
return a - static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) {
|
||||
return a * static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) {
|
||||
return a / static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) / b;
|
||||
}
|
||||
|
||||
//// Arithmetic with int64_t
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a + static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a - static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a * static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) {
|
||||
return a / static_cast<Float8_e5m2fnuz>(b);
|
||||
}
|
||||
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) + b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) - b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) * b;
|
||||
}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) {
|
||||
return static_cast<Float8_e5m2fnuz>(a) / b;
|
||||
}
|
||||
|
||||
/// NOTE: we do not define comparisons directly and instead rely on the implicit
|
||||
/// conversion from c10::Float8_e5m2fnuz to float.
|
||||
|
||||
} // namespace c10
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
class numeric_limits<c10::Float8_e5m2fnuz> {
|
||||
public:
|
||||
static constexpr bool is_signed = true;
|
||||
static constexpr bool is_integer = false;
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_exact = false;
|
||||
static constexpr bool has_infinity = false;
|
||||
static constexpr bool has_quiet_NaN = true;
|
||||
static constexpr bool has_signaling_NaN = false;
|
||||
static constexpr auto has_denorm = true;
|
||||
static constexpr auto has_denorm_loss = true;
|
||||
static constexpr auto round_style = numeric_limits<float>::round_style;
|
||||
static constexpr bool is_iec559 = false;
|
||||
static constexpr bool is_bounded = true;
|
||||
static constexpr bool is_modulo = false;
|
||||
static constexpr int digits = 3;
|
||||
static constexpr int digits10 = 0;
|
||||
static constexpr int max_digits10 = 2;
|
||||
static constexpr int radix = 2;
|
||||
static constexpr int min_exponent = -14;
|
||||
static constexpr int min_exponent10 = -4;
|
||||
static constexpr int max_exponent = 16;
|
||||
static constexpr int max_exponent10 = 4;
|
||||
static constexpr auto traps = numeric_limits<float>::traps;
|
||||
static constexpr auto tinyness_before =
|
||||
numeric_limits<float>::tinyness_before;
|
||||
|
||||
static constexpr c10::Float8_e5m2fnuz min() {
|
||||
return c10::Float8_e5m2fnuz(0x04, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz max() {
|
||||
return c10::Float8_e5m2fnuz(0x7F, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz lowest() {
|
||||
return c10::Float8_e5m2fnuz(0xFF, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz epsilon() {
|
||||
return c10::Float8_e5m2fnuz(0x34, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz round_error() {
|
||||
return c10::Float8_e5m2fnuz(0x38, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz infinity() {
|
||||
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
// TODO(future): we are mapping neg_zero to both inf and NaN, this is
|
||||
// surprising and we should figure out what to do about it.
|
||||
static constexpr c10::Float8_e5m2fnuz quiet_NaN() {
|
||||
return c10::Float8_e5m2fnuz(0x80, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
static constexpr c10::Float8_e5m2fnuz denorm_min() {
|
||||
return c10::Float8_e5m2fnuz(0x01, c10::Float8_e5m2fnuz::from_bits());
|
||||
}
|
||||
};
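// [Editor's note -- illustrative only.] Decoded under the e5m2fnuz bias of 16:
//   min()         0x04 -> 2^-15 ~= 3.0518e-05   (smallest normal)
//   max()         0x7F -> 57344.0
//   lowest()      0xFF -> -57344.0
//   denorm_min()  0x01 -> 2^-17 ~= 7.6294e-06
//   infinity() and quiet_NaN() both return the byte 0x80 (see the TODO above)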
|
||||
|
||||
} // namespace std
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_POP()
|
||||
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
|
||||
|
||||
@ -1,138 +1 @@
|
||||
#pragma once
|
||||
|
||||
/// Defines the Float8_e5m2fnuz type (8-bit floating-point) including
|
||||
/// conversions to standard C types and basic arithmetic operations. Note that
|
||||
/// arithmetic operations are implemented by converting to floating point and
|
||||
/// performing the operation in float32.
|
||||
/// Binary configuration remains the same as e5m2:
|
||||
/// s eeeee mm
|
||||
/// 1 sign bit
|
||||
/// 5 exponent bits
|
||||
/// 2 mantissa bits
|
||||
/// The key differences that e5m2fnuz brings are:
|
||||
/// bias = 16
|
||||
/// no infinities or negative zero
|
||||
/// NaN only when sign bit is 1, rest all 0s
|
||||
///
|
||||
/// Implementation based on the paper https://arxiv.org/pdf/2206.02915.pdf and
|
||||
/// the existing Float8_e4m3fn implementation.
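///
/// [Editor's note, worked example] With bias 16, the byte 0x4D (sign 0,
/// exponent 10011, mantissa 01) decodes as 2^(19 - 16) * (1 + 1/4) = 10,
/// half of the 20 the same byte denotes in Float8_e5m2.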
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/TypeSafeSignMath.h>
|
||||
#include <c10/util/floating_point_utils.h>
|
||||
|
||||
#if defined(__cplusplus)
|
||||
#include <cstdint>
|
||||
#elif !defined(__OPENCL_VERSION__)
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#include <iosfwd>
|
||||
#include <ostream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
* Convert a 32-bit floating-point number in IEEE single-precision format to a
|
||||
* 8-bit floating-point number in fp8 E5M2 format, in bit representation.
|
||||
*/
|
||||
inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) {
|
||||
/*
|
||||
* Binary representation of 65536.0f, which is the first value not
|
||||
* representable (i.e. the first value which would overflow in to the sign
|
||||
* bit, resulting in a NaN) in fp8e4m3fnuz range:
|
||||
* 1 00000 00 - fp8e5m2fnuz
|
||||
* 0 10001111 00000000000000000000000 - fp32
|
||||
*/
|
||||
constexpr uint32_t fnuz_max = UINT32_C(0x8F) << 23;
|
||||
|
||||
/*
|
||||
* A mask for converting fp32 numbers lower than fp8e5m2fnuz normal range
|
||||
* into denormalized representation.
|
||||
* magic number: ((127 - 16) + (23 - 2) + 1)
|
||||
*/
|
||||
constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23;
|
||||
|
||||
uint32_t f_bits = fp32_to_bits(f);
|
||||
uint32_t result = 0u;
|
||||
|
||||
/*
|
||||
* Extract the sign of the input number into the high bit of the 32-bit word:
|
||||
*
|
||||
* +---+----------------------------------+
|
||||
* | S |0000000 00000000 00000000 00000000|
|
||||
* +---+----------------------------------+
|
||||
* Bits 31 0-31
|
||||
*/
|
||||
const uint32_t sign = f_bits & UINT32_C(0x80000000);
|
||||
|
||||
/*
|
||||
* Set sign bit to 0
|
||||
*/
|
||||
f_bits ^= sign;
|
||||
|
||||
if (f_bits >= fnuz_max) {
|
||||
// NaN -- sign bit set to 1, rest 0s
|
||||
return 0x80;
|
||||
}
|
||||
|
||||
if (f_bits < (UINT32_C(0x70) << 23) /* 2^-15 in float32 */) {
|
||||
// Input exponent is less than -15, the smallest e5m2fnuz exponent, so the
|
||||
// number will become subnormal.
|
||||
f_bits = fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
|
||||
result = static_cast<uint8_t>(f_bits - denorm_mask);
|
||||
if (result == 0) {
|
||||
// fnuz types don't have negative zero.
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
// resulting mantissa is odd
|
||||
uint8_t mant_odd = (f_bits >> 21) & 1;
|
||||
|
||||
// update exponent, rounding bias part 1
|
||||
f_bits += ((uint32_t)(16 - 127) << 23) + 0xFFFFF;
|
||||
|
||||
// rounding bias part 2
|
||||
f_bits += mant_odd;
|
||||
|
||||
// take the bits!
|
||||
result = static_cast<uint8_t>(f_bits >> 21);
|
||||
}
|
||||
|
||||
result |= sign >> 24;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
struct alignas(1) Float8_e5m2fnuz {
|
||||
uint8_t x;
|
||||
|
||||
struct from_bits_t {};
|
||||
C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
|
||||
Float8_e5m2fnuz() = default;
|
||||
|
||||
constexpr C10_HOST_DEVICE Float8_e5m2fnuz(uint8_t bits, from_bits_t)
|
||||
: x(bits) {}
|
||||
inline C10_HOST_DEVICE Float8_e5m2fnuz(float value);
|
||||
inline C10_HOST_DEVICE operator float() const;
|
||||
inline C10_HOST_DEVICE bool isnan() const;
|
||||
inline C10_HOST_DEVICE bool isinf() const;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(
|
||||
std::ostream& out,
|
||||
const Float8_e5m2fnuz& value) {
|
||||
out << (float)value;
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#include <c10/util/Float8_e5m2fnuz-inl.h> // IWYU pragma: keep
|
||||
#include <torch/headeronly/util/Float8_e5m2fnuz.h>
|
||||
|
||||
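The layout documented above (s eeeee mm, bias 16, no infinities, NaN only at 0x80) can be sanity-checked with a small standalone decoder. This sketch is not part of the diff; `decode_e5m2fnuz` is a hypothetical helper written only to illustrate the encoding.

```cpp
// Minimal sketch, assuming the e5m2fnuz layout described in the header above.
#include <cmath>
#include <cstdint>
#include <cstdio>

float decode_e5m2fnuz(uint8_t b) {
  if (b == 0x80) return NAN;                 // the single NaN encoding
  int sign = (b >> 7) ? -1 : 1;
  int exp = (b >> 2) & 0x1F;                 // 5 exponent bits
  int mant = b & 0x3;                        // 2 mantissa bits
  if (exp == 0)                              // subnormal: (m/4) * 2^(1-16)
    return sign * std::ldexp(mant / 4.0f, 1 - 16);
  return sign * std::ldexp(1.0f + mant / 4.0f, exp - 16);
}

int main() {
  std::printf("min    (0x04) = %g\n", decode_e5m2fnuz(0x04)); // 2^-15
  std::printf("max    (0x7F) = %g\n", decode_e5m2fnuz(0x7F)); // 57344
  std::printf("lowest (0xFF) = %g\n", decode_e5m2fnuz(0xFF)); // -57344
}
```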
@@ -1,112 +1 @@
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <cstring>
#include <limits>

// TODO(#146647): Can we remove the below warning?
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif

namespace c10 {

/// Constructors

inline C10_HOST_DEVICE Float8_e8m0fnu::Float8_e8m0fnu(float value)
    : x(detail::fp8e8m0fnu_from_fp32_value(value)) {}

/// Implicit conversions

inline C10_HOST_DEVICE Float8_e8m0fnu::operator float() const {
  // TODO(#146647): maybe rewrite without control flow

  // if exponent is zero, need to special case to return 2^-127 instead of zero
  if (x == 0) {
    return c10::detail::fp32_from_bits(0x00400000);
  }

  // if exponent is NaN, need to special case to return properly encoded NaN
  if (isnan()) {
    return c10::detail::fp32_from_bits(0x7f800001);
  }

  // leave sign at 0, set the exponent bits, leave stored mantissa at 0
  uint32_t res = x << 23;

  return c10::detail::fp32_from_bits(res);
}

/// Special values helper

inline C10_HOST_DEVICE bool Float8_e8m0fnu::isnan() const {
  return x == 0b11111111;
}

/// NOTE: we do not define comparisons directly and instead rely on the implicit
/// conversion from c10::Float8_e8m0fnu to float.

} // namespace c10

namespace std {

template <>
class numeric_limits<c10::Float8_e8m0fnu> {
 public:
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = false;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = false;
  static constexpr bool has_quiet_NaN = true;
  static constexpr bool has_signaling_NaN = false;
  static constexpr auto has_denorm = false;
  static constexpr auto has_denorm_loss = false;
  static constexpr auto round_style = numeric_limits<float>::round_style;
  static constexpr bool is_iec559 = false;
  static constexpr bool is_bounded = true;
  static constexpr bool is_modulo = false;
  static constexpr int digits = 1;
  static constexpr int digits10 = 0;
  static constexpr int max_digits10 = 1; // just a 2!
  static constexpr int radix = 2;
  static constexpr int min_exponent = -126;
  static constexpr int min_exponent10 = -38;
  static constexpr int max_exponent = 128;
  static constexpr int max_exponent10 = 38;
  static constexpr auto traps = numeric_limits<float>::traps;
  static constexpr auto tinyness_before = false;

  static constexpr c10::Float8_e8m0fnu min() {
    // 2^-127
    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu lowest() {
    // 2^-127
    return c10::Float8_e8m0fnu(0b00000000, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu max() {
    // 254 biased, which is 127 unbiased, so 2^127
    return c10::Float8_e8m0fnu(0b11111110, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu epsilon() {
    // according to https://en.cppreference.com/w/cpp/types/numeric_limits, this
    // is "the difference between 1.0 and the next representable value of the
    // given floating-point type". The next representable value is 2.0, so the
    // difference is 1.0 which is 2^0. 0 unbiased is 127 biased.
    return c10::Float8_e8m0fnu(0b01111111, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu round_error() {
    // 0.5 in float, which is 2^-1, and -1 + 127 = 126
    return c10::Float8_e8m0fnu(0b01111110, c10::Float8_e8m0fnu::from_bits());
  }
  static constexpr c10::Float8_e8m0fnu quiet_NaN() {
    return c10::Float8_e8m0fnu(0b11111111, c10::Float8_e8m0fnu::from_bits());
  }
};

} // namespace std

C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/Float8_e8m0fnu.h>
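For reference, the conversion to float above amounts to value = 2^(x - 127), with x == 0 decoding to 2^-127 and x == 0xFF reserved for NaN. A minimal sketch (not part of the diff; `decode_e8m0fnu` is a made-up name used only for illustration):

```cpp
// Standalone decode rule for e8m0fnu, assuming value == 2^(stored_exponent - 127).
#include <cmath>
#include <cstdint>
#include <cstdio>

float decode_e8m0fnu(uint8_t x) {
  if (x == 0xFF) return NAN;              // the only NaN encoding
  return std::ldexp(1.0f, int(x) - 127);  // 2^(x - 127), including 2^-127 at x == 0
}

int main() {
  std::printf("%g %g %g\n",
              decode_e8m0fnu(0x7F),   // 1.0
              decode_e8m0fnu(0xFE),   // 2^127 (max)
              decode_e8m0fnu(0x00));  // 2^-127 (min/lowest)
}
```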
@@ -1,120 +1 @@
#pragma once

/// Defines the Float8_e8m0fnu type (8-bit floating-point) including
/// conversions to standard C types
/// Binary configuration:
/// eeeeeeee
/// no sign bits
/// 8 exponent bits
/// no mantissa bits
///
/// This is the E8M0 dtype from the OCP MX format spec
/// (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
/// Section 5.4.1)

#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/floating_point_utils.h>
#include <type_traits>

// TODO(#146647): do we need to special case OPENCL?
#if defined(__cplusplus)
#include <cstdint>
#elif !defined(__OPENCL_VERSION__)
#include <math.h>
#include <stdint.h>
#endif

#include <iosfwd>
#include <ostream>

namespace c10 {

namespace detail {

/*
 * Convert a 32-bit floating-point number in IEEE single-precision format to a
 * 8-bit floating-point number in fp8 e8m0fnu format, in bit representation.
 */
inline C10_HOST_DEVICE uint8_t fp8e8m0fnu_from_fp32_value(float f) {
  // TODO(#146647): maybe rewrite without control flow

  uint32_t f_bits = c10::detail::fp32_to_bits(f);

  // extract the exponent
  uint32_t exponent = (f_bits >> 23) & 0b11111111;

  // special case float32 NaN and +-inf to map to e8m0 nan
  if (exponent == 0b11111111) {
    return exponent;
  }

  // next, we use guard, round, sticky bits and the LSB to implement round to
  // nearest, with ties to even

  // guard bit - bit 23, or 22 zero-indexed
  uint8_t g = (f_bits & 0x400000) > 0;
  // round bit - bit 22, or 21 zero-indexed
  uint8_t r = (f_bits & 0x200000) > 0;
  // sticky bit - bits 21 to 1, or 20 to 0 zero-indexed
  uint8_t s = (f_bits & 0x1FFFFF) > 0;
  // in casting to e8m0, LSB is the implied mantissa bit. It equals 0 if the
  // original float32 is denormal, and 1 if the original float32 is normal.
  uint8_t lsb = exponent > 0;

  // implement the RNE logic
  bool round_up = false;

  // if g == 0, round down (no-op)
  if (g == 1) {
    if ((r == 1) || (s == 1)) {
      // round up
      round_up = true;
    } else {
      if (lsb == 1) {
        // round up
        round_up = true;
      }
      // if lsb == 0, round down (no-op)
    }
  }

  if (round_up) {
    // adjust exponent
    // note that if exponent was 255 we would have already returned earlier, so
    // we know we can add one safely without running out of bounds
    exponent++;
  }

  return exponent;
}

} // namespace detail

struct alignas(1) Float8_e8m0fnu {
  uint8_t x;

  struct from_bits_t {};
  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }

  Float8_e8m0fnu() = default;

  constexpr C10_HOST_DEVICE Float8_e8m0fnu(uint8_t bits, from_bits_t)
      : x(bits) {}
  inline C10_HOST_DEVICE Float8_e8m0fnu(float value);
  inline C10_HOST_DEVICE operator float() const;
  inline C10_HOST_DEVICE bool isnan() const;
};

inline std::ostream& operator<<(
    std::ostream& out,
    const Float8_e8m0fnu& value) {
  out << (float)value;
  return out;
}

} // namespace c10

#include <c10/util/Float8_e8m0fnu-inl.h> // IWYU pragma: keep
#include <torch/headeronly/util/Float8_e8m0fnu.h>
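The guard/round/sticky rounding above can be checked on a couple of concrete values. The sketch below re-implements the rounding standalone, for illustration only; it is not the library's API, and the condensed condition is equivalent to the nested if/else in the diff.

```cpp
// Worked check of the round-to-nearest-even logic described above.
#include <cassert>
#include <cstdint>
#include <cstring>

static uint8_t to_e8m0(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t exponent = (bits >> 23) & 0xFF;
  if (exponent == 0xFF) return 0xFF;     // NaN/inf -> NaN
  uint8_t g = (bits & 0x400000) > 0;     // guard
  uint8_t r = (bits & 0x200000) > 0;     // round
  uint8_t s = (bits & 0x1FFFFF) > 0;     // sticky
  uint8_t lsb = exponent > 0;            // implied mantissa bit
  if (g && (r || s || lsb)) exponent++;  // round up on > half or tie-to-even
  return static_cast<uint8_t>(exponent);
}

int main() {
  assert(to_e8m0(4.0f) == 127 + 2);  // exact power of two: kept as 2^2
  assert(to_e8m0(5.0f) == 127 + 2);  // 1.25 * 2^2: rounds down to 2^2
  assert(to_e8m0(6.0f) == 127 + 3);  // 1.5 * 2^2: tie, rounds up to 2^3
}
```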
@@ -1,140 +1 @@
#pragma once

#include <c10/macros/Macros.h>
#include <limits>
#include <type_traits>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wstring-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion")
#endif
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
#endif

namespace c10 {

/// Returns false since we cannot have x < 0 if x is unsigned.
template <typename T>
inline constexpr bool is_negative(
    const T& /*x*/,
    std::true_type /*is_unsigned*/) {
  return false;
}

/// Returns true if a signed variable x < 0
template <typename T>
inline constexpr bool is_negative(const T& x, std::false_type /*is_unsigned*/) {
  return x < T(0);
}

/// Returns true if x < 0
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T>
inline constexpr bool is_negative(const T& x) {
  return is_negative(x, std::is_unsigned<T>());
}

/// Returns the sign of an unsigned variable x as 0, 1
template <typename T>
inline constexpr int signum(const T& x, std::true_type /*is_unsigned*/) {
  return T(0) < x;
}

/// Returns the sign of a signed variable x as -1, 0, 1
template <typename T>
inline constexpr int signum(const T& x, std::false_type /*is_unsigned*/) {
  return (T(0) < x) - (x < T(0));
}

/// Returns the sign of x as -1, 0, 1
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T>
inline constexpr int signum(const T& x) {
  return signum(x, std::is_unsigned<T>());
}

/// Returns true if exactly one of a and b is negative
template <typename T, typename U>
inline constexpr bool signs_differ(const T& a, const U& b) {
  return is_negative(a) != is_negative(b);
}

// Suppress sign compare warning when compiling with GCC
// as the latter does not account for the short-circuit rule before
// raising the warning, see https://godbolt.org/z/Tr3Msnz99
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-compare"
#endif

/// Returns true if x is greater than the greatest value of the type Limit
template <typename Limit, typename T>
inline constexpr bool greater_than_max(const T& x) {
  constexpr bool can_overflow =
      std::numeric_limits<T>::digits > std::numeric_limits<Limit>::digits;
  return can_overflow && x > (std::numeric_limits<Limit>::max)();
}

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

/// Returns true if x < lowest(Limit). Standard comparison
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
    const T& x,
    std::false_type /*limit_is_unsigned*/,
    std::false_type /*x_is_unsigned*/) {
  return x < std::numeric_limits<Limit>::lowest();
}

/// Returns false since the limit is signed and therefore includes
/// negative values, but x cannot be negative because it is unsigned
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
    const T& /*x*/,
    std::false_type /*limit_is_unsigned*/,
    std::true_type /*x_is_unsigned*/) {
  return false;
}

/// Returns true if x < 0, where 0 is constructed from T.
/// Limit is not signed, so its lower value is zero
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
    const T& x,
    std::true_type /*limit_is_unsigned*/,
    std::false_type /*x_is_unsigned*/) {
  return x < T(0);
}

/// Returns false since both types are unsigned
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(
    const T& /*x*/,
    std::true_type /*limit_is_unsigned*/,
    std::true_type /*x_is_unsigned*/) {
  return false;
}

/// Returns true if x is less than the lowest value of type Limit
/// NOTE: Will fail on an unsigned custom type
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T& x) {
  return less_than_lowest<Limit>(
      x, std::is_unsigned<Limit>(), std::is_unsigned<T>());
}

} // namespace c10

C10_CLANG_DIAGNOSTIC_POP()
#include <torch/headeronly/util/TypeSafeSignMath.h>
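A typical use of these helpers is guarding a narrowing cast before calling static_cast. The sketch below is an assumption about usage, not code from the diff; `checked_cast` is a hypothetical wrapper written only to show how `greater_than_max` and `less_than_lowest` fit together.

```cpp
// Hedged usage sketch for the range-check helpers above.
#include <c10/util/TypeSafeSignMath.h>
#include <cstdint>
#include <stdexcept>

template <typename To, typename From>
To checked_cast(From x) {
  // Reject values outside the representable range of To before narrowing.
  if (c10::greater_than_max<To>(x) || c10::less_than_lowest<To>(x)) {
    throw std::runtime_error("value does not fit in target type");
  }
  return static_cast<To>(x);
}

int main() {
  int8_t ok = checked_cast<int8_t>(100);  // fits, no throw
  (void)ok;
  try {
    checked_cast<uint8_t>(-1);            // negative into unsigned: throws
  } catch (const std::runtime_error&) {
  }
}
```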
@@ -4,531 +4,7 @@
#include <c10/macros/Macros.h>
#include <c10/util/Half.h>

#if defined(__CUDACC__) || defined(__HIPCC__)
#include <thrust/complex.h>
#endif

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif
#if C10_CLANG_HAS_WARNING("-Wfloat-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion")
#endif

namespace c10 {

// c10::complex is an implementation of complex numbers that aims
// to work on all devices supported by PyTorch
//
// Most of the APIs duplicate std::complex
// Reference: https://en.cppreference.com/w/cpp/numeric/complex
//
// [NOTE: Complex Operator Unification]
// Operators currently use a mix of std::complex, thrust::complex, and
// c10::complex internally. The end state is that all operators will use
// c10::complex internally. Until then, there may be some hacks to support all
// variants.
//
//
// [Note on Constructors]
//
// The APIs of constructors are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/complex
//
// Since C++14, all constructors are constexpr in std::complex
//
// There are three types of constructors:
// - initializing from real and imag:
//   `constexpr complex( const T& re = T(), const T& im = T() );`
// - implicitly-declared copy constructor
// - converting constructors
//
// Converting constructors:
// - std::complex defines converting constructor between float/double/long
//   double, while we define converting constructor between float/double.
// - For these converting constructors, upcasting is implicit, downcasting is
//   explicit.
// - We also define explicit casting from std::complex/thrust::complex
//   - Note that the conversion from thrust is not constexpr, because
//     thrust does not define them as constexpr ????
//
//
// [Operator =]
//
// The APIs of operator = are mostly copied from C++ standard:
// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D
//
// Since C++20, all operator= are constexpr. Although we are not building with
// C++20, we also obey this behavior.
//
// There are three types of assign operator:
// - Assign a real value from the same scalar type
//   - In std, this is templated as complex& operator=(const T& x)
//     with specialization `complex& operator=(T x)` for float/double/long
//     double. Since we only support float and double, one will use `complex&
//     operator=(T x)`
// - Copy assignment operator and converting assignment operator
//   - There is no specialization of converting assignment operators, which type
//     is convertible is solely dependent on whether the scalar type is convertible
//
// In addition to the standard assignment, we also provide assignment operators
// with std and thrust
//
//
// [Casting operators]
//
// std::complex does not have casting operators. We define casting operators
// casting to std::complex and thrust::complex
//
//
// [Operator ""]
//
// std::complex has custom literals `i`, `if` and `il` defined in namespace
// `std::literals::complex_literals`. We define our own custom literals in the
// namespace `c10::complex_literals`. Our custom literals do not follow the
// same behavior as in std::complex; instead, we define _if, _id to construct
// float/double complex literals.
//
//
// [real() and imag()]
//
// In C++20, there are two overloads of these functions: one is to return the
// real/imag, another is to set real/imag; both are constexpr. We follow
// this design.
//
//
// [Operator +=,-=,*=,/=]
//
// Since C++20, these operators become constexpr. In our implementation, they
// are also constexpr.
//
// There are two types of such operators: operating with a real number, or
// operating with another complex number. For the operating with a real number,
// the generic template form has argument type `const T &`, while the overload
// for float/double/long double has `T`. We will follow the same type as
// float/double/long double in std.
//
// [Unary operator +-]
//
// Since C++20, they are constexpr. We also make them constexpr
//
// [Binary operators +-*/]
//
// Each operator has three versions (taking + as example):
// - complex + complex
// - complex + real
// - real + complex
//
// [Operator ==, !=]
//
// Each operator has three versions (taking == as example):
// - complex == complex
// - complex == real
// - real == complex
//
// Some of them are removed in C++20, but we decide to keep them
//
// [Operator <<, >>]
//
// These are implemented by casting to std::complex
//
//
//
// TODO(@zasdfgbnm): c10::complex<c10::Half> is not currently supported,
// because:
// - lots of members and functions of c10::Half are not constexpr
// - thrust::complex only supports float and double

template <typename T>
struct alignas(sizeof(T) * 2) complex {
  using value_type = T;

  T real_ = T(0);
  T imag_ = T(0);

  constexpr complex() = default;
  C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T())
      : real_(re), imag_(im) {}
  template <typename U>
  explicit constexpr complex(const std::complex<U>& other)
      : complex(other.real(), other.imag()) {}
#if defined(__CUDACC__) || defined(__HIPCC__)
  template <typename U>
  explicit C10_HOST_DEVICE complex(const thrust::complex<U>& other)
      : real_(other.real()), imag_(other.imag()) {}
  // NOTE cannot be implemented as follows due to a ROCm bug:
  // explicit C10_HOST_DEVICE complex(const thrust::complex<U> &other):
  //   complex(other.real(), other.imag()) {}
#endif

  // Use SFINAE to specialize casting constructor for c10::complex<float> and
  // c10::complex<double>
  template <typename U = T>
  C10_HOST_DEVICE explicit constexpr complex(
      const std::enable_if_t<std::is_same_v<U, float>, complex<double>>& other)
      : real_(other.real_), imag_(other.imag_) {}
  template <typename U = T>
  C10_HOST_DEVICE constexpr complex(
      const std::enable_if_t<std::is_same_v<U, double>, complex<float>>& other)
      : real_(other.real_), imag_(other.imag_) {}

  constexpr complex<T>& operator=(T re) {
    real_ = re;
    imag_ = 0;
    return *this;
  }

  constexpr complex<T>& operator+=(T re) {
    real_ += re;
    return *this;
  }

  constexpr complex<T>& operator-=(T re) {
    real_ -= re;
    return *this;
  }

  constexpr complex<T>& operator*=(T re) {
    real_ *= re;
    imag_ *= re;
    return *this;
  }

  constexpr complex<T>& operator/=(T re) {
    real_ /= re;
    imag_ /= re;
    return *this;
  }

  template <typename U>
  constexpr complex<T>& operator=(const complex<U>& rhs) {
    real_ = rhs.real();
    imag_ = rhs.imag();
    return *this;
  }

  template <typename U>
  constexpr complex<T>& operator+=(const complex<U>& rhs) {
    real_ += rhs.real();
    imag_ += rhs.imag();
    return *this;
  }

  template <typename U>
  constexpr complex<T>& operator-=(const complex<U>& rhs) {
    real_ -= rhs.real();
    imag_ -= rhs.imag();
    return *this;
  }

  template <typename U>
  constexpr complex<T>& operator*=(const complex<U>& rhs) {
    // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
    T a = real_;
    T b = imag_;
    U c = rhs.real();
    U d = rhs.imag();
    real_ = a * c - b * d;
    imag_ = a * d + b * c;
    return *this;
  }

#ifdef __APPLE__
#define FORCE_INLINE_APPLE __attribute__((always_inline))
#else
#define FORCE_INLINE_APPLE
#endif
  template <typename U>
  constexpr FORCE_INLINE_APPLE complex<T>& operator/=(const complex<U>& rhs)
      __ubsan_ignore_float_divide_by_zero__ {
    // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i
    // the calculation below follows numpy's complex division
    T a = real_;
    T b = imag_;
    U c = rhs.real();
    U d = rhs.imag();

#if defined(__GNUC__) && !defined(__clang__)
    // std::abs is already constexpr by gcc
    auto abs_c = std::abs(c);
    auto abs_d = std::abs(d);
#else
    auto abs_c = c < 0 ? -c : c;
    auto abs_d = d < 0 ? -d : d;
#endif

    if (abs_c >= abs_d) {
      if (abs_c == U(0) && abs_d == U(0)) {
        /* divide by zeros should yield a complex inf or nan */
        real_ = a / abs_c;
        imag_ = b / abs_d;
      } else {
        auto rat = d / c;
        auto scl = U(1.0) / (c + d * rat);
        real_ = (a + b * rat) * scl;
        imag_ = (b - a * rat) * scl;
      }
    } else {
      auto rat = c / d;
      auto scl = U(1.0) / (d + c * rat);
      real_ = (a * rat + b) * scl;
      imag_ = (b * rat - a) * scl;
    }
    return *this;
  }
#undef FORCE_INLINE_APPLE

  template <typename U>
  constexpr complex<T>& operator=(const std::complex<U>& rhs) {
    real_ = rhs.real();
    imag_ = rhs.imag();
    return *this;
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  template <typename U>
  C10_HOST_DEVICE complex<T>& operator=(const thrust::complex<U>& rhs) {
    real_ = rhs.real();
    imag_ = rhs.imag();
    return *this;
  }
#endif

  template <typename U>
  explicit constexpr operator std::complex<U>() const {
    return std::complex<U>(std::complex<T>(real(), imag()));
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  template <typename U>
  C10_HOST_DEVICE explicit operator thrust::complex<U>() const {
    return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag()));
  }
#endif

  // consistent with NumPy behavior
  explicit constexpr operator bool() const {
    return real() || imag();
  }

  C10_HOST_DEVICE constexpr T real() const {
    return real_;
  }
  constexpr void real(T value) {
    real_ = value;
  }
  C10_HOST_DEVICE constexpr T imag() const {
    return imag_;
  }
  constexpr void imag(T value) {
    imag_ = value;
  }
};

namespace complex_literals {

constexpr complex<float> operator""_if(long double imag) {
  return complex<float>(0.0f, static_cast<float>(imag));
}

constexpr complex<double> operator""_id(long double imag) {
  return complex<double>(0.0, static_cast<double>(imag));
}

constexpr complex<float> operator""_if(unsigned long long imag) {
  return complex<float>(0.0f, static_cast<float>(imag));
}

constexpr complex<double> operator""_id(unsigned long long imag) {
  return complex<double>(0.0, static_cast<double>(imag));
}

} // namespace complex_literals

template <typename T>
constexpr complex<T> operator+(const complex<T>& val) {
  return val;
}

template <typename T>
constexpr complex<T> operator-(const complex<T>& val) {
  return complex<T>(-val.real(), -val.imag());
}

template <typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const complex<T>& rhs) {
  complex<T> result = lhs;
  return result += rhs;
}

template <typename T>
constexpr complex<T> operator+(const complex<T>& lhs, const T& rhs) {
  complex<T> result = lhs;
  return result += rhs;
}

template <typename T>
constexpr complex<T> operator+(const T& lhs, const complex<T>& rhs) {
  return complex<T>(lhs + rhs.real(), rhs.imag());
}

template <typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const complex<T>& rhs) {
  complex<T> result = lhs;
  return result -= rhs;
}

template <typename T>
constexpr complex<T> operator-(const complex<T>& lhs, const T& rhs) {
  complex<T> result = lhs;
  return result -= rhs;
}

template <typename T>
constexpr complex<T> operator-(const T& lhs, const complex<T>& rhs) {
  complex<T> result = -rhs;
  return result += lhs;
}

template <typename T>
constexpr complex<T> operator*(const complex<T>& lhs, const complex<T>& rhs) {
  complex<T> result = lhs;
  return result *= rhs;
}

template <typename T>
constexpr complex<T> operator*(const complex<T>& lhs, const T& rhs) {
  complex<T> result = lhs;
  return result *= rhs;
}

template <typename T>
constexpr complex<T> operator*(const T& lhs, const complex<T>& rhs) {
  complex<T> result = rhs;
  return result *= lhs;
}

template <typename T>
constexpr complex<T> operator/(const complex<T>& lhs, const complex<T>& rhs) {
  complex<T> result = lhs;
  return result /= rhs;
}

template <typename T>
constexpr complex<T> operator/(const complex<T>& lhs, const T& rhs) {
  complex<T> result = lhs;
  return result /= rhs;
}

template <typename T>
constexpr complex<T> operator/(const T& lhs, const complex<T>& rhs) {
  complex<T> result(lhs, T());
  return result /= rhs;
}

// Define operators between integral scalars and c10::complex. std::complex does
// not support this when T is a floating-point number. This is useful because it
// saves a lot of "static_cast" when operating on a complex and an integer. This
// makes the code both less verbose and potentially more efficient.
#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
  typename std::enable_if_t< \
      std::is_floating_point_v<fT> && std::is_integral_v<iT>, \
      int> = 0

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
  return a + static_cast<fT>(b);
}

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
  return static_cast<fT>(a) + b;
}

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
  return a - static_cast<fT>(b);
}

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
  return static_cast<fT>(a) - b;
}

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
  return a * static_cast<fT>(b);
}

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
  return static_cast<fT>(a) * b;
}

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
  return a / static_cast<fT>(b);
}

template <typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
  return static_cast<fT>(a) / b;
}

#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION

template <typename T>
constexpr bool operator==(const complex<T>& lhs, const complex<T>& rhs) {
  return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());
}

template <typename T>
constexpr bool operator==(const complex<T>& lhs, const T& rhs) {
  return (lhs.real() == rhs) && (lhs.imag() == T());
}

template <typename T>
constexpr bool operator==(const T& lhs, const complex<T>& rhs) {
  return (lhs == rhs.real()) && (T() == rhs.imag());
}

template <typename T>
constexpr bool operator!=(const complex<T>& lhs, const complex<T>& rhs) {
  return !(lhs == rhs);
}

template <typename T>
constexpr bool operator!=(const complex<T>& lhs, const T& rhs) {
  return !(lhs == rhs);
}

template <typename T>
constexpr bool operator!=(const T& lhs, const complex<T>& rhs) {
  return !(lhs == rhs);
}

template <typename T, typename CharT, typename Traits>
std::basic_ostream<CharT, Traits>& operator<<(
    std::basic_ostream<CharT, Traits>& os,
    const complex<T>& x) {
  return (os << static_cast<std::complex<T>>(x));
}

template <typename T, typename CharT, typename Traits>
std::basic_istream<CharT, Traits>& operator>>(
    std::basic_istream<CharT, Traits>& is,
    complex<T>& x) {
  std::complex<T> tmp;
  is >> tmp;
  x = tmp;
  return is;
}

} // namespace c10
#include <torch/headeronly/util/complex.h>

// std functions
//
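A brief usage sketch of the API described in the notes above (custom `_if`/`_id` literals, mixed real/complex arithmetic, integer scalars without casts); not part of the diff, shown only as an illustration.

```cpp
// Minimal sketch of c10::complex usage, assuming the header above is on the include path.
#include <c10/util/complex.h>
#include <iostream>

int main() {
  using namespace c10::complex_literals;
  c10::complex<float> z = 1.0f + 2.0_if;  // 1 + 2i via the _if literal
  z *= 3;                                 // integer scalar, no static_cast needed
  auto w = z / c10::complex<float>(0.0f, 1.0f);  // divide by i
  std::cout << z << " " << w << std::endl;       // streamed via std::complex
}
```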
@@ -594,72 +70,6 @@ constexpr c10::complex<T> conj(const c10::complex<T>& z) {

} // namespace std

namespace c10 {

template <typename T>
C10_HOST_DEVICE complex<T> polar(const T& r, const T& theta = T()) {
#if defined(__CUDACC__) || defined(__HIPCC__)
  return static_cast<complex<T>>(thrust::polar(r, theta));
#else
  // std::polar() requires r >= 0, so spell out the explicit implementation to
  // avoid a branch.
  return complex<T>(r * std::cos(theta), r * std::sin(theta));
#endif
}

template <>
struct alignas(4) complex<Half> {
  Half real_;
  Half imag_;

  // Constructors
  complex() = default;
  // Half constructor is not constexpr so the following constructor can't
  // be constexpr
  C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
      : real_(real), imag_(imag) {}
  C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
      : real_(value.real()), imag_(value.imag()) {}

  // Conversion operator
  inline C10_HOST_DEVICE operator c10::complex<float>() const {
    return {real_, imag_};
  }

  constexpr C10_HOST_DEVICE Half real() const {
    return real_;
  }
  constexpr C10_HOST_DEVICE Half imag() const {
    return imag_;
  }

  C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
    real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
    imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
    return *this;
  }

  C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
    real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
    imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
    return *this;
  }

  C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
    auto a = static_cast<float>(real_);
    auto b = static_cast<float>(imag_);
    auto c = static_cast<float>(other.real());
    auto d = static_cast<float>(other.imag());
    real_ = a * c - b * d;
    imag_ = a * d + b * c;
    return *this;
  }
};

} // namespace c10

C10_CLANG_DIAGNOSTIC_POP()

#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
// math functions are included in a separate file
#include <c10/util/complex_math.h> // IWYU pragma: keep
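As a quick check of polar() as defined above against the identity r*(cos θ + i sin θ); this snippet is not part of the diff and only exercises the branch-free host implementation.

```cpp
// Sanity check for c10::polar, assuming the header above is available.
#include <c10/util/complex.h>
#include <cassert>
#include <cmath>

int main() {
  float r = 2.0f, t = 0.5f;
  c10::complex<float> z = c10::polar(r, t);
  assert(std::abs(z.real() - r * std::cos(t)) < 1e-6f);
  assert(std::abs(z.imag() - r * std::sin(t)) < 1e-6f);
}
```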
@@ -1,4 +1,3 @@
#include <c10/core/AllocatorConfig.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <c10/xpu/XPUCachingAllocator.h>
@@ -21,6 +20,8 @@ constexpr size_t kMinBlockSize = 512;
constexpr size_t kSmallSize = 1048576;
// "small" allocations are packed in 2 MiB blocks
constexpr size_t kSmallBuffer = 2097152;
// "large" allocations may be packed in 20 MiB blocks
constexpr size_t kLargeBuffer = 20971520;
// allocations between 1 and 10 MiB may use kLargeBuffer
constexpr size_t kMinLargeAlloc = 10485760;
// round up large allocations to 2 MiB
@@ -825,7 +825,6 @@ if(USE_MPS)
if(CAN_COMPILE_METAL)
  add_dependencies(torch_cpu metallibs)
  target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_basic,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_basic.metallib)
  target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_bfloat,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_bfloat.metallib)
else()
  target_compile_definitions(torch_cpu PRIVATE PYTORCH_JIT_COMPILE_SHADERS)
endif()
@@ -260,7 +260,7 @@ endif()
  # Determine if blas was compiled with the f2c conventions
  if(BLAS_LIBRARIES AND BLAS_CHECK_F2C)
    include(cmake/BLAS_ABI.cmake)
  endif(BLAS_LIBRARIES)
endif()

if(NOT INTERN_BUILD_MOBILE)
  set(AT_MKL_SEQUENTIAL 0)
Some files were not shown because too many files have changed in this diff.